base.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. import base64
  2. from abc import ABC, abstractmethod
  3. from typing import Optional, Union
  4. from core import hosted_llm_credentials
  5. from core.llm.error import QuotaExceededError, ModelCurrentlyNotSupportError, ProviderTokenNotInitError
  6. from extensions.ext_database import db
  7. from libs import rsa
  8. from models.account import Tenant
  9. from models.provider import Provider, ProviderType, ProviderName
  10. class BaseProvider(ABC):
  11. def __init__(self, tenant_id: str):
  12. self.tenant_id = tenant_id
  13. def get_provider_api_key(self, model_id: Optional[str] = None, prefer_custom: bool = True) -> str:
  14. """
  15. Returns the decrypted API key for the given tenant_id and provider_name.
  16. If the provider is of type SYSTEM and the quota is exceeded, raises a QuotaExceededError.
  17. If the provider is not found or not valid, raises a ProviderTokenNotInitError.
  18. """
  19. provider = self.get_provider(prefer_custom)
  20. if not provider:
  21. raise ProviderTokenNotInitError()
  22. if provider.provider_type == ProviderType.SYSTEM.value:
  23. quota_used = provider.quota_used if provider.quota_used is not None else 0
  24. quota_limit = provider.quota_limit if provider.quota_limit is not None else 0
  25. if model_id and model_id == 'gpt-4':
  26. raise ModelCurrentlyNotSupportError()
  27. if quota_used >= quota_limit:
  28. raise QuotaExceededError()
  29. return self.get_hosted_credentials()
  30. else:
  31. return self.get_decrypted_token(provider.encrypted_config)
  32. def get_provider(self, prefer_custom: bool) -> Optional[Provider]:
  33. """
  34. Returns the Provider instance for the given tenant_id and provider_name.
  35. If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag.
  36. """
  37. providers = db.session.query(Provider).filter(
  38. Provider.tenant_id == self.tenant_id,
  39. Provider.provider_name == self.get_provider_name().value
  40. ).order_by(Provider.provider_type.desc() if prefer_custom else Provider.provider_type).all()
  41. custom_provider = None
  42. system_provider = None
  43. for provider in providers:
  44. if provider.provider_type == ProviderType.CUSTOM.value:
  45. custom_provider = provider
  46. elif provider.provider_type == ProviderType.SYSTEM.value:
  47. system_provider = provider
  48. if custom_provider and custom_provider.is_valid and custom_provider.encrypted_config:
  49. return custom_provider
  50. elif system_provider and system_provider.is_valid:
  51. return system_provider
  52. else:
  53. return None
  54. def get_hosted_credentials(self) -> str:
  55. if self.get_provider_name() != ProviderName.OPENAI:
  56. raise ProviderTokenNotInitError()
  57. if not hosted_llm_credentials.openai or not hosted_llm_credentials.openai.api_key:
  58. raise ProviderTokenNotInitError()
  59. return hosted_llm_credentials.openai.api_key
  60. def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]:
  61. """
  62. Returns the provider configs.
  63. """
  64. try:
  65. config = self.get_provider_api_key()
  66. except:
  67. config = 'THIS-IS-A-MOCK-TOKEN'
  68. if obfuscated:
  69. return self.obfuscated_token(config)
  70. return config
  71. def obfuscated_token(self, token: str):
  72. return token[:6] + '*' * (len(token) - 8) + token[-2:]
  73. def get_token_type(self):
  74. return str
  75. def get_encrypted_token(self, config: Union[dict | str]):
  76. return self.encrypt_token(config)
  77. def get_decrypted_token(self, token: str):
  78. return self.decrypt_token(token)
  79. def encrypt_token(self, token):
  80. tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
  81. encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
  82. return base64.b64encode(encrypted_token).decode()
  83. def decrypt_token(self, token):
  84. return rsa.decrypt(base64.b64decode(token), self.tenant_id)
  85. @abstractmethod
  86. def get_provider_name(self):
  87. raise NotImplementedError
  88. @abstractmethod
  89. def get_credentials(self, model_id: Optional[str] = None) -> dict:
  90. raise NotImplementedError
  91. @abstractmethod
  92. def get_models(self, model_id: Optional[str] = None) -> list[dict]:
  93. raise NotImplementedError
  94. @abstractmethod
  95. def config_validate(self, config: str):
  96. raise NotImplementedError