# tool_manager.py (23 KB)
import json
import logging
import mimetypes
from collections.abc import Generator
from os import listdir, path
from threading import Lock
from typing import Any, Optional, Union

from flask import current_app

from core.agent.entities import AgentToolEntity
from core.model_runtime.utils.encoders import jsonable_encoder
from core.provider_manager import ProviderManager
from core.tools import *
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
    ApiProviderAuthType,
    ToolParameter,
)
from core.tools.entities.user_entities import UserToolProvider
from core.tools.errors import ToolProviderNotFoundError
from core.tools.provider.api_tool_provider import ApiBasedToolProviderController
from core.tools.provider.builtin._positions import BuiltinToolProviderSort
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
from core.tools.provider.model_tool_provider import ModelToolProviderController
from core.tools.tool.api_tool import ApiTool
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.tool.tool import Tool
from core.tools.utils.configuration import (
    ToolConfigurationManager,
    ToolParameterConfigurationManager,
)
from core.utils.module_import_helper import load_single_subclass_from_source
from core.workflow.nodes.tool.entities import ToolEntity
from extensions.ext_database import db
from models.tools import ApiToolProvider, BuiltinToolProvider
from services.tools_transform_service import ToolTransformService
  36. logger = logging.getLogger(__name__)
  37. class ToolManager:
  38. _builtin_provider_lock = Lock()
  39. _builtin_providers = {}
  40. _builtin_providers_loaded = False
  41. _builtin_tools_labels = {}
  42. @classmethod
  43. def get_builtin_provider(cls, provider: str) -> BuiltinToolProviderController:
  44. """
  45. get the builtin provider
  46. :param provider: the name of the provider
  47. :return: the provider
  48. """
  49. if len(cls._builtin_providers) == 0:
  50. # init the builtin providers
  51. cls.load_builtin_providers_cache()
  52. if provider not in cls._builtin_providers:
  53. raise ToolProviderNotFoundError(f'builtin provider {provider} not found')
  54. return cls._builtin_providers[provider]
  55. @classmethod
  56. def get_builtin_tool(cls, provider: str, tool_name: str) -> BuiltinTool:
  57. """
  58. get the builtin tool
  59. :param provider: the name of the provider
  60. :param tool_name: the name of the tool
  61. :return: the provider, the tool
  62. """
  63. provider_controller = cls.get_builtin_provider(provider)
  64. tool = provider_controller.get_tool(tool_name)
  65. return tool
  66. @classmethod
  67. def get_tool(cls, provider_type: str, provider_id: str, tool_name: str, tenant_id: str = None) \
  68. -> Union[BuiltinTool, ApiTool]:
  69. """
  70. get the tool
  71. :param provider_type: the type of the provider
  72. :param provider_name: the name of the provider
  73. :param tool_name: the name of the tool
  74. :return: the tool
  75. """
  76. if provider_type == 'builtin':
  77. return cls.get_builtin_tool(provider_id, tool_name)
  78. elif provider_type == 'api':
  79. if tenant_id is None:
  80. raise ValueError('tenant id is required for api provider')
  81. api_provider, _ = cls.get_api_provider_controller(tenant_id, provider_id)
  82. return api_provider.get_tool(tool_name)
  83. elif provider_type == 'app':
  84. raise NotImplementedError('app provider not implemented')
  85. else:
  86. raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
  87. @classmethod
  88. def get_tool_runtime(cls, provider_type: str, provider_name: str, tool_name: str, tenant_id: str) \
  89. -> Union[BuiltinTool, ApiTool]:
  90. """
  91. get the tool runtime
  92. :param provider_type: the type of the provider
  93. :param provider_name: the name of the provider
  94. :param tool_name: the name of the tool
  95. :return: the tool
  96. """
  97. if provider_type == 'builtin':
  98. builtin_tool = cls.get_builtin_tool(provider_name, tool_name)
  99. # check if the builtin tool need credentials
  100. provider_controller = cls.get_builtin_provider(provider_name)
  101. if not provider_controller.need_credentials:
  102. return builtin_tool.fork_tool_runtime(meta={
  103. 'tenant_id': tenant_id,
  104. 'credentials': {},
  105. })
  106. # get credentials
  107. builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
  108. BuiltinToolProvider.tenant_id == tenant_id,
  109. BuiltinToolProvider.provider == provider_name,
  110. ).first()
  111. if builtin_provider is None:
  112. raise ToolProviderNotFoundError(f'builtin provider {provider_name} not found')
  113. # decrypt the credentials
  114. credentials = builtin_provider.credentials
  115. controller = cls.get_builtin_provider(provider_name)
  116. tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
  117. decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
  118. return builtin_tool.fork_tool_runtime(meta={
  119. 'tenant_id': tenant_id,
  120. 'credentials': decrypted_credentials,
  121. 'runtime_parameters': {}
  122. })
  123. elif provider_type == 'api':
  124. if tenant_id is None:
  125. raise ValueError('tenant id is required for api provider')
  126. api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_name)
  127. # decrypt the credentials
  128. tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=api_provider)
  129. decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
  130. return api_provider.get_tool(tool_name).fork_tool_runtime(meta={
  131. 'tenant_id': tenant_id,
  132. 'credentials': decrypted_credentials,
  133. })
  134. elif provider_type == 'model':
  135. if tenant_id is None:
  136. raise ValueError('tenant id is required for model provider')
  137. # get model provider
  138. model_provider = cls.get_model_provider(tenant_id, provider_name)
  139. # get tool
  140. model_tool = model_provider.get_tool(tool_name)
  141. return model_tool.fork_tool_runtime(meta={
  142. 'tenant_id': tenant_id,
  143. 'credentials': model_tool.model_configuration['model_instance'].credentials
  144. })
  145. elif provider_type == 'app':
  146. raise NotImplementedError('app provider not implemented')
  147. else:
  148. raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
  149. @classmethod
  150. def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict) -> Union[str, int, float, bool]:
  151. """
  152. init runtime parameter
  153. """
  154. parameter_value = parameters.get(parameter_rule.name)
  155. if not parameter_value:
  156. # get default value
  157. parameter_value = parameter_rule.default
  158. if not parameter_value and parameter_rule.required:
  159. raise ValueError(f"tool parameter {parameter_rule.name} not found in tool config")
  160. if parameter_rule.type == ToolParameter.ToolParameterType.SELECT:
  161. # check if tool_parameter_config in options
  162. options = list(map(lambda x: x.value, parameter_rule.options))
  163. if parameter_value not in options:
  164. raise ValueError(
  165. f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}")
  166. # convert tool parameter config to correct type
  167. try:
  168. if parameter_rule.type == ToolParameter.ToolParameterType.NUMBER:
  169. # check if tool parameter is integer
  170. if isinstance(parameter_value, int):
  171. parameter_value = parameter_value
  172. elif isinstance(parameter_value, float):
  173. parameter_value = parameter_value
  174. elif isinstance(parameter_value, str):
  175. if '.' in parameter_value:
  176. parameter_value = float(parameter_value)
  177. else:
  178. parameter_value = int(parameter_value)
  179. elif parameter_rule.type == ToolParameter.ToolParameterType.BOOLEAN:
  180. parameter_value = bool(parameter_value)
  181. elif parameter_rule.type not in [ToolParameter.ToolParameterType.SELECT,
  182. ToolParameter.ToolParameterType.STRING]:
  183. parameter_value = str(parameter_value)
  184. elif parameter_rule.type == ToolParameter.ToolParameterType:
  185. parameter_value = str(parameter_value)
  186. except Exception as e:
  187. raise ValueError(f"tool parameter {parameter_rule.name} value {parameter_value} is not correct type")
  188. return parameter_value
  189. @classmethod
  190. def get_agent_tool_runtime(cls, tenant_id: str, agent_tool: AgentToolEntity) -> Tool:
  191. """
  192. get the agent tool runtime
  193. """
  194. tool_entity = cls.get_tool_runtime(
  195. provider_type=agent_tool.provider_type, provider_name=agent_tool.provider_id,
  196. tool_name=agent_tool.tool_name,
  197. tenant_id=tenant_id,
  198. )
  199. runtime_parameters = {}
  200. parameters = tool_entity.get_all_runtime_parameters()
  201. for parameter in parameters:
  202. if parameter.form == ToolParameter.ToolParameterForm.FORM:
  203. # save tool parameter to tool entity memory
  204. value = cls._init_runtime_parameter(parameter, agent_tool.tool_parameters)
  205. runtime_parameters[parameter.name] = value
  206. # decrypt runtime parameters
  207. encryption_manager = ToolParameterConfigurationManager(
  208. tenant_id=tenant_id,
  209. tool_runtime=tool_entity,
  210. provider_name=agent_tool.provider_id,
  211. provider_type=agent_tool.provider_type,
  212. )
  213. runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
  214. tool_entity.runtime.runtime_parameters.update(runtime_parameters)
  215. return tool_entity
  216. @classmethod
  217. def get_workflow_tool_runtime(cls, tenant_id: str, workflow_tool: ToolEntity):
  218. """
  219. get the workflow tool runtime
  220. """
  221. tool_entity = cls.get_tool_runtime(
  222. provider_type=workflow_tool.provider_type,
  223. provider_name=workflow_tool.provider_id,
  224. tool_name=workflow_tool.tool_name,
  225. tenant_id=tenant_id,
  226. )
  227. runtime_parameters = {}
  228. parameters = tool_entity.get_all_runtime_parameters()
  229. for parameter in parameters:
  230. # save tool parameter to tool entity memory
  231. if parameter.form == ToolParameter.ToolParameterForm.FORM:
  232. value = cls._init_runtime_parameter(parameter, workflow_tool.tool_configurations)
  233. runtime_parameters[parameter.name] = value
  234. # decrypt runtime parameters
  235. encryption_manager = ToolParameterConfigurationManager(
  236. tenant_id=tenant_id,
  237. tool_runtime=tool_entity,
  238. provider_name=workflow_tool.provider_id,
  239. provider_type=workflow_tool.provider_type,
  240. )
  241. if runtime_parameters:
  242. runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
  243. tool_entity.runtime.runtime_parameters.update(runtime_parameters)
  244. return tool_entity
  245. @classmethod
  246. def get_builtin_provider_icon(cls, provider: str) -> tuple[str, str]:
  247. """
  248. get the absolute path of the icon of the builtin provider
  249. :param provider: the name of the provider
  250. :return: the absolute path of the icon, the mime type of the icon
  251. """
  252. # get provider
  253. provider_controller = cls.get_builtin_provider(provider)
  254. absolute_path = path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider, '_assets',
  255. provider_controller.identity.icon)
  256. # check if the icon exists
  257. if not path.exists(absolute_path):
  258. raise ToolProviderNotFoundError(f'builtin provider {provider} icon not found')
  259. # get the mime type
  260. mime_type, _ = mimetypes.guess_type(absolute_path)
  261. mime_type = mime_type or 'application/octet-stream'
  262. return absolute_path, mime_type
  263. @classmethod
  264. def list_builtin_providers(cls) -> Generator[BuiltinToolProviderController, None, None]:
  265. # use cache first
  266. if cls._builtin_providers_loaded:
  267. yield from list(cls._builtin_providers.values())
  268. return
  269. with cls._builtin_provider_lock:
  270. if cls._builtin_providers_loaded:
  271. yield from list(cls._builtin_providers.values())
  272. return
  273. yield from cls._list_builtin_providers()
  274. @classmethod
  275. def _list_builtin_providers(cls) -> Generator[BuiltinToolProviderController, None, None]:
  276. """
  277. list all the builtin providers
  278. """
  279. for provider in listdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin')):
  280. if provider.startswith('__'):
  281. continue
  282. if path.isdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider)):
  283. if provider.startswith('__'):
  284. continue
  285. # init provider
  286. try:
  287. provider_class = load_single_subclass_from_source(
  288. module_name=f'core.tools.provider.builtin.{provider}.{provider}',
  289. script_path=path.join(path.dirname(path.realpath(__file__)),
  290. 'provider', 'builtin', provider, f'{provider}.py'),
  291. parent_type=BuiltinToolProviderController)
  292. provider: BuiltinToolProviderController = provider_class()
  293. cls._builtin_providers[provider.identity.name] = provider
  294. for tool in provider.get_tools():
  295. cls._builtin_tools_labels[tool.identity.name] = tool.identity.label
  296. yield provider
  297. except Exception as e:
  298. logger.error(f'load builtin provider {provider} error: {e}')
  299. continue
  300. # set builtin providers loaded
  301. cls._builtin_providers_loaded = True
  302. @classmethod
  303. def load_builtin_providers_cache(cls):
  304. for _ in cls.list_builtin_providers():
  305. pass
  306. @classmethod
  307. def clear_builtin_providers_cache(cls):
  308. cls._builtin_providers = {}
  309. cls._builtin_providers_loaded = False
  310. # @classmethod
  311. # def list_model_providers(cls, tenant_id: str = None) -> list[ModelToolProviderController]:
  312. # """
  313. # list all the model providers
  314. # :return: the list of the model providers
  315. # """
  316. # tenant_id = tenant_id or 'ffffffff-ffff-ffff-ffff-ffffffffffff'
  317. # # get configurations
  318. # model_configurations = ModelToolConfigurationManager.get_all_configuration()
  319. # # get all providers
  320. # provider_manager = ProviderManager()
  321. # configurations = provider_manager.get_configurations(tenant_id).values()
  322. # # get model providers
  323. # model_providers: list[ModelToolProviderController] = []
  324. # for configuration in configurations:
  325. # # all the model tool should be configurated
  326. # if configuration.provider.provider not in model_configurations:
  327. # continue
  328. # if not ModelToolProviderController.is_configuration_valid(configuration):
  329. # continue
  330. # model_providers.append(ModelToolProviderController.from_db(configuration))
  331. # return model_providers
  332. @classmethod
  333. def get_model_provider(cls, tenant_id: str, provider_name: str) -> ModelToolProviderController:
  334. """
  335. get the model provider
  336. :param provider_name: the name of the provider
  337. :return: the provider
  338. """
  339. # get configurations
  340. provider_manager = ProviderManager()
  341. configurations = provider_manager.get_configurations(tenant_id)
  342. configuration = configurations.get(provider_name)
  343. if configuration is None:
  344. raise ToolProviderNotFoundError(f'model provider {provider_name} not found')
  345. return ModelToolProviderController.from_db(configuration)
  346. @classmethod
  347. def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]:
  348. """
  349. get the tool label
  350. :param tool_name: the name of the tool
  351. :return: the label of the tool
  352. """
  353. cls._builtin_tools_labels
  354. if len(cls._builtin_tools_labels) == 0:
  355. # init the builtin providers
  356. cls.load_builtin_providers_cache()
  357. if tool_name not in cls._builtin_tools_labels:
  358. return None
  359. return cls._builtin_tools_labels[tool_name]
  360. @classmethod
  361. def user_list_providers(cls, user_id: str, tenant_id: str) -> list[UserToolProvider]:
  362. result_providers: dict[str, UserToolProvider] = {}
  363. # get builtin providers
  364. builtin_providers = cls.list_builtin_providers()
  365. # get db builtin providers
  366. db_builtin_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider). \
  367. filter(BuiltinToolProvider.tenant_id == tenant_id).all()
  368. find_db_builtin_provider = lambda provider: next(
  369. (x for x in db_builtin_providers if x.provider == provider),
  370. None
  371. )
  372. # append builtin providers
  373. for provider in builtin_providers:
  374. user_provider = ToolTransformService.builtin_provider_to_user_provider(
  375. provider_controller=provider,
  376. db_provider=find_db_builtin_provider(provider.identity.name),
  377. decrypt_credentials=False
  378. )
  379. result_providers[provider.identity.name] = user_provider
  380. # # get model tool providers
  381. # model_providers = cls.list_model_providers(tenant_id=tenant_id)
  382. # # append model providers
  383. # for provider in model_providers:
  384. # user_provider = ToolTransformService.model_provider_to_user_provider(
  385. # db_provider=provider,
  386. # )
  387. # result_providers[f'model_provider.{provider.identity.name}'] = user_provider
  388. # get db api providers
  389. db_api_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider). \
  390. filter(ApiToolProvider.tenant_id == tenant_id).all()
  391. for db_api_provider in db_api_providers:
  392. provider_controller = ToolTransformService.api_provider_to_controller(
  393. db_provider=db_api_provider,
  394. )
  395. user_provider = ToolTransformService.api_provider_to_user_provider(
  396. provider_controller=provider_controller,
  397. db_provider=db_api_provider,
  398. decrypt_credentials=False
  399. )
  400. result_providers[db_api_provider.name] = user_provider
  401. return BuiltinToolProviderSort.sort(list(result_providers.values()))
  402. @classmethod
  403. def get_api_provider_controller(cls, tenant_id: str, provider_id: str) -> tuple[
  404. ApiBasedToolProviderController, dict[str, Any]]:
  405. """
  406. get the api provider
  407. :param provider_name: the name of the provider
  408. :return: the provider controller, the credentials
  409. """
  410. provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
  411. ApiToolProvider.id == provider_id,
  412. ApiToolProvider.tenant_id == tenant_id,
  413. ).first()
  414. if provider is None:
  415. raise ToolProviderNotFoundError(f'api provider {provider_id} not found')
  416. controller = ApiBasedToolProviderController.from_db(
  417. provider,
  418. ApiProviderAuthType.API_KEY if provider.credentials['auth_type'] == 'api_key' else
  419. ApiProviderAuthType.NONE
  420. )
  421. controller.load_bundled_tools(provider.tools)
  422. return controller, provider.credentials
  423. @classmethod
  424. def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict:
  425. """
  426. get api provider
  427. """
  428. """
  429. get tool provider
  430. """
  431. provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
  432. ApiToolProvider.tenant_id == tenant_id,
  433. ApiToolProvider.name == provider,
  434. ).first()
  435. if provider is None:
  436. raise ValueError(f'you have not added provider {provider}')
  437. try:
  438. credentials = json.loads(provider.credentials_str) or {}
  439. except:
  440. credentials = {}
  441. # package tool provider controller
  442. controller = ApiBasedToolProviderController.from_db(
  443. provider, ApiProviderAuthType.API_KEY if credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE
  444. )
  445. # init tool configuration
  446. tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
  447. decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
  448. masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
  449. try:
  450. icon = json.loads(provider.icon)
  451. except:
  452. icon = {
  453. "background": "#252525",
  454. "content": "\ud83d\ude01"
  455. }
  456. return jsonable_encoder({
  457. 'schema_type': provider.schema_type,
  458. 'schema': provider.schema,
  459. 'tools': provider.tools,
  460. 'icon': icon,
  461. 'description': provider.description,
  462. 'credentials': masked_credentials,
  463. 'privacy_policy': provider.privacy_policy
  464. })
  465. @classmethod
  466. def get_tool_icon(cls, tenant_id: str, provider_type: str, provider_id: str) -> Union[str, dict]:
  467. """
  468. get the tool icon
  469. :param tenant_id: the id of the tenant
  470. :param provider_type: the type of the provider
  471. :param provider_id: the id of the provider
  472. :return:
  473. """
  474. provider_type = provider_type
  475. provider_id = provider_id
  476. if provider_type == 'builtin':
  477. return (current_app.config.get("CONSOLE_API_URL")
  478. + "/console/api/workspaces/current/tool-provider/builtin/"
  479. + provider_id
  480. + "/icon")
  481. elif provider_type == 'api':
  482. try:
  483. provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
  484. ApiToolProvider.tenant_id == tenant_id,
  485. ApiToolProvider.id == provider_id
  486. )
  487. return json.loads(provider.icon)
  488. except:
  489. return {
  490. "background": "#252525",
  491. "content": "\ud83d\ude01"
  492. }
  493. else:
  494. raise ValueError(f"provider type {provider_type} not found")
  495. ToolManager.load_builtin_providers_cache()