# dataset_docstore.py
  1. from collections.abc import Sequence
  2. from typing import Any, Optional, cast
  3. from sqlalchemy import func
  4. from core.model_manager import ModelManager
  5. from core.model_runtime.entities.model_entities import ModelType
  6. from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
  7. from core.rag.models.document import Document
  8. from extensions.ext_database import db
  9. from models.dataset import Dataset, DocumentSegment
  10. class DatasetDocumentStore:
  11. def __init__(
  12. self,
  13. dataset: Dataset,
  14. user_id: str,
  15. document_id: Optional[str] = None,
  16. ):
  17. self._dataset = dataset
  18. self._user_id = user_id
  19. self._document_id = document_id
  20. @classmethod
  21. def from_dict(cls, config_dict: dict[str, Any]) -> "DatasetDocumentStore":
  22. return cls(**config_dict)
  23. def to_dict(self) -> dict[str, Any]:
  24. """Serialize to dict."""
  25. return {
  26. "dataset_id": self._dataset.id,
  27. }
  28. @property
  29. def dateset_id(self) -> Any:
  30. return self._dataset.id
  31. @property
  32. def user_id(self) -> Any:
  33. return self._user_id
  34. @property
  35. def docs(self) -> dict[str, Document]:
  36. document_segments = db.session.query(DocumentSegment).filter(
  37. DocumentSegment.dataset_id == self._dataset.id
  38. ).all()
  39. output = {}
  40. for document_segment in document_segments:
  41. doc_id = document_segment.index_node_id
  42. output[doc_id] = Document(
  43. page_content=document_segment.content,
  44. metadata={
  45. "doc_id": document_segment.index_node_id,
  46. "doc_hash": document_segment.index_node_hash,
  47. "document_id": document_segment.document_id,
  48. "dataset_id": document_segment.dataset_id,
  49. }
  50. )
  51. return output
  52. def add_documents(
  53. self, docs: Sequence[Document], allow_update: bool = True
  54. ) -> None:
  55. max_position = db.session.query(func.max(DocumentSegment.position)).filter(
  56. DocumentSegment.document_id == self._document_id
  57. ).scalar()
  58. if max_position is None:
  59. max_position = 0
  60. embedding_model = None
  61. if self._dataset.indexing_technique == 'high_quality':
  62. model_manager = ModelManager()
  63. embedding_model = model_manager.get_model_instance(
  64. tenant_id=self._dataset.tenant_id,
  65. provider=self._dataset.embedding_model_provider,
  66. model_type=ModelType.TEXT_EMBEDDING,
  67. model=self._dataset.embedding_model
  68. )
  69. for doc in docs:
  70. if not isinstance(doc, Document):
  71. raise ValueError("doc must be a Document")
  72. segment_document = self.get_document(doc_id=doc.metadata['doc_id'], raise_error=False)
  73. # NOTE: doc could already exist in the store, but we overwrite it
  74. if not allow_update and segment_document:
  75. raise ValueError(
  76. f"doc_id {doc.metadata['doc_id']} already exists. "
  77. "Set allow_update to True to overwrite."
  78. )
  79. # calc embedding use tokens
  80. if embedding_model:
  81. model_type_instance = embedding_model.model_type_instance
  82. model_type_instance = cast(TextEmbeddingModel, model_type_instance)
  83. tokens = model_type_instance.get_num_tokens(
  84. model=embedding_model.model,
  85. credentials=embedding_model.credentials,
  86. texts=[doc.page_content]
  87. )
  88. else:
  89. tokens = 0
  90. if not segment_document:
  91. max_position += 1
  92. segment_document = DocumentSegment(
  93. tenant_id=self._dataset.tenant_id,
  94. dataset_id=self._dataset.id,
  95. document_id=self._document_id,
  96. index_node_id=doc.metadata['doc_id'],
  97. index_node_hash=doc.metadata['doc_hash'],
  98. position=max_position,
  99. content=doc.page_content,
  100. word_count=len(doc.page_content),
  101. tokens=tokens,
  102. enabled=False,
  103. created_by=self._user_id,
  104. )
  105. if 'answer' in doc.metadata and doc.metadata['answer']:
  106. segment_document.answer = doc.metadata.pop('answer', '')
  107. db.session.add(segment_document)
  108. else:
  109. segment_document.content = doc.page_content
  110. if 'answer' in doc.metadata and doc.metadata['answer']:
  111. segment_document.answer = doc.metadata.pop('answer', '')
  112. segment_document.index_node_hash = doc.metadata['doc_hash']
  113. segment_document.word_count = len(doc.page_content)
  114. segment_document.tokens = tokens
  115. db.session.commit()
  116. def document_exists(self, doc_id: str) -> bool:
  117. """Check if document exists."""
  118. result = self.get_document_segment(doc_id)
  119. return result is not None
  120. def get_document(
  121. self, doc_id: str, raise_error: bool = True
  122. ) -> Optional[Document]:
  123. document_segment = self.get_document_segment(doc_id)
  124. if document_segment is None:
  125. if raise_error:
  126. raise ValueError(f"doc_id {doc_id} not found.")
  127. else:
  128. return None
  129. return Document(
  130. page_content=document_segment.content,
  131. metadata={
  132. "doc_id": document_segment.index_node_id,
  133. "doc_hash": document_segment.index_node_hash,
  134. "document_id": document_segment.document_id,
  135. "dataset_id": document_segment.dataset_id,
  136. }
  137. )
  138. def delete_document(self, doc_id: str, raise_error: bool = True) -> None:
  139. document_segment = self.get_document_segment(doc_id)
  140. if document_segment is None:
  141. if raise_error:
  142. raise ValueError(f"doc_id {doc_id} not found.")
  143. else:
  144. return None
  145. db.session.delete(document_segment)
  146. db.session.commit()
  147. def set_document_hash(self, doc_id: str, doc_hash: str) -> None:
  148. """Set the hash for a given doc_id."""
  149. document_segment = self.get_document_segment(doc_id)
  150. if document_segment is None:
  151. return None
  152. document_segment.index_node_hash = doc_hash
  153. db.session.commit()
  154. def get_document_hash(self, doc_id: str) -> Optional[str]:
  155. """Get the stored hash for a document, if it exists."""
  156. document_segment = self.get_document_segment(doc_id)
  157. if document_segment is None:
  158. return None
  159. return document_segment.index_node_hash
  160. def get_document_segment(self, doc_id: str) -> DocumentSegment:
  161. document_segment = db.session.query(DocumentSegment).filter(
  162. DocumentSegment.dataset_id == self._dataset.id,
  163. DocumentSegment.index_node_id == doc_id
  164. ).first()
  165. return document_segment