dataset_docstore.py

from typing import Any, Dict, Optional, Sequence

from llama_index.data_structs import Node
from llama_index.docstore.types import BaseDocumentStore
from llama_index.docstore.utils import json_to_doc
from llama_index.schema import BaseDocument
from sqlalchemy import func

from core.llm.token_calculator import TokenCalculator
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
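
# A BaseDocumentStore implementation backed by the `document_segments` table:
# each llama_index Node is persisted as a DocumentSegment row (content, hash,
# position, token count) instead of being held in an in-memory dict.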
class DatasetDocumentStore(BaseDocumentStore):
    def __init__(
        self,
        dataset: Dataset,
        user_id: str,
        embedding_model_name: str,
        document_id: Optional[str] = None,
    ):
        self._dataset = dataset
        self._user_id = user_id
        self._embedding_model_name = embedding_model_name
        self._document_id = document_id

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> "DatasetDocumentStore":
        return cls(**config_dict)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to dict."""
        return {
            "dataset_id": self._dataset.id,
        }

    @property
    def dataset_id(self) -> Any:
        return self._dataset.id

    @property
    def user_id(self) -> Any:
        return self._user_id

    @property
    def embedding_model_name(self) -> Any:
        return self._embedding_model_name

    @property
    def docs(self) -> Dict[str, BaseDocument]:
        document_segments = db.session.query(DocumentSegment).filter(
            DocumentSegment.dataset_id == self._dataset.id
        ).all()

        output = {}
        for document_segment in document_segments:
            doc_id = document_segment.index_node_id
            result = self.segment_to_dict(document_segment)
            output[doc_id] = json_to_doc(result)

        return output

    def add_documents(
        self, docs: Sequence[BaseDocument], allow_update: bool = True
    ) -> None:
        max_position = db.session.query(func.max(DocumentSegment.position)).filter(
            DocumentSegment.document_id == self._document_id
        ).scalar()

        if max_position is None:
            max_position = 0

        for doc in docs:
            if doc.is_doc_id_none:
                raise ValueError("doc_id not set")

            if not isinstance(doc, Node):
                raise ValueError("doc must be a Node")

            # fetch the DocumentSegment row (not the reconstructed document),
            # so the update branch below mutates a persistent object
            segment_document = self.get_document_segment(doc_id=doc.get_doc_id())

            # NOTE: doc could already exist in the store, but we overwrite it
            if not allow_update and segment_document:
                raise ValueError(
                    f"doc_id {doc.get_doc_id()} already exists. "
                    "Set allow_update to True to overwrite."
                )

            # calculate the number of tokens the embedding model will consume
            tokens = TokenCalculator.get_num_tokens(self._embedding_model_name, doc.get_text())

            if not segment_document:
                max_position += 1

                segment_document = DocumentSegment(
                    tenant_id=self._dataset.tenant_id,
                    dataset_id=self._dataset.id,
                    document_id=self._document_id,
                    index_node_id=doc.get_doc_id(),
                    index_node_hash=doc.get_doc_hash(),
                    position=max_position,
                    content=doc.get_text(),
                    word_count=len(doc.get_text()),
                    tokens=tokens,
                    created_by=self._user_id,
                )
                db.session.add(segment_document)
            else:
                segment_document.content = doc.get_text()
                segment_document.index_node_hash = doc.get_doc_hash()
                segment_document.word_count = len(doc.get_text())
                segment_document.tokens = tokens

            db.session.commit()

    def document_exists(self, doc_id: str) -> bool:
        """Check if document exists."""
        result = self.get_document_segment(doc_id)
        return result is not None

    def get_document(
        self, doc_id: str, raise_error: bool = True
    ) -> Optional[BaseDocument]:
        document_segment = self.get_document_segment(doc_id)

        if document_segment is None:
            if raise_error:
                raise ValueError(f"doc_id {doc_id} not found.")
            else:
                return None

        result = self.segment_to_dict(document_segment)
        return json_to_doc(result)

    def delete_document(self, doc_id: str, raise_error: bool = True) -> None:
        document_segment = self.get_document_segment(doc_id)

        if document_segment is None:
            if raise_error:
                raise ValueError(f"doc_id {doc_id} not found.")
            else:
                return None

        db.session.delete(document_segment)
        db.session.commit()

    def set_document_hash(self, doc_id: str, doc_hash: str) -> None:
        """Set the hash for a given doc_id."""
        document_segment = self.get_document_segment(doc_id)

        if document_segment is None:
            return None

        document_segment.index_node_hash = doc_hash
        db.session.commit()

    def get_document_hash(self, doc_id: str) -> Optional[str]:
        """Get the stored hash for a document, if it exists."""
        document_segment = self.get_document_segment(doc_id)

        if document_segment is None:
            return None

        return document_segment.index_node_hash

    def update_docstore(self, other: "BaseDocumentStore") -> None:
        """Update docstore.

        Args:
            other (BaseDocumentStore): docstore to update from

        """
        self.add_documents(list(other.docs.values()))

    def get_document_segment(self, doc_id: str) -> Optional[DocumentSegment]:
        document_segment = db.session.query(DocumentSegment).filter(
            DocumentSegment.dataset_id == self._dataset.id,
            DocumentSegment.index_node_id == doc_id
        ).first()

        return document_segment

    def segment_to_dict(self, segment: DocumentSegment) -> Dict[str, Any]:
        return {
            "doc_id": segment.index_node_id,
            "doc_hash": segment.index_node_hash,
            "text": segment.content,
            "__type__": Node.get_type()
        }
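
# --- Usage sketch (illustrative, not part of the module) --------------------
# A minimal example of how this store might be driven, assuming an existing
# `Dataset` row and the legacy llama_index `Node(text=..., doc_id=...)`
# constructor; the user/document/node ids below are hypothetical placeholders.
#
#     from llama_index.data_structs import Node
#
#     dataset = db.session.query(Dataset).first()
#     docstore = DatasetDocumentStore(
#         dataset=dataset,
#         user_id="user-1234",                       # hypothetical user id
#         embedding_model_name="text-embedding-ada-002",
#         document_id="doc-5678",                    # hypothetical document id
#     )
#
#     node = Node(text="a segment of text", doc_id="node-0001")
#     docstore.add_documents([node])        # persists one DocumentSegment row
#     assert docstore.document_exists("node-0001")
#     doc = docstore.get_document("node-0001")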