from abc import ABC, abstractmethod
from datetime import datetime, timezone
from typing import Type, Optional

from flask import current_app
from pydantic import BaseModel

from core.model_providers.error import QuotaExceededError, LLMBadRequestError
from extensions.ext_database import db
from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules
from core.model_providers.models.entity.provider import ProviderQuotaUnit
from core.model_providers.rules import provider_rules
from models.provider import Provider, ProviderType, ProviderModel


class BaseModelProvider(BaseModel, ABC):
    """
    Abstract base for a model provider bound to one tenant's ``Provider`` row.

    Subclasses supply the provider-specific pieces: the fixed model list, the
    model classes, credential validation/encryption and parameter rules.
    This base class implements the shared logic:

    * listing configurable models;
    * checking and deducting system quota;
    * recording the last-used time.
    """

    # The tenant's provider database record this instance operates on.
    provider: Provider

    class Config:
        """Configuration for this pydantic object."""
        # Required because ``Provider`` is a database model, not a pydantic type.
        arbitrary_types_allowed = True

    @property
    @abstractmethod
    def provider_name(self):
        """
        Returns the name of a provider.

        Used as the key into ``provider_rules`` by ``get_rules``.
        """
        raise NotImplementedError

    def get_rules(self):
        """
        Returns the rules of a provider.

        :return: the ``provider_rules`` entry keyed by ``self.provider_name``.
        :raises KeyError: if ``provider_rules`` has no entry for this provider
            (assuming ``provider_rules`` is a plain mapping).
        """
        return provider_rules[self.provider_name]

    def get_supported_model_list(self, model_type: ModelType) -> list[dict]:
        """
        Get the list of supported models of a given type.

        The subclass's fixed list (``_get_fixed_model_list``) is returned when
        any of the following holds:

        * the provider does not support the ``custom`` provider type;
        * the rules declare no ``model_flexibility``;
        * ``model_flexibility`` is ``'fixed'``.

        Otherwise the tenant's valid ``ProviderModel`` rows of this type are
        returned, oldest first.

        :param model_type: the kind of model to list.
        :return: a list of dicts with ``id`` and ``name`` (both the model
            name). Text-generation models also carry a ``mode`` entry.
        """
        rules = self.get_rules()

        if ('custom' not in rules['support_provider_types']
                or 'model_flexibility' not in rules
                or rules['model_flexibility'] == 'fixed'):
            return self._get_fixed_model_list(model_type)

        # Get configurable provider models.
        # SQLAlchemy needs ``== True`` (not ``is True``) to build a SQL expression.
        provider_models = db.session.query(ProviderModel).filter(
            ProviderModel.tenant_id == self.provider.tenant_id,
            ProviderModel.provider_name == self.provider.provider_name,
            ProviderModel.model_type == model_type.value,
            ProviderModel.is_valid == True  # noqa: E712
        ).order_by(ProviderModel.created_at.asc()).all()

        provider_model_list = []
        for provider_model in provider_models:
            provider_model_dict = {
                'id': provider_model.model_name,
                'name': provider_model.model_name
            }

            if model_type == ModelType.TEXT_GENERATION:
                provider_model_dict['mode'] = self._get_text_generation_model_mode(provider_model.model_name)

            provider_model_list.append(provider_model_dict)

        return provider_model_list

    @abstractmethod
    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
        """
        Get the provider's built-in (non-configurable) model list.

        :param model_type: the kind of model to list.
        :return: a list of model description dicts.
        """
        raise NotImplementedError

    @abstractmethod
    def _get_text_generation_model_mode(self, model_name) -> str:
        """
        Get the text generation mode of a model.

        :param model_name: name of the model.
        :return: the mode string stored under ``mode`` by ``get_supported_model_list``.
        """
        raise NotImplementedError

    @abstractmethod
    def get_model_class(self, model_type: ModelType) -> Type:
        """
        Get the specific model class for a model type.

        :param model_type: the kind of model.
        :return: the class implementing that model type for this provider.
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
        """
        Check that provider credentials are valid, raising if they are not.

        :param credentials: the provider credentials to check.
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
        """
        Encrypt provider credentials for saving.

        :param tenant_id: the owning tenant.
        :param credentials: plain provider credentials.
        :return: the credentials in their stored form.
        """
        raise NotImplementedError

    @abstractmethod
    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
        """
        Get provider credentials for LLM use.

        :param obfuscated: whether secrets should be obfuscated in the result.
        :return: the provider credentials.
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
        """
        Check that model credentials are valid, raising if they are not.

        :param model_name: name of the model.
        :param model_type: the kind of model.
        :param credentials: the model credentials to check.
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
                                  credentials: dict) -> dict:
        """
        Encrypt model credentials for saving.

        :param tenant_id: the owning tenant.
        :param model_name: name of the model.
        :param model_type: the kind of model.
        :param credentials: plain model credentials.
        :return: the credentials in their stored form.
        """
        raise NotImplementedError

    @abstractmethod
    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
        """
        Get the parameter rules of a model.

        :param model_name: name of the model.
        :param model_type: the kind of model.
        :return: the model's ``ModelKwargsRules``.
        """
        raise NotImplementedError

    @abstractmethod
    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
        """
        Get model credentials for LLM use.

        :param model_name: name of the model.
        :param model_type: the kind of model.
        :param obfuscated: whether secrets should be obfuscated in the result.
        :return: the model credentials.
        """
        raise NotImplementedError

    @classmethod
    def is_provider_type_system_supported(cls) -> bool:
        """
        Report whether the system provider type is available.

        Reads ``EDITION`` from the current Flask app's config, so it needs an
        active app context.

        :return: True when ``EDITION`` is ``'CLOUD'``.
        """
        return current_app.config['EDITION'] == 'CLOUD'

    def check_quota_over_limit(self):
        """
        Check whether the provider's system quota is exhausted.

        Only applies when this provider is of the system type and its rules
        support the ``system`` provider type; otherwise it returns silently.

        :raises QuotaExceededError: if the provider row (matched by id) is not
            valid, or its ``quota_limit`` is not greater than ``quota_used``.
        """
        if self.provider.provider_type != ProviderType.SYSTEM.value:
            return

        rules = self.get_rules()
        if 'system' not in rules['support_provider_types']:
            return

        provider = db.session.query(Provider).filter(
            db.and_(
                Provider.id == self.provider.id,
                Provider.is_valid == True,  # noqa: E712 - SQL expression, not identity
                Provider.quota_limit > Provider.quota_used
            )
        ).first()

        if not provider:
            raise QuotaExceededError()

    def deduct_quota(self, used_tokens: int = 0) -> None:
        """
        Deduct consumed quota from the provider when its type is system.

        Nothing is deducted in the following cases:

        * the provider is not of the system type;
        * its rules do not support ``system``;
        * ``should_deduct_quota()`` returns False.

        The amount is ``used_tokens`` when the configured
        ``system_config.quota_unit`` is tokens. Otherwise (including when no
        unit is configured) it is 1 per call.

        :param used_tokens: tokens consumed by the call being charged.
        :return: None. The change is committed immediately.
        """
        if self.provider.provider_type != ProviderType.SYSTEM.value:
            return

        rules = self.get_rules()
        if 'system' not in rules['support_provider_types']:
            return

        if not self.should_deduct_quota():
            return

        # Default to counting calls when no quota unit is configured.
        if 'system_config' in rules and 'quota_unit' in rules['system_config']:
            quota_unit = rules['system_config']['quota_unit']
        else:
            quota_unit = ProviderQuotaUnit.TIMES.value

        used_quota = used_tokens if quota_unit == ProviderQuotaUnit.TOKENS.value else 1

        # The ``quota_limit > quota_used`` guard only stops charging once the
        # quota is already exhausted. A single deduction may still push
        # ``quota_used`` past ``quota_limit``.
        db.session.query(Provider).filter(
            Provider.tenant_id == self.provider.tenant_id,
            Provider.provider_name == self.provider.provider_name,
            Provider.provider_type == self.provider.provider_type,
            Provider.quota_type == self.provider.quota_type,
            Provider.quota_limit > Provider.quota_used
        ).update({'quota_used': Provider.quota_used + used_quota})
        db.session.commit()

    def should_deduct_quota(self):
        """
        Hook deciding whether ``deduct_quota`` should charge anything.

        The base implementation returns False, so ``deduct_quota`` is a no-op
        unless a subclass overrides this.
        """
        return False

    def update_last_used(self) -> None:
        """
        Set ``last_used`` to the current UTC time on this tenant's provider rows.

        Rows are matched by tenant and provider name, so all provider types of
        this provider are touched. The change is committed immediately.
        """
        # Naive UTC timestamp, identical to the deprecated ``datetime.utcnow()``.
        now_utc = datetime.now(timezone.utc).replace(tzinfo=None)
        db.session.query(Provider).filter(
            Provider.tenant_id == self.provider.tenant_id,
            Provider.provider_name == self.provider.provider_name
        ).update({'last_used': now_utc})
        db.session.commit()

    def _get_provider_model(self, model_name: str, model_type: ModelType) -> ProviderModel:
        """
        Get the tenant's valid configured model of the given name and type.

        :param model_name: name of the model.
        :param model_type: the kind of model.
        :return: the matching ``ProviderModel`` row.
        :raises LLMBadRequestError: if no valid matching row exists.
        """
        provider_model = db.session.query(ProviderModel).filter(
            ProviderModel.tenant_id == self.provider.tenant_id,
            ProviderModel.provider_name == self.provider.provider_name,
            ProviderModel.model_name == model_name,
            ProviderModel.model_type == model_type.value,
            ProviderModel.is_valid == True  # noqa: E712 - SQL expression, not identity
        ).first()

        if not provider_model:
            raise LLMBadRequestError(f"The model {model_name} does not exist. "
                                     f"Please check the configuration.")

        return provider_model


class CredentialsValidateFailedError(Exception):
    """
    Error type for credential validation failures.

    It is not raised within this module. It is presumably raised by
    subclasses' ``is_*_credentials_valid_or_raise`` implementations (verify).
    """