# tool.py
  1. from abc import ABC, abstractmethod
  2. from enum import Enum
  3. from typing import Any, Optional, Union
  4. from pydantic import BaseModel
  5. from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
  6. from core.tools.entities.tool_entities import (
  7. ToolDescription,
  8. ToolIdentity,
  9. ToolInvokeMessage,
  10. ToolParameter,
  11. ToolRuntimeImageVariable,
  12. ToolRuntimeVariable,
  13. ToolRuntimeVariablePool,
  14. )
  15. from core.tools.tool_file_manager import ToolFileManager
  16. class Tool(BaseModel, ABC):
  17. identity: ToolIdentity = None
  18. parameters: Optional[list[ToolParameter]] = None
  19. description: ToolDescription = None
  20. is_team_authorization: bool = False
  21. agent_callback: Optional[DifyAgentCallbackHandler] = None
  22. use_callback: bool = False
  23. class Runtime(BaseModel):
  24. """
  25. Meta data of a tool call processing
  26. """
  27. def __init__(self, **data: Any):
  28. super().__init__(**data)
  29. if not self.runtime_parameters:
  30. self.runtime_parameters = {}
  31. tenant_id: str = None
  32. tool_id: str = None
  33. credentials: dict[str, Any] = None
  34. runtime_parameters: dict[str, Any] = None
  35. runtime: Runtime = None
  36. variables: ToolRuntimeVariablePool = None
  37. def __init__(self, **data: Any):
  38. super().__init__(**data)
  39. if not self.agent_callback:
  40. self.use_callback = False
  41. else:
  42. self.use_callback = True
  43. class VARIABLE_KEY(Enum):
  44. IMAGE = 'image'
  45. def fork_tool_runtime(self, meta: dict[str, Any], agent_callback: DifyAgentCallbackHandler = None) -> 'Tool':
  46. """
  47. fork a new tool with meta data
  48. :param meta: the meta data of a tool call processing, tenant_id is required
  49. :return: the new tool
  50. """
  51. return self.__class__(
  52. identity=self.identity.copy() if self.identity else None,
  53. parameters=self.parameters.copy() if self.parameters else None,
  54. description=self.description.copy() if self.description else None,
  55. runtime=Tool.Runtime(**meta),
  56. agent_callback=agent_callback
  57. )
  58. def load_variables(self, variables: ToolRuntimeVariablePool):
  59. """
  60. load variables from database
  61. :param conversation_id: the conversation id
  62. """
  63. self.variables = variables
  64. def set_image_variable(self, variable_name: str, image_key: str) -> None:
  65. """
  66. set an image variable
  67. """
  68. if not self.variables:
  69. return
  70. self.variables.set_file(self.identity.name, variable_name, image_key)
  71. def set_text_variable(self, variable_name: str, text: str) -> None:
  72. """
  73. set a text variable
  74. """
  75. if not self.variables:
  76. return
  77. self.variables.set_text(self.identity.name, variable_name, text)
  78. def get_variable(self, name: Union[str, Enum]) -> Optional[ToolRuntimeVariable]:
  79. """
  80. get a variable
  81. :param name: the name of the variable
  82. :return: the variable
  83. """
  84. if not self.variables:
  85. return None
  86. if isinstance(name, Enum):
  87. name = name.value
  88. for variable in self.variables.pool:
  89. if variable.name == name:
  90. return variable
  91. return None
  92. def get_default_image_variable(self) -> Optional[ToolRuntimeVariable]:
  93. """
  94. get the default image variable
  95. :return: the image variable
  96. """
  97. if not self.variables:
  98. return None
  99. return self.get_variable(self.VARIABLE_KEY.IMAGE)
  100. def get_variable_file(self, name: Union[str, Enum]) -> Optional[bytes]:
  101. """
  102. get a variable file
  103. :param name: the name of the variable
  104. :return: the variable file
  105. """
  106. variable = self.get_variable(name)
  107. if not variable:
  108. return None
  109. if not isinstance(variable, ToolRuntimeImageVariable):
  110. return None
  111. message_file_id = variable.value
  112. # get file binary
  113. file_binary = ToolFileManager.get_file_binary_by_message_file_id(message_file_id)
  114. if not file_binary:
  115. return None
  116. return file_binary[0]
  117. def list_variables(self) -> list[ToolRuntimeVariable]:
  118. """
  119. list all variables
  120. :return: the variables
  121. """
  122. if not self.variables:
  123. return []
  124. return self.variables.pool
  125. def list_default_image_variables(self) -> list[ToolRuntimeVariable]:
  126. """
  127. list all image variables
  128. :return: the image variables
  129. """
  130. if not self.variables:
  131. return []
  132. result = []
  133. for variable in self.variables.pool:
  134. if variable.name.startswith(self.VARIABLE_KEY.IMAGE.value):
  135. result.append(variable)
  136. return result
  137. def invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]:
  138. # update tool_parameters
  139. if self.runtime.runtime_parameters:
  140. tool_parameters.update(self.runtime.runtime_parameters)
  141. # hit callback
  142. if self.use_callback:
  143. self.agent_callback.on_tool_start(
  144. tool_name=self.identity.name,
  145. tool_inputs=tool_parameters
  146. )
  147. try:
  148. result = self._invoke(
  149. user_id=user_id,
  150. tool_parameters=tool_parameters,
  151. )
  152. except Exception as e:
  153. if self.use_callback:
  154. self.agent_callback.on_tool_error(e)
  155. raise e
  156. if not isinstance(result, list):
  157. result = [result]
  158. # hit callback
  159. if self.use_callback:
  160. self.agent_callback.on_tool_end(
  161. tool_name=self.identity.name,
  162. tool_inputs=tool_parameters,
  163. tool_outputs=self._convert_tool_response_to_str(result)
  164. )
  165. return result
  166. def _convert_tool_response_to_str(self, tool_response: list[ToolInvokeMessage]) -> str:
  167. """
  168. Handle tool response
  169. """
  170. result = ''
  171. for response in tool_response:
  172. if response.type == ToolInvokeMessage.MessageType.TEXT:
  173. result += response.message
  174. elif response.type == ToolInvokeMessage.MessageType.LINK:
  175. result += f"result link: {response.message}. please tell user to check it."
  176. elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
  177. response.type == ToolInvokeMessage.MessageType.IMAGE:
  178. result += "image has been created and sent to user already, you should tell user to check it now."
  179. elif response.type == ToolInvokeMessage.MessageType.BLOB:
  180. if len(response.message) > 114:
  181. result += str(response.message[:114]) + '...'
  182. else:
  183. result += str(response.message)
  184. else:
  185. result += f"tool response: {response.message}."
  186. return result
  187. @abstractmethod
  188. def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
  189. pass
  190. def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None:
  191. """
  192. validate the credentials
  193. :param credentials: the credentials
  194. :param parameters: the parameters
  195. """
  196. pass
  197. def get_runtime_parameters(self) -> list[ToolParameter]:
  198. """
  199. get the runtime parameters
  200. interface for developer to dynamic change the parameters of a tool depends on the variables pool
  201. :return: the runtime parameters
  202. """
  203. return self.parameters
  204. def is_tool_available(self) -> bool:
  205. """
  206. check if the tool is available
  207. :return: if the tool is available
  208. """
  209. return True
  210. def create_image_message(self, image: str, save_as: str = '') -> ToolInvokeMessage:
  211. """
  212. create an image message
  213. :param image: the url of the image
  214. :return: the image message
  215. """
  216. return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE,
  217. message=image,
  218. save_as=save_as)
  219. def create_link_message(self, link: str, save_as: str = '') -> ToolInvokeMessage:
  220. """
  221. create a link message
  222. :param link: the url of the link
  223. :return: the link message
  224. """
  225. return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.LINK,
  226. message=link,
  227. save_as=save_as)
  228. def create_text_message(self, text: str, save_as: str = '') -> ToolInvokeMessage:
  229. """
  230. create a text message
  231. :param text: the text
  232. :return: the text message
  233. """
  234. return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.TEXT,
  235. message=text,
  236. save_as=save_as
  237. )
  238. def create_blob_message(self, blob: bytes, meta: dict = None, save_as: str = '') -> ToolInvokeMessage:
  239. """
  240. create a blob message
  241. :param blob: the blob
  242. :return: the blob message
  243. """
  244. return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.BLOB,
  245. message=blob, meta=meta,
  246. save_as=save_as
  247. )