base.py

import json
import logging
from abc import abstractmethod
from typing import Any, cast

from langchain.embeddings.base import Embeddings
from langchain.schema import BaseRetriever, Document
from langchain.vectorstores import VectorStore

from core.index.base import BaseIndex
from extensions.ext_database import db
from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment
from models.dataset import Document as DatasetDocument


class BaseVectorIndex(BaseIndex):
    """Shared behavior for vector-store-backed indexes.

    Concrete subclasses bind a specific LangChain ``VectorStore``
    implementation via the abstract hooks below.
    """

    def __init__(self, dataset: Dataset, embeddings: Embeddings):
        super().__init__(dataset)
        self._embeddings = embeddings
        self._vector_store = None

    def get_type(self) -> str:
        raise NotImplementedError

    @abstractmethod
    def get_index_name(self, dataset: Dataset) -> str:
        raise NotImplementedError

    @abstractmethod
    def to_index_struct(self) -> dict:
        raise NotImplementedError

    @abstractmethod
    def _get_vector_store(self) -> VectorStore:
        raise NotImplementedError

    @abstractmethod
    def _get_vector_store_class(self) -> type:
        raise NotImplementedError

    @abstractmethod
    def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]:
        raise NotImplementedError

    def search(self, query: str, **kwargs: Any) -> list[Document]:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        search_type = kwargs.get('search_type') or 'similarity'
        search_kwargs = kwargs.get('search_kwargs') or {}

        if search_type == 'similarity_score_threshold':
            # Default to a threshold of 0.0 when none (or a non-float) is given.
            score_threshold = search_kwargs.get("score_threshold")
            if score_threshold is None or not isinstance(score_threshold, float):
                search_kwargs['score_threshold'] = .0

            docs_with_similarity = vector_store.similarity_search_with_relevance_scores(
                query, **search_kwargs
            )

            docs = []
            for doc, similarity in docs_with_similarity:
                doc.metadata['score'] = similarity
                docs.append(doc)

            return docs

        # Supported search types and their kwargs:
        #   similarity: k
        #   mmr: k, fetch_k, lambda_mult
        #   similarity_score_threshold: k
        return vector_store.as_retriever(
            search_type=search_type,
            search_kwargs=search_kwargs
        ).get_relevant_documents(query)
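
    # Illustrative usage sketch (not part of the original module): `index`
    # stands for any concrete BaseVectorIndex subclass instance.
    #
    #   docs = index.search(
    #       "What is a vector index?",
    #       search_type='similarity_score_threshold',
    #       search_kwargs={'k': 4, 'score_threshold': 0.5},
    #   )
    #   for doc in docs:
    #       print(doc.metadata['score'], doc.page_content[:80])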

    def get_retriever(self, **kwargs: Any) -> BaseRetriever:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        return vector_store.as_retriever(**kwargs)

    def add_texts(self, texts: list[Document], **kwargs):
        # Indexes still using the legacy (origin) layout are rebuilt first.
        if self._is_origin():
            self.recreate_dataset(self.dataset)

        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        if kwargs.get('duplicate_check', False):
            texts = self._filter_duplicate_texts(texts)

        uuids = self._get_uuids(texts)
        vector_store.add_documents(texts, uuids=uuids)
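
    # Illustrative usage sketch (the metadata keys mirror those built in
    # recreate_dataset below):
    #
    #   index.add_texts(
    #       [Document(page_content="some segment text",
    #                 metadata={"doc_id": "node-1", "doc_hash": "abc123"})],
    #       duplicate_check=True,
    #   )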

    def text_exists(self, id: str) -> bool:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        return vector_store.text_exists(id)

    def delete_by_ids(self, ids: list[str]) -> None:
        # Legacy (origin) indexes are rebuilt rather than edited in place.
        if self._is_origin():
            self.recreate_dataset(self.dataset)
            return

        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        for node_id in ids:
            vector_store.del_text(node_id)

    def delete_by_group_id(self, group_id: str) -> None:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        if self.dataset.collection_binding_id:
            # Shared collection: remove only this group's vectors.
            vector_store.delete_by_group_id(group_id)
        else:
            # Dedicated collection: drop it entirely.
            vector_store.delete()

    def delete(self) -> None:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        vector_store.delete()

    def _is_origin(self):
        return False

    def recreate_dataset(self, dataset: Dataset):
        logging.info(f"Recreating dataset {dataset.id}")

        self.delete()

        dataset_documents = db.session.query(DatasetDocument).filter(
            DatasetDocument.dataset_id == dataset.id,
            DatasetDocument.indexing_status == 'completed',
            DatasetDocument.enabled == True,
            DatasetDocument.archived == False,
        ).all()

        documents = []
        for dataset_document in dataset_documents:
            segments = db.session.query(DocumentSegment).filter(
                DocumentSegment.document_id == dataset_document.id,
                DocumentSegment.status == 'completed',
                DocumentSegment.enabled == True
            ).all()

            for segment in segments:
                document = Document(
                    page_content=segment.content,
                    metadata={
                        "doc_id": segment.index_node_id,
                        "doc_hash": segment.index_node_hash,
                        "document_id": segment.document_id,
                        "dataset_id": segment.dataset_id,
                    }
                )

                documents.append(document)

        origin_index_struct = self.dataset.index_struct[:]
        self.dataset.index_struct = None

        if documents:
            try:
                self.create(documents)
            except Exception as e:
                # Restore the original index struct so the dataset is not left
                # pointing at a half-built index.
                self.dataset.index_struct = origin_index_struct
                raise e

        dataset.index_struct = json.dumps(self.to_index_struct())
        db.session.commit()

        self.dataset = dataset
        logging.info(f"Dataset {dataset.id} recreated successfully.")

    def create_qdrant_dataset(self, dataset: Dataset):
        logging.info(f"create_qdrant_dataset {dataset.id}")

        self.delete()

        dataset_documents = db.session.query(DatasetDocument).filter(
            DatasetDocument.dataset_id == dataset.id,
            DatasetDocument.indexing_status == 'completed',
            DatasetDocument.enabled == True,
            DatasetDocument.archived == False,
        ).all()

        documents = []
        for dataset_document in dataset_documents:
            segments = db.session.query(DocumentSegment).filter(
                DocumentSegment.document_id == dataset_document.id,
                DocumentSegment.status == 'completed',
                DocumentSegment.enabled == True
            ).all()

            for segment in segments:
                document = Document(
                    page_content=segment.content,
                    metadata={
                        "doc_id": segment.index_node_id,
                        "doc_hash": segment.index_node_hash,
                        "document_id": segment.document_id,
                        "dataset_id": segment.dataset_id,
                    }
                )

                documents.append(document)

        if documents:
            self.create(documents)

        logging.info(f"Dataset {dataset.id} recreated successfully.")

    def update_qdrant_dataset(self, dataset: Dataset):
        logging.info(f"update_qdrant_dataset {dataset.id}")

        segment = db.session.query(DocumentSegment).filter(
            DocumentSegment.dataset_id == dataset.id,
            DocumentSegment.status == 'completed',
            DocumentSegment.enabled == True
        ).first()

        if segment:
            # Only rewrite the index struct if the segment's node actually
            # exists in the vector store.
            if self.text_exists(segment.index_node_id):
                index_struct = {
                    "type": 'qdrant',
                    "vector_store": {"class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
                }
                dataset.index_struct = json.dumps(index_struct)
                db.session.commit()

        logging.info(f"Dataset {dataset.id} updated successfully.")

    def restore_dataset_in_one(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding):
        logging.info(f"Restoring dataset {dataset.id} into the shared collection")

        dataset_documents = db.session.query(DatasetDocument).filter(
            DatasetDocument.dataset_id == dataset.id,
            DatasetDocument.indexing_status == 'completed',
            DatasetDocument.enabled == True,
            DatasetDocument.archived == False,
        ).all()

        documents = []
        for dataset_document in dataset_documents:
            segments = db.session.query(DocumentSegment).filter(
                DocumentSegment.document_id == dataset_document.id,
                DocumentSegment.status == 'completed',
                DocumentSegment.enabled == True
            ).all()

            for segment in segments:
                document = Document(
                    page_content=segment.content,
                    metadata={
                        "doc_id": segment.index_node_id,
                        "doc_hash": segment.index_node_hash,
                        "document_id": segment.document_id,
                        "dataset_id": segment.dataset_id,
                    }
                )

                documents.append(document)

        if documents:
            self.add_texts(documents)

        logging.info(f"Dataset {dataset.id} restored successfully.")

    def delete_original_collection(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding):
        logging.info(f"delete original collection: {dataset.id}")

        self.delete()

        dataset.collection_binding_id = dataset_collection_binding.id
        db.session.add(dataset)
        db.session.commit()

        logging.info(f"Original collection for dataset {dataset.id} deleted successfully.")
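

# A minimal sketch of a concrete subclass (illustrative only; `MyVectorStore`
# is a hypothetical LangChain VectorStore implementation, not part of this
# codebase). Real subclasses such as the Qdrant-backed index referenced above
# follow the same shape.
#
#   class MyVectorIndex(BaseVectorIndex):
#       def get_type(self) -> str:
#           return 'my_store'
#
#       def get_index_name(self, dataset: Dataset) -> str:
#           return f"index_{dataset.id.replace('-', '_')}"
#
#       def to_index_struct(self) -> dict:
#           return {
#               "type": self.get_type(),
#               "vector_store": {"index_name": self.get_index_name(self.dataset)},
#           }
#
#       def _get_vector_store(self) -> VectorStore:
#           if self._vector_store is None:
#               self._vector_store = MyVectorStore(embedding=self._embeddings)
#           return self._vector_store
#
#       def _get_vector_store_class(self) -> type:
#           return MyVectorStore
#
#       def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]:
#           return []  # full-text search omitted in this sketch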