| 
					
				 | 
			
			
				@@ -1,23 +1,40 @@ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import threading 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from typing import Optional, cast 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-from langchain.tools import BaseTool 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from flask import Flask, current_app 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from core.entities.agent_entities import PlanningStrategy 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from core.memory.token_buffer_memory import TokenBufferMemory 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-from core.model_runtime.entities.model_entities import ModelFeature 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from core.model_manager import ModelInstance, ModelManager 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from core.model_runtime.entities.message_entities import PromptMessageTool 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from core.model_runtime.entities.model_entities import ModelFeature, ModelType 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-from core.rag.retrieval.agent_based_dataset_executor import AgentConfiguration, AgentExecutor 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from core.rag.datasource.retrieval_service import RetrievalService 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from core.rag.models.document import Document 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from core.rerank.rerank import RerankRunner 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from extensions.ext_database import db 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-from models.dataset import Dataset 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from models.dataset import Dataset, DatasetQuery, DocumentSegment 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from models.dataset import Document as DatasetDocument 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+default_retrieval_model = { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    'search_method': 'semantic_search', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    'reranking_enable': False, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    'reranking_model': { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'reranking_provider_name': '', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'reranking_model_name': '' 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    }, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    'top_k': 2, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    'score_threshold_enabled': False 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+} 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 class DatasetRetrieval: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    def retrieve(self, tenant_id: str, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def retrieve(self, app_id: str, user_id: str, tenant_id: str, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                  model_config: ModelConfigWithCredentialsEntity, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                  config: DatasetEntity, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                  query: str, 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -27,6 +44,8 @@ class DatasetRetrieval: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                  memory: Optional[TokenBufferMemory] = None) -> Optional[str]: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         """ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         Retrieve dataset. 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        :param app_id: app_id 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        :param user_id: user_id 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         :param tenant_id: tenant id 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         :param model_config: model config 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         :param config: dataset config 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -38,12 +57,22 @@ class DatasetRetrieval: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         :return: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         """ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         dataset_ids = config.dataset_ids 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if len(dataset_ids) == 0: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            return None 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         retrieve_config = config.retrieve_config 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         # check model is support tool calling 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         model_type_instance = model_config.provider_model_bundle.model_type_instance 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         model_type_instance = cast(LargeLanguageModel, model_type_instance) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        model_manager = ModelManager() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        model_instance = model_manager.get_model_instance( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            tenant_id=tenant_id, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            model_type=ModelType.LLM, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            provider=model_config.provider, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            model=model_config.model 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         # get model schema 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         model_schema = model_type_instance.get_model_schema( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             model=model_config.model, 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -59,56 +88,6 @@ class DatasetRetrieval: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             if ModelFeature.TOOL_CALL in features \ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                     or ModelFeature.MULTI_TOOL_CALL in features: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 planning_strategy = PlanningStrategy.ROUTER 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        dataset_retriever_tools = self.to_dataset_retriever_tool( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            tenant_id=tenant_id, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            dataset_ids=dataset_ids, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            retrieve_config=retrieve_config, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            return_resource=show_retrieve_source, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            invoke_from=invoke_from, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            hit_callback=hit_callback 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        if len(dataset_retriever_tools) == 0: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            return None 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        agent_configuration = AgentConfiguration( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            strategy=planning_strategy, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            model_config=model_config, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            tools=dataset_retriever_tools, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            memory=memory, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            max_iterations=10, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            max_execution_time=400.0, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            early_stopping_method="generate" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        agent_executor = AgentExecutor(agent_configuration) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        should_use_agent = agent_executor.should_use_agent(query) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        if not should_use_agent: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            return None 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        result = agent_executor.run(query) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        return result.output 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    def to_dataset_retriever_tool(self, tenant_id: str, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                                  dataset_ids: list[str], 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                                  retrieve_config: DatasetRetrieveConfigEntity, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                                  return_resource: bool, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                                  invoke_from: InvokeFrom, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                                  hit_callback: DatasetIndexToolCallbackHandler) \ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            -> Optional[list[BaseTool]]: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        """ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        A dataset tool is a tool that can be used to retrieve information from a dataset 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        :param tenant_id: tenant id 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        :param dataset_ids: dataset ids 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        :param retrieve_config: retrieve config 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        :param return_resource: return resource 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        :param invoke_from: invoke from 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        :param hit_callback: hit callback 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        """ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        tools = [] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         available_datasets = [] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         for dataset_id in dataset_ids: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             # get dataset from dataset id 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -127,56 +106,270 @@ class DatasetRetrieval: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 continue 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             available_datasets.append(dataset) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        all_documents = [] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        user_from = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user' 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            all_documents = self.single_retrieve(app_id, tenant_id, user_id, user_from, available_datasets, query, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                                 model_instance, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                                 model_config, planning_strategy) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            all_documents = self.multiple_retrieve(app_id, tenant_id, user_id, user_from, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                                   available_datasets, query, retrieve_config.top_k, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                                   retrieve_config.score_threshold, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                                   retrieve_config.reranking_model.get('reranking_provider_name'), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                                   retrieve_config.reranking_model.get('reranking_model_name')) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        document_score_list = {} 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        for item in all_documents: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            if 'score' in item.metadata and item.metadata['score']: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                document_score_list[item.metadata['doc_id']] = item.metadata['score'] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        document_context_list = [] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        index_node_ids = [document.metadata['doc_id'] for document in all_documents] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        segments = DocumentSegment.query.filter( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            DocumentSegment.dataset_id.in_(dataset_ids), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            DocumentSegment.completed_at.isnot(None), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            DocumentSegment.status == 'completed', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            DocumentSegment.enabled == True, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            DocumentSegment.index_node_id.in_(index_node_ids) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        ).all() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if segments: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            sorted_segments = sorted(segments, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                     key=lambda segment: index_node_id_to_position.get(segment.index_node_id, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                                                                       float('inf'))) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            for segment in sorted_segments: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                if segment.answer: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    document_context_list.append(f'question:{segment.content} answer:{segment.answer}') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    document_context_list.append(segment.content) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            if show_retrieve_source: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                context_list = [] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                resource_number = 1 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                for segment in sorted_segments: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    dataset = Dataset.query.filter_by( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        id=segment.dataset_id 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    ).first() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    document = DatasetDocument.query.filter(DatasetDocument.id == segment.document_id, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                                            DatasetDocument.enabled == True, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                                            DatasetDocument.archived == False, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                                            ).first() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    if dataset and document: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        source = { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                            'position': resource_number, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                            'dataset_id': dataset.id, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                            'dataset_name': dataset.name, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                            'document_id': document.id, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                            'document_name': document.name, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                            'data_source_type': document.data_source_type, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                            'segment_id': segment.id, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                            'retriever_from': invoke_from.to_source(), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                            'score': document_score_list.get(segment.index_node_id, None) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        if invoke_from.to_source() == 'dev': 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                            source['hit_count'] = segment.hit_count 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                            source['word_count'] = segment.word_count 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                            source['segment_position'] = segment.position 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                            source['index_node_hash'] = segment.index_node_hash 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        if segment.answer: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                            source['content'] = f'question:{segment.content} \nanswer:{segment.answer}' 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                            source['content'] = segment.content 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        context_list.append(source) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    resource_number += 1 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                if hit_callback: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    hit_callback.return_retriever_resource_info(context_list) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            return str("\n".join(document_context_list)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        return '' 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def single_retrieve(self, app_id: str, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        tenant_id: str, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        user_id: str, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        user_from: str, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        available_datasets: list, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        query: str, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        model_instance: ModelInstance, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        model_config: ModelConfigWithCredentialsEntity, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        planning_strategy: PlanningStrategy, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        ): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        tools = [] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        for dataset in available_datasets: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            description = dataset.description 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            if not description: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                description = 'useful for when you want to answer queries about the ' + dataset.name 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            description = description.replace('\n', '').replace('\r', '') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            message_tool = PromptMessageTool( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                name=dataset.id, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                description=description, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                parameters={ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    "type": "object", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    "properties": {}, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    "required": [], 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            tools.append(message_tool) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        dataset_id = None 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if planning_strategy == PlanningStrategy.REACT_ROUTER: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            react_multi_dataset_router = ReactMultiDatasetRouter() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            dataset_id = react_multi_dataset_router.invoke(query, tools, model_config, model_instance, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                                           user_id, tenant_id) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        elif planning_strategy == PlanningStrategy.ROUTER: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            function_call_router = FunctionCallMultiDatasetRouter() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            dataset_id = function_call_router.invoke(query, tools, model_config, model_instance) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if dataset_id: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             # get retrieval model config 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            default_retrieval_model = { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                'search_method': 'semantic_search', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                'reranking_enable': False, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                'reranking_model': { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    'reranking_provider_name': '', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    'reranking_model_name': '' 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                }, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                'top_k': 2, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                'score_threshold_enabled': False 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            for dataset in available_datasets: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            dataset = db.session.query(Dataset).filter( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                Dataset.id == dataset_id 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            ).first() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            if dataset: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 retrieval_model_config = dataset.retrieval_model \ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                     if dataset.retrieval_model else default_retrieval_model 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 # get top k 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 top_k = retrieval_model_config['top_k'] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                # get retrieval method 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                if dataset.indexing_technique == "economy": 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    retrival_method = 'keyword_search' 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    retrival_method = retrieval_model_config['search_method'] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                # get reranking model 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                reranking_model = retrieval_model_config['reranking_model'] \ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    if retrieval_model_config['reranking_enable'] else None 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 # get score threshold 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                score_threshold = None 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                score_threshold = .0 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 if score_threshold_enabled: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                     score_threshold = retrieval_model_config.get("score_threshold") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                tool = DatasetRetrieverTool.from_dataset( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    dataset=dataset, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    top_k=top_k, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    score_threshold=score_threshold, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    hit_callbacks=[hit_callback], 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    return_resource=return_resource, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    retriever_from=invoke_from.to_source() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                                    query=query, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                                    top_k=top_k, score_threshold=score_threshold, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                                    reranking_model=reranking_model) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                self._on_query(query, [dataset_id], app_id, user_from, user_id) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                if results: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    self._on_retrival_end(results) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                return results 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        return [] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                tools.append(tool) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            tool = DatasetMultiRetrieverTool.from_dataset( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                dataset_ids=[dataset.id for dataset in available_datasets], 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                tenant_id=tenant_id, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                top_k=retrieve_config.top_k or 2, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                score_threshold=retrieve_config.score_threshold, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                hit_callbacks=[hit_callback], 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                return_resource=return_resource, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                retriever_from=invoke_from.to_source(), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                reranking_provider_name=retrieve_config.reranking_model.get('reranking_provider_name'), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                reranking_model_name=retrieve_config.reranking_model.get('reranking_model_name') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
    def multiple_retrieve(self,
                          app_id: str,
                          tenant_id: str,
                          user_id: str,
                          user_from: str,
                          available_datasets: list,
                          query: str,
                          top_k: int,
                          score_threshold: float,
                          reranking_provider_name: str,
                          reranking_model_name: str):
        """
        Query every available dataset in its own thread, then rerank the merged hits.

        One thread per dataset runs ``self._retriever``; all threads append into a
        single shared list. After every thread has been joined, the collected
        documents are passed through a rerank model, the query is recorded via
        ``self._on_query`` and, if anything survived reranking,
        ``self._on_retrival_end`` is called with the result.

        :param app_id: forwarded to ``_on_query``
        :param tenant_id: tenant used to look up the rerank model instance
        :param user_id: forwarded to ``_on_query``
        :param user_from: forwarded to ``_on_query``
        :param available_datasets: objects exposing an ``id`` attribute
        :param query: text to retrieve and rerank against
        :param top_k: passed to each retriever thread and to the rerank runner
        :param score_threshold: passed to the rerank runner
        :param reranking_provider_name: provider of the rerank model
        :param reranking_model_name: name of the rerank model
        :return: whatever the rerank runner returns for the collected documents
        """
        dataset_ids = [dataset.id for dataset in available_datasets]
        collected_documents = []
        workers = []

        # fan out: one retrieval thread per dataset, all writing into the same list
        for dataset in available_datasets:
            worker = threading.Thread(
                target=self._retriever,
                kwargs=dict(
                    # hand over the real app object so the thread can open its own app context
                    flask_app=current_app._get_current_object(),
                    dataset_id=dataset.id,
                    query=query,
                    top_k=top_k,
                    all_documents=collected_documents,
                ),
            )
            workers.append(worker)
            worker.start()

        # wait until every retrieval thread has finished
        for worker in workers:
            worker.join()

        # rerank the merged documents
        reranker_instance = ModelManager().get_model_instance(
            tenant_id=tenant_id,
            provider=reranking_provider_name,
            model_type=ModelType.RERANK,
            model=reranking_model_name,
        )
        reranker = RerankRunner(reranker_instance)
        reranked_documents = reranker.run(query, collected_documents, score_threshold, top_k)

        self._on_query(query, dataset_ids, app_id, user_from, user_id)
        if reranked_documents:
            self._on_retrival_end(reranked_documents)
        return reranked_documents
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
    def _on_retrival_end(self, documents: list[Document]) -> None:
        """
        Increment the hit count of the segment matching each retrieved document.

        Segments are matched on ``index_node_id == metadata['doc_id']`` and, when
        the document's metadata carries a ``dataset_id``, on that dataset as well.
        The session is committed after every document.
        """
        for retrieved in documents:
            metadata = retrieved.metadata
            segments = db.session.query(DocumentSegment).filter(
                DocumentSegment.index_node_id == metadata['doc_id']
            )

            # narrow to the owning dataset when the document names one
            if 'dataset_id' in metadata:
                segments = segments.filter(DocumentSegment.dataset_id == metadata['dataset_id'])

            # add one to the hit count of the matched segment(s)
            hit_increment = {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}
            segments.update(hit_increment, synchronize_session=False)

            db.session.commit()
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def _on_query(self, query: str, dataset_ids: list[str], app_id: str, user_from: str, user_id: str) -> None: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        """ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        Handle query. 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        """ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if not query: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            return 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        for dataset_id in dataset_ids: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            dataset_query = DatasetQuery( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                dataset_id=dataset_id, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                content=query, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                source='app', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                source_app_id=app_id, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                created_by_role=user_from, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                created_by=user_id 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            db.session.add(dataset_query) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        db.session.commit() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        with flask_app.app_context(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            dataset = db.session.query(Dataset).filter( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                Dataset.id == dataset_id 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            ).first() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            if not dataset: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                return [] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            # get retrieval model , if the model is not setting , using default 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            if dataset.indexing_technique == "economy": 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                # use keyword table query 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                documents = RetrievalService.retrieve(retrival_method='keyword_search', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                                      dataset_id=dataset.id, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                                      query=query, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                                      top_k=top_k 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                                      ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                if documents: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    all_documents.extend(documents) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                if top_k > 0: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    # retrieval source 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                                          dataset_id=dataset.id, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                                          query=query, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                                          top_k=top_k, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                                          score_threshold=retrieval_model['score_threshold'] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                                          if retrieval_model['score_threshold_enabled'] else None, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                                          reranking_model=retrieval_model['reranking_model'] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                                          if retrieval_model['reranking_enable'] else None 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                                          ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        return tools 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    all_documents.extend(documents) 
			 |