base.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. import base64
  2. from abc import ABC, abstractmethod
  3. from typing import Optional, Union
  4. from core import hosted_llm_credentials
  5. from core.llm.error import QuotaExceededError, ModelCurrentlyNotSupportError, ProviderTokenNotInitError
  6. from extensions.ext_database import db
  7. from libs import rsa
  8. from models.account import Tenant
  9. from models.provider import Provider, ProviderType, ProviderName
  10. class BaseProvider(ABC):
  11. def __init__(self, tenant_id: str):
  12. self.tenant_id = tenant_id
  13. def get_provider_api_key(self, model_id: Optional[str] = None, prefer_custom: bool = True) -> Union[str | dict]:
  14. """
  15. Returns the decrypted API key for the given tenant_id and provider_name.
  16. If the provider is of type SYSTEM and the quota is exceeded, raises a QuotaExceededError.
  17. If the provider is not found or not valid, raises a ProviderTokenNotInitError.
  18. """
  19. provider = self.get_provider(prefer_custom)
  20. if not provider:
  21. raise ProviderTokenNotInitError()
  22. if provider.provider_type == ProviderType.SYSTEM.value:
  23. quota_used = provider.quota_used if provider.quota_used is not None else 0
  24. quota_limit = provider.quota_limit if provider.quota_limit is not None else 0
  25. if model_id and model_id == 'gpt-4':
  26. raise ModelCurrentlyNotSupportError()
  27. if quota_used >= quota_limit:
  28. raise QuotaExceededError()
  29. return self.get_hosted_credentials()
  30. else:
  31. return self.get_decrypted_token(provider.encrypted_config)
  32. def get_provider(self, prefer_custom: bool) -> Optional[Provider]:
  33. """
  34. Returns the Provider instance for the given tenant_id and provider_name.
  35. If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag.
  36. """
  37. return BaseProvider.get_valid_provider(self.tenant_id, self.get_provider_name().value, prefer_custom)
  38. @classmethod
  39. def get_valid_provider(cls, tenant_id: str, provider_name: str = None, prefer_custom: bool = False) -> Optional[Provider]:
  40. """
  41. Returns the Provider instance for the given tenant_id and provider_name.
  42. If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag.
  43. """
  44. query = db.session.query(Provider).filter(
  45. Provider.tenant_id == tenant_id
  46. )
  47. if provider_name:
  48. query = query.filter(Provider.provider_name == provider_name)
  49. providers = query.order_by(Provider.provider_type.desc() if prefer_custom else Provider.provider_type).all()
  50. custom_provider = None
  51. system_provider = None
  52. for provider in providers:
  53. if provider.provider_type == ProviderType.CUSTOM.value and provider.is_valid and provider.encrypted_config:
  54. custom_provider = provider
  55. elif provider.provider_type == ProviderType.SYSTEM.value and provider.is_valid:
  56. system_provider = provider
  57. if custom_provider:
  58. return custom_provider
  59. elif system_provider:
  60. return system_provider
  61. else:
  62. return None
  63. def get_hosted_credentials(self) -> str:
  64. if self.get_provider_name() != ProviderName.OPENAI:
  65. raise ProviderTokenNotInitError()
  66. if not hosted_llm_credentials.openai or not hosted_llm_credentials.openai.api_key:
  67. raise ProviderTokenNotInitError()
  68. return hosted_llm_credentials.openai.api_key
  69. def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]:
  70. """
  71. Returns the provider configs.
  72. """
  73. try:
  74. config = self.get_provider_api_key()
  75. except:
  76. config = ''
  77. if obfuscated:
  78. return self.obfuscated_token(config)
  79. return config
  80. def obfuscated_token(self, token: str):
  81. return token[:6] + '*' * (len(token) - 8) + token[-2:]
  82. def get_token_type(self):
  83. return str
  84. def get_encrypted_token(self, config: Union[dict | str]):
  85. return self.encrypt_token(config)
  86. def get_decrypted_token(self, token: str):
  87. return self.decrypt_token(token)
  88. def encrypt_token(self, token):
  89. tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
  90. encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
  91. return base64.b64encode(encrypted_token).decode()
  92. def decrypt_token(self, token):
  93. return rsa.decrypt(base64.b64decode(token), self.tenant_id)
  94. @abstractmethod
  95. def get_provider_name(self):
  96. raise NotImplementedError
  97. @abstractmethod
  98. def get_credentials(self, model_id: Optional[str] = None) -> dict:
  99. raise NotImplementedError
  100. @abstractmethod
  101. def get_models(self, model_id: Optional[str] = None) -> list[dict]:
  102. raise NotImplementedError
  103. @abstractmethod
  104. def config_validate(self, config: str):
  105. raise NotImplementedError