tool.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343
  1. from abc import ABC, abstractmethod
  2. from copy import deepcopy
  3. from enum import Enum
  4. from typing import Any, Optional, Union
  5. from pydantic import BaseModel, ConfigDict, field_validator
  6. from pydantic_core.core_schema import ValidationInfo
  7. from core.app.entities.app_invoke_entities import InvokeFrom
  8. from core.file.file_obj import FileVar
  9. from core.tools.entities.tool_entities import (
  10. ToolDescription,
  11. ToolIdentity,
  12. ToolInvokeFrom,
  13. ToolInvokeMessage,
  14. ToolParameter,
  15. ToolProviderType,
  16. ToolRuntimeImageVariable,
  17. ToolRuntimeVariable,
  18. ToolRuntimeVariablePool,
  19. )
  20. from core.tools.tool_file_manager import ToolFileManager
  21. from core.tools.utils.tool_parameter_converter import ToolParameterConverter
  22. class Tool(BaseModel, ABC):
  23. identity: Optional[ToolIdentity] = None
  24. parameters: Optional[list[ToolParameter]] = None
  25. description: Optional[ToolDescription] = None
  26. is_team_authorization: bool = False
  27. # pydantic configs
  28. model_config = ConfigDict(protected_namespaces=())
  29. @field_validator('parameters', mode='before')
  30. @classmethod
  31. def set_parameters(cls, v, validation_info: ValidationInfo) -> list[ToolParameter]:
  32. return v or []
  33. class Runtime(BaseModel):
  34. """
  35. Meta data of a tool call processing
  36. """
  37. def __init__(self, **data: Any):
  38. super().__init__(**data)
  39. if not self.runtime_parameters:
  40. self.runtime_parameters = {}
  41. tenant_id: Optional[str] = None
  42. tool_id: Optional[str] = None
  43. invoke_from: Optional[InvokeFrom] = None
  44. tool_invoke_from: Optional[ToolInvokeFrom] = None
  45. credentials: Optional[dict[str, Any]] = None
  46. runtime_parameters: Optional[dict[str, Any]] = None
  47. runtime: Optional[Runtime] = None
  48. variables: Optional[ToolRuntimeVariablePool] = None
  49. def __init__(self, **data: Any):
  50. super().__init__(**data)
  51. class VARIABLE_KEY(Enum):
  52. IMAGE = 'image'
  53. def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'Tool':
  54. """
  55. fork a new tool with meta data
  56. :param meta: the meta data of a tool call processing, tenant_id is required
  57. :return: the new tool
  58. """
  59. return self.__class__(
  60. identity=self.identity.model_copy() if self.identity else None,
  61. parameters=self.parameters.copy() if self.parameters else None,
  62. description=self.description.model_copy() if self.description else None,
  63. runtime=Tool.Runtime(**runtime),
  64. )
  65. @abstractmethod
  66. def tool_provider_type(self) -> ToolProviderType:
  67. """
  68. get the tool provider type
  69. :return: the tool provider type
  70. """
  71. def load_variables(self, variables: ToolRuntimeVariablePool):
  72. """
  73. load variables from database
  74. :param conversation_id: the conversation id
  75. """
  76. self.variables = variables
  77. def set_image_variable(self, variable_name: str, image_key: str) -> None:
  78. """
  79. set an image variable
  80. """
  81. if not self.variables:
  82. return
  83. self.variables.set_file(self.identity.name, variable_name, image_key)
  84. def set_text_variable(self, variable_name: str, text: str) -> None:
  85. """
  86. set a text variable
  87. """
  88. if not self.variables:
  89. return
  90. self.variables.set_text(self.identity.name, variable_name, text)
  91. def get_variable(self, name: Union[str, Enum]) -> Optional[ToolRuntimeVariable]:
  92. """
  93. get a variable
  94. :param name: the name of the variable
  95. :return: the variable
  96. """
  97. if not self.variables:
  98. return None
  99. if isinstance(name, Enum):
  100. name = name.value
  101. for variable in self.variables.pool:
  102. if variable.name == name:
  103. return variable
  104. return None
  105. def get_default_image_variable(self) -> Optional[ToolRuntimeVariable]:
  106. """
  107. get the default image variable
  108. :return: the image variable
  109. """
  110. if not self.variables:
  111. return None
  112. return self.get_variable(self.VARIABLE_KEY.IMAGE)
  113. def get_variable_file(self, name: Union[str, Enum]) -> Optional[bytes]:
  114. """
  115. get a variable file
  116. :param name: the name of the variable
  117. :return: the variable file
  118. """
  119. variable = self.get_variable(name)
  120. if not variable:
  121. return None
  122. if not isinstance(variable, ToolRuntimeImageVariable):
  123. return None
  124. message_file_id = variable.value
  125. # get file binary
  126. file_binary = ToolFileManager.get_file_binary_by_message_file_id(message_file_id)
  127. if not file_binary:
  128. return None
  129. return file_binary[0]
  130. def list_variables(self) -> list[ToolRuntimeVariable]:
  131. """
  132. list all variables
  133. :return: the variables
  134. """
  135. if not self.variables:
  136. return []
  137. return self.variables.pool
  138. def list_default_image_variables(self) -> list[ToolRuntimeVariable]:
  139. """
  140. list all image variables
  141. :return: the image variables
  142. """
  143. if not self.variables:
  144. return []
  145. result = []
  146. for variable in self.variables.pool:
  147. if variable.name.startswith(self.VARIABLE_KEY.IMAGE.value):
  148. result.append(variable)
  149. return result
  150. def invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]:
  151. # update tool_parameters
  152. if self.runtime.runtime_parameters:
  153. tool_parameters.update(self.runtime.runtime_parameters)
  154. # try parse tool parameters into the correct type
  155. tool_parameters = self._transform_tool_parameters_type(tool_parameters)
  156. result = self._invoke(
  157. user_id=user_id,
  158. tool_parameters=tool_parameters,
  159. )
  160. if not isinstance(result, list):
  161. result = [result]
  162. return result
  163. def _transform_tool_parameters_type(self, tool_parameters: dict[str, Any]) -> dict[str, Any]:
  164. """
  165. Transform tool parameters type
  166. """
  167. # Temp fix for the issue that the tool parameters will be converted to empty while validating the credentials
  168. result = deepcopy(tool_parameters)
  169. for parameter in self.parameters or []:
  170. if parameter.name in tool_parameters:
  171. result[parameter.name] = ToolParameterConverter.cast_parameter_by_type(tool_parameters[parameter.name], parameter.type)
  172. return result
  173. @abstractmethod
  174. def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
  175. pass
  176. def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None:
  177. """
  178. validate the credentials
  179. :param credentials: the credentials
  180. :param parameters: the parameters
  181. """
  182. pass
  183. def get_runtime_parameters(self) -> list[ToolParameter]:
  184. """
  185. get the runtime parameters
  186. interface for developer to dynamic change the parameters of a tool depends on the variables pool
  187. :return: the runtime parameters
  188. """
  189. return self.parameters
  190. def get_all_runtime_parameters(self) -> list[ToolParameter]:
  191. """
  192. get all runtime parameters
  193. :return: all runtime parameters
  194. """
  195. parameters = self.parameters or []
  196. parameters = parameters.copy()
  197. user_parameters = self.get_runtime_parameters() or []
  198. user_parameters = user_parameters.copy()
  199. # override parameters
  200. for parameter in user_parameters:
  201. # check if parameter in tool parameters
  202. found = False
  203. for tool_parameter in parameters:
  204. if tool_parameter.name == parameter.name:
  205. found = True
  206. break
  207. if found:
  208. # override parameter
  209. tool_parameter.type = parameter.type
  210. tool_parameter.form = parameter.form
  211. tool_parameter.required = parameter.required
  212. tool_parameter.default = parameter.default
  213. tool_parameter.options = parameter.options
  214. tool_parameter.llm_description = parameter.llm_description
  215. else:
  216. # add new parameter
  217. parameters.append(parameter)
  218. return parameters
  219. def create_image_message(self, image: str, save_as: str = '') -> ToolInvokeMessage:
  220. """
  221. create an image message
  222. :param image: the url of the image
  223. :return: the image message
  224. """
  225. return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE,
  226. message=image,
  227. save_as=save_as)
  228. def create_file_var_message(self, file_var: FileVar) -> ToolInvokeMessage:
  229. return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.FILE_VAR,
  230. message='',
  231. meta={
  232. 'file_var': file_var
  233. },
  234. save_as='')
  235. def create_link_message(self, link: str, save_as: str = '') -> ToolInvokeMessage:
  236. """
  237. create a link message
  238. :param link: the url of the link
  239. :return: the link message
  240. """
  241. return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.LINK,
  242. message=link,
  243. save_as=save_as)
  244. def create_text_message(self, text: str, save_as: str = '') -> ToolInvokeMessage:
  245. """
  246. create a text message
  247. :param text: the text
  248. :return: the text message
  249. """
  250. return ToolInvokeMessage(
  251. type=ToolInvokeMessage.MessageType.TEXT,
  252. message=text,
  253. save_as=save_as
  254. )
  255. def create_blob_message(self, blob: bytes, meta: dict = None, save_as: str = '') -> ToolInvokeMessage:
  256. """
  257. create a blob message
  258. :param blob: the blob
  259. :return: the blob message
  260. """
  261. return ToolInvokeMessage(
  262. type=ToolInvokeMessage.MessageType.BLOB,
  263. message=blob, meta=meta,
  264. save_as=save_as
  265. )
  266. def create_json_message(self, object: dict) -> ToolInvokeMessage:
  267. """
  268. create a json message
  269. """
  270. return ToolInvokeMessage(
  271. type=ToolInvokeMessage.MessageType.JSON,
  272. message=object
  273. )