123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283 |
- from abc import ABC, abstractmethod
- from datetime import datetime
- from typing import Type, Optional
- from flask import current_app
- from pydantic import BaseModel
- from core.model_providers.error import QuotaExceededError, LLMBadRequestError
- from extensions.ext_database import db
- from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules
- from core.model_providers.models.entity.provider import ProviderQuotaUnit
- from core.model_providers.rules import provider_rules
- from models.provider import Provider, ProviderType, ProviderModel
class BaseModelProvider(BaseModel, ABC):
    """
    Abstract base class for model providers.

    Wraps a tenant's ``Provider`` database row and provides the shared logic for
    listing supported models, checking/deducting system quota and recording
    usage. Credential handling, model classes and parameter rules are left to
    subclasses via the abstract methods below.
    """

    # The tenant-scoped provider record this instance operates on.
    provider: Provider

    class Config:
        """Configuration for this pydantic object."""

        # ``Provider`` is not a pydantic type, so pydantic must be told to
        # accept it as an arbitrary type.
        arbitrary_types_allowed = True

    @property
    @abstractmethod
    def provider_name(self):
        """
        Returns the name of a provider.

        Used as the lookup key into ``provider_rules`` by ``get_rules``.
        """
        raise NotImplementedError

    def get_rules(self):
        """
        Returns the rules of a provider.

        :return: the ``provider_rules`` entry for ``self.provider_name``
            (subscript lookup, so an unregistered name is not defaulted).
        """
        return provider_rules[self.provider_name]

    def get_supported_model_list(self, model_type: ModelType) -> list[dict]:
        """
        get supported model object list for use.

        If the provider's rules do not list ``custom`` in
        ``support_provider_types``, or its ``model_flexibility`` is ``fixed``
        (which is also assumed when the key is absent), the built-in list from
        ``_get_fixed_model_list`` is returned. Otherwise the tenant's valid
        ``ProviderModel`` rows of ``model_type`` are returned, oldest first.

        :param model_type: the kind of model to list.
        :return: for configurable providers, a list of
            ``{'id': model_name, 'name': model_name}`` dicts; for fixed
            providers, whatever ``_get_fixed_model_list`` returns.
        """
        rules = self.get_rules()

        # Either condition means tenants cannot register their own models.
        if ('custom' not in rules['support_provider_types']
                or rules.get('model_flexibility', 'fixed') == 'fixed'):
            return self._get_fixed_model_list(model_type)

        # get configurable provider models
        provider_models = db.session.query(ProviderModel).filter(
            ProviderModel.tenant_id == self.provider.tenant_id,
            ProviderModel.provider_name == self.provider.provider_name,
            ProviderModel.model_type == model_type.value,
            ProviderModel.is_valid == True  # noqa: E712 -- builds a SQL expression
        ).order_by(ProviderModel.created_at.asc()).all()

        return [{
            'id': provider_model.model_name,
            'name': provider_model.model_name
        } for provider_model in provider_models]

    @abstractmethod
    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
        """
        get the provider's built-in (non-configurable) model list.

        :param model_type: the kind of model to list.
        :return: list of model descriptor dicts.
        """
        raise NotImplementedError

    @abstractmethod
    def get_model_class(self, model_type: ModelType) -> Type:
        """
        get specific model class.

        :param model_type: the kind of model whose class is requested.
        :return: the class implementing that model type for this provider.
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
        """
        check provider credentials valid.

        Implementations raise when the credentials are not valid.

        :param credentials: provider-level credentials to validate.
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
        """
        encrypt provider credentials for save.

        :param tenant_id: tenant that owns the credentials.
        :param credentials: plain provider credentials.
        :return: credentials in the form to be persisted.
        """
        raise NotImplementedError

    @abstractmethod
    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
        """
        get credentials for llm use.

        :param obfuscated: whether secret values should be obfuscated.
        :return: provider credentials.
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
        """
        check model credentials valid.

        Implementations raise when the credentials are not valid.

        :param model_name: name of the model being configured.
        :param model_type: kind of the model being configured.
        :param credentials: model-level credentials to validate.
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
                                  credentials: dict) -> dict:
        """
        encrypt model credentials for save.

        :param tenant_id: tenant that owns the credentials.
        :param model_name: name of the model the credentials belong to.
        :param model_type: kind of that model.
        :param credentials: plain model credentials.
        :return: credentials in the form to be persisted.
        """
        raise NotImplementedError

    @abstractmethod
    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
        """
        get model parameter rules.

        :param model_name: name of the model.
        :param model_type: kind of the model.
        :return: the parameter rules for that model.
        """
        raise NotImplementedError

    @abstractmethod
    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
        """
        get credentials for llm use.

        :param model_name: name of the model.
        :param model_type: kind of the model.
        :param obfuscated: whether secret values should be obfuscated.
        :return: model credentials.
        """
        raise NotImplementedError

    @classmethod
    def is_provider_type_system_supported(cls) -> bool:
        """
        Whether the system (hosted) provider type is available.

        :return: True only when the Flask app's ``EDITION`` config is ``CLOUD``.
        """
        return current_app.config['EDITION'] == 'CLOUD'

    def check_quota_over_limit(self) -> None:
        """
        check provider quota over limit.

        Only system-type providers whose rules list ``system`` in
        ``support_provider_types`` are checked; all others return immediately.

        :raises QuotaExceededError: if this provider's row is not valid or its
            ``quota_used`` is no longer below ``quota_limit``.
        """
        if self.provider.provider_type != ProviderType.SYSTEM.value:
            return

        rules = self.get_rules()
        if 'system' not in rules['support_provider_types']:
            return

        provider = db.session.query(Provider).filter(
            db.and_(
                Provider.id == self.provider.id,
                Provider.is_valid == True,  # noqa: E712 -- builds a SQL expression
                Provider.quota_limit > Provider.quota_used
            )
        ).first()

        if not provider:
            raise QuotaExceededError()

    def deduct_quota(self, used_tokens: int = 0) -> None:
        """
        deduct available quota for system-type providers.

        Nothing happens unless the provider type is system, its rules list
        ``system`` in ``support_provider_types``, and ``should_deduct_quota()``
        returns True. The charge depends on ``system_config.quota_unit`` in the
        rules (times when unspecified): ``used_tokens`` when it is tokens,
        otherwise 1 per call. The session is committed afterwards.

        :param used_tokens: tokens consumed; only charged when the quota unit
            is tokens.
        :return:
        """
        if self.provider.provider_type != ProviderType.SYSTEM.value:
            return

        rules = self.get_rules()
        if 'system' not in rules['support_provider_types']:
            return

        if not self.should_deduct_quota():
            return

        # A missing (or null) system_config / quota_unit means per-call counting.
        system_config = rules.get('system_config') or {}
        quota_unit = system_config.get('quota_unit', ProviderQuotaUnit.TIMES.value)

        used_quota = used_tokens if quota_unit == ProviderQuotaUnit.TOKENS.value else 1

        # The ``quota_limit > quota_used`` guard turns this into a no-op once
        # the quota is exhausted; a single charge may still overshoot the limit.
        db.session.query(Provider).filter(
            Provider.tenant_id == self.provider.tenant_id,
            Provider.provider_name == self.provider.provider_name,
            Provider.provider_type == self.provider.provider_type,
            Provider.quota_type == self.provider.quota_type,
            Provider.quota_limit > Provider.quota_used
        ).update({'quota_used': Provider.quota_used + used_quota})
        db.session.commit()

    def should_deduct_quota(self) -> bool:
        """
        Hook deciding whether ``deduct_quota`` charges anything.

        :return: False in this base class; subclasses override to enable it.
        """
        return False

    def update_last_used(self) -> None:
        """
        update last used time.

        Sets ``last_used`` to ``datetime.utcnow()`` on every ``Provider`` row
        matching this tenant and provider name. The filter does not include
        ``provider_type``, so rows of every provider type with that name are
        touched. Commits the session.

        :return:
        """
        db.session.query(Provider).filter(
            Provider.tenant_id == self.provider.tenant_id,
            Provider.provider_name == self.provider.provider_name
        ).update({'last_used': datetime.utcnow()})
        db.session.commit()

    def get_payment_info(self) -> Optional[dict]:
        """
        get product info if it is payable.

        :return: None in this base class; subclasses may override.
        """
        return None

    def _get_provider_model(self, model_name: str, model_type: ModelType) -> ProviderModel:
        """
        get provider model.

        :param model_name: name of the configured model.
        :param model_type: kind of the configured model.
        :return: the tenant's valid ``ProviderModel`` row for that name and type.
        :raises LLMBadRequestError: if no such valid row exists.
        """
        provider_model = db.session.query(ProviderModel).filter(
            ProviderModel.tenant_id == self.provider.tenant_id,
            ProviderModel.provider_name == self.provider.provider_name,
            ProviderModel.model_name == model_name,
            ProviderModel.model_type == model_type.value,
            ProviderModel.is_valid == True  # noqa: E712 -- builds a SQL expression
        ).first()

        if not provider_model:
            raise LLMBadRequestError(f"The model {model_name} does not exist. "
                                     f"Please check the configuration.")

        return provider_model
class CredentialsValidateFailedError(Exception):
    """
    Error type for rejected provider or model credentials.

    Presumably raised by the ``*_credentials_valid_or_raise`` checks in
    provider subclasses (the raising sites are not visible here). It adds no
    state of its own beyond the standard ``Exception`` arguments.
    """
|