tool_manager.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467
  1. import importlib
  2. import json
  3. import logging
  4. import mimetypes
  5. from os import listdir, path
  6. from typing import Any, Dict, List, Tuple, Union
  7. from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
  8. from core.model_runtime.entities.message_entities import PromptMessage
  9. from core.tools.entities.common_entities import I18nObject
  10. from core.tools.entities.constant import DEFAULT_PROVIDERS
  11. from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeMessage, ToolProviderCredentials
  12. from core.tools.entities.user_entities import UserToolProvider
  13. from core.tools.errors import ToolProviderNotFoundError
  14. from core.tools.provider.api_tool_provider import ApiBasedToolProviderController
  15. from core.tools.provider.app_tool_provider import AppBasedToolProviderEntity
  16. from core.tools.provider.builtin._positions import BuiltinToolProviderSort
  17. from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
  18. from core.tools.provider.tool_provider import ToolProviderController
  19. from core.tools.tool.api_tool import ApiTool
  20. from core.tools.tool.builtin_tool import BuiltinTool
  21. from core.tools.utils.configuration import ToolConfiguration
  22. from core.tools.utils.encoder import serialize_base_model_dict
  23. from extensions.ext_database import db
  24. from models.tools import ApiToolProvider, BuiltinToolProvider
  25. logger = logging.getLogger(__name__)
  26. _builtin_providers = {}
  27. _builtin_tools_labels = {}
  28. class ToolManager:
  29. @staticmethod
  30. def invoke(
  31. provider: str,
  32. tool_id: str,
  33. tool_name: str,
  34. tool_parameters: Dict[str, Any],
  35. credentials: Dict[str, Any],
  36. prompt_messages: List[PromptMessage],
  37. ) -> List[ToolInvokeMessage]:
  38. """
  39. invoke the assistant
  40. :param provider: the name of the provider
  41. :param tool_id: the id of the tool
  42. :param tool_name: the name of the tool, defined in `get_tools`
  43. :param tool_parameters: the parameters of the tool
  44. :param credentials: the credentials of the tool
  45. :param prompt_messages: the prompt messages that the tool can use
  46. :return: the messages that the tool wants to send to the user
  47. """
  48. provider_entity: ToolProviderController = None
  49. if provider == DEFAULT_PROVIDERS.API_BASED:
  50. provider_entity = ApiBasedToolProviderController()
  51. elif provider == DEFAULT_PROVIDERS.APP_BASED:
  52. provider_entity = AppBasedToolProviderEntity()
  53. if provider_entity is None:
  54. # fetch the provider from .provider.builtin
  55. py_path = path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, f'{provider}.py')
  56. spec = importlib.util.spec_from_file_location(f'core.tools.provider.builtin.{provider}.{provider}', py_path)
  57. mod = importlib.util.module_from_spec(spec)
  58. spec.loader.exec_module(mod)
  59. # get all the classes in the module
  60. classes = [ x for _, x in vars(mod).items()
  61. if isinstance(x, type) and x != ToolProviderController and issubclass(x, ToolProviderController)
  62. ]
  63. if len(classes) == 0:
  64. raise ToolProviderNotFoundError(f'provider {provider} not found')
  65. if len(classes) > 1:
  66. raise ToolProviderNotFoundError(f'multiple providers found for {provider}')
  67. provider_entity = classes[0]()
  68. return provider_entity.invoke(tool_id, tool_name, tool_parameters, credentials, prompt_messages)
  69. @staticmethod
  70. def get_builtin_provider(provider: str) -> BuiltinToolProviderController:
  71. global _builtin_providers
  72. """
  73. get the builtin provider
  74. :param provider: the name of the provider
  75. :return: the provider
  76. """
  77. if len(_builtin_providers) == 0:
  78. # init the builtin providers
  79. ToolManager.list_builtin_providers()
  80. if provider not in _builtin_providers:
  81. raise ToolProviderNotFoundError(f'builtin provider {provider} not found')
  82. return _builtin_providers[provider]
  83. @staticmethod
  84. def get_builtin_tool(provider: str, tool_name: str) -> BuiltinTool:
  85. """
  86. get the builtin tool
  87. :param provider: the name of the provider
  88. :param tool_name: the name of the tool
  89. :return: the provider, the tool
  90. """
  91. provider_controller = ToolManager.get_builtin_provider(provider)
  92. tool = provider_controller.get_tool(tool_name)
  93. return tool
  94. @staticmethod
  95. def get_tool(provider_type: str, provider_id: str, tool_name: str, tenant_id: str = None) \
  96. -> Union[BuiltinTool, ApiTool]:
  97. """
  98. get the tool
  99. :param provider_type: the type of the provider
  100. :param provider_name: the name of the provider
  101. :param tool_name: the name of the tool
  102. :return: the tool
  103. """
  104. if provider_type == 'builtin':
  105. return ToolManager.get_builtin_tool(provider_id, tool_name)
  106. elif provider_type == 'api':
  107. if tenant_id is None:
  108. raise ValueError('tenant id is required for api provider')
  109. api_provider, _ = ToolManager.get_api_provider_controller(tenant_id, provider_id)
  110. return api_provider.get_tool(tool_name)
  111. elif provider_type == 'app':
  112. raise NotImplementedError('app provider not implemented')
  113. else:
  114. raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
  115. @staticmethod
  116. def get_tool_runtime(provider_type: str, provider_name: str, tool_name: str, tenant_id,
  117. agent_callback: DifyAgentCallbackHandler = None) \
  118. -> Union[BuiltinTool, ApiTool]:
  119. """
  120. get the tool runtime
  121. :param provider_type: the type of the provider
  122. :param provider_name: the name of the provider
  123. :param tool_name: the name of the tool
  124. :return: the tool
  125. """
  126. if provider_type == 'builtin':
  127. builtin_tool = ToolManager.get_builtin_tool(provider_name, tool_name)
  128. # check if the builtin tool need credentials
  129. provider_controller = ToolManager.get_builtin_provider(provider_name)
  130. if not provider_controller.need_credentials:
  131. return builtin_tool.fork_tool_runtime(meta={
  132. 'tenant_id': tenant_id,
  133. 'credentials': {},
  134. }, agent_callback=agent_callback)
  135. # get credentials
  136. builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
  137. BuiltinToolProvider.tenant_id == tenant_id,
  138. BuiltinToolProvider.provider == provider_name,
  139. ).first()
  140. if builtin_provider is None:
  141. raise ToolProviderNotFoundError(f'builtin provider {provider_name} not found')
  142. # decrypt the credentials
  143. credentials = builtin_provider.credentials
  144. controller = ToolManager.get_builtin_provider(provider_name)
  145. tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller)
  146. decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
  147. return builtin_tool.fork_tool_runtime(meta={
  148. 'tenant_id': tenant_id,
  149. 'credentials': decrypted_credentials,
  150. 'runtime_parameters': {}
  151. }, agent_callback=agent_callback)
  152. elif provider_type == 'api':
  153. if tenant_id is None:
  154. raise ValueError('tenant id is required for api provider')
  155. api_provider, credentials = ToolManager.get_api_provider_controller(tenant_id, provider_name)
  156. # decrypt the credentials
  157. tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=api_provider)
  158. decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
  159. return api_provider.get_tool(tool_name).fork_tool_runtime(meta={
  160. 'tenant_id': tenant_id,
  161. 'credentials': decrypted_credentials,
  162. })
  163. elif provider_type == 'app':
  164. raise NotImplementedError('app provider not implemented')
  165. else:
  166. raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
  167. @staticmethod
  168. def get_builtin_provider_icon(provider: str) -> Tuple[str, str]:
  169. """
  170. get the absolute path of the icon of the builtin provider
  171. :param provider: the name of the provider
  172. :return: the absolute path of the icon, the mime type of the icon
  173. """
  174. # get provider
  175. provider_controller = ToolManager.get_builtin_provider(provider)
  176. absolute_path = path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider, '_assets', provider_controller.identity.icon)
  177. # check if the icon exists
  178. if not path.exists(absolute_path):
  179. raise ToolProviderNotFoundError(f'builtin provider {provider} icon not found')
  180. # get the mime type
  181. mime_type, _ = mimetypes.guess_type(absolute_path)
  182. mime_type = mime_type or 'application/octet-stream'
  183. return absolute_path, mime_type
  184. @staticmethod
  185. def list_builtin_providers() -> List[BuiltinToolProviderController]:
  186. global _builtin_providers
  187. # use cache first
  188. if len(_builtin_providers) > 0:
  189. return list(_builtin_providers.values())
  190. builtin_providers: List[BuiltinToolProviderController] = []
  191. for provider in listdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin')):
  192. if provider.startswith('__'):
  193. continue
  194. if path.isdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider)):
  195. if provider.startswith('__'):
  196. continue
  197. py_path = path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider, f'{provider}.py')
  198. spec = importlib.util.spec_from_file_location(f'core.tools.provider.builtin.{provider}.{provider}', py_path)
  199. mod = importlib.util.module_from_spec(spec)
  200. spec.loader.exec_module(mod)
  201. # load all classes
  202. classes = [
  203. obj for name, obj in vars(mod).items()
  204. if isinstance(obj, type) and obj != BuiltinToolProviderController and issubclass(obj, BuiltinToolProviderController)
  205. ]
  206. if len(classes) == 0:
  207. raise ToolProviderNotFoundError(f'provider {provider} not found')
  208. if len(classes) > 1:
  209. raise ToolProviderNotFoundError(f'multiple providers found for {provider}')
  210. # init provider
  211. provider_class = classes[0]
  212. builtin_providers.append(provider_class())
  213. # cache the builtin providers
  214. for provider in builtin_providers:
  215. _builtin_providers[provider.identity.name] = provider
  216. for tool in provider.get_tools():
  217. _builtin_tools_labels[tool.identity.name] = tool.identity.label
  218. return builtin_providers
  219. @staticmethod
  220. def get_tool_label(tool_name: str) -> Union[I18nObject, None]:
  221. """
  222. get the tool label
  223. :param tool_name: the name of the tool
  224. :return: the label of the tool
  225. """
  226. global _builtin_tools_labels
  227. if len(_builtin_tools_labels) == 0:
  228. # init the builtin providers
  229. ToolManager.list_builtin_providers()
  230. if tool_name not in _builtin_tools_labels:
  231. return None
  232. return _builtin_tools_labels[tool_name]
  233. @staticmethod
  234. def user_list_providers(
  235. user_id: str,
  236. tenant_id: str,
  237. ) -> List[UserToolProvider]:
  238. result_providers: Dict[str, UserToolProvider] = {}
  239. # get builtin providers
  240. builtin_providers = ToolManager.list_builtin_providers()
  241. # append builtin providers
  242. for provider in builtin_providers:
  243. result_providers[provider.identity.name] = UserToolProvider(
  244. id=provider.identity.name,
  245. author=provider.identity.author,
  246. name=provider.identity.name,
  247. description=I18nObject(
  248. en_US=provider.identity.description.en_US,
  249. zh_Hans=provider.identity.description.zh_Hans,
  250. ),
  251. icon=provider.identity.icon,
  252. label=I18nObject(
  253. en_US=provider.identity.label.en_US,
  254. zh_Hans=provider.identity.label.zh_Hans,
  255. ),
  256. type=UserToolProvider.ProviderType.BUILTIN,
  257. team_credentials={},
  258. is_team_authorization=False,
  259. )
  260. # get credentials schema
  261. schema = provider.get_credentials_schema()
  262. for name, value in schema.items():
  263. result_providers[provider.identity.name].team_credentials[name] = \
  264. ToolProviderCredentials.CredentialsType.default(value.type)
  265. # check if the provider need credentials
  266. if not provider.need_credentials:
  267. result_providers[provider.identity.name].is_team_authorization = True
  268. result_providers[provider.identity.name].allow_delete = False
  269. # get db builtin providers
  270. db_builtin_providers: List[BuiltinToolProvider] = db.session.query(BuiltinToolProvider). \
  271. filter(BuiltinToolProvider.tenant_id == tenant_id).all()
  272. for db_builtin_provider in db_builtin_providers:
  273. # add provider into providers
  274. credentials = db_builtin_provider.credentials
  275. provider_name = db_builtin_provider.provider
  276. result_providers[provider_name].is_team_authorization = True
  277. # package builtin tool provider controller
  278. controller = ToolManager.get_builtin_provider(provider_name)
  279. # init tool configuration
  280. tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller)
  281. # decrypt the credentials and mask the credentials
  282. decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials)
  283. masked_credentials = tool_configuration.mask_tool_credentials(credentials=decrypted_credentials)
  284. result_providers[provider_name].team_credentials = masked_credentials
  285. # get db api providers
  286. db_api_providers: List[ApiToolProvider] = db.session.query(ApiToolProvider). \
  287. filter(ApiToolProvider.tenant_id == tenant_id).all()
  288. for db_api_provider in db_api_providers:
  289. username = 'Anonymous'
  290. try:
  291. username = db_api_provider.user.name
  292. except Exception as e:
  293. logger.error(f'failed to get user name for api provider {db_api_provider.id}: {str(e)}')
  294. # add provider into providers
  295. credentials = db_api_provider.credentials
  296. provider_name = db_api_provider.name
  297. result_providers[provider_name] = UserToolProvider(
  298. id=db_api_provider.id,
  299. author=username,
  300. name=db_api_provider.name,
  301. description=I18nObject(
  302. en_US=db_api_provider.description,
  303. zh_Hans=db_api_provider.description,
  304. ),
  305. icon=db_api_provider.icon,
  306. label=I18nObject(
  307. en_US=db_api_provider.name,
  308. zh_Hans=db_api_provider.name,
  309. ),
  310. type=UserToolProvider.ProviderType.API,
  311. team_credentials={},
  312. is_team_authorization=True,
  313. )
  314. # package tool provider controller
  315. controller = ApiBasedToolProviderController.from_db(
  316. db_provider=db_api_provider,
  317. auth_type=ApiProviderAuthType.API_KEY if db_api_provider.credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE
  318. )
  319. # init tool configuration
  320. tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller)
  321. # decrypt the credentials and mask the credentials
  322. decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials)
  323. masked_credentials = tool_configuration.mask_tool_credentials(credentials=decrypted_credentials)
  324. result_providers[provider_name].team_credentials = masked_credentials
  325. return BuiltinToolProviderSort.sort(list(result_providers.values()))
  326. @staticmethod
  327. def get_api_provider_controller(tenant_id: str, provider_id: str) -> Tuple[ApiBasedToolProviderController, Dict[str, Any]]:
  328. """
  329. get the api provider
  330. :param provider_name: the name of the provider
  331. :return: the provider controller, the credentials
  332. """
  333. provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
  334. ApiToolProvider.id == provider_id,
  335. ApiToolProvider.tenant_id == tenant_id,
  336. ).first()
  337. if provider is None:
  338. raise ToolProviderNotFoundError(f'api provider {provider_id} not found')
  339. controller = ApiBasedToolProviderController.from_db(
  340. provider, ApiProviderAuthType.API_KEY if provider.credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE
  341. )
  342. controller.load_bundled_tools(provider.tools)
  343. return controller, provider.credentials
  344. @staticmethod
  345. def user_get_api_provider(provider: str, tenant_id: str) -> dict:
  346. """
  347. get api provider
  348. """
  349. """
  350. get tool provider
  351. """
  352. provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
  353. ApiToolProvider.tenant_id == tenant_id,
  354. ApiToolProvider.name == provider,
  355. ).first()
  356. if provider is None:
  357. raise ValueError(f'you have not added provider {provider}')
  358. try:
  359. credentials = json.loads(provider.credentials_str) or {}
  360. except:
  361. credentials = {}
  362. # package tool provider controller
  363. controller = ApiBasedToolProviderController.from_db(
  364. provider, ApiProviderAuthType.API_KEY if credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE
  365. )
  366. # init tool configuration
  367. tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller)
  368. decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
  369. masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
  370. try:
  371. icon = json.loads(provider.icon)
  372. except:
  373. icon = {
  374. "background": "#252525",
  375. "content": "\ud83d\ude01"
  376. }
  377. return json.loads(serialize_base_model_dict({
  378. 'schema_type': provider.schema_type,
  379. 'schema': provider.schema,
  380. 'tools': provider.tools,
  381. 'icon': icon,
  382. 'description': provider.description,
  383. 'credentials': masked_credentials,
  384. 'privacy_policy': provider.privacy_policy
  385. }))