dataset_retrieval.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626
  1. import math
  2. import threading
  3. from collections import Counter
  4. from typing import Optional, cast
  5. from flask import Flask, current_app
  6. from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
  7. from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
  8. from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
  9. from core.entities.agent_entities import PlanningStrategy
  10. from core.memory.token_buffer_memory import TokenBufferMemory
  11. from core.model_manager import ModelInstance, ModelManager
  12. from core.model_runtime.entities.message_entities import PromptMessageTool
  13. from core.model_runtime.entities.model_entities import ModelFeature, ModelType
  14. from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
  15. from core.ops.entities.trace_entity import TraceTaskName
  16. from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
  17. from core.ops.utils import measure_time
  18. from core.rag.data_post_processor.data_post_processor import DataPostProcessor
  19. from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
  20. from core.rag.datasource.retrieval_service import RetrievalService
  21. from core.rag.models.document import Document
  22. from core.rag.retrieval.retrival_methods import RetrievalMethod
  23. from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
  24. from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
  25. from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
  26. from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
  27. from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
  28. from extensions.ext_database import db
  29. from models.dataset import Dataset, DatasetQuery, DocumentSegment
  30. from models.dataset import Document as DatasetDocument
# Fallback retrieval configuration applied when a dataset has no
# retrieval_model of its own: plain semantic search, reranking off,
# top 2 results, no score threshold.
default_retrieval_model = {
    'search_method': RetrievalMethod.SEMANTIC_SEARCH.value,
    'reranking_enable': False,
    'reranking_model': {
        'reranking_provider_name': '',
        'reranking_model_name': ''
    },
    'top_k': 2,
    'score_threshold_enabled': False
}
class DatasetRetrieval:
    """Orchestrates retrieval over one or more datasets for an app query."""

    def __init__(self, application_generate_entity=None):
        # Held only so _on_retrival_end can read .trace_manager from it;
        # may be None when tracing is not configured.
        self.application_generate_entity = application_generate_entity
  44. def retrieve(
  45. self, app_id: str, user_id: str, tenant_id: str,
  46. model_config: ModelConfigWithCredentialsEntity,
  47. config: DatasetEntity,
  48. query: str,
  49. invoke_from: InvokeFrom,
  50. show_retrieve_source: bool,
  51. hit_callback: DatasetIndexToolCallbackHandler,
  52. message_id: str,
  53. memory: Optional[TokenBufferMemory] = None,
  54. ) -> Optional[str]:
  55. """
  56. Retrieve dataset.
  57. :param app_id: app_id
  58. :param user_id: user_id
  59. :param tenant_id: tenant id
  60. :param model_config: model config
  61. :param config: dataset config
  62. :param query: query
  63. :param invoke_from: invoke from
  64. :param show_retrieve_source: show retrieve source
  65. :param hit_callback: hit callback
  66. :param message_id: message id
  67. :param memory: memory
  68. :return:
  69. """
  70. dataset_ids = config.dataset_ids
  71. if len(dataset_ids) == 0:
  72. return None
  73. retrieve_config = config.retrieve_config
  74. # check model is support tool calling
  75. model_type_instance = model_config.provider_model_bundle.model_type_instance
  76. model_type_instance = cast(LargeLanguageModel, model_type_instance)
  77. model_manager = ModelManager()
  78. model_instance = model_manager.get_model_instance(
  79. tenant_id=tenant_id,
  80. model_type=ModelType.LLM,
  81. provider=model_config.provider,
  82. model=model_config.model
  83. )
  84. # get model schema
  85. model_schema = model_type_instance.get_model_schema(
  86. model=model_config.model,
  87. credentials=model_config.credentials
  88. )
  89. if not model_schema:
  90. return None
  91. planning_strategy = PlanningStrategy.REACT_ROUTER
  92. features = model_schema.features
  93. if features:
  94. if ModelFeature.TOOL_CALL in features \
  95. or ModelFeature.MULTI_TOOL_CALL in features:
  96. planning_strategy = PlanningStrategy.ROUTER
  97. available_datasets = []
  98. for dataset_id in dataset_ids:
  99. # get dataset from dataset id
  100. dataset = db.session.query(Dataset).filter(
  101. Dataset.tenant_id == tenant_id,
  102. Dataset.id == dataset_id
  103. ).first()
  104. # pass if dataset is not available
  105. if not dataset:
  106. continue
  107. # pass if dataset is not available
  108. if (dataset and dataset.available_document_count == 0
  109. and dataset.available_document_count == 0):
  110. continue
  111. available_datasets.append(dataset)
  112. all_documents = []
  113. user_from = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'
  114. if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
  115. all_documents = self.single_retrieve(
  116. app_id, tenant_id, user_id, user_from, available_datasets, query,
  117. model_instance,
  118. model_config, planning_strategy, message_id
  119. )
  120. elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
  121. all_documents = self.multiple_retrieve(
  122. app_id, tenant_id, user_id, user_from,
  123. available_datasets, query, retrieve_config.top_k,
  124. retrieve_config.score_threshold,
  125. retrieve_config.rerank_mode,
  126. retrieve_config.reranking_model,
  127. retrieve_config.weights,
  128. retrieve_config.reranking_enabled,
  129. message_id,
  130. )
  131. document_score_list = {}
  132. for item in all_documents:
  133. if item.metadata.get('score'):
  134. document_score_list[item.metadata['doc_id']] = item.metadata['score']
  135. document_context_list = []
  136. index_node_ids = [document.metadata['doc_id'] for document in all_documents]
  137. segments = DocumentSegment.query.filter(
  138. DocumentSegment.dataset_id.in_(dataset_ids),
  139. DocumentSegment.completed_at.isnot(None),
  140. DocumentSegment.status == 'completed',
  141. DocumentSegment.enabled == True,
  142. DocumentSegment.index_node_id.in_(index_node_ids)
  143. ).all()
  144. if segments:
  145. index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
  146. sorted_segments = sorted(segments,
  147. key=lambda segment: index_node_id_to_position.get(segment.index_node_id,
  148. float('inf')))
  149. for segment in sorted_segments:
  150. if segment.answer:
  151. document_context_list.append(f'question:{segment.get_sign_content()} answer:{segment.answer}')
  152. else:
  153. document_context_list.append(segment.get_sign_content())
  154. if show_retrieve_source:
  155. context_list = []
  156. resource_number = 1
  157. for segment in sorted_segments:
  158. dataset = Dataset.query.filter_by(
  159. id=segment.dataset_id
  160. ).first()
  161. document = DatasetDocument.query.filter(DatasetDocument.id == segment.document_id,
  162. DatasetDocument.enabled == True,
  163. DatasetDocument.archived == False,
  164. ).first()
  165. if dataset and document:
  166. source = {
  167. 'position': resource_number,
  168. 'dataset_id': dataset.id,
  169. 'dataset_name': dataset.name,
  170. 'document_id': document.id,
  171. 'document_name': document.name,
  172. 'data_source_type': document.data_source_type,
  173. 'segment_id': segment.id,
  174. 'retriever_from': invoke_from.to_source(),
  175. 'score': document_score_list.get(segment.index_node_id, None)
  176. }
  177. if invoke_from.to_source() == 'dev':
  178. source['hit_count'] = segment.hit_count
  179. source['word_count'] = segment.word_count
  180. source['segment_position'] = segment.position
  181. source['index_node_hash'] = segment.index_node_hash
  182. if segment.answer:
  183. source['content'] = f'question:{segment.content} \nanswer:{segment.answer}'
  184. else:
  185. source['content'] = segment.content
  186. context_list.append(source)
  187. resource_number += 1
  188. if hit_callback:
  189. hit_callback.return_retriever_resource_info(context_list)
  190. return str("\n".join(document_context_list))
  191. return ''
    def single_retrieve(
        self, app_id: str,
        tenant_id: str,
        user_id: str,
        user_from: str,
        available_datasets: list,
        query: str,
        model_instance: ModelInstance,
        model_config: ModelConfigWithCredentialsEntity,
        planning_strategy: PlanningStrategy,
        message_id: Optional[str] = None,
    ):
        """
        Let an LLM router pick ONE dataset for the query, then retrieve from it.

        Builds one zero-argument tool per dataset (tool name == dataset id) and
        asks either the ReAct router or the function-call router to choose.

        :return: the retrieved documents, or [] when no dataset was selected
            or the selected dataset no longer exists
        """
        tools = []
        for dataset in available_datasets:
            description = dataset.description
            if not description:
                description = 'useful for when you want to answer queries about the ' + dataset.name
            # routers expect a single-line description
            description = description.replace('\n', '').replace('\r', '')
            message_tool = PromptMessageTool(
                name=dataset.id,
                description=description,
                parameters={
                    "type": "object",
                    "properties": {},
                    "required": [],
                }
            )
            tools.append(message_tool)
        dataset_id = None
        if planning_strategy == PlanningStrategy.REACT_ROUTER:
            react_multi_dataset_router = ReactMultiDatasetRouter()
            dataset_id = react_multi_dataset_router.invoke(query, tools, model_config, model_instance,
                                                           user_id, tenant_id)
        elif planning_strategy == PlanningStrategy.ROUTER:
            function_call_router = FunctionCallMultiDatasetRouter()
            dataset_id = function_call_router.invoke(query, tools, model_config, model_instance)
        if dataset_id:
            # get retrieval model config
            dataset = db.session.query(Dataset).filter(
                Dataset.id == dataset_id
            ).first()
            if dataset:
                retrieval_model_config = dataset.retrieval_model \
                    if dataset.retrieval_model else default_retrieval_model
                # get top k
                top_k = retrieval_model_config['top_k']
                # economy datasets only have a keyword index
                if dataset.indexing_technique == "economy":
                    retrival_method = 'keyword_search'
                else:
                    retrival_method = retrieval_model_config['search_method']
                # reranking model only applies when reranking is enabled
                reranking_model = retrieval_model_config['reranking_model'] \
                    if retrieval_model_config['reranking_enable'] else None
                # score threshold defaults to 0.0 when the feature is disabled
                score_threshold = .0
                score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
                if score_threshold_enabled:
                    score_threshold = retrieval_model_config.get("score_threshold")
                with measure_time() as timer:
                    results = RetrievalService.retrieve(
                        retrival_method=retrival_method, dataset_id=dataset.id,
                        query=query,
                        top_k=top_k, score_threshold=score_threshold,
                        reranking_model=reranking_model,
                        reranking_mode=retrieval_model_config.get('reranking_mode', 'reranking_model'),
                        weights=retrieval_model_config.get('weights', None),
                    )
                self._on_query(query, [dataset_id], app_id, user_from, user_id)
                if results:
                    self._on_retrival_end(results, message_id, timer)
                return results
        return []
    def multiple_retrieve(
        self,
        app_id: str,
        tenant_id: str,
        user_id: str,
        user_from: str,
        available_datasets: list,
        query: str,
        top_k: int,
        score_threshold: float,
        reranking_mode: str,
        reranking_model: Optional[dict] = None,
        weights: Optional[dict] = None,
        reranking_enable: bool = True,
        message_id: Optional[str] = None,
    ):
        """
        Retrieve from every available dataset in parallel, then merge.

        One thread per dataset runs self._retriever; each worker appends into
        the shared all_documents list.  After all threads join, the merged
        list is either reranked by a DataPostProcessor or re-scored locally.
        """
        threads = []
        all_documents = []
        dataset_ids = [dataset.id for dataset in available_datasets]
        # NOTE(review): index_type ends up as the indexing_technique of the
        # LAST dataset in the loop — presumably all datasets share one
        # technique here; verify against callers.
        index_type = None
        for dataset in available_datasets:
            index_type = dataset.indexing_technique
            # the real Flask app object is passed through so the worker
            # thread can push its own app context
            retrieval_thread = threading.Thread(target=self._retriever, kwargs={
                'flask_app': current_app._get_current_object(),
                'dataset_id': dataset.id,
                'query': query,
                'top_k': top_k,
                'all_documents': all_documents,
            })
            threads.append(retrieval_thread)
            retrieval_thread.start()
        for thread in threads:
            thread.join()
        with measure_time() as timer:
            if reranking_enable:
                # do rerank for searched documents
                data_post_processor = DataPostProcessor(
                    tenant_id, reranking_mode,
                    reranking_model, weights, False
                )
                all_documents = data_post_processor.invoke(
                    query=query,
                    documents=all_documents,
                    score_threshold=score_threshold,
                    top_n=top_k
                )
            else:
                # no reranker: order locally by keyword TF-IDF or vector score
                if index_type == "economy":
                    all_documents = self.calculate_keyword_score(query, all_documents, top_k)
                elif index_type == "high_quality":
                    all_documents = self.calculate_vector_score(all_documents, top_k, score_threshold)
        self._on_query(query, dataset_ids, app_id, user_from, user_id)
        if all_documents:
            self._on_retrival_end(all_documents, message_id, timer)
        return all_documents
  320. def _on_retrival_end(
  321. self, documents: list[Document], message_id: Optional[str] = None, timer: Optional[dict] = None
  322. ) -> None:
  323. """Handle retrival end."""
  324. for document in documents:
  325. query = db.session.query(DocumentSegment).filter(
  326. DocumentSegment.index_node_id == document.metadata['doc_id']
  327. )
  328. # if 'dataset_id' in document.metadata:
  329. if 'dataset_id' in document.metadata:
  330. query = query.filter(DocumentSegment.dataset_id == document.metadata['dataset_id'])
  331. # add hit count to document segment
  332. query.update(
  333. {DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
  334. synchronize_session=False
  335. )
  336. db.session.commit()
  337. # get tracing instance
  338. trace_manager: TraceQueueManager = self.application_generate_entity.trace_manager if self.application_generate_entity else None
  339. if trace_manager:
  340. trace_manager.add_trace_task(
  341. TraceTask(
  342. TraceTaskName.DATASET_RETRIEVAL_TRACE,
  343. message_id=message_id,
  344. documents=documents,
  345. timer=timer
  346. )
  347. )
  348. def _on_query(self, query: str, dataset_ids: list[str], app_id: str, user_from: str, user_id: str) -> None:
  349. """
  350. Handle query.
  351. """
  352. if not query:
  353. return
  354. dataset_queries = []
  355. for dataset_id in dataset_ids:
  356. dataset_query = DatasetQuery(
  357. dataset_id=dataset_id,
  358. content=query,
  359. source='app',
  360. source_app_id=app_id,
  361. created_by_role=user_from,
  362. created_by=user_id
  363. )
  364. dataset_queries.append(dataset_query)
  365. if dataset_queries:
  366. db.session.add_all(dataset_queries)
  367. db.session.commit()
    def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list):
        """
        Thread worker: retrieve from a single dataset and extend all_documents.

        Pushes a fresh Flask app context because it runs on a worker thread
        (see multiple_retrieve).
        """
        with flask_app.app_context():
            dataset = db.session.query(Dataset).filter(
                Dataset.id == dataset_id
            ).first()
            if not dataset:
                return []
            # use the dataset's retrieval model; fall back to the default when unset
            retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
            if dataset.indexing_technique == "economy":
                # use keyword table query
                documents = RetrievalService.retrieve(retrival_method='keyword_search',
                                                      dataset_id=dataset.id,
                                                      query=query,
                                                      top_k=top_k
                                                      )
                if documents:
                    all_documents.extend(documents)
            else:
                if top_k > 0:
                    # retrieval source; the score-threshold / reranking kwargs
                    # are only passed when the matching feature is enabled in
                    # the retrieval model
                    documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
                                                          dataset_id=dataset.id,
                                                          query=query,
                                                          top_k=top_k,
                                                          score_threshold=retrieval_model.get('score_threshold', .0)
                                                          if retrieval_model['score_threshold_enabled'] else None,
                                                          reranking_model=retrieval_model.get('reranking_model', None)
                                                          if retrieval_model['reranking_enable'] else None,
                                                          reranking_mode=retrieval_model.get('reranking_mode')
                                                          if retrieval_model.get('reranking_mode') else 'reranking_model',
                                                          weights=retrieval_model.get('weights', None),
                                                          )
                    all_documents.extend(documents)
  402. def to_dataset_retriever_tool(self, tenant_id: str,
  403. dataset_ids: list[str],
  404. retrieve_config: DatasetRetrieveConfigEntity,
  405. return_resource: bool,
  406. invoke_from: InvokeFrom,
  407. hit_callback: DatasetIndexToolCallbackHandler) \
  408. -> Optional[list[DatasetRetrieverBaseTool]]:
  409. """
  410. A dataset tool is a tool that can be used to retrieve information from a dataset
  411. :param tenant_id: tenant id
  412. :param dataset_ids: dataset ids
  413. :param retrieve_config: retrieve config
  414. :param return_resource: return resource
  415. :param invoke_from: invoke from
  416. :param hit_callback: hit callback
  417. """
  418. tools = []
  419. available_datasets = []
  420. for dataset_id in dataset_ids:
  421. # get dataset from dataset id
  422. dataset = db.session.query(Dataset).filter(
  423. Dataset.tenant_id == tenant_id,
  424. Dataset.id == dataset_id
  425. ).first()
  426. # pass if dataset is not available
  427. if not dataset:
  428. continue
  429. # pass if dataset is not available
  430. if (dataset and dataset.available_document_count == 0
  431. and dataset.available_document_count == 0):
  432. continue
  433. available_datasets.append(dataset)
  434. if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
  435. # get retrieval model config
  436. default_retrieval_model = {
  437. 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value,
  438. 'reranking_enable': False,
  439. 'reranking_model': {
  440. 'reranking_provider_name': '',
  441. 'reranking_model_name': ''
  442. },
  443. 'top_k': 2,
  444. 'score_threshold_enabled': False
  445. }
  446. for dataset in available_datasets:
  447. retrieval_model_config = dataset.retrieval_model \
  448. if dataset.retrieval_model else default_retrieval_model
  449. # get top k
  450. top_k = retrieval_model_config['top_k']
  451. # get score threshold
  452. score_threshold = None
  453. score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
  454. if score_threshold_enabled:
  455. score_threshold = retrieval_model_config.get("score_threshold")
  456. tool = DatasetRetrieverTool.from_dataset(
  457. dataset=dataset,
  458. top_k=top_k,
  459. score_threshold=score_threshold,
  460. hit_callbacks=[hit_callback],
  461. return_resource=return_resource,
  462. retriever_from=invoke_from.to_source()
  463. )
  464. tools.append(tool)
  465. elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
  466. tool = DatasetMultiRetrieverTool.from_dataset(
  467. dataset_ids=[dataset.id for dataset in available_datasets],
  468. tenant_id=tenant_id,
  469. top_k=retrieve_config.top_k or 2,
  470. score_threshold=retrieve_config.score_threshold,
  471. hit_callbacks=[hit_callback],
  472. return_resource=return_resource,
  473. retriever_from=invoke_from.to_source(),
  474. reranking_provider_name=retrieve_config.reranking_model.get('reranking_provider_name'),
  475. reranking_model_name=retrieve_config.reranking_model.get('reranking_model_name')
  476. )
  477. tools.append(tool)
  478. return tools
  479. def calculate_keyword_score(self, query: str, documents: list[Document], top_k: int) -> list[Document]:
  480. """
  481. Calculate keywords scores
  482. :param query: search query
  483. :param documents: documents for reranking
  484. :return:
  485. """
  486. keyword_table_handler = JiebaKeywordTableHandler()
  487. query_keywords = keyword_table_handler.extract_keywords(query, None)
  488. documents_keywords = []
  489. for document in documents:
  490. # get the document keywords
  491. document_keywords = keyword_table_handler.extract_keywords(document.page_content, None)
  492. document.metadata['keywords'] = document_keywords
  493. documents_keywords.append(document_keywords)
  494. # Counter query keywords(TF)
  495. query_keyword_counts = Counter(query_keywords)
  496. # total documents
  497. total_documents = len(documents)
  498. # calculate all documents' keywords IDF
  499. all_keywords = set()
  500. for document_keywords in documents_keywords:
  501. all_keywords.update(document_keywords)
  502. keyword_idf = {}
  503. for keyword in all_keywords:
  504. # calculate include query keywords' documents
  505. doc_count_containing_keyword = sum(1 for doc_keywords in documents_keywords if keyword in doc_keywords)
  506. # IDF
  507. keyword_idf[keyword] = math.log((1 + total_documents) / (1 + doc_count_containing_keyword)) + 1
  508. query_tfidf = {}
  509. for keyword, count in query_keyword_counts.items():
  510. tf = count
  511. idf = keyword_idf.get(keyword, 0)
  512. query_tfidf[keyword] = tf * idf
  513. # calculate all documents' TF-IDF
  514. documents_tfidf = []
  515. for document_keywords in documents_keywords:
  516. document_keyword_counts = Counter(document_keywords)
  517. document_tfidf = {}
  518. for keyword, count in document_keyword_counts.items():
  519. tf = count
  520. idf = keyword_idf.get(keyword, 0)
  521. document_tfidf[keyword] = tf * idf
  522. documents_tfidf.append(document_tfidf)
  523. def cosine_similarity(vec1, vec2):
  524. intersection = set(vec1.keys()) & set(vec2.keys())
  525. numerator = sum(vec1[x] * vec2[x] for x in intersection)
  526. sum1 = sum(vec1[x] ** 2 for x in vec1.keys())
  527. sum2 = sum(vec2[x] ** 2 for x in vec2.keys())
  528. denominator = math.sqrt(sum1) * math.sqrt(sum2)
  529. if not denominator:
  530. return 0.0
  531. else:
  532. return float(numerator) / denominator
  533. similarities = []
  534. for document_tfidf in documents_tfidf:
  535. similarity = cosine_similarity(query_tfidf, document_tfidf)
  536. similarities.append(similarity)
  537. for document, score in zip(documents, similarities):
  538. # format document
  539. document.metadata['score'] = score
  540. documents = sorted(documents, key=lambda x: x.metadata['score'], reverse=True)
  541. return documents[:top_k] if top_k else documents
  542. def calculate_vector_score(self, all_documents: list[Document],
  543. top_k: int, score_threshold: float) -> list[Document]:
  544. filter_documents = []
  545. for document in all_documents:
  546. if score_threshold is None or document.metadata['score'] >= score_threshold:
  547. filter_documents.append(document)
  548. if not filter_documents:
  549. return []
  550. filter_documents = sorted(filter_documents, key=lambda x: x.metadata['score'], reverse=True)
  551. return filter_documents[:top_k] if top_k else filter_documents