dataset_retriever_tool.py 3.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. from typing import Any
  2. from langchain.tools import BaseTool
  3. from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
  4. from core.entities.application_entities import DatasetRetrieveConfigEntity, InvokeFrom
  5. from core.features.dataset_retrieval.dataset_retrieval import DatasetRetrievalFeature
  6. from core.tools.entities.common_entities import I18nObject
  7. from core.tools.entities.tool_entities import ToolDescription, ToolIdentity, ToolInvokeMessage, ToolParameter
  8. from core.tools.tool.tool import Tool
  9. class DatasetRetrieverTool(Tool):
  10. langchain_tool: BaseTool
  11. @staticmethod
  12. def get_dataset_tools(tenant_id: str,
  13. dataset_ids: list[str],
  14. retrieve_config: DatasetRetrieveConfigEntity,
  15. return_resource: bool,
  16. invoke_from: InvokeFrom,
  17. hit_callback: DatasetIndexToolCallbackHandler
  18. ) -> list['DatasetRetrieverTool']:
  19. """
  20. get dataset tool
  21. """
  22. # check if retrieve_config is valid
  23. if dataset_ids is None or len(dataset_ids) == 0:
  24. return []
  25. if retrieve_config is None:
  26. return []
  27. feature = DatasetRetrievalFeature()
  28. # save original retrieve strategy, and set retrieve strategy to SINGLE
  29. # Agent only support SINGLE mode
  30. original_retriever_mode = retrieve_config.retrieve_strategy
  31. retrieve_config.retrieve_strategy = DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE
  32. langchain_tools = feature.to_dataset_retriever_tool(
  33. tenant_id=tenant_id,
  34. dataset_ids=dataset_ids,
  35. retrieve_config=retrieve_config,
  36. return_resource=return_resource,
  37. invoke_from=invoke_from,
  38. hit_callback=hit_callback
  39. )
  40. # restore retrieve strategy
  41. retrieve_config.retrieve_strategy = original_retriever_mode
  42. # convert langchain tools to Tools
  43. tools = []
  44. for langchain_tool in langchain_tools:
  45. tool = DatasetRetrieverTool(
  46. langchain_tool=langchain_tool,
  47. identity=ToolIdentity(author='', name=langchain_tool.name, label=I18nObject(en_US='', zh_Hans='')),
  48. parameters=[],
  49. is_team_authorization=True,
  50. description=ToolDescription(
  51. human=I18nObject(en_US='', zh_Hans=''),
  52. llm=langchain_tool.description),
  53. runtime=DatasetRetrieverTool.Runtime()
  54. )
  55. tools.append(tool)
  56. return tools
  57. def get_runtime_parameters(self) -> list[ToolParameter]:
  58. return [
  59. ToolParameter(name='query',
  60. label=I18nObject(en_US='', zh_Hans=''),
  61. human_description=I18nObject(en_US='', zh_Hans=''),
  62. type=ToolParameter.ToolParameterType.STRING,
  63. form=ToolParameter.ToolParameterForm.LLM,
  64. llm_description='Query for the dataset to be used to retrieve the dataset.',
  65. required=True,
  66. default=''),
  67. ]
  68. def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
  69. """
  70. invoke dataset retriever tool
  71. """
  72. query = tool_parameters.get('query', None)
  73. if not query:
  74. return self.create_text_message(text='please input query')
  75. # invoke dataset retriever tool
  76. result = self.langchain_tool._run(query=query)
  77. return self.create_text_message(text=result)
  78. def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None:
  79. """
  80. validate the credentials for dataset retriever tool
  81. """
  82. pass