# azure_provider.py (3.5 KB)
import json
import logging
from typing import Optional, Union

import requests

from core.llm.provider.base import BaseProvider
from models.provider import ProviderName


  6. class AzureProvider(BaseProvider):
  7. def get_models(self, model_id: Optional[str] = None) -> list[dict]:
  8. credentials = self.get_credentials(model_id)
  9. url = "{}/openai/deployments?api-version={}".format(
  10. credentials.get('openai_api_base'),
  11. credentials.get('openai_api_version')
  12. )
  13. headers = {
  14. "api-key": credentials.get('openai_api_key'),
  15. "content-type": "application/json; charset=utf-8"
  16. }
  17. response = requests.get(url, headers=headers)
  18. if response.status_code == 200:
  19. result = response.json()
  20. return [{
  21. 'id': deployment['id'],
  22. 'name': '{} ({})'.format(deployment['id'], deployment['model'])
  23. } for deployment in result['data'] if deployment['status'] == 'succeeded']
  24. else:
  25. # TODO: optimize in future
  26. raise Exception('Failed to get deployments from Azure OpenAI. Status code: {}'.format(response.status_code))
  27. def get_credentials(self, model_id: Optional[str] = None) -> dict:
  28. """
  29. Returns the API credentials for Azure OpenAI as a dictionary.
  30. """
  31. encrypted_config = self.get_provider_api_key(model_id=model_id)
  32. config = json.loads(encrypted_config)
  33. config['openai_api_type'] = 'azure'
  34. config['deployment_name'] = model_id
  35. return config
  36. def get_provider_name(self):
  37. return ProviderName.AZURE_OPENAI
  38. def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]:
  39. """
  40. Returns the provider configs.
  41. """
  42. try:
  43. config = self.get_provider_api_key()
  44. config = json.loads(config)
  45. except:
  46. config = {
  47. 'openai_api_type': 'azure',
  48. 'openai_api_version': '2023-03-15-preview',
  49. 'openai_api_base': 'https://foo.microsoft.com/bar',
  50. 'openai_api_key': ''
  51. }
  52. if obfuscated:
  53. if not config.get('openai_api_key'):
  54. config = {
  55. 'openai_api_type': 'azure',
  56. 'openai_api_version': '2023-03-15-preview',
  57. 'openai_api_base': 'https://foo.microsoft.com/bar',
  58. 'openai_api_key': ''
  59. }
  60. config['openai_api_key'] = self.obfuscated_token(config.get('openai_api_key'))
  61. return config
  62. return config
  63. def get_token_type(self):
  64. # TODO: change to dict when implemented
  65. return lambda value: value
  66. def config_validate(self, config: Union[dict | str]):
  67. """
  68. Validates the given config.
  69. """
  70. # TODO: implement
  71. pass
  72. def get_encrypted_token(self, config: Union[dict | str]):
  73. """
  74. Returns the encrypted token.
  75. """
  76. return json.dumps({
  77. 'openai_api_type': 'azure',
  78. 'openai_api_version': '2023-03-15-preview',
  79. 'openai_api_base': config['openai_api_base'],
  80. 'openai_api_key': self.encrypt_token(config['openai_api_key'])
  81. })
  82. def get_decrypted_token(self, token: str):
  83. """
  84. Returns the decrypted token.
  85. """
  86. config = json.loads(token)
  87. config['openai_api_key'] = self.decrypt_token(config['openai_api_key'])
  88. return config