dataset_docstore.py

from typing import Any, Dict, Optional, Sequence

from langchain.schema import Document
from sqlalchemy import func

from core.model_providers.model_factory import ModelFactory
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment


class DatesetDocumentStore:
    """Document store that persists documents as DocumentSegment rows of a dataset."""

    def __init__(
        self,
        dataset: Dataset,
        user_id: str,
        document_id: Optional[str] = None,
    ):
        self._dataset = dataset
        self._user_id = user_id
        self._document_id = document_id

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> "DatesetDocumentStore":
        return cls(**config_dict)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to dict."""
        return {
            "dataset_id": self._dataset.id,
        }

    @property
    def dateset_id(self) -> Any:
        return self._dataset.id

    @property
    def user_id(self) -> Any:
        return self._user_id

    @property
    def docs(self) -> Dict[str, Document]:
        """Return all segments of the dataset as a mapping of index_node_id to Document."""
        document_segments = db.session.query(DocumentSegment).filter(
            DocumentSegment.dataset_id == self._dataset.id
        ).all()

        output = {}
        for document_segment in document_segments:
            doc_id = document_segment.index_node_id
            output[doc_id] = Document(
                page_content=document_segment.content,
                metadata={
                    "doc_id": document_segment.index_node_id,
                    "doc_hash": document_segment.index_node_hash,
                    "document_id": document_segment.document_id,
                    "dataset_id": document_segment.dataset_id,
                }
            )

        return output

    def add_documents(
        self, docs: Sequence[Document], allow_update: bool = True
    ) -> None:
        """Insert the given documents as new segments, or update existing segments in place."""
        max_position = db.session.query(func.max(DocumentSegment.position)).filter(
            DocumentSegment.document_id == self._document_id
        ).scalar()

        if max_position is None:
            max_position = 0

        embedding_model = None
        if self._dataset.indexing_technique == 'high_quality':
            embedding_model = ModelFactory.get_embedding_model(
                tenant_id=self._dataset.tenant_id,
                model_provider_name=self._dataset.embedding_model_provider,
                model_name=self._dataset.embedding_model
            )

        for doc in docs:
            if not isinstance(doc, Document):
                raise ValueError("doc must be a Document")

            segment_document = self.get_document_segment(doc_id=doc.metadata['doc_id'])

            # NOTE: the doc may already exist in the store; by default it is overwritten
            if not allow_update and segment_document:
                raise ValueError(
                    f"doc_id {doc.metadata['doc_id']} already exists. "
                    "Set allow_update to True to overwrite."
                )

            # calculate how many tokens embedding this segment will consume
            tokens = embedding_model.get_num_tokens(doc.page_content) if embedding_model else 0

            if not segment_document:
                max_position += 1

                segment_document = DocumentSegment(
                    tenant_id=self._dataset.tenant_id,
                    dataset_id=self._dataset.id,
                    document_id=self._document_id,
                    index_node_id=doc.metadata['doc_id'],
                    index_node_hash=doc.metadata['doc_hash'],
                    position=max_position,
                    content=doc.page_content,
                    word_count=len(doc.page_content),
                    tokens=tokens,
                    enabled=False,
                    created_by=self._user_id,
                )
                if 'answer' in doc.metadata and doc.metadata['answer']:
                    segment_document.answer = doc.metadata.pop('answer', '')

                db.session.add(segment_document)
            else:
                segment_document.content = doc.page_content
                if 'answer' in doc.metadata and doc.metadata['answer']:
                    segment_document.answer = doc.metadata.pop('answer', '')
                segment_document.index_node_hash = doc.metadata['doc_hash']
                segment_document.word_count = len(doc.page_content)
                segment_document.tokens = tokens

            db.session.commit()

    def document_exists(self, doc_id: str) -> bool:
        """Check if document exists."""
        result = self.get_document_segment(doc_id)
        return result is not None

    def get_document(
        self, doc_id: str, raise_error: bool = True
    ) -> Optional[Document]:
        """Load a segment by its index node id and wrap it as a langchain Document."""
        document_segment = self.get_document_segment(doc_id)
        if document_segment is None:
            if raise_error:
                raise ValueError(f"doc_id {doc_id} not found.")
            else:
                return None

        return Document(
            page_content=document_segment.content,
            metadata={
                "doc_id": document_segment.index_node_id,
                "doc_hash": document_segment.index_node_hash,
                "document_id": document_segment.document_id,
                "dataset_id": document_segment.dataset_id,
            }
        )

    def delete_document(self, doc_id: str, raise_error: bool = True) -> None:
        """Delete the segment with the given doc_id, optionally raising if it is missing."""
        document_segment = self.get_document_segment(doc_id)
        if document_segment is None:
            if raise_error:
                raise ValueError(f"doc_id {doc_id} not found.")
            else:
                return None

        db.session.delete(document_segment)
        db.session.commit()

    def set_document_hash(self, doc_id: str, doc_hash: str) -> None:
        """Set the hash for a given doc_id."""
        document_segment = self.get_document_segment(doc_id)
        if document_segment is None:
            return None

        document_segment.index_node_hash = doc_hash
        db.session.commit()

    def get_document_hash(self, doc_id: str) -> Optional[str]:
        """Get the stored hash for a document, if it exists."""
        document_segment = self.get_document_segment(doc_id)
        if document_segment is None:
            return None

        return document_segment.index_node_hash

    def get_document_segment(self, doc_id: str) -> Optional[DocumentSegment]:
        """Fetch the DocumentSegment row matching this dataset and index node id, if any."""
        document_segment = db.session.query(DocumentSegment).filter(
            DocumentSegment.dataset_id == self._dataset.id,
            DocumentSegment.index_node_id == doc_id
        ).first()

        return document_segment
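

# Illustrative usage sketch (not part of the original module): roughly how this
# store might be driven during indexing, assuming an existing Dataset row and
# pre-split langchain Documents whose metadata already carries 'doc_id' and
# 'doc_hash', as the methods above expect. The names `dataset`, `current_user`,
# `document_id` and `nodes` below are hypothetical placeholders.
#
#   store = DatesetDocumentStore(
#       dataset=dataset,
#       user_id=current_user.id,
#       document_id=document_id,
#   )
#   store.add_documents(nodes, allow_update=True)  # insert new segments or refresh existing ones
#
#   first_id = nodes[0].metadata['doc_id']
#   if store.document_exists(first_id):
#       doc = store.get_document(first_id)  # returns a langchain Document
#       store.set_document_hash(first_id, doc.metadata['doc_hash'])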