dataset_retriever_tool.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. from typing import Any
  2. from core.app.app_config.entities import DatasetRetrieveConfigEntity
  3. from core.app.entities.app_invoke_entities import InvokeFrom
  4. from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
  5. from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
  6. from core.tools.entities.common_entities import I18nObject
  7. from core.tools.entities.tool_entities import (
  8. ToolDescription,
  9. ToolIdentity,
  10. ToolInvokeMessage,
  11. ToolParameter,
  12. ToolProviderType,
  13. )
  14. from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
  15. from core.tools.tool.tool import Tool
  class DatasetRetrieverTool(Tool):
      """Adapter that exposes a dataset retriever as a regular ``Tool``.

      Each instance wraps one ``DatasetRetrieverBaseTool`` and forwards the
      ``query`` tool parameter to it, returning the retriever's output as a
      text message.
      """

      # The wrapped retriever; ``_invoke`` delegates the actual lookup to it.
      retrieval_tool: DatasetRetrieverBaseTool

      @staticmethod
      def get_dataset_tools(
          tenant_id: str,
          dataset_ids: list[str],
          retrieve_config: DatasetRetrieveConfigEntity,
          return_resource: bool,
          invoke_from: InvokeFrom,
          hit_callback: DatasetIndexToolCallbackHandler,
      ) -> list["DatasetRetrieverTool"]:
          """Build one ``DatasetRetrieverTool`` per retriever created for the datasets.

          Args:
              tenant_id: Forwarded to ``DatasetRetrieval.to_dataset_retriever_tool``.
              dataset_ids: Datasets to create retriever tools for.
              retrieve_config: Retrieval configuration. Its ``retrieve_strategy``
                  is switched to ``SINGLE`` only for the duration of the call
                  and is always restored afterwards.
              return_resource: Forwarded to the retrieval feature.
              invoke_from: Forwarded to the retrieval feature.
              hit_callback: Forwarded to the retrieval feature.

          Returns:
              The wrapped tools, or an empty list when there are no dataset
              ids, no retrieve config, or no retriever tools were produced.
          """
          # Nothing to build without datasets or a retrieval configuration.
          if not dataset_ids or retrieve_config is None:
              return []

          feature = DatasetRetrieval()

          # Agent only supports SINGLE mode: remember the caller's strategy,
          # force SINGLE while the retriever tools are created, then restore it.
          original_retriever_mode = retrieve_config.retrieve_strategy
          retrieve_config.retrieve_strategy = DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE
          try:
              retrieval_tools = feature.to_dataset_retriever_tool(
                  tenant_id=tenant_id,
                  dataset_ids=dataset_ids,
                  retrieve_config=retrieve_config,
                  return_resource=return_resource,
                  invoke_from=invoke_from,
                  hit_callback=hit_callback,
              )
          finally:
              # Restore even when tool creation raises, so the caller's config
              # object is never left stuck in SINGLE mode.
              retrieve_config.retrieve_strategy = original_retriever_mode

          # NOTE(review): guards against the feature returning None or an empty
          # collection - presumably when no dataset is usable; confirm against
          # DatasetRetrieval.to_dataset_retriever_tool.
          if not retrieval_tools:
              return []

          # Convert each retrieval tool into a Tool the tool framework can use.
          return [
              DatasetRetrieverTool(
                  retrieval_tool=retrieval_tool,
                  identity=ToolIdentity(
                      provider="",
                      author="",
                      name=retrieval_tool.name,
                      label=I18nObject(en_US="", zh_Hans=""),
                  ),
                  parameters=[],
                  is_team_authorization=True,
                  description=ToolDescription(
                      human=I18nObject(en_US="", zh_Hans=""),
                      llm=retrieval_tool.description,
                  ),
                  runtime=DatasetRetrieverTool.Runtime(),
              )
              for retrieval_tool in retrieval_tools
          ]

      def get_runtime_parameters(self) -> list[ToolParameter]:
          """Return the runtime parameters of this tool.

          The tool takes a single required string parameter, ``query``, whose
          form is ``LLM``.
          """
          return [
              ToolParameter(
                  name="query",
                  label=I18nObject(en_US="", zh_Hans=""),
                  human_description=I18nObject(en_US="", zh_Hans=""),
                  type=ToolParameter.ToolParameterType.STRING,
                  form=ToolParameter.ToolParameterForm.LLM,
                  llm_description="Query for the dataset to be used to retrieve the dataset.",
                  required=True,
                  default="",
              ),
          ]

      def tool_provider_type(self) -> ToolProviderType:
          """Return the provider type of this tool (``DATASET_RETRIEVAL``)."""
          return ToolProviderType.DATASET_RETRIEVAL

      def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
          """Run the wrapped retriever with the ``query`` tool parameter.

          Args:
              user_id: Not used by this implementation.
              tool_parameters: Tool arguments; only ``"query"`` is read.

          Returns:
              A text message with the retriever's result, or a text message
              asking for a query when ``"query"`` is missing or empty.
          """
          query = tool_parameters.get("query")
          if not query:
              return self.create_text_message(text="please input query")

          # Delegate the lookup to the wrapped retriever.
          result = self.retrieval_tool._run(query=query)
          return self.create_text_message(text=result)

      def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None:
          """Validate the credentials for the dataset retriever tool.

          This implementation performs no checks and always returns ``None``.
          """
          pass