from typing import Any

from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiBasedToolBundle
from core.tools.entities.tool_entities import (
    ApiProviderAuthType,
    ToolCredentialsOption,
    ToolProviderCredentials,
    ToolProviderType,
)
from core.tools.provider.tool_provider import ToolProviderController
from core.tools.tool.api_tool import ApiTool
from core.tools.tool.tool import Tool
from extensions.ext_database import db
from models.tools import ApiToolProvider


class ApiBasedToolProviderController(ToolProviderController):
    """Tool provider controller for tools described by API tool bundles."""

    # Identifier of the backing provider record; from_db() fills it from
    # ``db_provider.id`` (empty string when the record has no id).
    provider_id: str

    @staticmethod
    def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> 'ApiBasedToolProviderController':
        """
            build a controller from a database provider record

            :param db_provider: the database record of the api provider
            :param auth_type: the auth type of the provider; it decides which
                credential fields are added to the credentials schema
            :return: the controller
            :raises ValueError: if ``auth_type`` is neither API_KEY nor NONE
        """
        # Every provider exposes the auth type selector itself.
        credentials_schema = {
            'auth_type': ToolProviderCredentials(
                name='auth_type',
                required=True,
                type=ToolProviderCredentials.CredentialsType.SELECT,
                options=[
                    ToolCredentialsOption(value='none', label=I18nObject(en_US='None', zh_Hans='无')),
                    ToolCredentialsOption(value='api_key', label=I18nObject(en_US='api_key', zh_Hans='api_key'))
                ],
                default='none',
                help=I18nObject(
                    en_US='The auth type of the api provider',
                    zh_Hans='api provider 的认证类型'
                )
            )
        }

        if auth_type == ApiProviderAuthType.API_KEY:
            # API key auth additionally needs the header name, the key value
            # and the header prefix.
            credentials_schema = {
                **credentials_schema,
                'api_key_header': ToolProviderCredentials(
                    name='api_key_header',
                    required=False,
                    default='api_key',
                    type=ToolProviderCredentials.CredentialsType.TEXT_INPUT,
                    help=I18nObject(
                        en_US='The header name of the api key',
                        zh_Hans='携带 api key 的 header 名称'
                    )
                ),
                'api_key_value': ToolProviderCredentials(
                    name='api_key_value',
                    required=True,
                    type=ToolProviderCredentials.CredentialsType.SECRET_INPUT,
                    help=I18nObject(
                        en_US='The api key',
                        zh_Hans='api key的值'
                    )
                ),
                'api_key_header_prefix': ToolProviderCredentials(
                    name='api_key_header_prefix',
                    required=False,
                    default='basic',
                    type=ToolProviderCredentials.CredentialsType.SELECT,
                    help=I18nObject(
                        en_US='The prefix of the api key header',
                        zh_Hans='api key header 的前缀'
                    ),
                    options=[
                        ToolCredentialsOption(value='basic', label=I18nObject(en_US='Basic', zh_Hans='Basic')),
                        ToolCredentialsOption(value='bearer', label=I18nObject(en_US='Bearer', zh_Hans='Bearer')),
                        ToolCredentialsOption(value='custom', label=I18nObject(en_US='Custom', zh_Hans='Custom'))
                    ]
                )
            }
        elif auth_type != ApiProviderAuthType.NONE:
            # NONE needs no extra credential fields; anything else is unsupported.
            raise ValueError(f'invalid auth type {auth_type}')

        # The author is only available when the record is linked to a user.
        author = db_provider.user.name if db_provider.user_id and db_provider.user else ''

        return ApiBasedToolProviderController(
            identity={
                'author': author,
                'name': db_provider.name,
                'label': {
                    'en_US': db_provider.name,
                    'zh_Hans': db_provider.name
                },
                'description': {
                    'en_US': db_provider.description,
                    'zh_Hans': db_provider.description
                },
                'icon': db_provider.icon,
            },
            credentials_schema=credentials_schema,
            provider_id=db_provider.id or '',
        )

    @property
    def app_type(self) -> ToolProviderType:
        """
            the type of this provider, always ``ToolProviderType.API_BASED``
        """
        return ToolProviderType.API_BASED

    def _validate_credentials(self, tool_name: str, credentials: dict[str, Any]) -> None:
        """
            credentials validation hook, intentionally a no-op here

            :param tool_name: the name of the tool
            :param credentials: the credentials to validate
        """
        pass

    def validate_parameters(self, tool_name: str, tool_parameters: dict[str, Any]) -> None:
        """
            parameter validation hook, intentionally a no-op here

            :param tool_name: the name of the tool
            :param tool_parameters: the parameters to validate
        """
        pass

    def _parse_tool_bundle(self, tool_bundle: ApiBasedToolBundle) -> ApiTool:
        """
            parse tool bundle to tool

            The bundle's operation id is used as the tool name and as both
            labels; the bundle's summary (or an empty string) is used as the
            human and llm descriptions.

            :param tool_bundle: the tool bundle
            :return: the tool
        """
        summary = tool_bundle.summary or ''

        return ApiTool(
            api_bundle=tool_bundle,
            identity={
                'author': tool_bundle.author,
                'name': tool_bundle.operation_id,
                'label': {
                    'en_US': tool_bundle.operation_id,
                    'zh_Hans': tool_bundle.operation_id
                },
                'icon': self.identity.icon,
                'provider': self.provider_id,
            },
            description={
                'human': {
                    'en_US': summary,
                    'zh_Hans': summary
                },
                'llm': summary
            },
            parameters=tool_bundle.parameters if tool_bundle.parameters else [],
        )

    def load_bundled_tools(self, tools: list[ApiBasedToolBundle]) -> list[ApiTool]:
        """
            load bundled tools

            Replaces ``self.tools`` with the tools parsed from the given bundles.

            :param tools: the bundled tools
            :return: the tools
        """
        self.tools = [self._parse_tool_bundle(tool) for tool in tools]
        return self.tools

    def get_tools(self, user_id: str, tenant_id: str) -> list[ApiTool]:
        """
            fetch tools from database

            Returns the already loaded tools when there are any. Otherwise the
            tenant's api providers with this provider's name are queried and
            their bundles are parsed. The result is only stored on the
            controller when at least one provider record was found, so an
            empty result is looked up again on the next call.

            :param user_id: the user id (not used by this implementation)
            :param tenant_id: the tenant id
            :return: the tools
        """
        if self.tools is not None:
            return self.tools

        tools: list[ApiTool] = []

        # get tenant api providers
        db_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider).filter(
            ApiToolProvider.tenant_id == tenant_id,
            ApiToolProvider.name == self.identity.name
        ).all()

        if db_providers:
            for db_provider in db_providers:
                for tool in db_provider.tools:
                    assistant_tool = self._parse_tool_bundle(tool)
                    assistant_tool.is_team_authorization = True
                    tools.append(assistant_tool)

            # Store the list only after every provider parsed successfully, so
            # a failure half-way never leaves a partial list behind.
            self.tools = tools

        return tools

    def get_tool(self, tool_name: str) -> ApiTool:
        """
            get tool by name

            The tools must have been loaded before, either through
            load_bundled_tools() or get_tools(); this method cannot load them
            itself because it knows neither the user id nor the tenant id.

            :param tool_name: the name of the tool
            :return: the tool
            :raises ValueError: if the tools are not loaded or no tool has that name
        """
        if self.tools is None:
            # get_tools() needs a user id and a tenant id which are not
            # available here, so fail with a clear error instead of a TypeError.
            raise ValueError(
                f'tool {tool_name} not found: tools of this provider are not loaded, '
                'call load_bundled_tools() or get_tools() first'
            )

        for tool in self.tools:
            if tool.identity.name == tool_name:
                return tool

        raise ValueError(f'tool {tool_name} not found')