123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109 |
- from typing import Any
- from core.app.app_config.entities import DatasetRetrieveConfigEntity
- from core.app.entities.app_invoke_entities import InvokeFrom
- from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
- from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
- from core.tools.entities.common_entities import I18nObject
- from core.tools.entities.tool_entities import (
- ToolDescription,
- ToolIdentity,
- ToolInvokeMessage,
- ToolParameter,
- ToolProviderType,
- )
- from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
- from core.tools.tool.tool import Tool
class DatasetRetrieverTool(Tool):
    """
    Adapter that exposes a ``DatasetRetrieverBaseTool`` through the generic ``Tool`` interface.

    Instances are normally created with :meth:`get_dataset_tools`, which wraps every retrieval
    tool produced by ``DatasetRetrieval.to_dataset_retriever_tool`` in one of these objects.
    """

    # The wrapped retrieval tool; ``_invoke`` delegates the actual search to it.
    retrieval_tool: DatasetRetrieverBaseTool

    @staticmethod
    def get_dataset_tools(
        tenant_id: str,
        dataset_ids: list[str],
        retrieve_config: DatasetRetrieveConfigEntity,
        return_resource: bool,
        invoke_from: InvokeFrom,
        hit_callback: DatasetIndexToolCallbackHandler,
    ) -> list["DatasetRetrieverTool"]:
        """
        Build dataset retriever tools for the given datasets.

        While the underlying retrieval tools are created, the retrieve strategy on
        ``retrieve_config`` is temporarily forced to ``SINGLE``. It is restored to its original
        value afterwards, including when tool creation raises. Exceptions raised by
        ``DatasetRetrieval.to_dataset_retriever_tool`` propagate unchanged.

        :param tenant_id: tenant the datasets belong to; forwarded to the retrieval feature
        :param dataset_ids: ids of the datasets to expose; ``None`` or empty yields no tools
        :param retrieve_config: retrieval configuration; ``None`` yields no tools
        :param return_resource: forwarded to ``DatasetRetrieval.to_dataset_retriever_tool``
        :param invoke_from: forwarded to ``DatasetRetrieval.to_dataset_retriever_tool``
        :param hit_callback: forwarded to ``DatasetRetrieval.to_dataset_retriever_tool``
        :return: one ``DatasetRetrieverTool`` per retrieval tool that was created
        """
        # Nothing to expose without datasets or a configuration.
        if not dataset_ids or retrieve_config is None:
            return []

        feature = DatasetRetrieval()

        # Agent only supports SINGLE mode, so switch the caller-owned config to SINGLE for the
        # duration of the call. The try/finally guarantees the original strategy is restored
        # even if tool creation raises, so the shared config is never left mutated.
        original_retrieve_strategy = retrieve_config.retrieve_strategy
        retrieve_config.retrieve_strategy = DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE
        try:
            retrieval_tools = feature.to_dataset_retriever_tool(
                tenant_id=tenant_id,
                dataset_ids=dataset_ids,
                retrieve_config=retrieve_config,
                return_resource=return_resource,
                invoke_from=invoke_from,
                hit_callback=hit_callback,
            )
        finally:
            retrieve_config.retrieve_strategy = original_retrieve_strategy

        # Wrap each retrieval tool in the generic Tool interface. The name and LLM-facing
        # description come from the wrapped tool; labels are left empty.
        # NOTE(review): `or []` guards against a None result in case
        # to_dataset_retriever_tool can return None — confirm its declared return type.
        return [
            DatasetRetrieverTool(
                retrieval_tool=retrieval_tool,
                identity=ToolIdentity(
                    provider="", author="", name=retrieval_tool.name, label=I18nObject(en_US="", zh_Hans="")
                ),
                parameters=[],
                is_team_authorization=True,
                description=ToolDescription(human=I18nObject(en_US="", zh_Hans=""), llm=retrieval_tool.description),
                runtime=DatasetRetrieverTool.Runtime(),
            )
            for retrieval_tool in retrieval_tools or []
        ]

    def get_runtime_parameters(self) -> list[ToolParameter]:
        """
        Return the parameters the LLM must supply when calling this tool.

        :return: a single required, LLM-filled string parameter named ``query``
        """
        return [
            ToolParameter(
                name="query",
                label=I18nObject(en_US="", zh_Hans=""),
                human_description=I18nObject(en_US="", zh_Hans=""),
                type=ToolParameter.ToolParameterType.STRING,
                form=ToolParameter.ToolParameterForm.LLM,
                llm_description="Query for the dataset to be used to retrieve the dataset.",
                required=True,
                default="",
            ),
        ]

    def tool_provider_type(self) -> ToolProviderType:
        """Return the provider type of this tool (always ``DATASET_RETRIEVAL``)."""
        return ToolProviderType.DATASET_RETRIEVAL

    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
        """
        Run the wrapped retrieval tool with the ``query`` parameter.

        :param user_id: id of the invoking user (not used by this tool)
        :param tool_parameters: tool arguments; only ``"query"`` is read
        :return: a text message holding the retrieval result, or a prompt asking for a query
            when ``query`` is missing or empty
        """
        query = tool_parameters.get("query")
        if not query:
            return self.create_text_message(text="please input query")

        # Delegate the search to the wrapped dataset retriever tool.
        result = self.retrieval_tool._run(query=query)
        return self.create_text_message(text=result)

    def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None:
        """
        Validate credentials for the dataset retriever tool.

        This implementation performs no validation and always succeeds.
        """
|