# dataset_retriever_tool.py (4.0 KB)
  1. from typing import Any
  2. from langchain.tools import BaseTool
  3. from core.app.app_config.entities import DatasetRetrieveConfigEntity
  4. from core.app.entities.app_invoke_entities import InvokeFrom
  5. from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
  6. from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
  7. from core.tools.entities.common_entities import I18nObject
  8. from core.tools.entities.tool_entities import (
  9. ToolDescription,
  10. ToolIdentity,
  11. ToolInvokeMessage,
  12. ToolParameter,
  13. ToolProviderType,
  14. )
  15. from core.tools.tool.tool import Tool
  16. class DatasetRetrieverTool(Tool):
  17. langchain_tool: BaseTool
  18. @staticmethod
  19. def get_dataset_tools(tenant_id: str,
  20. dataset_ids: list[str],
  21. retrieve_config: DatasetRetrieveConfigEntity,
  22. return_resource: bool,
  23. invoke_from: InvokeFrom,
  24. hit_callback: DatasetIndexToolCallbackHandler
  25. ) -> list['DatasetRetrieverTool']:
  26. """
  27. get dataset tool
  28. """
  29. # check if retrieve_config is valid
  30. if dataset_ids is None or len(dataset_ids) == 0:
  31. return []
  32. if retrieve_config is None:
  33. return []
  34. feature = DatasetRetrieval()
  35. # save original retrieve strategy, and set retrieve strategy to SINGLE
  36. # Agent only support SINGLE mode
  37. original_retriever_mode = retrieve_config.retrieve_strategy
  38. retrieve_config.retrieve_strategy = DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE
  39. langchain_tools = feature.to_dataset_retriever_tool(
  40. tenant_id=tenant_id,
  41. dataset_ids=dataset_ids,
  42. retrieve_config=retrieve_config,
  43. return_resource=return_resource,
  44. invoke_from=invoke_from,
  45. hit_callback=hit_callback
  46. )
  47. # restore retrieve strategy
  48. retrieve_config.retrieve_strategy = original_retriever_mode
  49. # convert langchain tools to Tools
  50. tools = []
  51. for langchain_tool in langchain_tools:
  52. tool = DatasetRetrieverTool(
  53. langchain_tool=langchain_tool,
  54. identity=ToolIdentity(provider='', author='', name=langchain_tool.name, label=I18nObject(en_US='', zh_Hans='')),
  55. parameters=[],
  56. is_team_authorization=True,
  57. description=ToolDescription(
  58. human=I18nObject(en_US='', zh_Hans=''),
  59. llm=langchain_tool.description),
  60. runtime=DatasetRetrieverTool.Runtime()
  61. )
  62. tools.append(tool)
  63. return tools
  64. def get_runtime_parameters(self) -> list[ToolParameter]:
  65. return [
  66. ToolParameter(name='query',
  67. label=I18nObject(en_US='', zh_Hans=''),
  68. human_description=I18nObject(en_US='', zh_Hans=''),
  69. type=ToolParameter.ToolParameterType.STRING,
  70. form=ToolParameter.ToolParameterForm.LLM,
  71. llm_description='Query for the dataset to be used to retrieve the dataset.',
  72. required=True,
  73. default=''),
  74. ]
  75. def tool_provider_type(self) -> ToolProviderType:
  76. return ToolProviderType.DATASET_RETRIEVAL
  77. def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
  78. """
  79. invoke dataset retriever tool
  80. """
  81. query = tool_parameters.get('query', None)
  82. if not query:
  83. return self.create_text_message(text='please input query')
  84. # invoke dataset retriever tool
  85. result = self.langchain_tool._run(query=query)
  86. return self.create_text_message(text=result)
  87. def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None:
  88. """
  89. validate the credentials for dataset retriever tool
  90. """
  91. pass