dataset_docstore.py

from typing import Any, Dict, Optional, Sequence

from langchain.schema import Document
from sqlalchemy import func

from core.llm.token_calculator import TokenCalculator
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment


class DatesetDocumentStore:
    """Persists LangChain documents as DocumentSegment rows belonging to a single dataset."""

    def __init__(
        self,
        dataset: Dataset,
        user_id: str,
        embedding_model_name: str,
        document_id: Optional[str] = None,
    ):
        self._dataset = dataset
        self._user_id = user_id
        self._embedding_model_name = embedding_model_name
        self._document_id = document_id

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> "DatesetDocumentStore":
        return cls(**config_dict)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to dict."""
        return {
            "dataset_id": self._dataset.id,
        }

    @property
    def dateset_id(self) -> Any:
        return self._dataset.id

    @property
    def user_id(self) -> Any:
        return self._user_id

    @property
    def embedding_model_name(self) -> Any:
        return self._embedding_model_name

    @property
    def docs(self) -> Dict[str, Document]:
        """Return all segments of the dataset as Documents, keyed by index node id."""
        document_segments = db.session.query(DocumentSegment).filter(
            DocumentSegment.dataset_id == self._dataset.id
        ).all()

        output = {}
        for document_segment in document_segments:
            doc_id = document_segment.index_node_id
            output[doc_id] = Document(
                page_content=document_segment.content,
                metadata={
                    "doc_id": document_segment.index_node_id,
                    "doc_hash": document_segment.index_node_hash,
                    "document_id": document_segment.document_id,
                    "dataset_id": document_segment.dataset_id,
                }
            )

        return output

    def add_documents(
        self, docs: Sequence[Document], allow_update: bool = True
    ) -> None:
        """Insert or update dataset segments from the given documents."""
        max_position = db.session.query(func.max(DocumentSegment.position)).filter(
            DocumentSegment.document_id == self._document_id
        ).scalar()

        if max_position is None:
            max_position = 0

        for doc in docs:
            if not isinstance(doc, Document):
                raise ValueError("doc must be a Document")

            # Fetch the existing segment row (if any) so that updates are persisted.
            segment_document = self.get_document_segment(doc_id=doc.metadata['doc_id'])

            # NOTE: doc could already exist in the store, but we overwrite it
            if not allow_update and segment_document:
                raise ValueError(
                    f"doc_id {doc.metadata['doc_id']} already exists. "
                    "Set allow_update to True to overwrite."
                )

            # Calculate how many tokens embedding this content will consume.
            tokens = TokenCalculator.get_num_tokens(self._embedding_model_name, doc.page_content)

            if not segment_document:
                max_position += 1

                segment_document = DocumentSegment(
                    tenant_id=self._dataset.tenant_id,
                    dataset_id=self._dataset.id,
                    document_id=self._document_id,
                    index_node_id=doc.metadata['doc_id'],
                    index_node_hash=doc.metadata['doc_hash'],
                    position=max_position,
                    content=doc.page_content,
                    word_count=len(doc.page_content),
                    tokens=tokens,
                    created_by=self._user_id,
                )
                if 'answer' in doc.metadata and doc.metadata['answer']:
                    segment_document.answer = doc.metadata.pop('answer', '')

                db.session.add(segment_document)
            else:
                segment_document.content = doc.page_content
                if 'answer' in doc.metadata and doc.metadata['answer']:
                    segment_document.answer = doc.metadata.pop('answer', '')
                segment_document.index_node_hash = doc.metadata['doc_hash']
                segment_document.word_count = len(doc.page_content)
                segment_document.tokens = tokens

            db.session.commit()

    def document_exists(self, doc_id: str) -> bool:
        """Check if document exists."""
        result = self.get_document_segment(doc_id)
        return result is not None

    def get_document(
        self, doc_id: str, raise_error: bool = True
    ) -> Optional[Document]:
        """Return the stored segment as a Document; raise or return None if missing."""
        document_segment = self.get_document_segment(doc_id)

        if document_segment is None:
            if raise_error:
                raise ValueError(f"doc_id {doc_id} not found.")
            else:
                return None

        return Document(
            page_content=document_segment.content,
            metadata={
                "doc_id": document_segment.index_node_id,
                "doc_hash": document_segment.index_node_hash,
                "document_id": document_segment.document_id,
                "dataset_id": document_segment.dataset_id,
            }
        )

    def delete_document(self, doc_id: str, raise_error: bool = True) -> None:
        document_segment = self.get_document_segment(doc_id)

        if document_segment is None:
            if raise_error:
                raise ValueError(f"doc_id {doc_id} not found.")
            else:
                return

        db.session.delete(document_segment)
        db.session.commit()

    def set_document_hash(self, doc_id: str, doc_hash: str) -> None:
        """Set the hash for a given doc_id."""
        document_segment = self.get_document_segment(doc_id)

        if document_segment is None:
            return None

        document_segment.index_node_hash = doc_hash
        db.session.commit()

    def get_document_hash(self, doc_id: str) -> Optional[str]:
        """Get the stored hash for a document, if it exists."""
        document_segment = self.get_document_segment(doc_id)

        if document_segment is None:
            return None

        return document_segment.index_node_hash

    def get_document_segment(self, doc_id: str) -> Optional[DocumentSegment]:
        """Fetch the DocumentSegment row for doc_id, or None if it does not exist."""
        document_segment = db.session.query(DocumentSegment).filter(
            DocumentSegment.dataset_id == self._dataset.id,
            DocumentSegment.index_node_id == doc_id
        ).first()

        return document_segment
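

# Usage sketch (illustrative only, not part of the module). It assumes a Flask
# app context with the `db` session available, and that `dataset`, `user`, and
# `document` are existing ORM rows provided by the calling service; the
# embedding model name and the "node-1" index node id are example values only.
#
#   store = DatesetDocumentStore(
#       dataset=dataset,
#       user_id=user.id,
#       embedding_model_name="text-embedding-ada-002",
#       document_id=document.id,
#   )
#   store.add_documents([
#       Document(
#           page_content="segment text",
#           metadata={"doc_id": "node-1", "doc_hash": "abc123"},
#       )
#   ])
#   assert store.document_exists("node-1")
#   store.set_document_hash("node-1", "def456")
#   store.delete_document("node-1", raise_error=False)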