import json
import logging
from typing import Optional

from httpx import get

from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.entities.api_entities import UserTool, UserToolProvider
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import (
    ApiProviderAuthType,
    ApiProviderSchemaType,
    ToolCredentialsOption,
    ToolProviderCredentials,
)
from core.tools.provider.api_tool_provider import ApiToolProviderController
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolConfigurationManager
from core.tools.utils.parser import ApiBasedToolSchemaParser
from extensions.ext_database import db
from models.tools import ApiToolProvider
from services.tools.tools_transform_service import ToolTransformService

logger = logging.getLogger(__name__)

# Upper bound on the number of API operations a single custom API provider may expose.
MAX_API_TOOLS_PER_PROVIDER = 100


class ApiToolManageService:
    """Service-layer operations for tenant-owned custom API tool providers."""

    @staticmethod
    def parser_api_schema(schema: str) -> dict:
        """
        Parse an API schema and describe how it can be configured as a provider.

        :param schema: raw schema text, handed to
            ``ApiBasedToolSchemaParser.auto_parse_to_tool_bundle`` for detection and parsing
        :return: JSON-encodable dict with ``schema_type``, ``parameters_schema``
            (the parsed tool bundles), ``credentials_schema`` (the auth form
            definition) and ``warning`` (warnings collected by the parser)
        :raises ValueError: if the schema cannot be parsed or encoded
        """
        warnings = {}
        try:
            tool_bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, warning=warnings)
        except Exception as e:
            # Wrap exactly once so the message is not double-prefixed.
            raise ValueError(f'invalid schema: {str(e)}') from e

        # Static description of the auth form: either no auth, or an API key
        # sent in a configurable HTTP header.
        credentials_schema = [
            ToolProviderCredentials(
                name='auth_type',
                type=ToolProviderCredentials.CredentialsType.SELECT,
                required=True,
                default='none',
                options=[
                    ToolCredentialsOption(value='none', label=I18nObject(
                        en_US='None',
                        zh_Hans='无'
                    )),
                    ToolCredentialsOption(value='api_key', label=I18nObject(
                        en_US='Api Key',
                        zh_Hans='Api Key'
                    )),
                ],
                placeholder=I18nObject(
                    en_US='Select auth type',
                    zh_Hans='选择认证方式'
                )
            ),
            ToolProviderCredentials(
                name='api_key_header',
                type=ToolProviderCredentials.CredentialsType.TEXT_INPUT,
                required=False,
                placeholder=I18nObject(
                    en_US='Enter api key header',
                    zh_Hans='输入 api key header,如:X-API-KEY'
                ),
                default='api_key',
                help=I18nObject(
                    en_US='HTTP header name for api key',
                    zh_Hans='HTTP 头部字段名,用于传递 api key'
                )
            ),
            ToolProviderCredentials(
                name='api_key_value',
                type=ToolProviderCredentials.CredentialsType.TEXT_INPUT,
                required=False,
                placeholder=I18nObject(
                    en_US='Enter api key',
                    zh_Hans='输入 api key'
                ),
                default=''
            ),
        ]

        try:
            return jsonable_encoder({
                'schema_type': schema_type,
                'parameters_schema': tool_bundles,
                'credentials_schema': credentials_schema,
                'warning': warnings
            })
        except Exception as e:
            # Keep the historical contract: any failure here surfaces as ValueError.
            raise ValueError(f'invalid schema: {str(e)}') from e

    @staticmethod
    def convert_schema_to_tool_bundles(
        schema: str, extra_info: Optional[dict] = None
    ) -> tuple[list[ApiToolBundle], str]:
        """
        Convert a raw schema into tool bundles.

        :param schema: raw schema text
        :param extra_info: optional dict the parser may fill with extra provider
            information (callers in this class read ``description`` from it)
        :return: ``(tool_bundles, schema_type)`` as returned by the parser
        :raises ValueError: if the schema cannot be parsed
        """
        try:
            return ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, extra_info=extra_info)
        except Exception as e:
            raise ValueError(f'invalid schema: {str(e)}') from e

    @staticmethod
    def create_api_tool_provider(
        user_id: str, tenant_id: str, provider_name: str, icon: dict, credentials: dict,
        schema_type: str, schema: str, privacy_policy: str, custom_disclaimer: str, labels: list[str]
    ):
        """
        Create a custom API tool provider for a tenant.

        Parses ``schema`` into tool bundles, encrypts ``credentials``, persists a
        new ``ApiToolProvider`` row and then updates the provider's labels.

        :raises ValueError: on an unknown schema type, missing ``auth_type``,
            an already existing provider name, an unparsable schema, or more than
            ``MAX_API_TOOLS_PER_PROVIDER`` APIs
        """
        if schema_type not in [member.value for member in ApiProviderSchemaType]:
            raise ValueError(f'invalid schema type {schema_type}')

        if 'auth_type' not in credentials:
            raise ValueError('auth_type is required')

        # check if the provider exists
        existing_provider = db.session.query(ApiToolProvider).filter(
            ApiToolProvider.tenant_id == tenant_id,
            ApiToolProvider.name == provider_name,
        ).first()
        if existing_provider is not None:
            raise ValueError(f'provider {provider_name} already exists')

        # The parser fills extra_info (e.g. the description) as a side channel.
        extra_info = {}
        # schema_type is replaced by the type the parser actually detected.
        tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)

        if len(tool_bundles) > MAX_API_TOOLS_PER_PROVIDER:
            raise ValueError(f'the number of apis should be less than {MAX_API_TOOLS_PER_PROVIDER}')

        db_provider = ApiToolProvider(
            tenant_id=tenant_id,
            user_id=user_id,
            name=provider_name,
            icon=json.dumps(icon),
            schema=schema,
            description=extra_info.get('description', ''),
            schema_type_str=schema_type,
            tools_str=json.dumps(jsonable_encoder(tool_bundles)),
            # placeholder; replaced with the encrypted credentials below
            credentials_str={},
            privacy_policy=privacy_policy,
            custom_disclaimer=custom_disclaimer
        )

        # get auth type, none or api key
        auth_type = ApiProviderAuthType.value_of(credentials['auth_type'])

        provider_controller = ApiToolProviderController.from_db(db_provider, auth_type)
        provider_controller.load_bundled_tools(tool_bundles)

        # encrypt credentials before persisting them
        tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
        encrypted_credentials = tool_configuration.encrypt_tool_credentials(credentials)
        db_provider.credentials_str = json.dumps(encrypted_credentials)

        db.session.add(db_provider)
        db.session.commit()

        ToolLabelManager.update_tool_labels(provider_controller, labels)

        return {'result': 'success'}

    @staticmethod
    def get_api_tool_provider_remote_schema(
        user_id: str, tenant_id: str, url: str
    ):
        """
        Fetch a schema from ``url`` and return it if it parses as an API schema.

        :return: ``{'schema': <response text>}``
        :raises ValueError: if the request fails, returns a non-200 status, or
            the body does not parse as an API schema
        """
        headers = {
            "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36 Edg/120.0.0.0",
            "Accept": "*/*",
        }

        # NOTE(review): this performs a server-side request to an arbitrary
        # user-supplied URL. Parsing the body afterwards only limits what is
        # returned to the caller, not what is requested; consider routing this
        # through an SSRF-safe HTTP client/proxy — TODO confirm.
        try:
            response = get(url, headers=headers, timeout=10)
            if response.status_code != 200:
                raise ValueError(f'Got status code {response.status_code}')
            schema = response.text

            # Only hand the body back if it is a valid API schema, so arbitrary
            # page contents are never echoed to the caller.
            ApiToolManageService.parser_api_schema(schema)
        except Exception as e:
            logger.error("parse api schema error: %s", e)
            raise ValueError('invalid schema, please check the url you provided') from e

        return {
            'schema': schema
        }

    @staticmethod
    def list_api_tool_provider_tools(
        user_id: str, tenant_id: str, provider: str
    ) -> list[UserTool]:
        """
        List the tools of the tenant's API provider named ``provider``.

        :raises ValueError: if the tenant has no provider with that name
        """
        # Use a separate local so the provider *name* is still available for
        # the error message below (it used to be shadowed and printed "None").
        db_provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
            ApiToolProvider.tenant_id == tenant_id,
            ApiToolProvider.name == provider,
        ).first()

        if db_provider is None:
            raise ValueError(f'you have not added provider {provider}')

        controller = ToolTransformService.api_provider_to_controller(db_provider=db_provider)
        labels = ToolLabelManager.get_tool_labels(controller)

        return [
            ToolTransformService.tool_to_user_tool(
                tool_bundle,
                labels=labels,
            ) for tool_bundle in db_provider.tools
        ]

    @staticmethod
    def update_api_tool_provider(
        user_id: str, tenant_id: str, provider_name: str, original_provider: str, icon: dict, credentials: dict,
        schema_type: str, schema: str, privacy_policy: str, custom_disclaimer: str, labels: list[str]
    ):
        """
        Update the tenant's API provider currently named ``original_provider``.

        Credential values equal to their masked form are treated as unchanged and
        replaced with the stored originals before re-encryption. The caller's
        ``credentials`` dict is not modified.

        :raises ValueError: on an unknown schema type, missing ``auth_type``, an
            unknown provider, an unparsable schema, or more than
            ``MAX_API_TOOLS_PER_PROVIDER`` APIs
        """
        if schema_type not in [member.value for member in ApiProviderSchemaType]:
            raise ValueError(f'invalid schema type {schema_type}')

        # Validate before mutating the session-attached row, so a failure cannot
        # leave a half-updated provider pending in the session.
        if 'auth_type' not in credentials:
            raise ValueError('auth_type is required')

        # check if the provider exists
        provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
            ApiToolProvider.tenant_id == tenant_id,
            ApiToolProvider.name == original_provider,
        ).first()

        if provider is None:
            raise ValueError(f'api provider {original_provider} does not exist')

        # The parser fills extra_info (e.g. the description) as a side channel.
        extra_info = {}
        tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)

        # Same limit as on creation, so an update cannot bypass it.
        if len(tool_bundles) > MAX_API_TOOLS_PER_PROVIDER:
            raise ValueError(f'the number of apis should be less than {MAX_API_TOOLS_PER_PROVIDER}')

        # update db provider
        provider.name = provider_name
        provider.icon = json.dumps(icon)
        provider.schema = schema
        provider.description = extra_info.get('description', '')
        # Store the detected schema type, consistent with create_api_tool_provider.
        provider.schema_type_str = schema_type
        provider.tools_str = json.dumps(jsonable_encoder(tool_bundles))
        provider.privacy_policy = privacy_policy
        provider.custom_disclaimer = custom_disclaimer

        # get auth type, none or api key
        auth_type = ApiProviderAuthType.value_of(credentials['auth_type'])

        provider_controller = ApiToolProviderController.from_db(provider, auth_type)
        provider_controller.load_bundled_tools(tool_bundles)

        # get original credentials if exists
        tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
        original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
        masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)

        # A value equal to its masked form was not changed by the user: keep the
        # stored original. Build a new dict instead of mutating the caller's.
        merged_credentials = {
            name: original_credentials[name]
            if name in masked_credentials and value == masked_credentials[name]
            else value
            for name, value in credentials.items()
        }

        encrypted_credentials = tool_configuration.encrypt_tool_credentials(merged_credentials)
        provider.credentials_str = json.dumps(encrypted_credentials)

        db.session.add(provider)
        db.session.commit()

        # delete cache
        tool_configuration.delete_tool_credentials_cache()

        ToolLabelManager.update_tool_labels(provider_controller, labels)

        return {'result': 'success'}

    @staticmethod
    def delete_api_tool_provider(
        user_id: str, tenant_id: str, provider_name: str
    ):
        """
        Delete the tenant's API provider named ``provider_name``.

        :raises ValueError: if the tenant has no provider with that name
        """
        provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
            ApiToolProvider.tenant_id == tenant_id,
            ApiToolProvider.name == provider_name,
        ).first()

        if provider is None:
            raise ValueError(f'you have not added provider {provider_name}')

        db.session.delete(provider)
        db.session.commit()

        return {'result': 'success'}

    @staticmethod
    def get_api_tool_provider(
        user_id: str, tenant_id: str, provider: str
    ):
        """Return the API provider ``provider`` via ``ToolManager.user_get_api_provider``."""
        return ToolManager.user_get_api_provider(provider=provider, tenant_id=tenant_id)

    @staticmethod
    def test_api_tool_preview(
        tenant_id: str,
        provider_name: str,
        tool_name: str,
        credentials: dict,
        parameters: dict,
        schema_type: str,
        schema: str
    ):
        """
        Run a single tool from ``schema`` before (or without) saving the provider.

        If the tenant already has a provider named ``provider_name``, credential
        values equal to their masked form are replaced with the stored originals.
        The caller's ``credentials`` dict is not modified.

        :return: ``{'result': ...}`` on success or ``{'error': <message>}`` if
            validation or the tool call fails
        :raises ValueError: on an unknown schema type, unparsable schema,
            unknown ``tool_name`` or missing ``auth_type``
        """
        if schema_type not in [member.value for member in ApiProviderSchemaType]:
            raise ValueError(f'invalid schema type {schema_type}')

        if 'auth_type' not in credentials:
            raise ValueError('auth_type is required')

        try:
            tool_bundles, _ = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema)
        except Exception as e:
            raise ValueError('invalid schema') from e

        # make sure the requested operation exists in the schema
        tool_bundle = next((tb for tb in tool_bundles if tb.operation_id == tool_name), None)
        if tool_bundle is None:
            raise ValueError(f'invalid tool name {tool_name}')

        db_provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
            ApiToolProvider.tenant_id == tenant_id,
            ApiToolProvider.name == provider_name,
        ).first()

        if not db_provider:
            # create a fake, unsaved db provider just to build a controller
            db_provider = ApiToolProvider(
                tenant_id='', user_id='', name='', icon='',
                schema=schema,
                description='',
                schema_type_str=ApiProviderSchemaType.OPENAPI.value,
                tools_str=json.dumps(jsonable_encoder(tool_bundles)),
                credentials_str=json.dumps(credentials),
            )

        # get auth type, none or api key
        auth_type = ApiProviderAuthType.value_of(credentials['auth_type'])

        provider_controller = ApiToolProviderController.from_db(db_provider, auth_type)
        provider_controller.load_bundled_tools(tool_bundles)

        # Work on a copy so the caller's dict is never mutated.
        credentials = dict(credentials)

        # Only a persisted provider has stored credentials to fall back on.
        if db_provider.id:
            tool_configuration = ToolConfigurationManager(
                tenant_id=tenant_id,
                provider_controller=provider_controller
            )
            # Decrypt the *stored* credentials (previously the incoming request
            # values were decrypted, so masked secrets were never restored).
            decrypted_credentials = tool_configuration.decrypt_tool_credentials(db_provider.credentials)
            masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
            # a value equal to its masked form is unchanged: use the stored original
            credentials = {
                name: decrypted_credentials[name]
                if name in masked_credentials and value == masked_credentials[name]
                else value
                for name, value in credentials.items()
            }

        try:
            provider_controller.validate_credentials_format(credentials)
            tool = provider_controller.get_tool(tool_name)
            tool = tool.fork_tool_runtime(runtime={
                'credentials': credentials,
                'tenant_id': tenant_id,
            })
            result = tool.validate_credentials(credentials, parameters)
        except Exception as e:
            # Preview is best-effort: report the failure instead of raising.
            return {'error': str(e)}

        return {'result': result or 'empty response'}

    @staticmethod
    def list_api_tools(
        user_id: str, tenant_id: str
    ) -> list[UserToolProvider]:
        """
        List all API providers of a tenant together with their tools and labels.
        """
        db_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider).filter(
            ApiToolProvider.tenant_id == tenant_id
        ).all() or []

        result: list[UserToolProvider] = []

        for provider in db_providers:
            # convert provider controller to user provider
            provider_controller = ToolTransformService.api_provider_to_controller(db_provider=provider)
            labels = ToolLabelManager.get_tool_labels(provider_controller)
            user_provider = ToolTransformService.api_provider_to_user_provider(
                provider_controller,
                db_provider=provider,
                decrypt_credentials=True
            )
            user_provider.labels = labels

            # add icon
            ToolTransformService.repack_provider(user_provider)

            tools = provider_controller.get_tools(
                user_id=user_id, tenant_id=tenant_id
            )

            for tool in tools:
                user_provider.tools.append(ToolTransformService.tool_to_user_tool(
                    tenant_id=tenant_id,
                    tool=tool,
                    credentials=user_provider.original_credentials,
                    labels=labels
                ))

            result.append(user_provider)

        return result