from typing import Any

from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.app.entities.app_invoke_entities import InvokeFrom
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
    ToolDescription,
    ToolIdentity,
    ToolInvokeMessage,
    ToolParameter,
    ToolProviderType,
)
from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from core.tools.tool.tool import Tool


class DatasetRetrieverTool(Tool):
    """Wrap a dataset retriever tool so it can be used as a regular ``Tool``."""

    # Underlying retriever; ``_invoke`` delegates the search to its ``_run`` method.
    # (The field name's spelling is kept as-is because callers pass it by keyword.)
    retrival_tool: DatasetRetrieverBaseTool

    @staticmethod
    def get_dataset_tools(tenant_id: str,
                          dataset_ids: list[str],
                          retrieve_config: DatasetRetrieveConfigEntity,
                          return_resource: bool,
                          invoke_from: InvokeFrom,
                          hit_callback: DatasetIndexToolCallbackHandler
                          ) -> list['DatasetRetrieverTool']:
        """
        Build one ``DatasetRetrieverTool`` for each retriever tool produced for the given datasets.

        :param tenant_id: tenant id forwarded to ``DatasetRetrieval.to_dataset_retriever_tool``
        :param dataset_ids: dataset ids to expose; ``None`` or empty yields no tools
        :param retrieve_config: retrieval configuration; ``None`` yields no tools. Its
            ``retrieve_strategy`` is temporarily forced to SINGLE while the retriever tools
            are built and is always restored afterwards, even if building them raises.
        :param return_resource: forwarded to ``DatasetRetrieval.to_dataset_retriever_tool``
        :param invoke_from: forwarded to ``DatasetRetrieval.to_dataset_retriever_tool``
        :param hit_callback: forwarded to ``DatasetRetrieval.to_dataset_retriever_tool``
        :return: the wrapped tools (possibly empty)
        """
        # Nothing to retrieve from, or no configuration to retrieve with.
        if not dataset_ids or retrieve_config is None:
            return []

        feature = DatasetRetrieval()

        # Agents only support SINGLE retrieval mode, so the strategy is overridden here.
        # The config object belongs to the caller, so the original value is restored in
        # ``finally``; otherwise an exception would leave the override in place.
        original_retriever_mode = retrieve_config.retrieve_strategy
        retrieve_config.retrieve_strategy = DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE
        try:
            retrival_tools = feature.to_dataset_retriever_tool(
                tenant_id=tenant_id,
                dataset_ids=dataset_ids,
                retrieve_config=retrieve_config,
                return_resource=return_resource,
                invoke_from=invoke_from,
                hit_callback=hit_callback
            )
        finally:
            retrieve_config.retrieve_strategy = original_retriever_mode

        # NOTE(review): ``or []`` defensively treats a missing (None) result as "no tools";
        # confirm whether DatasetRetrieval.to_dataset_retriever_tool can return None.
        return [
            DatasetRetrieverTool(
                retrival_tool=retrival_tool,
                identity=ToolIdentity(provider='', author='', name=retrival_tool.name,
                                      label=I18nObject(en_US='', zh_Hans='')),
                parameters=[],
                is_team_authorization=True,
                description=ToolDescription(
                    human=I18nObject(en_US='', zh_Hans=''),
                    llm=retrival_tool.description),
                runtime=DatasetRetrieverTool.Runtime()
            )
            for retrival_tool in retrival_tools or []
        ]

    def get_runtime_parameters(self) -> list[ToolParameter]:
        """Return the single required, LLM-filled string parameter ``query``."""
        return [
            ToolParameter(name='query',
                          label=I18nObject(en_US='', zh_Hans=''),
                          human_description=I18nObject(en_US='', zh_Hans=''),
                          type=ToolParameter.ToolParameterType.STRING,
                          form=ToolParameter.ToolParameterForm.LLM,
                          llm_description='Query for the dataset to be used to retrieve the dataset.',
                          required=True,
                          default=''),
        ]

    def tool_provider_type(self) -> ToolProviderType:
        """Identify this tool as a dataset-retrieval tool."""
        return ToolProviderType.DATASET_RETRIEVAL

    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
        """
        Run the wrapped retriever with the ``query`` parameter.

        :param user_id: id of the invoking user (not used here)
        :param tool_parameters: tool arguments; ``query`` is read from it
        :return: a text message with the retrieval result, or a prompt asking for a
            query when ``query`` is missing or empty
        """
        query = tool_parameters.get('query')
        if not query:
            return self.create_text_message(text='please input query')

        result = self.retrival_tool._run(query=query)

        return self.create_text_message(text=result)

    def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None:
        """
        Validate the credentials for the dataset retriever tool.

        Intentionally a no-op: this tool defines no credential checks.
        """
 |