openai_provider.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. import json
  2. import logging
  3. from json import JSONDecodeError
  4. from typing import Type, Optional
  5. from flask import current_app
  6. from openai.error import AuthenticationError, OpenAIError
  7. import openai
  8. from core.helper import encrypter
  9. from core.model_providers.models.entity.provider import ModelFeature
  10. from core.model_providers.models.speech2text.openai_whisper import OpenAIWhisper
  11. from core.model_providers.models.base import BaseProviderModel
  12. from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding
  13. from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
  14. from core.model_providers.models.llm.openai_model import OpenAIModel
  15. from core.model_providers.models.moderation.openai_moderation import OpenAIModeration
  16. from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
  17. from core.model_providers.providers.hosted import hosted_model_providers
  18. from models.provider import ProviderType, ProviderQuotaType
  19. class OpenAIProvider(BaseModelProvider):
  20. @property
  21. def provider_name(self):
  22. """
  23. Returns the name of a provider.
  24. """
  25. return 'openai'
  26. def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
  27. if model_type == ModelType.TEXT_GENERATION:
  28. models = [
  29. {
  30. 'id': 'gpt-3.5-turbo',
  31. 'name': 'gpt-3.5-turbo',
  32. 'features': [
  33. ModelFeature.AGENT_THOUGHT.value
  34. ]
  35. },
  36. {
  37. 'id': 'gpt-3.5-turbo-instruct',
  38. 'name': 'GPT-3.5-Turbo-Instruct',
  39. },
  40. {
  41. 'id': 'gpt-3.5-turbo-16k',
  42. 'name': 'gpt-3.5-turbo-16k',
  43. 'features': [
  44. ModelFeature.AGENT_THOUGHT.value
  45. ]
  46. },
  47. {
  48. 'id': 'gpt-4',
  49. 'name': 'gpt-4',
  50. 'features': [
  51. ModelFeature.AGENT_THOUGHT.value
  52. ]
  53. },
  54. {
  55. 'id': 'gpt-4-32k',
  56. 'name': 'gpt-4-32k',
  57. 'features': [
  58. ModelFeature.AGENT_THOUGHT.value
  59. ]
  60. },
  61. {
  62. 'id': 'text-davinci-003',
  63. 'name': 'text-davinci-003',
  64. }
  65. ]
  66. if self.provider.provider_type == ProviderType.SYSTEM.value \
  67. and self.provider.quota_type == ProviderQuotaType.TRIAL.value:
  68. models = [item for item in models if item['id'] not in ['gpt-4', 'gpt-4-32k']]
  69. return models
  70. elif model_type == ModelType.EMBEDDINGS:
  71. return [
  72. {
  73. 'id': 'text-embedding-ada-002',
  74. 'name': 'text-embedding-ada-002'
  75. }
  76. ]
  77. elif model_type == ModelType.SPEECH_TO_TEXT:
  78. return [
  79. {
  80. 'id': 'whisper-1',
  81. 'name': 'whisper-1'
  82. }
  83. ]
  84. elif model_type == ModelType.MODERATION:
  85. return [
  86. {
  87. 'id': 'text-moderation-stable',
  88. 'name': 'text-moderation-stable'
  89. }
  90. ]
  91. else:
  92. return []
  93. def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
  94. """
  95. Returns the model class.
  96. :param model_type:
  97. :return:
  98. """
  99. if model_type == ModelType.TEXT_GENERATION:
  100. model_class = OpenAIModel
  101. elif model_type == ModelType.EMBEDDINGS:
  102. model_class = OpenAIEmbedding
  103. elif model_type == ModelType.MODERATION:
  104. model_class = OpenAIModeration
  105. elif model_type == ModelType.SPEECH_TO_TEXT:
  106. model_class = OpenAIWhisper
  107. else:
  108. raise NotImplementedError
  109. return model_class
  110. def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
  111. """
  112. get model parameter rules.
  113. :param model_name:
  114. :param model_type:
  115. :return:
  116. """
  117. model_max_tokens = {
  118. 'gpt-4': 8192,
  119. 'gpt-4-32k': 32768,
  120. 'gpt-3.5-turbo': 4096,
  121. 'gpt-3.5-turbo-instruct': 8192,
  122. 'gpt-3.5-turbo-16k': 16384,
  123. 'text-davinci-003': 4097,
  124. }
  125. return ModelKwargsRules(
  126. temperature=KwargRule[float](min=0, max=2, default=1, precision=2),
  127. top_p=KwargRule[float](min=0, max=1, default=1, precision=2),
  128. presence_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
  129. frequency_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
  130. max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 4097), default=16, precision=0),
  131. )
  132. @classmethod
  133. def is_provider_credentials_valid_or_raise(cls, credentials: dict):
  134. """
  135. Validates the given credentials.
  136. """
  137. if 'openai_api_key' not in credentials:
  138. raise CredentialsValidateFailedError('OpenAI API key is required')
  139. try:
  140. credentials_kwargs = {
  141. "api_key": credentials['openai_api_key']
  142. }
  143. if 'openai_api_base' in credentials and credentials['openai_api_base']:
  144. credentials_kwargs['api_base'] = credentials['openai_api_base'] + '/v1'
  145. if 'openai_organization' in credentials:
  146. credentials_kwargs['organization'] = credentials['openai_organization']
  147. openai.ChatCompletion.create(
  148. messages=[{"role": "user", "content": 'ping'}],
  149. model='gpt-3.5-turbo',
  150. timeout=10,
  151. request_timeout=(5, 30),
  152. max_tokens=20,
  153. **credentials_kwargs
  154. )
  155. except (AuthenticationError, OpenAIError) as ex:
  156. raise CredentialsValidateFailedError(str(ex))
  157. except Exception as ex:
  158. logging.exception('OpenAI config validation failed')
  159. raise ex
  160. @classmethod
  161. def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
  162. credentials['openai_api_key'] = encrypter.encrypt_token(tenant_id, credentials['openai_api_key'])
  163. return credentials
  164. def get_provider_credentials(self, obfuscated: bool = False) -> dict:
  165. if self.provider.provider_type == ProviderType.CUSTOM.value:
  166. try:
  167. credentials = json.loads(self.provider.encrypted_config)
  168. except JSONDecodeError:
  169. credentials = {
  170. 'openai_api_base': None,
  171. 'openai_api_key': self.provider.encrypted_config,
  172. 'openai_organization': None
  173. }
  174. if credentials['openai_api_key']:
  175. credentials['openai_api_key'] = encrypter.decrypt_token(
  176. self.provider.tenant_id,
  177. credentials['openai_api_key']
  178. )
  179. if obfuscated:
  180. credentials['openai_api_key'] = encrypter.obfuscated_token(credentials['openai_api_key'])
  181. if 'openai_api_base' not in credentials or not credentials['openai_api_base']:
  182. credentials['openai_api_base'] = None
  183. else:
  184. credentials['openai_api_base'] = credentials['openai_api_base'] + '/v1'
  185. if 'openai_organization' not in credentials:
  186. credentials['openai_organization'] = None
  187. return credentials
  188. else:
  189. if hosted_model_providers.openai:
  190. return {
  191. 'openai_api_base': hosted_model_providers.openai.api_base,
  192. 'openai_api_key': hosted_model_providers.openai.api_key,
  193. 'openai_organization': hosted_model_providers.openai.api_organization
  194. }
  195. else:
  196. return {
  197. 'openai_api_base': None,
  198. 'openai_api_key': None,
  199. 'openai_organization': None
  200. }
  201. @classmethod
  202. def is_provider_type_system_supported(cls) -> bool:
  203. if current_app.config['EDITION'] != 'CLOUD':
  204. return False
  205. if hosted_model_providers.openai:
  206. return True
  207. return False
  208. def should_deduct_quota(self):
  209. if hosted_model_providers.openai \
  210. and hosted_model_providers.openai.quota_limit and hosted_model_providers.openai.quota_limit > 0:
  211. return True
  212. return False
  213. def get_payment_info(self) -> Optional[dict]:
  214. """
  215. get payment info if it payable.
  216. :return:
  217. """
  218. if hosted_model_providers.openai \
  219. and hosted_model_providers.openai.paid_enabled:
  220. return {
  221. 'product_id': hosted_model_providers.openai.paid_stripe_price_id,
  222. 'increase_quota': hosted_model_providers.openai.paid_increase_quota,
  223. }
  224. return None
  225. @classmethod
  226. def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
  227. """
  228. check model credentials valid.
  229. :param model_name:
  230. :param model_type:
  231. :param credentials:
  232. """
  233. return
  234. @classmethod
  235. def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType, credentials: dict) -> dict:
  236. """
  237. encrypt model credentials for save.
  238. :param tenant_id:
  239. :param model_name:
  240. :param model_type:
  241. :param credentials:
  242. :return:
  243. """
  244. return {}
  245. def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
  246. """
  247. get credentials for llm use.
  248. :param model_name:
  249. :param model_type:
  250. :param obfuscated:
  251. :return:
  252. """
  253. return self.get_provider_credentials(obfuscated)