dataset_retrieval.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471
  1. import threading
  2. from typing import Optional, cast
  3. from flask import Flask, current_app
  4. from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
  5. from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
  6. from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
  7. from core.entities.agent_entities import PlanningStrategy
  8. from core.memory.token_buffer_memory import TokenBufferMemory
  9. from core.model_manager import ModelInstance, ModelManager
  10. from core.model_runtime.entities.message_entities import PromptMessageTool
  11. from core.model_runtime.entities.model_entities import ModelFeature, ModelType
  12. from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
  13. from core.rag.datasource.retrieval_service import RetrievalService
  14. from core.rag.models.document import Document
  15. from core.rag.rerank.rerank import RerankRunner
  16. from core.rag.retrieval.retrival_methods import RetrievalMethod
  17. from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
  18. from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
  19. from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
  20. from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
  21. from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
  22. from extensions.ext_database import db
  23. from models.dataset import Dataset, DatasetQuery, DocumentSegment
  24. from models.dataset import Document as DatasetDocument
  25. default_retrieval_model = {
  26. 'search_method': RetrievalMethod.SEMANTIC_SEARCH,
  27. 'reranking_enable': False,
  28. 'reranking_model': {
  29. 'reranking_provider_name': '',
  30. 'reranking_model_name': ''
  31. },
  32. 'top_k': 2,
  33. 'score_threshold_enabled': False
  34. }
  35. class DatasetRetrieval:
  36. def retrieve(self, app_id: str, user_id: str, tenant_id: str,
  37. model_config: ModelConfigWithCredentialsEntity,
  38. config: DatasetEntity,
  39. query: str,
  40. invoke_from: InvokeFrom,
  41. show_retrieve_source: bool,
  42. hit_callback: DatasetIndexToolCallbackHandler,
  43. memory: Optional[TokenBufferMemory] = None) -> Optional[str]:
  44. """
  45. Retrieve dataset.
  46. :param app_id: app_id
  47. :param user_id: user_id
  48. :param tenant_id: tenant id
  49. :param model_config: model config
  50. :param config: dataset config
  51. :param query: query
  52. :param invoke_from: invoke from
  53. :param show_retrieve_source: show retrieve source
  54. :param hit_callback: hit callback
  55. :param memory: memory
  56. :return:
  57. """
  58. dataset_ids = config.dataset_ids
  59. if len(dataset_ids) == 0:
  60. return None
  61. retrieve_config = config.retrieve_config
  62. # check model is support tool calling
  63. model_type_instance = model_config.provider_model_bundle.model_type_instance
  64. model_type_instance = cast(LargeLanguageModel, model_type_instance)
  65. model_manager = ModelManager()
  66. model_instance = model_manager.get_model_instance(
  67. tenant_id=tenant_id,
  68. model_type=ModelType.LLM,
  69. provider=model_config.provider,
  70. model=model_config.model
  71. )
  72. # get model schema
  73. model_schema = model_type_instance.get_model_schema(
  74. model=model_config.model,
  75. credentials=model_config.credentials
  76. )
  77. if not model_schema:
  78. return None
  79. planning_strategy = PlanningStrategy.REACT_ROUTER
  80. features = model_schema.features
  81. if features:
  82. if ModelFeature.TOOL_CALL in features \
  83. or ModelFeature.MULTI_TOOL_CALL in features:
  84. planning_strategy = PlanningStrategy.ROUTER
  85. available_datasets = []
  86. for dataset_id in dataset_ids:
  87. # get dataset from dataset id
  88. dataset = db.session.query(Dataset).filter(
  89. Dataset.tenant_id == tenant_id,
  90. Dataset.id == dataset_id
  91. ).first()
  92. # pass if dataset is not available
  93. if not dataset:
  94. continue
  95. # pass if dataset is not available
  96. if (dataset and dataset.available_document_count == 0
  97. and dataset.available_document_count == 0):
  98. continue
  99. available_datasets.append(dataset)
  100. all_documents = []
  101. user_from = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'
  102. if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
  103. all_documents = self.single_retrieve(app_id, tenant_id, user_id, user_from, available_datasets, query,
  104. model_instance,
  105. model_config, planning_strategy)
  106. elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
  107. all_documents = self.multiple_retrieve(app_id, tenant_id, user_id, user_from,
  108. available_datasets, query, retrieve_config.top_k,
  109. retrieve_config.score_threshold,
  110. retrieve_config.reranking_model.get('reranking_provider_name'),
  111. retrieve_config.reranking_model.get('reranking_model_name'))
  112. document_score_list = {}
  113. for item in all_documents:
  114. if item.metadata.get('score'):
  115. document_score_list[item.metadata['doc_id']] = item.metadata['score']
  116. document_context_list = []
  117. index_node_ids = [document.metadata['doc_id'] for document in all_documents]
  118. segments = DocumentSegment.query.filter(
  119. DocumentSegment.dataset_id.in_(dataset_ids),
  120. DocumentSegment.completed_at.isnot(None),
  121. DocumentSegment.status == 'completed',
  122. DocumentSegment.enabled == True,
  123. DocumentSegment.index_node_id.in_(index_node_ids)
  124. ).all()
  125. if segments:
  126. index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
  127. sorted_segments = sorted(segments,
  128. key=lambda segment: index_node_id_to_position.get(segment.index_node_id,
  129. float('inf')))
  130. for segment in sorted_segments:
  131. if segment.answer:
  132. document_context_list.append(f'question:{segment.get_sign_content()} answer:{segment.answer}')
  133. else:
  134. document_context_list.append(segment.get_sign_content())
  135. if show_retrieve_source:
  136. context_list = []
  137. resource_number = 1
  138. for segment in sorted_segments:
  139. dataset = Dataset.query.filter_by(
  140. id=segment.dataset_id
  141. ).first()
  142. document = DatasetDocument.query.filter(DatasetDocument.id == segment.document_id,
  143. DatasetDocument.enabled == True,
  144. DatasetDocument.archived == False,
  145. ).first()
  146. if dataset and document:
  147. source = {
  148. 'position': resource_number,
  149. 'dataset_id': dataset.id,
  150. 'dataset_name': dataset.name,
  151. 'document_id': document.id,
  152. 'document_name': document.name,
  153. 'data_source_type': document.data_source_type,
  154. 'segment_id': segment.id,
  155. 'retriever_from': invoke_from.to_source(),
  156. 'score': document_score_list.get(segment.index_node_id, None)
  157. }
  158. if invoke_from.to_source() == 'dev':
  159. source['hit_count'] = segment.hit_count
  160. source['word_count'] = segment.word_count
  161. source['segment_position'] = segment.position
  162. source['index_node_hash'] = segment.index_node_hash
  163. if segment.answer:
  164. source['content'] = f'question:{segment.content} \nanswer:{segment.answer}'
  165. else:
  166. source['content'] = segment.content
  167. context_list.append(source)
  168. resource_number += 1
  169. if hit_callback:
  170. hit_callback.return_retriever_resource_info(context_list)
  171. return str("\n".join(document_context_list))
  172. return ''
  173. def single_retrieve(self, app_id: str,
  174. tenant_id: str,
  175. user_id: str,
  176. user_from: str,
  177. available_datasets: list,
  178. query: str,
  179. model_instance: ModelInstance,
  180. model_config: ModelConfigWithCredentialsEntity,
  181. planning_strategy: PlanningStrategy,
  182. ):
  183. tools = []
  184. for dataset in available_datasets:
  185. description = dataset.description
  186. if not description:
  187. description = 'useful for when you want to answer queries about the ' + dataset.name
  188. description = description.replace('\n', '').replace('\r', '')
  189. message_tool = PromptMessageTool(
  190. name=dataset.id,
  191. description=description,
  192. parameters={
  193. "type": "object",
  194. "properties": {},
  195. "required": [],
  196. }
  197. )
  198. tools.append(message_tool)
  199. dataset_id = None
  200. if planning_strategy == PlanningStrategy.REACT_ROUTER:
  201. react_multi_dataset_router = ReactMultiDatasetRouter()
  202. dataset_id = react_multi_dataset_router.invoke(query, tools, model_config, model_instance,
  203. user_id, tenant_id)
  204. elif planning_strategy == PlanningStrategy.ROUTER:
  205. function_call_router = FunctionCallMultiDatasetRouter()
  206. dataset_id = function_call_router.invoke(query, tools, model_config, model_instance)
  207. if dataset_id:
  208. # get retrieval model config
  209. dataset = db.session.query(Dataset).filter(
  210. Dataset.id == dataset_id
  211. ).first()
  212. if dataset:
  213. retrieval_model_config = dataset.retrieval_model \
  214. if dataset.retrieval_model else default_retrieval_model
  215. # get top k
  216. top_k = retrieval_model_config['top_k']
  217. # get retrieval method
  218. if dataset.indexing_technique == "economy":
  219. retrival_method = 'keyword_search'
  220. else:
  221. retrival_method = retrieval_model_config['search_method']
  222. # get reranking model
  223. reranking_model = retrieval_model_config['reranking_model'] \
  224. if retrieval_model_config['reranking_enable'] else None
  225. # get score threshold
  226. score_threshold = .0
  227. score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
  228. if score_threshold_enabled:
  229. score_threshold = retrieval_model_config.get("score_threshold")
  230. results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id,
  231. query=query,
  232. top_k=top_k, score_threshold=score_threshold,
  233. reranking_model=reranking_model)
  234. self._on_query(query, [dataset_id], app_id, user_from, user_id)
  235. if results:
  236. self._on_retrival_end(results)
  237. return results
  238. return []
  239. def multiple_retrieve(self,
  240. app_id: str,
  241. tenant_id: str,
  242. user_id: str,
  243. user_from: str,
  244. available_datasets: list,
  245. query: str,
  246. top_k: int,
  247. score_threshold: float,
  248. reranking_provider_name: str,
  249. reranking_model_name: str):
  250. threads = []
  251. all_documents = []
  252. dataset_ids = [dataset.id for dataset in available_datasets]
  253. for dataset in available_datasets:
  254. retrieval_thread = threading.Thread(target=self._retriever, kwargs={
  255. 'flask_app': current_app._get_current_object(),
  256. 'dataset_id': dataset.id,
  257. 'query': query,
  258. 'top_k': top_k,
  259. 'all_documents': all_documents,
  260. })
  261. threads.append(retrieval_thread)
  262. retrieval_thread.start()
  263. for thread in threads:
  264. thread.join()
  265. # do rerank for searched documents
  266. model_manager = ModelManager()
  267. rerank_model_instance = model_manager.get_model_instance(
  268. tenant_id=tenant_id,
  269. provider=reranking_provider_name,
  270. model_type=ModelType.RERANK,
  271. model=reranking_model_name
  272. )
  273. rerank_runner = RerankRunner(rerank_model_instance)
  274. all_documents = rerank_runner.run(query, all_documents,
  275. score_threshold,
  276. top_k)
  277. self._on_query(query, dataset_ids, app_id, user_from, user_id)
  278. if all_documents:
  279. self._on_retrival_end(all_documents)
  280. return all_documents
  281. def _on_retrival_end(self, documents: list[Document]) -> None:
  282. """Handle retrival end."""
  283. for document in documents:
  284. query = db.session.query(DocumentSegment).filter(
  285. DocumentSegment.index_node_id == document.metadata['doc_id']
  286. )
  287. # if 'dataset_id' in document.metadata:
  288. if 'dataset_id' in document.metadata:
  289. query = query.filter(DocumentSegment.dataset_id == document.metadata['dataset_id'])
  290. # add hit count to document segment
  291. query.update(
  292. {DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
  293. synchronize_session=False
  294. )
  295. db.session.commit()
  296. def _on_query(self, query: str, dataset_ids: list[str], app_id: str, user_from: str, user_id: str) -> None:
  297. """
  298. Handle query.
  299. """
  300. if not query:
  301. return
  302. dataset_queries = []
  303. for dataset_id in dataset_ids:
  304. dataset_query = DatasetQuery(
  305. dataset_id=dataset_id,
  306. content=query,
  307. source='app',
  308. source_app_id=app_id,
  309. created_by_role=user_from,
  310. created_by=user_id
  311. )
  312. dataset_queries.append(dataset_query)
  313. if dataset_queries:
  314. db.session.add_all(dataset_queries)
  315. db.session.commit()
  316. def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list):
  317. with flask_app.app_context():
  318. dataset = db.session.query(Dataset).filter(
  319. Dataset.id == dataset_id
  320. ).first()
  321. if not dataset:
  322. return []
  323. # get retrieval model , if the model is not setting , using default
  324. retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
  325. if dataset.indexing_technique == "economy":
  326. # use keyword table query
  327. documents = RetrievalService.retrieve(retrival_method='keyword_search',
  328. dataset_id=dataset.id,
  329. query=query,
  330. top_k=top_k
  331. )
  332. if documents:
  333. all_documents.extend(documents)
  334. else:
  335. if top_k > 0:
  336. # retrieval source
  337. documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
  338. dataset_id=dataset.id,
  339. query=query,
  340. top_k=top_k,
  341. score_threshold=retrieval_model['score_threshold']
  342. if retrieval_model['score_threshold_enabled'] else None,
  343. reranking_model=retrieval_model['reranking_model']
  344. if retrieval_model['reranking_enable'] else None
  345. )
  346. all_documents.extend(documents)
  347. def to_dataset_retriever_tool(self, tenant_id: str,
  348. dataset_ids: list[str],
  349. retrieve_config: DatasetRetrieveConfigEntity,
  350. return_resource: bool,
  351. invoke_from: InvokeFrom,
  352. hit_callback: DatasetIndexToolCallbackHandler) \
  353. -> Optional[list[DatasetRetrieverBaseTool]]:
  354. """
  355. A dataset tool is a tool that can be used to retrieve information from a dataset
  356. :param tenant_id: tenant id
  357. :param dataset_ids: dataset ids
  358. :param retrieve_config: retrieve config
  359. :param return_resource: return resource
  360. :param invoke_from: invoke from
  361. :param hit_callback: hit callback
  362. """
  363. tools = []
  364. available_datasets = []
  365. for dataset_id in dataset_ids:
  366. # get dataset from dataset id
  367. dataset = db.session.query(Dataset).filter(
  368. Dataset.tenant_id == tenant_id,
  369. Dataset.id == dataset_id
  370. ).first()
  371. # pass if dataset is not available
  372. if not dataset:
  373. continue
  374. # pass if dataset is not available
  375. if (dataset and dataset.available_document_count == 0
  376. and dataset.available_document_count == 0):
  377. continue
  378. available_datasets.append(dataset)
  379. if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
  380. # get retrieval model config
  381. default_retrieval_model = {
  382. 'search_method': RetrievalMethod.SEMANTIC_SEARCH,
  383. 'reranking_enable': False,
  384. 'reranking_model': {
  385. 'reranking_provider_name': '',
  386. 'reranking_model_name': ''
  387. },
  388. 'top_k': 2,
  389. 'score_threshold_enabled': False
  390. }
  391. for dataset in available_datasets:
  392. retrieval_model_config = dataset.retrieval_model \
  393. if dataset.retrieval_model else default_retrieval_model
  394. # get top k
  395. top_k = retrieval_model_config['top_k']
  396. # get score threshold
  397. score_threshold = None
  398. score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
  399. if score_threshold_enabled:
  400. score_threshold = retrieval_model_config.get("score_threshold")
  401. tool = DatasetRetrieverTool.from_dataset(
  402. dataset=dataset,
  403. top_k=top_k,
  404. score_threshold=score_threshold,
  405. hit_callbacks=[hit_callback],
  406. return_resource=return_resource,
  407. retriever_from=invoke_from.to_source()
  408. )
  409. tools.append(tool)
  410. elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
  411. tool = DatasetMultiRetrieverTool.from_dataset(
  412. dataset_ids=[dataset.id for dataset in available_datasets],
  413. tenant_id=tenant_id,
  414. top_k=retrieve_config.top_k or 2,
  415. score_threshold=retrieve_config.score_threshold,
  416. hit_callbacks=[hit_callback],
  417. return_resource=return_resource,
  418. retriever_from=invoke_from.to_source(),
  419. reranking_provider_name=retrieve_config.reranking_model.get('reranking_provider_name'),
  420. reranking_model_name=retrieve_config.reranking_model.get('reranking_model_name')
  421. )
  422. tools.append(tool)
  423. return tools