# base.py (8.1 KB)
  1. from abc import ABC, abstractmethod
  2. from datetime import datetime
  3. from typing import Type, Optional
  4. from flask import current_app
  5. from pydantic import BaseModel
  6. from core.model_providers.error import QuotaExceededError, LLMBadRequestError
  7. from extensions.ext_database import db
  8. from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules
  9. from core.model_providers.models.entity.provider import ProviderQuotaUnit
  10. from core.model_providers.rules import provider_rules
  11. from models.provider import Provider, ProviderType, ProviderModel
  12. class BaseModelProvider(BaseModel, ABC):
  13. provider: Provider
  14. class Config:
  15. """Configuration for this pydantic object."""
  16. arbitrary_types_allowed = True
  17. @property
  18. @abstractmethod
  19. def provider_name(self):
  20. """
  21. Returns the name of a provider.
  22. """
  23. raise NotImplementedError
  24. def get_rules(self):
  25. """
  26. Returns the rules of a provider.
  27. """
  28. return provider_rules[self.provider_name]
  29. def get_supported_model_list(self, model_type: ModelType) -> list[dict]:
  30. """
  31. get supported model object list for use.
  32. :param model_type:
  33. :return:
  34. """
  35. rules = self.get_rules()
  36. if 'custom' not in rules['support_provider_types']:
  37. return self._get_fixed_model_list(model_type)
  38. if 'model_flexibility' not in rules:
  39. return self._get_fixed_model_list(model_type)
  40. if rules['model_flexibility'] == 'fixed':
  41. return self._get_fixed_model_list(model_type)
  42. # get configurable provider models
  43. provider_models = db.session.query(ProviderModel).filter(
  44. ProviderModel.tenant_id == self.provider.tenant_id,
  45. ProviderModel.provider_name == self.provider.provider_name,
  46. ProviderModel.model_type == model_type.value,
  47. ProviderModel.is_valid == True
  48. ).order_by(ProviderModel.created_at.asc()).all()
  49. return [{
  50. 'id': provider_model.model_name,
  51. 'name': provider_model.model_name
  52. } for provider_model in provider_models]
  53. @abstractmethod
  54. def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
  55. """
  56. get supported model object list for use.
  57. :param model_type:
  58. :return:
  59. """
  60. raise NotImplementedError
  61. @abstractmethod
  62. def get_model_class(self, model_type: ModelType) -> Type:
  63. """
  64. get specific model class.
  65. :param model_type:
  66. :return:
  67. """
  68. raise NotImplementedError
  69. @classmethod
  70. @abstractmethod
  71. def is_provider_credentials_valid_or_raise(cls, credentials: dict):
  72. """
  73. check provider credentials valid.
  74. :param credentials:
  75. """
  76. raise NotImplementedError
  77. @classmethod
  78. @abstractmethod
  79. def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
  80. """
  81. encrypt provider credentials for save.
  82. :param tenant_id:
  83. :param credentials:
  84. :return:
  85. """
  86. raise NotImplementedError
  87. @abstractmethod
  88. def get_provider_credentials(self, obfuscated: bool = False) -> dict:
  89. """
  90. get credentials for llm use.
  91. :param obfuscated:
  92. :return:
  93. """
  94. raise NotImplementedError
  95. @classmethod
  96. @abstractmethod
  97. def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
  98. """
  99. check model credentials valid.
  100. :param model_name:
  101. :param model_type:
  102. :param credentials:
  103. """
  104. raise NotImplementedError
  105. @classmethod
  106. @abstractmethod
  107. def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
  108. credentials: dict) -> dict:
  109. """
  110. encrypt model credentials for save.
  111. :param tenant_id:
  112. :param model_name:
  113. :param model_type:
  114. :param credentials:
  115. :return:
  116. """
  117. raise NotImplementedError
  118. @abstractmethod
  119. def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
  120. """
  121. get model parameter rules.
  122. :param model_name:
  123. :param model_type:
  124. :return:
  125. """
  126. raise NotImplementedError
  127. @abstractmethod
  128. def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
  129. """
  130. get credentials for llm use.
  131. :param model_name:
  132. :param model_type:
  133. :param obfuscated:
  134. :return:
  135. """
  136. raise NotImplementedError
  137. @classmethod
  138. def is_provider_type_system_supported(cls) -> bool:
  139. return current_app.config['EDITION'] == 'CLOUD'
  140. def check_quota_over_limit(self):
  141. """
  142. check provider quota over limit.
  143. :return:
  144. """
  145. if self.provider.provider_type != ProviderType.SYSTEM.value:
  146. return
  147. rules = self.get_rules()
  148. if 'system' not in rules['support_provider_types']:
  149. return
  150. provider = db.session.query(Provider).filter(
  151. db.and_(
  152. Provider.id == self.provider.id,
  153. Provider.is_valid == True,
  154. Provider.quota_limit > Provider.quota_used
  155. )
  156. ).first()
  157. if not provider:
  158. raise QuotaExceededError()
  159. def deduct_quota(self, used_tokens: int = 0) -> None:
  160. """
  161. deduct available quota when provider type is system or paid.
  162. :return:
  163. """
  164. if self.provider.provider_type != ProviderType.SYSTEM.value:
  165. return
  166. rules = self.get_rules()
  167. if 'system' not in rules['support_provider_types']:
  168. return
  169. if not self.should_deduct_quota():
  170. return
  171. if 'system_config' not in rules:
  172. quota_unit = ProviderQuotaUnit.TIMES.value
  173. elif 'quota_unit' not in rules['system_config']:
  174. quota_unit = ProviderQuotaUnit.TIMES.value
  175. else:
  176. quota_unit = rules['system_config']['quota_unit']
  177. if quota_unit == ProviderQuotaUnit.TOKENS.value:
  178. used_quota = used_tokens
  179. else:
  180. used_quota = 1
  181. db.session.query(Provider).filter(
  182. Provider.tenant_id == self.provider.tenant_id,
  183. Provider.provider_name == self.provider.provider_name,
  184. Provider.provider_type == self.provider.provider_type,
  185. Provider.quota_type == self.provider.quota_type,
  186. Provider.quota_limit > Provider.quota_used
  187. ).update({'quota_used': Provider.quota_used + used_quota})
  188. db.session.commit()
  189. def should_deduct_quota(self):
  190. return False
  191. def update_last_used(self) -> None:
  192. """
  193. update last used time.
  194. :return:
  195. """
  196. db.session.query(Provider).filter(
  197. Provider.tenant_id == self.provider.tenant_id,
  198. Provider.provider_name == self.provider.provider_name
  199. ).update({'last_used': datetime.utcnow()})
  200. db.session.commit()
  201. def get_payment_info(self) -> Optional[dict]:
  202. """
  203. get product info if it payable.
  204. :return:
  205. """
  206. return None
  207. def _get_provider_model(self, model_name: str, model_type: ModelType) -> ProviderModel:
  208. """
  209. get provider model.
  210. :param model_name:
  211. :param model_type:
  212. :return:
  213. """
  214. provider_model = db.session.query(ProviderModel).filter(
  215. ProviderModel.tenant_id == self.provider.tenant_id,
  216. ProviderModel.provider_name == self.provider.provider_name,
  217. ProviderModel.model_name == model_name,
  218. ProviderModel.model_type == model_type.value,
  219. ProviderModel.is_valid == True
  220. ).first()
  221. if not provider_model:
  222. raise LLMBadRequestError(f"The model {model_name} does not exist. "
  223. f"Please check the configuration.")
  224. return provider_model
  225. class CredentialsValidateFailedError(Exception):
  226. pass