builtin_tools_manage_service.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229
  1. import json
  2. import logging
  3. from core.model_runtime.utils.encoders import jsonable_encoder
  4. from core.tools.entities.api_entities import UserTool, UserToolProvider
  5. from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
  6. from core.tools.provider.builtin._positions import BuiltinToolProviderSort
  7. from core.tools.provider.tool_provider import ToolProviderController
  8. from core.tools.tool_label_manager import ToolLabelManager
  9. from core.tools.tool_manager import ToolManager
  10. from core.tools.utils.configuration import ToolConfigurationManager
  11. from extensions.ext_database import db
  12. from models.tools import BuiltinToolProvider
  13. from services.tools.tools_transform_service import ToolTransformService
  14. logger = logging.getLogger(__name__)
  15. class BuiltinToolManageService:
  16. @staticmethod
  17. def list_builtin_tool_provider_tools(
  18. user_id: str, tenant_id: str, provider: str
  19. ) -> list[UserTool]:
  20. """
  21. list builtin tool provider tools
  22. """
  23. provider_controller: ToolProviderController = ToolManager.get_builtin_provider(provider)
  24. tools = provider_controller.get_tools()
  25. tool_provider_configurations = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
  26. # check if user has added the provider
  27. builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
  28. BuiltinToolProvider.tenant_id == tenant_id,
  29. BuiltinToolProvider.provider == provider,
  30. ).first()
  31. credentials = {}
  32. if builtin_provider is not None:
  33. # get credentials
  34. credentials = builtin_provider.credentials
  35. credentials = tool_provider_configurations.decrypt_tool_credentials(credentials)
  36. result = []
  37. for tool in tools:
  38. result.append(ToolTransformService.tool_to_user_tool(
  39. tool=tool,
  40. credentials=credentials,
  41. tenant_id=tenant_id,
  42. labels=ToolLabelManager.get_tool_labels(provider_controller)
  43. ))
  44. return result
  45. @staticmethod
  46. def list_builtin_provider_credentials_schema(
  47. provider_name
  48. ):
  49. """
  50. list builtin provider credentials schema
  51. :return: the list of tool providers
  52. """
  53. provider = ToolManager.get_builtin_provider(provider_name)
  54. return jsonable_encoder([
  55. v for _, v in (provider.credentials_schema or {}).items()
  56. ])
  57. @staticmethod
  58. def update_builtin_tool_provider(
  59. user_id: str, tenant_id: str, provider_name: str, credentials: dict
  60. ):
  61. """
  62. update builtin tool provider
  63. """
  64. # get if the provider exists
  65. provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
  66. BuiltinToolProvider.tenant_id == tenant_id,
  67. BuiltinToolProvider.provider == provider_name,
  68. ).first()
  69. try:
  70. # get provider
  71. provider_controller = ToolManager.get_builtin_provider(provider_name)
  72. if not provider_controller.need_credentials:
  73. raise ValueError(f'provider {provider_name} does not need credentials')
  74. tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
  75. # get original credentials if exists
  76. if provider is not None:
  77. original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
  78. masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
  79. # check if the credential has changed, save the original credential
  80. for name, value in credentials.items():
  81. if name in masked_credentials and value == masked_credentials[name]:
  82. credentials[name] = original_credentials[name]
  83. # validate credentials
  84. provider_controller.validate_credentials(credentials)
  85. # encrypt credentials
  86. credentials = tool_configuration.encrypt_tool_credentials(credentials)
  87. except (ToolProviderNotFoundError, ToolNotFoundError, ToolProviderCredentialValidationError) as e:
  88. raise ValueError(str(e))
  89. if provider is None:
  90. # create provider
  91. provider = BuiltinToolProvider(
  92. tenant_id=tenant_id,
  93. user_id=user_id,
  94. provider=provider_name,
  95. encrypted_credentials=json.dumps(credentials),
  96. )
  97. db.session.add(provider)
  98. db.session.commit()
  99. else:
  100. provider.encrypted_credentials = json.dumps(credentials)
  101. db.session.add(provider)
  102. db.session.commit()
  103. # delete cache
  104. tool_configuration.delete_tool_credentials_cache()
  105. return { 'result': 'success' }
  106. @staticmethod
  107. def get_builtin_tool_provider_credentials(
  108. user_id: str, tenant_id: str, provider: str
  109. ):
  110. """
  111. get builtin tool provider credentials
  112. """
  113. provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
  114. BuiltinToolProvider.tenant_id == tenant_id,
  115. BuiltinToolProvider.provider == provider,
  116. ).first()
  117. if provider is None:
  118. return {}
  119. provider_controller = ToolManager.get_builtin_provider(provider.provider)
  120. tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
  121. credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
  122. credentials = tool_configuration.mask_tool_credentials(credentials)
  123. return credentials
  124. @staticmethod
  125. def delete_builtin_tool_provider(
  126. user_id: str, tenant_id: str, provider_name: str
  127. ):
  128. """
  129. delete tool provider
  130. """
  131. provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
  132. BuiltinToolProvider.tenant_id == tenant_id,
  133. BuiltinToolProvider.provider == provider_name,
  134. ).first()
  135. if provider is None:
  136. raise ValueError(f'you have not added provider {provider_name}')
  137. db.session.delete(provider)
  138. db.session.commit()
  139. # delete cache
  140. provider_controller = ToolManager.get_builtin_provider(provider_name)
  141. tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
  142. tool_configuration.delete_tool_credentials_cache()
  143. return { 'result': 'success' }
  144. @staticmethod
  145. def get_builtin_tool_provider_icon(
  146. provider: str
  147. ):
  148. """
  149. get tool provider icon and it's mimetype
  150. """
  151. icon_path, mime_type = ToolManager.get_builtin_provider_icon(provider)
  152. with open(icon_path, 'rb') as f:
  153. icon_bytes = f.read()
  154. return icon_bytes, mime_type
  155. @staticmethod
  156. def list_builtin_tools(
  157. user_id: str, tenant_id: str
  158. ) -> list[UserToolProvider]:
  159. """
  160. list builtin tools
  161. """
  162. # get all builtin providers
  163. provider_controllers = ToolManager.list_builtin_providers()
  164. # get all user added providers
  165. db_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider).filter(
  166. BuiltinToolProvider.tenant_id == tenant_id
  167. ).all() or []
  168. # find provider
  169. find_provider = lambda provider: next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None)
  170. result: list[UserToolProvider] = []
  171. for provider_controller in provider_controllers:
  172. try:
  173. # convert provider controller to user provider
  174. user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
  175. provider_controller=provider_controller,
  176. db_provider=find_provider(provider_controller.identity.name),
  177. decrypt_credentials=True
  178. )
  179. # add icon
  180. ToolTransformService.repack_provider(user_builtin_provider)
  181. tools = provider_controller.get_tools()
  182. for tool in tools:
  183. user_builtin_provider.tools.append(ToolTransformService.tool_to_user_tool(
  184. tenant_id=tenant_id,
  185. tool=tool,
  186. credentials=user_builtin_provider.original_credentials,
  187. labels=ToolLabelManager.get_tool_labels(provider_controller)
  188. ))
  189. result.append(user_builtin_provider)
  190. except Exception as e:
  191. raise e
  192. return BuiltinToolProviderSort.sort(result)