azure_provider.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. import json
  2. import logging
  3. from typing import Optional, Union
  4. import openai
  5. import requests
  6. from core.llm.provider.base import BaseProvider
  7. from core.llm.provider.errors import ValidateFailedError
  8. from models.provider import ProviderName
  9. AZURE_OPENAI_API_VERSION = '2023-07-01-preview'
  10. class AzureProvider(BaseProvider):
  11. def get_models(self, model_id: Optional[str] = None, credentials: Optional[dict] = None) -> list[dict]:
  12. return []
  13. def check_embedding_model(self, credentials: Optional[dict] = None):
  14. credentials = self.get_credentials('text-embedding-ada-002') if not credentials else credentials
  15. try:
  16. result = openai.Embedding.create(input=['test'],
  17. engine='text-embedding-ada-002',
  18. timeout=60,
  19. api_key=str(credentials.get('openai_api_key')),
  20. api_base=str(credentials.get('openai_api_base')),
  21. api_type='azure',
  22. api_version=str(credentials.get('openai_api_version')))["data"][0][
  23. "embedding"]
  24. except openai.error.AuthenticationError as e:
  25. raise AzureAuthenticationError(str(e))
  26. except openai.error.APIConnectionError as e:
  27. raise AzureRequestFailedError(
  28. 'Failed to request Azure OpenAI, please check your API Base Endpoint, The format is `https://xxx.openai.azure.com/`')
  29. except openai.error.InvalidRequestError as e:
  30. if e.http_status == 404:
  31. raise AzureRequestFailedError("Please check your 'gpt-3.5-turbo' or 'text-embedding-ada-002' "
  32. "deployment name is exists in Azure AI")
  33. else:
  34. raise AzureRequestFailedError(
  35. 'Failed to request Azure OpenAI. cause: {}'.format(str(e)))
  36. except openai.error.OpenAIError as e:
  37. raise AzureRequestFailedError(
  38. 'Failed to request Azure OpenAI. cause: {}'.format(str(e)))
  39. if not isinstance(result, list):
  40. raise AzureRequestFailedError('Failed to request Azure OpenAI.')
  41. def get_credentials(self, model_id: Optional[str] = None) -> dict:
  42. """
  43. Returns the API credentials for Azure OpenAI as a dictionary.
  44. """
  45. config = self.get_provider_api_key(model_id=model_id)
  46. config['openai_api_type'] = 'azure'
  47. config['openai_api_version'] = AZURE_OPENAI_API_VERSION
  48. if model_id == 'text-embedding-ada-002':
  49. config['deployment'] = model_id.replace('.', '') if model_id else None
  50. config['chunk_size'] = 16
  51. else:
  52. config['deployment_name'] = model_id.replace('.', '') if model_id else None
  53. return config
  54. def get_provider_name(self):
  55. return ProviderName.AZURE_OPENAI
  56. def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
  57. """
  58. Returns the provider configs.
  59. """
  60. try:
  61. config = self.get_provider_api_key(only_custom=only_custom)
  62. except:
  63. config = {
  64. 'openai_api_type': 'azure',
  65. 'openai_api_version': AZURE_OPENAI_API_VERSION,
  66. 'openai_api_base': '',
  67. 'openai_api_key': ''
  68. }
  69. if obfuscated:
  70. if not config.get('openai_api_key'):
  71. config = {
  72. 'openai_api_type': 'azure',
  73. 'openai_api_version': AZURE_OPENAI_API_VERSION,
  74. 'openai_api_base': '',
  75. 'openai_api_key': ''
  76. }
  77. config['openai_api_key'] = self.obfuscated_token(config.get('openai_api_key'))
  78. return config
  79. return config
  80. def get_token_type(self):
  81. return dict
  82. def config_validate(self, config: Union[dict | str]):
  83. """
  84. Validates the given config.
  85. """
  86. try:
  87. if not isinstance(config, dict):
  88. raise ValueError('Config must be a object.')
  89. if 'openai_api_version' not in config:
  90. config['openai_api_version'] = AZURE_OPENAI_API_VERSION
  91. self.check_embedding_model(credentials=config)
  92. except ValidateFailedError as e:
  93. raise e
  94. except AzureAuthenticationError:
  95. raise ValidateFailedError('Validation failed, please check your API Key.')
  96. except AzureRequestFailedError as ex:
  97. raise ValidateFailedError('Validation failed, error: {}.'.format(str(ex)))
  98. except Exception as ex:
  99. logging.exception('Azure OpenAI Credentials validation failed')
  100. raise ValidateFailedError('Validation failed, error: {}.'.format(str(ex)))
  101. def get_encrypted_token(self, config: Union[dict | str]):
  102. """
  103. Returns the encrypted token.
  104. """
  105. return json.dumps({
  106. 'openai_api_type': 'azure',
  107. 'openai_api_version': AZURE_OPENAI_API_VERSION,
  108. 'openai_api_base': config['openai_api_base'],
  109. 'openai_api_key': self.encrypt_token(config['openai_api_key'])
  110. })
  111. def get_decrypted_token(self, token: str):
  112. """
  113. Returns the decrypted token.
  114. """
  115. config = json.loads(token)
  116. config['openai_api_key'] = self.decrypt_token(config['openai_api_key'])
  117. return config
  118. class AzureAuthenticationError(Exception):
  119. pass
  120. class AzureRequestFailedError(Exception):
  121. pass