| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699 | 
import json
import logging
from typing import Optional

from httpx import get

from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiBasedToolBundle
from core.tools.entities.tool_entities import (
    ApiProviderAuthType,
    ApiProviderSchemaType,
    ToolCredentialsOption,
    ToolProviderCredentials,
)
from core.tools.entities.user_entities import UserTool, UserToolProvider
from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
from core.tools.provider.api_tool_provider import ApiBasedToolProviderController
from core.tools.provider.builtin._positions import BuiltinToolProviderSort
from core.tools.provider.tool_provider import ToolProviderController
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolConfigurationManager
from core.tools.utils.parser import ApiBasedToolSchemaParser
from extensions.ext_database import db
from models.tools import ApiToolProvider, BuiltinToolProvider
from services.model_provider_service import ModelProviderService
from services.tools_transform_service import ToolTransformService

logger = logging.getLogger(__name__)

# Upper bound on the number of apis (tool bundles) a single api tool provider may define.
_MAX_API_TOOLS_PER_PROVIDER = 100


class ToolManageService:
    @staticmethod
    def _restore_masked_credentials(
        credentials: dict, original_credentials: dict, masked_credentials: dict
    ) -> dict:
        """
            Merge submitted credentials with the stored originals.

            A submitted field whose value still equals the masked form of the stored
            value is treated as unchanged, and the stored original value is used
            instead. The input dicts are not modified.

            :param credentials: credentials submitted by the client
            :param original_credentials: decrypted credentials currently stored
            :param masked_credentials: ``original_credentials`` after masking
            :return: a new dict with unchanged (still masked) fields restored
        """
        restored = dict(credentials)
        for name, value in credentials.items():
            if name in masked_credentials and value == masked_credentials[name]:
                restored[name] = original_credentials[name]
        return restored

    @staticmethod
    def list_tool_providers(user_id: str, tenant_id: str):
        """
            list tool providers visible to the user in the tenant

            :return: the list of tool providers, each serialized with ``to_dict()``
        """
        providers = ToolManager.user_list_providers(user_id, tenant_id)

        # add icon
        for provider in providers:
            ToolTransformService.repack_provider(provider)

        return [provider.to_dict() for provider in providers]

    @staticmethod
    def list_builtin_tool_provider_tools(
        user_id: str, tenant_id: str, provider: str
    ) -> list[UserTool]:
        """
            list the tools of a builtin tool provider

            If the tenant has stored credentials for this provider, they are
            decrypted and passed along when converting each tool.
        """
        provider_controller: ToolProviderController = ToolManager.get_builtin_provider(provider)
        tools = provider_controller.get_tools()

        tool_provider_configurations = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
        # check if user has added the provider
        builtin_provider: Optional[BuiltinToolProvider] = db.session.query(BuiltinToolProvider).filter(
            BuiltinToolProvider.tenant_id == tenant_id,
            BuiltinToolProvider.provider == provider,
        ).first()

        credentials = {}
        if builtin_provider is not None:
            credentials = tool_provider_configurations.decrypt_tool_credentials(builtin_provider.credentials)

        return [
            ToolTransformService.tool_to_user_tool(tool=tool, credentials=credentials, tenant_id=tenant_id)
            for tool in tools
        ]

    @staticmethod
    def list_builtin_provider_credentials_schema(
        provider_name
    ):
        """
            list builtin provider credentials schema

            :return: the values of the provider's ``credentials_schema`` mapping,
                     JSON-encoded; an empty list if the provider defines none
        """
        provider = ToolManager.get_builtin_provider(provider_name)
        return jsonable_encoder(list((provider.credentials_schema or {}).values()))

    @staticmethod
    def _build_api_credentials_schema() -> list[ToolProviderCredentials]:
        """
            build the credential form offered for api based tool providers:
            an auth type selector plus the header name and value used for api keys
        """
        return [
            ToolProviderCredentials(
                name='auth_type',
                type=ToolProviderCredentials.CredentialsType.SELECT,
                required=True,
                default='none',
                options=[
                    ToolCredentialsOption(value='none', label=I18nObject(
                        en_US='None',
                        zh_Hans='无'
                    )),
                    ToolCredentialsOption(value='api_key', label=I18nObject(
                        en_US='Api Key',
                        zh_Hans='Api Key'
                    )),
                ],
                placeholder=I18nObject(
                    en_US='Select auth type',
                    zh_Hans='选择认证方式'
                )
            ),
            ToolProviderCredentials(
                name='api_key_header',
                type=ToolProviderCredentials.CredentialsType.TEXT_INPUT,
                required=False,
                placeholder=I18nObject(
                    en_US='Enter api key header',
                    zh_Hans='输入 api key header,如:X-API-KEY'
                ),
                default='api_key',
                help=I18nObject(
                    en_US='HTTP header name for api key',
                    zh_Hans='HTTP 头部字段名,用于传递 api key'
                )
            ),
            ToolProviderCredentials(
                name='api_key_value',
                type=ToolProviderCredentials.CredentialsType.TEXT_INPUT,
                required=False,
                placeholder=I18nObject(
                    en_US='Enter api key',
                    zh_Hans='输入 api key'
                ),
                default=''
            ),
        ]

    @staticmethod
    def parser_api_schema(schema: str) -> dict:
        """
            parse an api schema and describe it for the client

            :param schema: raw schema text, handed to
                           ``ApiBasedToolSchemaParser.auto_parse_to_tool_bundle``
            :return: JSON-encoded dict with ``schema_type``, ``parameters_schema``
                     (the parsed tool bundles), ``credentials_schema`` and ``warning``
            :raises ValueError: if the schema cannot be parsed or encoded
        """
        warnings = {}
        try:
            tool_bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, warning=warnings)
        except Exception as e:
            # wrapped exactly once, so the message is not prefixed twice
            raise ValueError(f'invalid schema: {str(e)}') from e

        try:
            return jsonable_encoder({
                'schema_type': schema_type,
                'parameters_schema': tool_bundles,
                'credentials_schema': ToolManageService._build_api_credentials_schema(),
                'warning': warnings
            })
        except Exception as e:
            raise ValueError(f'invalid schema: {str(e)}') from e

    @staticmethod
    def convert_schema_to_tool_bundles(
        schema: str, extra_info: Optional[dict] = None
    ) -> tuple[list[ApiBasedToolBundle], str]:
        """
            convert schema to tool bundles

            :param schema: raw schema text
            :param extra_info: optional dict the parser may fill with extra data
                               (callers in this class read ``description`` from it)
            :return: tuple of (tool bundles, schema type)
            :raises ValueError: if the schema cannot be parsed
        """
        try:
            return ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, extra_info=extra_info)
        except Exception as e:
            raise ValueError(f'invalid schema: {str(e)}') from e

    @staticmethod
    def create_api_tool_provider(
        user_id: str, tenant_id: str, provider_name: str, icon: dict, credentials: dict,
        schema_type: str, schema: str, privacy_policy: str, custom_disclaimer: str
    ):
        """
            create api tool provider

            :raises ValueError: on an unknown schema type, missing ``auth_type``,
                                duplicate provider name, unparsable schema or more
                                than the allowed number of apis
        """
        if schema_type not in [member.value for member in ApiProviderSchemaType]:
            raise ValueError(f'invalid schema type {schema_type}')

        # validate the auth type up front, before any db query or schema parsing
        if 'auth_type' not in credentials:
            raise ValueError('auth_type is required')
        # get auth type, none or api key
        auth_type = ApiProviderAuthType.value_of(credentials['auth_type'])

        # check if the provider exists
        provider: Optional[ApiToolProvider] = db.session.query(ApiToolProvider).filter(
            ApiToolProvider.tenant_id == tenant_id,
            ApiToolProvider.name == provider_name,
        ).first()

        if provider is not None:
            raise ValueError(f'provider {provider_name} already exists')

        # parse openapi to tool bundle
        extra_info = {}
        # extra info like description will be set here
        tool_bundles, schema_type = ToolManageService.convert_schema_to_tool_bundles(schema, extra_info)

        if len(tool_bundles) > _MAX_API_TOOLS_PER_PROVIDER:
            raise ValueError(f'the number of apis should be less than {_MAX_API_TOOLS_PER_PROVIDER}')

        # create db provider
        db_provider = ApiToolProvider(
            tenant_id=tenant_id,
            user_id=user_id,
            name=provider_name,
            icon=json.dumps(icon),
            schema=schema,
            description=extra_info.get('description', ''),
            schema_type_str=schema_type,
            tools_str=json.dumps(jsonable_encoder(tool_bundles)),
            # placeholder (valid JSON text); replaced by the encrypted credentials below
            credentials_str='{}',
            privacy_policy=privacy_policy,
            custom_disclaimer=custom_disclaimer
        )

        # create provider entity
        provider_controller = ApiBasedToolProviderController.from_db(db_provider, auth_type)
        # load tools into provider entity
        provider_controller.load_bundled_tools(tool_bundles)

        # encrypt credentials
        tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
        encrypted_credentials = tool_configuration.encrypt_tool_credentials(credentials)
        db_provider.credentials_str = json.dumps(encrypted_credentials)

        db.session.add(db_provider)
        db.session.commit()

        return { 'result': 'success' }

    @staticmethod
    def get_api_tool_provider_remote_schema(
        user_id: str, tenant_id: str, url: str
    ):
        """
            fetch an api schema from a remote url

            The response body is only returned if it parses as an api schema, so
            arbitrary response content is not echoed back to the caller.

            :return: ``{'schema': <response text>}``
            :raises ValueError: if the request fails, returns a non-200 status, or
                                the body is not a valid schema
        """
        headers = {
            "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36 Edg/120.0.0.0",
            "Accept": "*/*",
        }

        try:
            # NOTE(review): the url is not validated here before the request is sent;
            # any SSRF protection (e.g. blocking internal addresses) must happen elsewhere.
            response = get(url, headers=headers, timeout=10)
            if response.status_code != 200:
                raise ValueError(f'Got status code {response.status_code}')
            schema = response.text

            # only hand back content that parses as a schema; this limits what a
            # request to an arbitrary url can disclose, it does not stop the request
            ToolManageService.parser_api_schema(schema)
        except Exception as e:
            logger.error("parse api schema error: %s", e)
            raise ValueError('invalid schema, please check the url you provided') from e

        return {
            'schema': schema
        }

    @staticmethod
    def list_api_tool_provider_tools(
        user_id: str, tenant_id: str, provider: str
    ) -> list[UserTool]:
        """
            list the tools of an api tool provider added by the tenant

            :raises ValueError: if the tenant has no api provider with that name
        """
        # separate name so the ``provider`` argument is still available for the error message
        db_provider: Optional[ApiToolProvider] = db.session.query(ApiToolProvider).filter(
            ApiToolProvider.tenant_id == tenant_id,
            ApiToolProvider.name == provider,
        ).first()

        if db_provider is None:
            raise ValueError(f'you have not added provider {provider}')

        return [
            ToolTransformService.tool_to_user_tool(tool_bundle) for tool_bundle in db_provider.tools
        ]

    @staticmethod
    def update_builtin_tool_provider(
        user_id: str, tenant_id: str, provider_name: str, credentials: dict
    ):
        """
            create or update the tenant's credentials for a builtin tool provider

            For an existing record, fields submitted in their masked form are
            replaced by the stored values before validation. The caller's
            ``credentials`` dict is not modified.

            :raises ValueError: if the provider is unknown, needs no credentials, or
                                the credentials fail validation
        """
        # work on a copy so the caller's dict never receives decrypted secrets
        credentials = dict(credentials)

        # get if the provider exists
        provider: Optional[BuiltinToolProvider] = db.session.query(BuiltinToolProvider).filter(
            BuiltinToolProvider.tenant_id == tenant_id,
            BuiltinToolProvider.provider == provider_name,
        ).first()

        try:
            # get provider
            provider_controller = ToolManager.get_builtin_provider(provider_name)
            if not provider_controller.need_credentials:
                raise ValueError(f'provider {provider_name} does not need credentials')
            tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
            # get original credentials if exists
            if provider is not None:
                original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
                masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
                # keep the stored value for every field the user did not change
                credentials = ToolManageService._restore_masked_credentials(
                    credentials, original_credentials, masked_credentials
                )
            # validate credentials
            provider_controller.validate_credentials(credentials)
            # encrypt credentials
            credentials = tool_configuration.encrypt_tool_credentials(credentials)
        except (ToolProviderNotFoundError, ToolNotFoundError, ToolProviderCredentialValidationError) as e:
            raise ValueError(str(e)) from e

        is_new = provider is None
        if is_new:
            # create provider
            provider = BuiltinToolProvider(
                tenant_id=tenant_id,
                user_id=user_id,
                provider=provider_name,
                encrypted_credentials=json.dumps(credentials),
            )
        else:
            provider.encrypted_credentials = json.dumps(credentials)

        db.session.add(provider)
        db.session.commit()

        if not is_new:
            # delete the cached credentials of the existing provider
            tool_configuration.delete_tool_credentials_cache()

        return { 'result': 'success' }

    @staticmethod
    def get_builtin_tool_provider_credentials(
        user_id: str, tenant_id: str, provider: str
    ):
        """
            get the tenant's credentials for a builtin tool provider, decrypted and then masked

            :return: masked credentials, or an empty dict if the tenant has not added the provider
        """
        db_provider: Optional[BuiltinToolProvider] = db.session.query(BuiltinToolProvider).filter(
            BuiltinToolProvider.tenant_id == tenant_id,
            BuiltinToolProvider.provider == provider,
        ).first()

        if db_provider is None:
            return {}

        provider_controller = ToolManager.get_builtin_provider(db_provider.provider)
        tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
        credentials = tool_configuration.decrypt_tool_credentials(db_provider.credentials)
        return tool_configuration.mask_tool_credentials(credentials)

    @staticmethod
    def update_api_tool_provider(
        user_id: str, tenant_id: str, provider_name: str, original_provider: str, icon: dict, credentials: dict,
        schema_type: str, schema: str, privacy_policy: str, custom_disclaimer: str
    ):
        """
            update api tool provider

            :param provider_name: new name of the provider
            :param original_provider: current name of the provider to update
            :raises ValueError: on an unknown schema type, missing ``auth_type``,
                                unknown provider, a rename colliding with another
                                provider, unparsable schema or too many apis
        """
        if schema_type not in [member.value for member in ApiProviderSchemaType]:
            raise ValueError(f'invalid schema type {schema_type}')

        # input validation and schema parsing happen before the ORM object is
        # modified, so a rejected request does not leave a dirty provider in the session
        if 'auth_type' not in credentials:
            raise ValueError('auth_type is required')
        # get auth type, none or api key
        auth_type = ApiProviderAuthType.value_of(credentials['auth_type'])

        # check if the provider exists
        provider: Optional[ApiToolProvider] = db.session.query(ApiToolProvider).filter(
            ApiToolProvider.tenant_id == tenant_id,
            ApiToolProvider.name == original_provider,
        ).first()

        if provider is None:
            raise ValueError(f'api provider {original_provider} does not exists')

        # a rename must not collide with another provider of the same tenant
        if provider_name != original_provider:
            conflicting_provider = db.session.query(ApiToolProvider).filter(
                ApiToolProvider.tenant_id == tenant_id,
                ApiToolProvider.name == provider_name,
            ).first()
            if conflicting_provider is not None:
                raise ValueError(f'provider {provider_name} already exists')

        # parse openapi to tool bundle
        extra_info = {}
        # extra info like description will be set here
        tool_bundles, schema_type = ToolManageService.convert_schema_to_tool_bundles(schema, extra_info)

        # same limit as create_api_tool_provider, so it cannot be bypassed via update
        if len(tool_bundles) > _MAX_API_TOOLS_PER_PROVIDER:
            raise ValueError(f'the number of apis should be less than {_MAX_API_TOOLS_PER_PROVIDER}')

        # update db provider
        provider.name = provider_name
        provider.icon = json.dumps(icon)
        provider.schema = schema
        provider.description = extra_info.get('description', '')
        # store the schema type reported by the parser, as create_api_tool_provider does
        provider.schema_type_str = schema_type
        provider.tools_str = json.dumps(jsonable_encoder(tool_bundles))
        provider.privacy_policy = privacy_policy
        provider.custom_disclaimer = custom_disclaimer

        # create provider entity
        provider_controller = ApiBasedToolProviderController.from_db(provider, auth_type)
        # load tools into provider entity
        provider_controller.load_bundled_tools(tool_bundles)

        # get original credentials if exists
        tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)

        original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
        masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
        # keep the stored value for every field the user did not change
        # (returns a new dict; the caller's credentials are left untouched)
        credentials = ToolManageService._restore_masked_credentials(
            credentials, original_credentials, masked_credentials
        )

        credentials = tool_configuration.encrypt_tool_credentials(credentials)
        provider.credentials_str = json.dumps(credentials)

        db.session.add(provider)
        db.session.commit()

        # delete cache
        tool_configuration.delete_tool_credentials_cache()

        return { 'result': 'success' }

    @staticmethod
    def delete_builtin_tool_provider(
        user_id: str, tenant_id: str, provider_name: str
    ):
        """
            delete the tenant's builtin tool provider record and its credentials cache

            :raises ValueError: if the tenant has not added the provider
        """
        provider: Optional[BuiltinToolProvider] = db.session.query(BuiltinToolProvider).filter(
            BuiltinToolProvider.tenant_id == tenant_id,
            BuiltinToolProvider.provider == provider_name,
        ).first()

        if provider is None:
            raise ValueError(f'you have not added provider {provider_name}')

        db.session.delete(provider)
        db.session.commit()

        # delete cache
        provider_controller = ToolManager.get_builtin_provider(provider_name)
        tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
        tool_configuration.delete_tool_credentials_cache()

        return { 'result': 'success' }

    @staticmethod
    def get_builtin_tool_provider_icon(
        provider: str
    ):
        """
            get a builtin tool provider's icon and its mimetype

            :return: tuple of (icon bytes, mime type)
        """
        icon_path, mime_type = ToolManager.get_builtin_provider_icon(provider)
        with open(icon_path, 'rb') as f:
            icon_bytes = f.read()

        return icon_bytes, mime_type

    @staticmethod
    def get_model_tool_provider_icon(
        provider: str
    ):
        """
            get a model tool provider's small icon (en_US) and its mimetype

            :return: tuple of (icon bytes, mime type)
            :raises ValueError: if no icon is returned for the provider
        """
        service = ModelProviderService()
        icon_bytes, mime_type = service.get_model_provider_icon(provider=provider, icon_type='icon_small', lang='en_US')

        if icon_bytes is None:
            raise ValueError(f'provider {provider} does not exists')

        return icon_bytes, mime_type

    @staticmethod
    def list_model_tool_provider_tools(
        user_id: str, tenant_id: str, provider: str
    ) -> list[UserTool]:
        """
            list model tool provider tools

            :return: the tools as ``UserTool`` entries, passed through ``jsonable_encoder``
        """
        provider_controller = ToolManager.get_model_provider(tenant_id=tenant_id, provider_name=provider)
        tools = provider_controller.get_tools(user_id=user_id, tenant_id=tenant_id)

        result = [
            UserTool(
                author=tool.identity.author,
                name=tool.identity.name,
                label=tool.identity.label,
                description=tool.description.human,
                parameters=tool.parameters or []
            ) for tool in tools
        ]

        return jsonable_encoder(result)

    @staticmethod
    def delete_api_tool_provider(
        user_id: str, tenant_id: str, provider_name: str
    ):
        """
            delete the tenant's api tool provider

            :raises ValueError: if the tenant has not added the provider
        """
        provider: Optional[ApiToolProvider] = db.session.query(ApiToolProvider).filter(
            ApiToolProvider.tenant_id == tenant_id,
            ApiToolProvider.name == provider_name,
        ).first()

        if provider is None:
            raise ValueError(f'you have not added provider {provider_name}')

        db.session.delete(provider)
        db.session.commit()

        return { 'result': 'success' }

    @staticmethod
    def get_api_tool_provider(
        user_id: str, tenant_id: str, provider: str
    ):
        """
            get api tool provider (delegates to ``ToolManager.user_get_api_provider``)
        """
        return ToolManager.user_get_api_provider(provider=provider, tenant_id=tenant_id)

    @staticmethod
    def test_api_tool_preview(
        tenant_id: str,
        provider_name: str,
        tool_name: str,
        credentials: dict,
        parameters: dict,
        schema_type: str,
        schema: str
    ):
        """
            test an api tool before adding (or saving) the api tool provider

            If the tenant already has a provider named ``provider_name``, every
            submitted credential still in its masked form is replaced by the stored
            (decrypted) value. The caller's ``credentials`` dict is not modified.

            :return: ``{'result': ...}`` when the tool call succeeds, otherwise
                     ``{'error': <message>}``
            :raises ValueError: on an unknown schema type, unparsable schema, unknown
                                tool name or missing ``auth_type``
        """
        if schema_type not in [member.value for member in ApiProviderSchemaType]:
            raise ValueError(f'invalid schema type {schema_type}')

        try:
            tool_bundles, _ = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema)
        except Exception as e:
            raise ValueError('invalid schema') from e

        # get tool bundle
        tool_bundle = next((tb for tb in tool_bundles if tb.operation_id == tool_name), None)
        if tool_bundle is None:
            raise ValueError(f'invalid tool name {tool_name}')

        db_provider: Optional[ApiToolProvider] = db.session.query(ApiToolProvider).filter(
            ApiToolProvider.tenant_id == tenant_id,
            ApiToolProvider.name == provider_name,
        ).first()

        if not db_provider:
            # create a fake db provider
            db_provider = ApiToolProvider(
                tenant_id='', user_id='', name='', icon='',
                schema=schema,
                description='',
                schema_type_str=ApiProviderSchemaType.OPENAPI.value,
                tools_str=json.dumps(jsonable_encoder(tool_bundles)),
                credentials_str=json.dumps(credentials),
            )

        if 'auth_type' not in credentials:
            raise ValueError('auth_type is required')

        # get auth type, none or api key
        auth_type = ApiProviderAuthType.value_of(credentials['auth_type'])

        # create provider entity
        provider_controller = ApiBasedToolProviderController.from_db(db_provider, auth_type)
        # load tools into provider entity
        provider_controller.load_bundled_tools(tool_bundles)

        # an existing provider has stored encrypted credentials: decrypt the STORED
        # ones (the submitted ones are plain or masked) and restore masked fields
        if db_provider.id:
            tool_configuration = ToolConfigurationManager(
                tenant_id=tenant_id,
                provider_controller=provider_controller
            )
            original_credentials = tool_configuration.decrypt_tool_credentials(db_provider.credentials)
            masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
            credentials = ToolManageService._restore_masked_credentials(
                credentials, original_credentials, masked_credentials
            )

        try:
            provider_controller.validate_credentials_format(credentials)
            # get tool
            tool = provider_controller.get_tool(tool_name)
            tool = tool.fork_tool_runtime(meta={
                'credentials': credentials,
                'tenant_id': tenant_id,
            })
            result = tool.validate_credentials(credentials, parameters)
        except Exception as e:
            # best-effort preview: failures are reported to the client, not raised
            return { 'error': str(e) }

        return { 'result': result or 'empty response' }

    @staticmethod
    def list_builtin_tools(
        user_id: str, tenant_id: str
    ) -> list[UserToolProvider]:
        """
            list all builtin tool providers with their tools, merged with the
            tenant's stored provider records, sorted by ``BuiltinToolProviderSort``
        """
        # get all builtin providers
        provider_controllers = ToolManager.list_builtin_providers()

        # get all user added providers
        db_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider).filter(
            BuiltinToolProvider.tenant_id == tenant_id
        ).all() or []

        # index the tenant's records by provider name once, instead of scanning the
        # list for every builtin provider; setdefault keeps the first row per name
        db_providers_by_name: dict[str, BuiltinToolProvider] = {}
        for db_provider in db_providers:
            db_providers_by_name.setdefault(db_provider.provider, db_provider)

        result: list[UserToolProvider] = []

        for provider_controller in provider_controllers:
            # convert provider controller to user provider
            user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
                provider_controller=provider_controller,
                db_provider=db_providers_by_name.get(provider_controller.identity.name),
                decrypt_credentials=True
            )

            # add icon
            ToolTransformService.repack_provider(user_builtin_provider)

            tools = provider_controller.get_tools()
            for tool in tools:
                user_builtin_provider.tools.append(ToolTransformService.tool_to_user_tool(
                    tenant_id=tenant_id,
                    tool=tool,
                    credentials=user_builtin_provider.original_credentials,
                ))

            result.append(user_builtin_provider)

        return BuiltinToolProviderSort.sort(result)

    @staticmethod
    def list_api_tools(
        user_id: str, tenant_id: str
    ) -> list[UserToolProvider]:
        """
            list the tenant's api tool providers, each with its tools
        """
        # get all api providers
        db_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider).filter(
            ApiToolProvider.tenant_id == tenant_id
        ).all() or []

        result: list[UserToolProvider] = []

        for provider in db_providers:
            # convert provider controller to user provider
            provider_controller = ToolTransformService.api_provider_to_controller(db_provider=provider)
            user_provider = ToolTransformService.api_provider_to_user_provider(
                provider_controller,
                db_provider=provider,
                decrypt_credentials=True
            )

            # add icon
            ToolTransformService.repack_provider(user_provider)

            tools = provider_controller.get_tools(
                user_id=user_id, tenant_id=tenant_id
            )

            for tool in tools:
                user_provider.tools.append(ToolTransformService.tool_to_user_tool(
                    tenant_id=tenant_id,
                    tool=tool,
                    credentials=user_provider.original_credentials,
                ))

            result.append(user_provider)

        return result
 |