tool_manager.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467
  1. import importlib
  2. import json
  3. import logging
  4. import mimetypes
  5. from os import listdir, path
  6. from typing import Any, Union
  7. from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
  8. from core.model_runtime.entities.message_entities import PromptMessage
  9. from core.tools.entities.common_entities import I18nObject
  10. from core.tools.entities.constant import DEFAULT_PROVIDERS
  11. from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeMessage, ToolProviderCredentials
  12. from core.tools.entities.user_entities import UserToolProvider
  13. from core.tools.errors import ToolProviderNotFoundError
  14. from core.tools.provider.api_tool_provider import ApiBasedToolProviderController
  15. from core.tools.provider.app_tool_provider import AppBasedToolProviderEntity
  16. from core.tools.provider.builtin._positions import BuiltinToolProviderSort
  17. from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
  18. from core.tools.provider.tool_provider import ToolProviderController
  19. from core.tools.tool.api_tool import ApiTool
  20. from core.tools.tool.builtin_tool import BuiltinTool
  21. from core.tools.utils.configuration import ToolConfiguration
  22. from core.tools.utils.encoder import serialize_base_model_dict
  23. from extensions.ext_database import db
  24. from models.tools import ApiToolProvider, BuiltinToolProvider
  25. logger = logging.getLogger(__name__)
  26. _builtin_providers = {}
  27. _builtin_tools_labels = {}
  28. class ToolManager:
  29. @staticmethod
  30. def invoke(
  31. provider: str,
  32. tool_id: str,
  33. tool_name: str,
  34. tool_parameters: dict[str, Any],
  35. credentials: dict[str, Any],
  36. prompt_messages: list[PromptMessage],
  37. ) -> list[ToolInvokeMessage]:
  38. """
  39. invoke the assistant
  40. :param provider: the name of the provider
  41. :param tool_id: the id of the tool
  42. :param tool_name: the name of the tool, defined in `get_tools`
  43. :param tool_parameters: the parameters of the tool
  44. :param credentials: the credentials of the tool
  45. :param prompt_messages: the prompt messages that the tool can use
  46. :return: the messages that the tool wants to send to the user
  47. """
  48. provider_entity: ToolProviderController = None
  49. if provider == DEFAULT_PROVIDERS.API_BASED:
  50. provider_entity = ApiBasedToolProviderController()
  51. elif provider == DEFAULT_PROVIDERS.APP_BASED:
  52. provider_entity = AppBasedToolProviderEntity()
  53. if provider_entity is None:
  54. # fetch the provider from .provider.builtin
  55. py_path = path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, f'{provider}.py')
  56. spec = importlib.util.spec_from_file_location(f'core.tools.provider.builtin.{provider}.{provider}', py_path)
  57. mod = importlib.util.module_from_spec(spec)
  58. spec.loader.exec_module(mod)
  59. # get all the classes in the module
  60. classes = [ x for _, x in vars(mod).items()
  61. if isinstance(x, type) and x != ToolProviderController and issubclass(x, ToolProviderController)
  62. ]
  63. if len(classes) == 0:
  64. raise ToolProviderNotFoundError(f'provider {provider} not found')
  65. if len(classes) > 1:
  66. raise ToolProviderNotFoundError(f'multiple providers found for {provider}')
  67. provider_entity = classes[0]()
  68. return provider_entity.invoke(tool_id, tool_name, tool_parameters, credentials, prompt_messages)
  69. @staticmethod
  70. def get_builtin_provider(provider: str) -> BuiltinToolProviderController:
  71. global _builtin_providers
  72. """
  73. get the builtin provider
  74. :param provider: the name of the provider
  75. :return: the provider
  76. """
  77. if len(_builtin_providers) == 0:
  78. # init the builtin providers
  79. ToolManager.list_builtin_providers()
  80. if provider not in _builtin_providers:
  81. raise ToolProviderNotFoundError(f'builtin provider {provider} not found')
  82. return _builtin_providers[provider]
  83. @staticmethod
  84. def get_builtin_tool(provider: str, tool_name: str) -> BuiltinTool:
  85. """
  86. get the builtin tool
  87. :param provider: the name of the provider
  88. :param tool_name: the name of the tool
  89. :return: the provider, the tool
  90. """
  91. provider_controller = ToolManager.get_builtin_provider(provider)
  92. tool = provider_controller.get_tool(tool_name)
  93. return tool
  94. @staticmethod
  95. def get_tool(provider_type: str, provider_id: str, tool_name: str, tenant_id: str = None) \
  96. -> Union[BuiltinTool, ApiTool]:
  97. """
  98. get the tool
  99. :param provider_type: the type of the provider
  100. :param provider_name: the name of the provider
  101. :param tool_name: the name of the tool
  102. :return: the tool
  103. """
  104. if provider_type == 'builtin':
  105. return ToolManager.get_builtin_tool(provider_id, tool_name)
  106. elif provider_type == 'api':
  107. if tenant_id is None:
  108. raise ValueError('tenant id is required for api provider')
  109. api_provider, _ = ToolManager.get_api_provider_controller(tenant_id, provider_id)
  110. return api_provider.get_tool(tool_name)
  111. elif provider_type == 'app':
  112. raise NotImplementedError('app provider not implemented')
  113. else:
  114. raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
  115. @staticmethod
  116. def get_tool_runtime(provider_type: str, provider_name: str, tool_name: str, tenant_id,
  117. agent_callback: DifyAgentCallbackHandler = None) \
  118. -> Union[BuiltinTool, ApiTool]:
  119. """
  120. get the tool runtime
  121. :param provider_type: the type of the provider
  122. :param provider_name: the name of the provider
  123. :param tool_name: the name of the tool
  124. :return: the tool
  125. """
  126. if provider_type == 'builtin':
  127. builtin_tool = ToolManager.get_builtin_tool(provider_name, tool_name)
  128. # check if the builtin tool need credentials
  129. provider_controller = ToolManager.get_builtin_provider(provider_name)
  130. if not provider_controller.need_credentials:
  131. return builtin_tool.fork_tool_runtime(meta={
  132. 'tenant_id': tenant_id,
  133. 'credentials': {},
  134. }, agent_callback=agent_callback)
  135. # get credentials
  136. builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
  137. BuiltinToolProvider.tenant_id == tenant_id,
  138. BuiltinToolProvider.provider == provider_name,
  139. ).first()
  140. if builtin_provider is None:
  141. raise ToolProviderNotFoundError(f'builtin provider {provider_name} not found')
  142. # decrypt the credentials
  143. credentials = builtin_provider.credentials
  144. controller = ToolManager.get_builtin_provider(provider_name)
  145. tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller)
  146. decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
  147. return builtin_tool.fork_tool_runtime(meta={
  148. 'tenant_id': tenant_id,
  149. 'credentials': decrypted_credentials,
  150. 'runtime_parameters': {}
  151. }, agent_callback=agent_callback)
  152. elif provider_type == 'api':
  153. if tenant_id is None:
  154. raise ValueError('tenant id is required for api provider')
  155. api_provider, credentials = ToolManager.get_api_provider_controller(tenant_id, provider_name)
  156. # decrypt the credentials
  157. tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=api_provider)
  158. decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
  159. return api_provider.get_tool(tool_name).fork_tool_runtime(meta={
  160. 'tenant_id': tenant_id,
  161. 'credentials': decrypted_credentials,
  162. })
  163. elif provider_type == 'app':
  164. raise NotImplementedError('app provider not implemented')
  165. else:
  166. raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
  167. @staticmethod
  168. def get_builtin_provider_icon(provider: str) -> tuple[str, str]:
  169. """
  170. get the absolute path of the icon of the builtin provider
  171. :param provider: the name of the provider
  172. :return: the absolute path of the icon, the mime type of the icon
  173. """
  174. # get provider
  175. provider_controller = ToolManager.get_builtin_provider(provider)
  176. absolute_path = path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider, '_assets', provider_controller.identity.icon)
  177. # check if the icon exists
  178. if not path.exists(absolute_path):
  179. raise ToolProviderNotFoundError(f'builtin provider {provider} icon not found')
  180. # get the mime type
  181. mime_type, _ = mimetypes.guess_type(absolute_path)
  182. mime_type = mime_type or 'application/octet-stream'
  183. return absolute_path, mime_type
  184. @staticmethod
  185. def list_builtin_providers() -> list[BuiltinToolProviderController]:
  186. global _builtin_providers
  187. # use cache first
  188. if len(_builtin_providers) > 0:
  189. return list(_builtin_providers.values())
  190. builtin_providers: list[BuiltinToolProviderController] = []
  191. for provider in listdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin')):
  192. if provider.startswith('__'):
  193. continue
  194. if path.isdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider)):
  195. if provider.startswith('__'):
  196. continue
  197. py_path = path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider, f'{provider}.py')
  198. spec = importlib.util.spec_from_file_location(f'core.tools.provider.builtin.{provider}.{provider}', py_path)
  199. mod = importlib.util.module_from_spec(spec)
  200. spec.loader.exec_module(mod)
  201. # load all classes
  202. classes = [
  203. obj for name, obj in vars(mod).items()
  204. if isinstance(obj, type) and obj != BuiltinToolProviderController and issubclass(obj, BuiltinToolProviderController)
  205. ]
  206. if len(classes) == 0:
  207. raise ToolProviderNotFoundError(f'provider {provider} not found')
  208. if len(classes) > 1:
  209. raise ToolProviderNotFoundError(f'multiple providers found for {provider}')
  210. # init provider
  211. provider_class = classes[0]
  212. builtin_providers.append(provider_class())
  213. # cache the builtin providers
  214. for provider in builtin_providers:
  215. _builtin_providers[provider.identity.name] = provider
  216. for tool in provider.get_tools():
  217. _builtin_tools_labels[tool.identity.name] = tool.identity.label
  218. return builtin_providers
  219. @staticmethod
  220. def get_tool_label(tool_name: str) -> Union[I18nObject, None]:
  221. """
  222. get the tool label
  223. :param tool_name: the name of the tool
  224. :return: the label of the tool
  225. """
  226. global _builtin_tools_labels
  227. if len(_builtin_tools_labels) == 0:
  228. # init the builtin providers
  229. ToolManager.list_builtin_providers()
  230. if tool_name not in _builtin_tools_labels:
  231. return None
  232. return _builtin_tools_labels[tool_name]
  233. @staticmethod
  234. def user_list_providers(
  235. user_id: str,
  236. tenant_id: str,
  237. ) -> list[UserToolProvider]:
  238. result_providers: dict[str, UserToolProvider] = {}
  239. # get builtin providers
  240. builtin_providers = ToolManager.list_builtin_providers()
  241. # append builtin providers
  242. for provider in builtin_providers:
  243. result_providers[provider.identity.name] = UserToolProvider(
  244. id=provider.identity.name,
  245. author=provider.identity.author,
  246. name=provider.identity.name,
  247. description=I18nObject(
  248. en_US=provider.identity.description.en_US,
  249. zh_Hans=provider.identity.description.zh_Hans,
  250. ),
  251. icon=provider.identity.icon,
  252. label=I18nObject(
  253. en_US=provider.identity.label.en_US,
  254. zh_Hans=provider.identity.label.zh_Hans,
  255. ),
  256. type=UserToolProvider.ProviderType.BUILTIN,
  257. team_credentials={},
  258. is_team_authorization=False,
  259. )
  260. # get credentials schema
  261. schema = provider.get_credentials_schema()
  262. for name, value in schema.items():
  263. result_providers[provider.identity.name].team_credentials[name] = \
  264. ToolProviderCredentials.CredentialsType.default(value.type)
  265. # check if the provider need credentials
  266. if not provider.need_credentials:
  267. result_providers[provider.identity.name].is_team_authorization = True
  268. result_providers[provider.identity.name].allow_delete = False
  269. # get db builtin providers
  270. db_builtin_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider). \
  271. filter(BuiltinToolProvider.tenant_id == tenant_id).all()
  272. for db_builtin_provider in db_builtin_providers:
  273. # add provider into providers
  274. credentials = db_builtin_provider.credentials
  275. provider_name = db_builtin_provider.provider
  276. result_providers[provider_name].is_team_authorization = True
  277. # package builtin tool provider controller
  278. controller = ToolManager.get_builtin_provider(provider_name)
  279. # init tool configuration
  280. tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller)
  281. # decrypt the credentials and mask the credentials
  282. decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials)
  283. masked_credentials = tool_configuration.mask_tool_credentials(credentials=decrypted_credentials)
  284. result_providers[provider_name].team_credentials = masked_credentials
  285. # get db api providers
  286. db_api_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider). \
  287. filter(ApiToolProvider.tenant_id == tenant_id).all()
  288. for db_api_provider in db_api_providers:
  289. username = 'Anonymous'
  290. try:
  291. username = db_api_provider.user.name
  292. except Exception as e:
  293. logger.error(f'failed to get user name for api provider {db_api_provider.id}: {str(e)}')
  294. # add provider into providers
  295. credentials = db_api_provider.credentials
  296. provider_name = db_api_provider.name
  297. result_providers[provider_name] = UserToolProvider(
  298. id=db_api_provider.id,
  299. author=username,
  300. name=db_api_provider.name,
  301. description=I18nObject(
  302. en_US=db_api_provider.description,
  303. zh_Hans=db_api_provider.description,
  304. ),
  305. icon=db_api_provider.icon,
  306. label=I18nObject(
  307. en_US=db_api_provider.name,
  308. zh_Hans=db_api_provider.name,
  309. ),
  310. type=UserToolProvider.ProviderType.API,
  311. team_credentials={},
  312. is_team_authorization=True,
  313. )
  314. # package tool provider controller
  315. controller = ApiBasedToolProviderController.from_db(
  316. db_provider=db_api_provider,
  317. auth_type=ApiProviderAuthType.API_KEY if db_api_provider.credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE
  318. )
  319. # init tool configuration
  320. tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller)
  321. # decrypt the credentials and mask the credentials
  322. decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials)
  323. masked_credentials = tool_configuration.mask_tool_credentials(credentials=decrypted_credentials)
  324. result_providers[provider_name].team_credentials = masked_credentials
  325. return BuiltinToolProviderSort.sort(list(result_providers.values()))
  326. @staticmethod
  327. def get_api_provider_controller(tenant_id: str, provider_id: str) -> tuple[ApiBasedToolProviderController, dict[str, Any]]:
  328. """
  329. get the api provider
  330. :param provider_name: the name of the provider
  331. :return: the provider controller, the credentials
  332. """
  333. provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
  334. ApiToolProvider.id == provider_id,
  335. ApiToolProvider.tenant_id == tenant_id,
  336. ).first()
  337. if provider is None:
  338. raise ToolProviderNotFoundError(f'api provider {provider_id} not found')
  339. controller = ApiBasedToolProviderController.from_db(
  340. provider, ApiProviderAuthType.API_KEY if provider.credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE
  341. )
  342. controller.load_bundled_tools(provider.tools)
  343. return controller, provider.credentials
  344. @staticmethod
  345. def user_get_api_provider(provider: str, tenant_id: str) -> dict:
  346. """
  347. get api provider
  348. """
  349. """
  350. get tool provider
  351. """
  352. provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
  353. ApiToolProvider.tenant_id == tenant_id,
  354. ApiToolProvider.name == provider,
  355. ).first()
  356. if provider is None:
  357. raise ValueError(f'you have not added provider {provider}')
  358. try:
  359. credentials = json.loads(provider.credentials_str) or {}
  360. except:
  361. credentials = {}
  362. # package tool provider controller
  363. controller = ApiBasedToolProviderController.from_db(
  364. provider, ApiProviderAuthType.API_KEY if credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE
  365. )
  366. # init tool configuration
  367. tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller)
  368. decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
  369. masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
  370. try:
  371. icon = json.loads(provider.icon)
  372. except:
  373. icon = {
  374. "background": "#252525",
  375. "content": "\ud83d\ude01"
  376. }
  377. return json.loads(serialize_base_model_dict({
  378. 'schema_type': provider.schema_type,
  379. 'schema': provider.schema,
  380. 'tools': provider.tools,
  381. 'icon': icon,
  382. 'description': provider.description,
  383. 'credentials': masked_credentials,
  384. 'privacy_policy': provider.privacy_policy
  385. }))