import json
import logging
from typing import Optional

from httpx import get

from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.entities.api_entities import UserTool, UserToolProvider
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import (
    ApiProviderAuthType,
    ApiProviderSchemaType,
    ToolCredentialsOption,
    ToolProviderCredentials,
)
from core.tools.provider.api_tool_provider import ApiToolProviderController
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolConfigurationManager
from core.tools.utils.parser import ApiBasedToolSchemaParser
from extensions.ext_database import db
from models.tools import ApiToolProvider
from services.tools.tools_transform_service import ToolTransformService

logger = logging.getLogger(__name__)

# Maximum number of API operations (tools) a single provider may contain;
# enforced on both create and update.
MAX_TOOLS_PER_PROVIDER = 100


class ApiToolManageService:
    """
    Service-layer operations on tenant-scoped, schema-defined API tool providers
    (rows of ``ApiToolProvider``).
    """

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _validate_schema_type(schema_type: str) -> None:
        """
        Raise ValueError unless ``schema_type`` is a value of ``ApiProviderSchemaType``.
        """
        if schema_type not in [member.value for member in ApiProviderSchemaType]:
            raise ValueError(f"invalid schema type {schema_type}")

    @staticmethod
    def _get_db_provider(tenant_id: str, provider_name: str) -> Optional[ApiToolProvider]:
        """
        Return the tenant's API tool provider named ``provider_name``, or None if it does not exist.
        """
        return (
            db.session.query(ApiToolProvider)
            .filter(
                ApiToolProvider.tenant_id == tenant_id,
                ApiToolProvider.name == provider_name,
            )
            .first()
        )

    @staticmethod
    def _get_auth_type(credentials: dict) -> ApiProviderAuthType:
        """
        Read the mandatory ``auth_type`` entry from ``credentials`` (none or api key).

        :raises ValueError: if ``auth_type`` is missing
            (NOTE(review): ``ApiProviderAuthType.value_of`` may raise for unknown values — confirm its type)
        """
        if "auth_type" not in credentials:
            raise ValueError("auth_type is required")
        return ApiProviderAuthType.value_of(credentials["auth_type"])

    @staticmethod
    def _build_provider_controller(
        db_provider: ApiToolProvider, auth_type: ApiProviderAuthType, tool_bundles: list[ApiToolBundle]
    ) -> ApiToolProviderController:
        """
        Create a provider controller from ``db_provider`` and load ``tool_bundles`` into it.
        """
        provider_controller = ApiToolProviderController.from_db(db_provider, auth_type)
        provider_controller.load_bundled_tools(tool_bundles)
        return provider_controller

    @staticmethod
    def _restore_masked_credentials(
        credentials: dict, original_credentials: dict, tool_configuration: ToolConfigurationManager
    ) -> dict:
        """
        Return a copy of ``credentials`` where every value that equals the masked form of the stored
        credential is replaced by the stored (decrypted) value.

        A field the user did not edit presumably comes back in its masked form; without this the mask
        itself would be saved as the new secret. The caller's dict is not modified.
        """
        masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
        restored = dict(credentials)
        for name, value in credentials.items():
            if name in masked_credentials and value == masked_credentials[name]:
                restored[name] = original_credentials[name]
        return restored

    @staticmethod
    def _build_credentials_schema() -> list[ToolProviderCredentials]:
        """
        Return the credential form offered for every API provider: an auth type selector
        (none / api key) plus the api key header name and value.
        """
        return [
            ToolProviderCredentials(
                name="auth_type",
                type=ToolProviderCredentials.CredentialsType.SELECT,
                required=True,
                default="none",
                options=[
                    ToolCredentialsOption(value="none", label=I18nObject(en_US="None", zh_Hans="无")),
                    ToolCredentialsOption(value="api_key", label=I18nObject(en_US="Api Key", zh_Hans="Api Key")),
                ],
                placeholder=I18nObject(en_US="Select auth type", zh_Hans="选择认证方式"),
            ),
            ToolProviderCredentials(
                name="api_key_header",
                type=ToolProviderCredentials.CredentialsType.TEXT_INPUT,
                required=False,
                placeholder=I18nObject(en_US="Enter api key header", zh_Hans="输入 api key header,如:X-API-KEY"),
                default="api_key",
                help=I18nObject(en_US="HTTP header name for api key", zh_Hans="HTTP 头部字段名,用于传递 api key"),
            ),
            ToolProviderCredentials(
                name="api_key_value",
                type=ToolProviderCredentials.CredentialsType.TEXT_INPUT,
                required=False,
                placeholder=I18nObject(en_US="Enter api key", zh_Hans="输入 api key"),
                default="",
            ),
        ]

    # --------------------------------------------------------------- public API

    @staticmethod
    def parser_api_schema(schema: str) -> dict:
        """
        Parse an API schema and describe it for the client.

        :param schema: raw schema text; the format is auto-detected by ``ApiBasedToolSchemaParser``
        :return: JSON-encoded dict with ``schema_type``, ``parameters_schema`` (the parsed tool bundles),
            ``credentials_schema`` (the auth form) and ``warning`` (warnings collected by the parser)
        :raises ValueError: if the schema cannot be parsed or encoded
        """
        warnings: dict = {}
        try:
            tool_bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, warning=warnings)
        except Exception as e:
            raise ValueError(f"invalid schema: {str(e)}") from e

        # Encoding is wrapped separately so a parse error is not prefixed with "invalid schema:" twice.
        try:
            return jsonable_encoder(
                {
                    "schema_type": schema_type,
                    "parameters_schema": tool_bundles,
                    "credentials_schema": ApiToolManageService._build_credentials_schema(),
                    "warning": warnings,
                }
            )
        except Exception as e:
            raise ValueError(f"invalid schema: {str(e)}") from e

    @staticmethod
    def convert_schema_to_tool_bundles(
        schema: str, extra_info: Optional[dict] = None
    ) -> tuple[list[ApiToolBundle], str]:
        """
        Convert an API schema into tool bundles.

        :param schema: raw schema text
        :param extra_info: dict passed through to the parser; callers read ``description`` from it afterwards
        :return: ``(tool_bundles, schema_type)``
        :raises ValueError: if the schema cannot be parsed
        """
        try:
            return ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, extra_info=extra_info)
        except Exception as e:
            raise ValueError(f"invalid schema: {str(e)}") from e

    @staticmethod
    def create_api_tool_provider(
        user_id: str,
        tenant_id: str,
        provider_name: str,
        icon: dict,
        credentials: dict,
        schema_type: str,
        schema: str,
        privacy_policy: str,
        custom_disclaimer: str,
        labels: list[str],
    ):
        """
        Create a new API tool provider for the tenant and persist it with encrypted credentials.

        :return: ``{"result": "success"}``
        :raises ValueError: on an invalid schema type or schema, a duplicate provider name,
            too many APIs, or a missing ``auth_type`` in ``credentials``
        """
        ApiToolManageService._validate_schema_type(schema_type)

        if ApiToolManageService._get_db_provider(tenant_id, provider_name) is not None:
            raise ValueError(f"provider {provider_name} already exists")

        # parse openapi to tool bundle; extra info like description will be set here
        extra_info: dict = {}
        # The parser's detected schema type replaces the requested one.
        tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
        if len(tool_bundles) > MAX_TOOLS_PER_PROVIDER:
            raise ValueError(f"the number of apis should be less than {MAX_TOOLS_PER_PROVIDER}")

        db_provider = ApiToolProvider(
            tenant_id=tenant_id,
            user_id=user_id,
            name=provider_name,
            icon=json.dumps(icon),
            schema=schema,
            description=extra_info.get("description", ""),
            schema_type_str=schema_type,
            tools_str=json.dumps(jsonable_encoder(tool_bundles)),
            # placeholder; replaced with the encrypted credentials below, before the row is added
            credentials_str="{}",
            privacy_policy=privacy_policy,
            custom_disclaimer=custom_disclaimer,
        )

        auth_type = ApiToolManageService._get_auth_type(credentials)
        provider_controller = ApiToolManageService._build_provider_controller(db_provider, auth_type, tool_bundles)

        tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
        encrypted_credentials = tool_configuration.encrypt_tool_credentials(credentials)
        db_provider.credentials_str = json.dumps(encrypted_credentials)

        db.session.add(db_provider)
        db.session.commit()

        ToolLabelManager.update_tool_labels(provider_controller, labels)

        return {"result": "success"}

    @staticmethod
    def get_api_tool_provider_remote_schema(user_id: str, tenant_id: str, url: str):
        """
        Download an API schema from ``url`` and return its raw text.

        The body is returned only if it parses as an API schema. Per the original intent, this limits
        SSRF abuse, because arbitrary response bodies are never echoed back. The request itself is still
        sent to the caller-supplied URL.

        :return: ``{"schema": <raw schema text>}``
        :raises ValueError: on a non-200 response, a request failure or an unparseable body
        """
        headers = {
            "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko)"
            " Chrome/120.0.0.0 Safari/537.36 Edg/120.0.0.0",
            "Accept": "*/*",
        }

        try:
            response = get(url, headers=headers, timeout=10)
            if response.status_code != 200:
                raise ValueError(f"Got status code {response.status_code}")
            schema = response.text

            # try to parse schema, avoid SSRF attack
            ApiToolManageService.parser_api_schema(schema)
        except Exception as e:
            logger.error("parse api schema error: %s", str(e))
            raise ValueError("invalid schema, please check the url you provided") from e

        return {"schema": schema}

    @staticmethod
    def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider: str) -> list[UserTool]:
        """
        List the tools of one of the tenant's API providers.

        :raises ValueError: if the tenant has no provider named ``provider``
        """
        # A distinct name keeps ``provider`` (the requested name) available for the error message.
        db_provider = ApiToolManageService._get_db_provider(tenant_id, provider)
        if db_provider is None:
            raise ValueError(f"you have not added provider {provider}")

        controller = ToolTransformService.api_provider_to_controller(db_provider=db_provider)
        labels = ToolLabelManager.get_tool_labels(controller)

        return [
            ToolTransformService.tool_to_user_tool(
                tool_bundle,
                labels=labels,
            )
            for tool_bundle in db_provider.tools
        ]

    @staticmethod
    def update_api_tool_provider(
        user_id: str,
        tenant_id: str,
        provider_name: str,
        original_provider: str,
        icon: dict,
        credentials: dict,
        schema_type: str,
        schema: str,
        privacy_policy: str,
        custom_disclaimer: str,
        labels: list[str],
    ):
        """
        Update (and optionally rename) an existing API tool provider.

        Credential fields submitted in their masked form keep their stored value.

        :param provider_name: new provider name
        :param original_provider: current name of the provider to update
        :return: ``{"result": "success"}``
        :raises ValueError: on an invalid schema type or schema, an unknown ``original_provider``,
            a rename onto an existing name, too many APIs, or a missing ``auth_type``
        """
        ApiToolManageService._validate_schema_type(schema_type)

        provider = ApiToolManageService._get_db_provider(tenant_id, original_provider)
        if provider is None:
            raise ValueError(f"api provider {original_provider} does not exist")

        # Creation enforces unique names per tenant; renaming must not break that.
        if (
            provider_name != original_provider
            and ApiToolManageService._get_db_provider(tenant_id, provider_name) is not None
        ):
            raise ValueError(f"provider {provider_name} already exists")

        # parse openapi to tool bundle; extra info like description will be set here
        extra_info: dict = {}
        tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
        if len(tool_bundles) > MAX_TOOLS_PER_PROVIDER:
            raise ValueError(f"the number of apis should be less than {MAX_TOOLS_PER_PROVIDER}")

        # Validate before mutating the session-bound row, so a rejected request leaves no dirty state.
        auth_type = ApiToolManageService._get_auth_type(credentials)

        provider.name = provider_name
        provider.icon = json.dumps(icon)
        provider.schema = schema
        provider.description = extra_info.get("description", "")
        # Store the detected schema type, as create_api_tool_provider does.
        provider.schema_type_str = schema_type
        provider.tools_str = json.dumps(jsonable_encoder(tool_bundles))
        provider.privacy_policy = privacy_policy
        provider.custom_disclaimer = custom_disclaimer

        provider_controller = ApiToolManageService._build_provider_controller(provider, auth_type, tool_bundles)

        # Keep stored secrets for fields the user left masked.
        tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
        original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
        credentials = ApiToolManageService._restore_masked_credentials(
            credentials, original_credentials, tool_configuration
        )

        encrypted_credentials = tool_configuration.encrypt_tool_credentials(credentials)
        provider.credentials_str = json.dumps(encrypted_credentials)

        db.session.add(provider)
        db.session.commit()

        # NOTE(review): the controller is built after the rename, so this presumably clears the cache
        # entry for the new name only — confirm whether the old name's entry also needs clearing.
        tool_configuration.delete_tool_credentials_cache()

        ToolLabelManager.update_tool_labels(provider_controller, labels)

        return {"result": "success"}

    @staticmethod
    def delete_api_tool_provider(user_id: str, tenant_id: str, provider_name: str):
        """
        Delete one of the tenant's API tool providers.

        :return: ``{"result": "success"}``
        :raises ValueError: if the tenant has no provider named ``provider_name``
        """
        provider = ApiToolManageService._get_db_provider(tenant_id, provider_name)
        if provider is None:
            raise ValueError(f"you have not added provider {provider_name}")

        # NOTE(review): unlike update_api_tool_provider, no credentials cache is cleared here —
        # confirm a stale entry cannot outlive the deleted provider.
        db.session.delete(provider)
        db.session.commit()

        return {"result": "success"}

    @staticmethod
    def get_api_tool_provider(user_id: str, tenant_id: str, provider: str):
        """
        Get an API tool provider; delegates to ``ToolManager.user_get_api_provider``.
        """
        return ToolManager.user_get_api_provider(provider=provider, tenant_id=tenant_id)

    @staticmethod
    def test_api_tool_preview(
        tenant_id: str,
        provider_name: str,
        tool_name: str,
        credentials: dict,
        parameters: dict,
        schema_type: str,
        schema: str,
    ):
        """
        Run one tool from ``schema`` before the provider is saved.

        If a provider named ``provider_name`` already exists, masked credential fields are replaced
        by its stored values. Otherwise a transient, unsaved provider is used.

        :return: ``{"result": <tool output or "empty response">}``, or ``{"error": <message>}``
            if validation or the call fails
        :raises ValueError: on an invalid schema type or schema, an unknown ``tool_name``,
            or a missing ``auth_type``
        """
        ApiToolManageService._validate_schema_type(schema_type)

        try:
            tool_bundles, _ = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema)
        except Exception as e:
            raise ValueError("invalid schema") from e

        # the requested operation must exist in the schema
        if not any(tb.operation_id == tool_name for tb in tool_bundles):
            raise ValueError(f"invalid tool name {tool_name}")

        db_provider = ApiToolManageService._get_db_provider(tenant_id, provider_name)
        if db_provider is None:
            # transient provider used only for this preview; it is never added to the session
            db_provider = ApiToolProvider(
                tenant_id="",
                user_id="",
                name="",
                icon="",
                schema=schema,
                description="",
                schema_type_str=ApiProviderSchemaType.OPENAPI.value,
                tools_str=json.dumps(jsonable_encoder(tool_bundles)),
                credentials_str=json.dumps(credentials),
            )

        auth_type = ApiToolManageService._get_auth_type(credentials)
        provider_controller = ApiToolManageService._build_provider_controller(db_provider, auth_type, tool_bundles)

        if db_provider.id:
            # Decrypt the *stored* credentials (mirrors update_api_tool_provider); decrypting the
            # incoming, possibly masked values cannot recover the originals.
            tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
            stored_credentials = tool_configuration.decrypt_tool_credentials(db_provider.credentials)
            credentials = ApiToolManageService._restore_masked_credentials(
                credentials, stored_credentials, tool_configuration
            )

        try:
            provider_controller.validate_credentials_format(credentials)
            tool = provider_controller.get_tool(tool_name)
            tool = tool.fork_tool_runtime(
                runtime={
                    "credentials": credentials,
                    "tenant_id": tenant_id,
                }
            )
            result = tool.validate_credentials(credentials, parameters)
        except Exception as e:
            # preview failures are reported to the caller rather than raised
            return {"error": str(e)}

        return {"result": result or "empty response"}

    @staticmethod
    def list_api_tools(user_id: str, tenant_id: str) -> list[UserToolProvider]:
        """
        List all of the tenant's API providers, each with its labels, icon and tools.
        """
        db_providers: list[ApiToolProvider] = (
            db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all() or []
        )

        result: list[UserToolProvider] = []

        for provider in db_providers:
            provider_controller = ToolTransformService.api_provider_to_controller(db_provider=provider)
            labels = ToolLabelManager.get_tool_labels(provider_controller)
            user_provider = ToolTransformService.api_provider_to_user_provider(
                provider_controller, db_provider=provider, decrypt_credentials=True
            )
            user_provider.labels = labels

            # add icon
            ToolTransformService.repack_provider(user_provider)

            for tool in provider_controller.get_tools(user_id=user_id, tenant_id=tenant_id):
                user_provider.tools.append(
                    ToolTransformService.tool_to_user_tool(
                        tenant_id=tenant_id, tool=tool, credentials=user_provider.original_credentials, labels=labels
                    )
                )

            result.append(user_provider)

        return result
 |