api_tool_provider.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. from typing import Any, Dict, List
  2. from core.tools.entities.tool_entities import ToolProviderType, ApiProviderAuthType, ToolProviderCredentials, ToolCredentialsOption
  3. from core.tools.entities.common_entities import I18nObject
  4. from core.tools.entities.tool_bundle import ApiBasedToolBundle
  5. from core.tools.tool.tool import Tool
  6. from core.tools.tool.api_tool import ApiTool
  7. from core.tools.provider.tool_provider import ToolProviderController
  8. from extensions.ext_database import db
  9. from models.tools import ApiToolProvider
  10. class ApiBasedToolProviderController(ToolProviderController):
  11. @staticmethod
  12. def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> 'ApiBasedToolProviderController':
  13. credentials_schema = {
  14. 'auth_type': ToolProviderCredentials(
  15. name='auth_type',
  16. required=True,
  17. type=ToolProviderCredentials.CredentialsType.SELECT,
  18. options=[
  19. ToolCredentialsOption(value='none', label=I18nObject(en_US='None', zh_Hans='无')),
  20. ToolCredentialsOption(value='api_key', label=I18nObject(en_US='api_key', zh_Hans='api_key'))
  21. ],
  22. default='none',
  23. help=I18nObject(
  24. en_US='The auth type of the api provider',
  25. zh_Hans='api provider 的认证类型'
  26. )
  27. )
  28. }
  29. if auth_type == ApiProviderAuthType.API_KEY:
  30. credentials_schema = {
  31. **credentials_schema,
  32. 'api_key_header': ToolProviderCredentials(
  33. name='api_key_header',
  34. required=False,
  35. default='api_key',
  36. type=ToolProviderCredentials.CredentialsType.TEXT_INPUT,
  37. help=I18nObject(
  38. en_US='The header name of the api key',
  39. zh_Hans='携带 api key 的 header 名称'
  40. )
  41. ),
  42. 'api_key_value': ToolProviderCredentials(
  43. name='api_key_value',
  44. required=True,
  45. type=ToolProviderCredentials.CredentialsType.SECRET_INPUT,
  46. help=I18nObject(
  47. en_US='The api key',
  48. zh_Hans='api key的值'
  49. )
  50. )
  51. }
  52. elif auth_type == ApiProviderAuthType.NONE:
  53. pass
  54. else:
  55. raise ValueError(f'invalid auth type {auth_type}')
  56. return ApiBasedToolProviderController(**{
  57. 'identity': {
  58. 'author': db_provider.user.name if db_provider.user_id and db_provider.user else '',
  59. 'name': db_provider.name,
  60. 'label': {
  61. 'en_US': db_provider.name,
  62. 'zh_Hans': db_provider.name
  63. },
  64. 'description': {
  65. 'en_US': db_provider.description,
  66. 'zh_Hans': db_provider.description
  67. },
  68. 'icon': db_provider.icon
  69. },
  70. 'credentials_schema': credentials_schema
  71. })
  72. @property
  73. def app_type(self) -> ToolProviderType:
  74. return ToolProviderType.API_BASED
  75. def _validate_credentials(self, tool_name: str, credentials: Dict[str, Any]) -> None:
  76. pass
  77. def validate_parameters(self, tool_name: str, tool_parameters: Dict[str, Any]) -> None:
  78. pass
  79. def _parse_tool_bundle(self, tool_bundle: ApiBasedToolBundle) -> ApiTool:
  80. """
  81. parse tool bundle to tool
  82. :param tool_bundle: the tool bundle
  83. :return: the tool
  84. """
  85. return ApiTool(**{
  86. 'api_bundle': tool_bundle,
  87. 'identity' : {
  88. 'author': tool_bundle.author,
  89. 'name': tool_bundle.operation_id,
  90. 'label': {
  91. 'en_US': tool_bundle.operation_id,
  92. 'zh_Hans': tool_bundle.operation_id
  93. },
  94. 'icon': tool_bundle.icon if tool_bundle.icon else ''
  95. },
  96. 'description': {
  97. 'human': {
  98. 'en_US': tool_bundle.summary or '',
  99. 'zh_Hans': tool_bundle.summary or ''
  100. },
  101. 'llm': tool_bundle.summary or ''
  102. },
  103. 'parameters' : tool_bundle.parameters if tool_bundle.parameters else [],
  104. })
  105. def load_bundled_tools(self, tools: List[ApiBasedToolBundle]) -> List[ApiTool]:
  106. """
  107. load bundled tools
  108. :param tools: the bundled tools
  109. :return: the tools
  110. """
  111. self.tools = [self._parse_tool_bundle(tool) for tool in tools]
  112. return self.tools
  113. def get_tools(self, user_id: str, tanent_id: str) -> List[ApiTool]:
  114. """
  115. fetch tools from database
  116. :param user_id: the user id
  117. :param tanent_id: the tanent id
  118. :return: the tools
  119. """
  120. if self.tools is not None:
  121. return self.tools
  122. tools: List[Tool] = []
  123. # get tanent api providers
  124. db_providers: List[ApiToolProvider] = db.session.query(ApiToolProvider).filter(
  125. ApiToolProvider.tenant_id == tanent_id,
  126. ApiToolProvider.name == self.identity.name
  127. ).all()
  128. if db_providers and len(db_providers) != 0:
  129. for db_provider in db_providers:
  130. for tool in db_provider.tools:
  131. assistant_tool = self._parse_tool_bundle(tool)
  132. assistant_tool.is_team_authorization = True
  133. tools.append(assistant_tool)
  134. self.tools = tools
  135. return tools
  136. def get_tool(self, tool_name: str) -> ApiTool:
  137. """
  138. get tool by name
  139. :param tool_name: the name of the tool
  140. :return: the tool
  141. """
  142. if self.tools is None:
  143. self.get_tools()
  144. for tool in self.tools:
  145. if tool.identity.name == tool_name:
  146. return tool
  147. raise ValueError(f'tool {tool_name} not found')