api_tool_provider.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. from typing import Any, Dict, List
  2. from core.tools.entities.common_entities import I18nObject
  3. from core.tools.entities.tool_bundle import ApiBasedToolBundle
  4. from core.tools.entities.tool_entities import (ApiProviderAuthType, ToolCredentialsOption, ToolProviderCredentials,
  5. ToolProviderType)
  6. from core.tools.provider.tool_provider import ToolProviderController
  7. from core.tools.tool.api_tool import ApiTool
  8. from core.tools.tool.tool import Tool
  9. from extensions.ext_database import db
  10. from models.tools import ApiToolProvider
  11. class ApiBasedToolProviderController(ToolProviderController):
  12. @staticmethod
  13. def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> 'ApiBasedToolProviderController':
  14. credentials_schema = {
  15. 'auth_type': ToolProviderCredentials(
  16. name='auth_type',
  17. required=True,
  18. type=ToolProviderCredentials.CredentialsType.SELECT,
  19. options=[
  20. ToolCredentialsOption(value='none', label=I18nObject(en_US='None', zh_Hans='无')),
  21. ToolCredentialsOption(value='api_key', label=I18nObject(en_US='api_key', zh_Hans='api_key'))
  22. ],
  23. default='none',
  24. help=I18nObject(
  25. en_US='The auth type of the api provider',
  26. zh_Hans='api provider 的认证类型'
  27. )
  28. )
  29. }
  30. if auth_type == ApiProviderAuthType.API_KEY:
  31. credentials_schema = {
  32. **credentials_schema,
  33. 'api_key_header': ToolProviderCredentials(
  34. name='api_key_header',
  35. required=False,
  36. default='api_key',
  37. type=ToolProviderCredentials.CredentialsType.TEXT_INPUT,
  38. help=I18nObject(
  39. en_US='The header name of the api key',
  40. zh_Hans='携带 api key 的 header 名称'
  41. )
  42. ),
  43. 'api_key_value': ToolProviderCredentials(
  44. name='api_key_value',
  45. required=True,
  46. type=ToolProviderCredentials.CredentialsType.SECRET_INPUT,
  47. help=I18nObject(
  48. en_US='The api key',
  49. zh_Hans='api key的值'
  50. )
  51. )
  52. }
  53. elif auth_type == ApiProviderAuthType.NONE:
  54. pass
  55. else:
  56. raise ValueError(f'invalid auth type {auth_type}')
  57. return ApiBasedToolProviderController(**{
  58. 'identity': {
  59. 'author': db_provider.user.name if db_provider.user_id and db_provider.user else '',
  60. 'name': db_provider.name,
  61. 'label': {
  62. 'en_US': db_provider.name,
  63. 'zh_Hans': db_provider.name
  64. },
  65. 'description': {
  66. 'en_US': db_provider.description,
  67. 'zh_Hans': db_provider.description
  68. },
  69. 'icon': db_provider.icon
  70. },
  71. 'credentials_schema': credentials_schema
  72. })
  73. @property
  74. def app_type(self) -> ToolProviderType:
  75. return ToolProviderType.API_BASED
  76. def _validate_credentials(self, tool_name: str, credentials: Dict[str, Any]) -> None:
  77. pass
  78. def validate_parameters(self, tool_name: str, tool_parameters: Dict[str, Any]) -> None:
  79. pass
  80. def _parse_tool_bundle(self, tool_bundle: ApiBasedToolBundle) -> ApiTool:
  81. """
  82. parse tool bundle to tool
  83. :param tool_bundle: the tool bundle
  84. :return: the tool
  85. """
  86. return ApiTool(**{
  87. 'api_bundle': tool_bundle,
  88. 'identity' : {
  89. 'author': tool_bundle.author,
  90. 'name': tool_bundle.operation_id,
  91. 'label': {
  92. 'en_US': tool_bundle.operation_id,
  93. 'zh_Hans': tool_bundle.operation_id
  94. },
  95. 'icon': tool_bundle.icon if tool_bundle.icon else ''
  96. },
  97. 'description': {
  98. 'human': {
  99. 'en_US': tool_bundle.summary or '',
  100. 'zh_Hans': tool_bundle.summary or ''
  101. },
  102. 'llm': tool_bundle.summary or ''
  103. },
  104. 'parameters' : tool_bundle.parameters if tool_bundle.parameters else [],
  105. })
  106. def load_bundled_tools(self, tools: List[ApiBasedToolBundle]) -> List[ApiTool]:
  107. """
  108. load bundled tools
  109. :param tools: the bundled tools
  110. :return: the tools
  111. """
  112. self.tools = [self._parse_tool_bundle(tool) for tool in tools]
  113. return self.tools
  114. def get_tools(self, user_id: str, tenant_id: str) -> List[ApiTool]:
  115. """
  116. fetch tools from database
  117. :param user_id: the user id
  118. :param tenant_id: the tenant id
  119. :return: the tools
  120. """
  121. if self.tools is not None:
  122. return self.tools
  123. tools: List[Tool] = []
  124. # get tenant api providers
  125. db_providers: List[ApiToolProvider] = db.session.query(ApiToolProvider).filter(
  126. ApiToolProvider.tenant_id == tenant_id,
  127. ApiToolProvider.name == self.identity.name
  128. ).all()
  129. if db_providers and len(db_providers) != 0:
  130. for db_provider in db_providers:
  131. for tool in db_provider.tools:
  132. assistant_tool = self._parse_tool_bundle(tool)
  133. assistant_tool.is_team_authorization = True
  134. tools.append(assistant_tool)
  135. self.tools = tools
  136. return tools
  137. def get_tool(self, tool_name: str) -> ApiTool:
  138. """
  139. get tool by name
  140. :param tool_name: the name of the tool
  141. :return: the tool
  142. """
  143. if self.tools is None:
  144. self.get_tools()
  145. for tool in self.tools:
  146. if tool.identity.name == tool_name:
  147. return tool
  148. raise ValueError(f'tool {tool_name} not found')