builtin_tool_provider.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290
  1. import importlib
  2. from abc import abstractmethod
  3. from os import listdir, path
  4. from typing import Any, Dict, List
  5. from yaml import FullLoader, load
  6. from core.tools.entities.tool_entities import ToolParameter, ToolProviderCredentials, ToolProviderType
  7. from core.tools.entities.user_entities import UserToolProviderCredentials
  8. from core.tools.errors import (
  9. ToolNotFoundError,
  10. ToolParameterValidationError,
  11. ToolProviderCredentialValidationError,
  12. ToolProviderNotFoundError,
  13. )
  14. from core.tools.provider.tool_provider import ToolProviderController
  15. from core.tools.tool.builtin_tool import BuiltinTool
  16. from core.tools.tool.tool import Tool
  17. class BuiltinToolProviderController(ToolProviderController):
  18. def __init__(self, **data: Any) -> None:
  19. if self.app_type == ToolProviderType.API_BASED or self.app_type == ToolProviderType.APP_BASED:
  20. super().__init__(**data)
  21. return
  22. # load provider yaml
  23. provider = self.__class__.__module__.split('.')[-1]
  24. yaml_path = path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, f'{provider}.yaml')
  25. try:
  26. with open(yaml_path, 'r') as f:
  27. provider_yaml = load(f.read(), FullLoader)
  28. except:
  29. raise ToolProviderNotFoundError(f'can not load provider yaml for {provider}')
  30. if 'credentials_for_provider' in provider_yaml and provider_yaml['credentials_for_provider'] is not None:
  31. # set credentials name
  32. for credential_name in provider_yaml['credentials_for_provider']:
  33. provider_yaml['credentials_for_provider'][credential_name]['name'] = credential_name
  34. super().__init__(**{
  35. 'identity': provider_yaml['identity'],
  36. 'credentials_schema': provider_yaml['credentials_for_provider'] if 'credentials_for_provider' in provider_yaml else None,
  37. })
  38. def _get_builtin_tools(self) -> List[Tool]:
  39. """
  40. returns a list of tools that the provider can provide
  41. :return: list of tools
  42. """
  43. if self.tools:
  44. return self.tools
  45. provider = self.identity.name
  46. tool_path = path.join(path.dirname(path.realpath(__file__)), "builtin", provider, "tools")
  47. # get all the yaml files in the tool path
  48. tool_files = list(filter(lambda x: x.endswith(".yaml") and not x.startswith("__"), listdir(tool_path)))
  49. tools = []
  50. for tool_file in tool_files:
  51. with open(path.join(tool_path, tool_file), "r") as f:
  52. # get tool name
  53. tool_name = tool_file.split(".")[0]
  54. tool = load(f.read(), FullLoader)
  55. # get tool class, import the module
  56. py_path = path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, 'tools', f'{tool_name}.py')
  57. spec = importlib.util.spec_from_file_location(f'core.tools.provider.builtin.{provider}.tools.{tool_name}', py_path)
  58. mod = importlib.util.module_from_spec(spec)
  59. spec.loader.exec_module(mod)
  60. # get all the classes in the module
  61. classes = [x for _, x in vars(mod).items()
  62. if isinstance(x, type) and x not in [BuiltinTool, Tool] and issubclass(x, BuiltinTool)
  63. ]
  64. assistant_tool_class = classes[0]
  65. tools.append(assistant_tool_class(**tool))
  66. self.tools = tools
  67. return tools
  68. def get_credentials_schema(self) -> Dict[str, ToolProviderCredentials]:
  69. """
  70. returns the credentials schema of the provider
  71. :return: the credentials schema
  72. """
  73. if not self.credentials_schema:
  74. return {}
  75. return self.credentials_schema.copy()
  76. def user_get_credentials_schema(self) -> UserToolProviderCredentials:
  77. """
  78. returns the credentials schema of the provider, this method is used for user
  79. :return: the credentials schema
  80. """
  81. credentials = self.credentials_schema.copy()
  82. return UserToolProviderCredentials(credentials=credentials)
  83. def get_tools(self) -> List[Tool]:
  84. """
  85. returns a list of tools that the provider can provide
  86. :return: list of tools
  87. """
  88. return self._get_builtin_tools()
  89. def get_tool(self, tool_name: str) -> Tool:
  90. """
  91. returns the tool that the provider can provide
  92. """
  93. return next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
  94. def get_parameters(self, tool_name: str) -> List[ToolParameter]:
  95. """
  96. returns the parameters of the tool
  97. :param tool_name: the name of the tool, defined in `get_tools`
  98. :return: list of parameters
  99. """
  100. tool = next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
  101. if tool is None:
  102. raise ToolNotFoundError(f'tool {tool_name} not found')
  103. return tool.parameters
  104. @property
  105. def need_credentials(self) -> bool:
  106. """
  107. returns whether the provider needs credentials
  108. :return: whether the provider needs credentials
  109. """
  110. return self.credentials_schema is not None and len(self.credentials_schema) != 0
  111. @property
  112. def app_type(self) -> ToolProviderType:
  113. """
  114. returns the type of the provider
  115. :return: type of the provider
  116. """
  117. return ToolProviderType.BUILT_IN
  118. def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: Dict[str, Any]) -> None:
  119. """
  120. validate the parameters of the tool and set the default value if needed
  121. :param tool_name: the name of the tool, defined in `get_tools`
  122. :param tool_parameters: the parameters of the tool
  123. """
  124. tool_parameters_schema = self.get_parameters(tool_name)
  125. tool_parameters_need_to_validate: Dict[str, ToolParameter] = {}
  126. for parameter in tool_parameters_schema:
  127. tool_parameters_need_to_validate[parameter.name] = parameter
  128. for parameter in tool_parameters:
  129. if parameter not in tool_parameters_need_to_validate:
  130. raise ToolParameterValidationError(f'parameter {parameter} not found in tool {tool_name}')
  131. # check type
  132. parameter_schema = tool_parameters_need_to_validate[parameter]
  133. if parameter_schema.type == ToolParameter.ToolParameterType.STRING:
  134. if not isinstance(tool_parameters[parameter], str):
  135. raise ToolParameterValidationError(f'parameter {parameter} should be string')
  136. elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER:
  137. if not isinstance(tool_parameters[parameter], (int, float)):
  138. raise ToolParameterValidationError(f'parameter {parameter} should be number')
  139. if parameter_schema.min is not None and tool_parameters[parameter] < parameter_schema.min:
  140. raise ToolParameterValidationError(f'parameter {parameter} should be greater than {parameter_schema.min}')
  141. if parameter_schema.max is not None and tool_parameters[parameter] > parameter_schema.max:
  142. raise ToolParameterValidationError(f'parameter {parameter} should be less than {parameter_schema.max}')
  143. elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN:
  144. if not isinstance(tool_parameters[parameter], bool):
  145. raise ToolParameterValidationError(f'parameter {parameter} should be boolean')
  146. elif parameter_schema.type == ToolParameter.ToolParameterType.SELECT:
  147. if not isinstance(tool_parameters[parameter], str):
  148. raise ToolParameterValidationError(f'parameter {parameter} should be string')
  149. options = parameter_schema.options
  150. if not isinstance(options, list):
  151. raise ToolParameterValidationError(f'parameter {parameter} options should be list')
  152. if tool_parameters[parameter] not in [x.value for x in options]:
  153. raise ToolParameterValidationError(f'parameter {parameter} should be one of {options}')
  154. tool_parameters_need_to_validate.pop(parameter)
  155. for parameter in tool_parameters_need_to_validate:
  156. parameter_schema = tool_parameters_need_to_validate[parameter]
  157. if parameter_schema.required:
  158. raise ToolParameterValidationError(f'parameter {parameter} is required')
  159. # the parameter is not set currently, set the default value if needed
  160. if parameter_schema.default is not None:
  161. default_value = parameter_schema.default
  162. # parse default value into the correct type
  163. if parameter_schema.type == ToolParameter.ToolParameterType.STRING or \
  164. parameter_schema.type == ToolParameter.ToolParameterType.SELECT:
  165. default_value = str(default_value)
  166. elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER:
  167. default_value = float(default_value)
  168. elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN:
  169. default_value = bool(default_value)
  170. tool_parameters[parameter] = default_value
  171. def validate_credentials_format(self, credentials: Dict[str, Any]) -> None:
  172. """
  173. validate the format of the credentials of the provider and set the default value if needed
  174. :param credentials: the credentials of the tool
  175. """
  176. credentials_schema = self.credentials_schema
  177. if credentials_schema is None:
  178. return
  179. credentials_need_to_validate: Dict[str, ToolProviderCredentials] = {}
  180. for credential_name in credentials_schema:
  181. credentials_need_to_validate[credential_name] = credentials_schema[credential_name]
  182. for credential_name in credentials:
  183. if credential_name not in credentials_need_to_validate:
  184. raise ToolProviderCredentialValidationError(f'credential {credential_name} not found in provider {self.identity.name}')
  185. # check type
  186. credential_schema = credentials_need_to_validate[credential_name]
  187. if credential_schema == ToolProviderCredentials.CredentialsType.SECRET_INPUT or \
  188. credential_schema == ToolProviderCredentials.CredentialsType.TEXT_INPUT:
  189. if not isinstance(credentials[credential_name], str):
  190. raise ToolProviderCredentialValidationError(f'credential {credential_schema.label.en_US} should be string')
  191. elif credential_schema.type == ToolProviderCredentials.CredentialsType.SELECT:
  192. if not isinstance(credentials[credential_name], str):
  193. raise ToolProviderCredentialValidationError(f'credential {credential_schema.label.en_US} should be string')
  194. options = credential_schema.options
  195. if not isinstance(options, list):
  196. raise ToolProviderCredentialValidationError(f'credential {credential_schema.label.en_US} options should be list')
  197. if credentials[credential_name] not in [x.value for x in options]:
  198. raise ToolProviderCredentialValidationError(f'credential {credential_schema.label.en_US} should be one of {options}')
  199. if credentials[credential_name]:
  200. credentials_need_to_validate.pop(credential_name)
  201. for credential_name in credentials_need_to_validate:
  202. credential_schema = credentials_need_to_validate[credential_name]
  203. if credential_schema.required:
  204. raise ToolProviderCredentialValidationError(f'credential {credential_schema.label.en_US} is required')
  205. # the credential is not set currently, set the default value if needed
  206. if credential_schema.default is not None:
  207. default_value = credential_schema.default
  208. # parse default value into the correct type
  209. if credential_schema.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT or \
  210. credential_schema.type == ToolProviderCredentials.CredentialsType.TEXT_INPUT or \
  211. credential_schema.type == ToolProviderCredentials.CredentialsType.SELECT:
  212. default_value = str(default_value)
  213. credentials[credential_name] = default_value
  214. def validate_credentials(self, credentials: Dict[str, Any]) -> None:
  215. """
  216. validate the credentials of the provider
  217. :param tool_name: the name of the tool, defined in `get_tools`
  218. :param credentials: the credentials of the tool
  219. """
  220. # validate credentials format
  221. self.validate_credentials_format(credentials)
  222. # validate credentials
  223. self._validate_credentials(credentials)
  224. @abstractmethod
  225. def _validate_credentials(self, credentials: Dict[str, Any]) -> None:
  226. """
  227. validate the credentials of the provider
  228. :param tool_name: the name of the tool, defined in `get_tools`
  229. :param credentials: the credentials of the tool
  230. """
  231. pass