# dataset_docstore.py
  1. from typing import Any, Dict, Optional, Sequence, cast
  2. from langchain.schema import Document
  3. from sqlalchemy import func
  4. from core.model_manager import ModelManager
  5. from core.model_runtime.entities.model_entities import ModelType
  6. from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
  7. from extensions.ext_database import db
  8. from models.dataset import Dataset, DocumentSegment
  9. class DatasetDocumentStore:
  10. def __init__(
  11. self,
  12. dataset: Dataset,
  13. user_id: str,
  14. document_id: Optional[str] = None,
  15. ):
  16. self._dataset = dataset
  17. self._user_id = user_id
  18. self._document_id = document_id
  19. @classmethod
  20. def from_dict(cls, config_dict: Dict[str, Any]) -> "DatasetDocumentStore":
  21. return cls(**config_dict)
  22. def to_dict(self) -> Dict[str, Any]:
  23. """Serialize to dict."""
  24. return {
  25. "dataset_id": self._dataset.id,
  26. }
  27. @property
  28. def dateset_id(self) -> Any:
  29. return self._dataset.id
  30. @property
  31. def user_id(self) -> Any:
  32. return self._user_id
  33. @property
  34. def docs(self) -> Dict[str, Document]:
  35. document_segments = db.session.query(DocumentSegment).filter(
  36. DocumentSegment.dataset_id == self._dataset.id
  37. ).all()
  38. output = {}
  39. for document_segment in document_segments:
  40. doc_id = document_segment.index_node_id
  41. output[doc_id] = Document(
  42. page_content=document_segment.content,
  43. metadata={
  44. "doc_id": document_segment.index_node_id,
  45. "doc_hash": document_segment.index_node_hash,
  46. "document_id": document_segment.document_id,
  47. "dataset_id": document_segment.dataset_id,
  48. }
  49. )
  50. return output
  51. def add_documents(
  52. self, docs: Sequence[Document], allow_update: bool = True
  53. ) -> None:
  54. max_position = db.session.query(func.max(DocumentSegment.position)).filter(
  55. DocumentSegment.document_id == self._document_id
  56. ).scalar()
  57. if max_position is None:
  58. max_position = 0
  59. embedding_model = None
  60. if self._dataset.indexing_technique == 'high_quality':
  61. model_manager = ModelManager()
  62. embedding_model = model_manager.get_model_instance(
  63. tenant_id=self._dataset.tenant_id,
  64. provider=self._dataset.embedding_model_provider,
  65. model_type=ModelType.TEXT_EMBEDDING,
  66. model=self._dataset.embedding_model
  67. )
  68. for doc in docs:
  69. if not isinstance(doc, Document):
  70. raise ValueError("doc must be a Document")
  71. segment_document = self.get_document(doc_id=doc.metadata['doc_id'], raise_error=False)
  72. # NOTE: doc could already exist in the store, but we overwrite it
  73. if not allow_update and segment_document:
  74. raise ValueError(
  75. f"doc_id {doc.metadata['doc_id']} already exists. "
  76. "Set allow_update to True to overwrite."
  77. )
  78. # calc embedding use tokens
  79. if embedding_model:
  80. model_type_instance = embedding_model.model_type_instance
  81. model_type_instance = cast(TextEmbeddingModel, model_type_instance)
  82. tokens = model_type_instance.get_num_tokens(
  83. model=embedding_model.model,
  84. credentials=embedding_model.credentials,
  85. texts=[doc.page_content]
  86. )
  87. else:
  88. tokens = 0
  89. if not segment_document:
  90. max_position += 1
  91. segment_document = DocumentSegment(
  92. tenant_id=self._dataset.tenant_id,
  93. dataset_id=self._dataset.id,
  94. document_id=self._document_id,
  95. index_node_id=doc.metadata['doc_id'],
  96. index_node_hash=doc.metadata['doc_hash'],
  97. position=max_position,
  98. content=doc.page_content,
  99. word_count=len(doc.page_content),
  100. tokens=tokens,
  101. enabled=False,
  102. created_by=self._user_id,
  103. )
  104. if 'answer' in doc.metadata and doc.metadata['answer']:
  105. segment_document.answer = doc.metadata.pop('answer', '')
  106. db.session.add(segment_document)
  107. else:
  108. segment_document.content = doc.page_content
  109. if 'answer' in doc.metadata and doc.metadata['answer']:
  110. segment_document.answer = doc.metadata.pop('answer', '')
  111. segment_document.index_node_hash = doc.metadata['doc_hash']
  112. segment_document.word_count = len(doc.page_content)
  113. segment_document.tokens = tokens
  114. db.session.commit()
  115. def document_exists(self, doc_id: str) -> bool:
  116. """Check if document exists."""
  117. result = self.get_document_segment(doc_id)
  118. return result is not None
  119. def get_document(
  120. self, doc_id: str, raise_error: bool = True
  121. ) -> Optional[Document]:
  122. document_segment = self.get_document_segment(doc_id)
  123. if document_segment is None:
  124. if raise_error:
  125. raise ValueError(f"doc_id {doc_id} not found.")
  126. else:
  127. return None
  128. return Document(
  129. page_content=document_segment.content,
  130. metadata={
  131. "doc_id": document_segment.index_node_id,
  132. "doc_hash": document_segment.index_node_hash,
  133. "document_id": document_segment.document_id,
  134. "dataset_id": document_segment.dataset_id,
  135. }
  136. )
  137. def delete_document(self, doc_id: str, raise_error: bool = True) -> None:
  138. document_segment = self.get_document_segment(doc_id)
  139. if document_segment is None:
  140. if raise_error:
  141. raise ValueError(f"doc_id {doc_id} not found.")
  142. else:
  143. return None
  144. db.session.delete(document_segment)
  145. db.session.commit()
  146. def set_document_hash(self, doc_id: str, doc_hash: str) -> None:
  147. """Set the hash for a given doc_id."""
  148. document_segment = self.get_document_segment(doc_id)
  149. if document_segment is None:
  150. return None
  151. document_segment.index_node_hash = doc_hash
  152. db.session.commit()
  153. def get_document_hash(self, doc_id: str) -> Optional[str]:
  154. """Get the stored hash for a document, if it exists."""
  155. document_segment = self.get_document_segment(doc_id)
  156. if document_segment is None:
  157. return None
  158. return document_segment.index_node_hash
  159. def get_document_segment(self, doc_id: str) -> DocumentSegment:
  160. document_segment = db.session.query(DocumentSegment).filter(
  161. DocumentSegment.dataset_id == self._dataset.id,
  162. DocumentSegment.index_node_id == doc_id
  163. ).first()
  164. return document_segment