|  | @@ -1,22 +1,24 @@
 | 
	
		
			
				|  |  |  import json
 | 
	
		
			
				|  |  | +import logging
 | 
	
		
			
				|  |  |  from typing import Optional, Union
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  import requests
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  from core.llm.provider.base import BaseProvider
 | 
	
		
			
				|  |  | +from core.llm.provider.errors import ValidateFailedError
 | 
	
		
			
				|  |  |  from models.provider import ProviderName
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  class AzureProvider(BaseProvider):
 | 
	
		
			
				|  |  | -    def get_models(self, model_id: Optional[str] = None) -> list[dict]:
 | 
	
		
			
				|  |  | -        credentials = self.get_credentials(model_id)
 | 
	
		
			
				|  |  | +    def get_models(self, model_id: Optional[str] = None, credentials: Optional[dict] = None) -> list[dict]:
 | 
	
		
			
				|  |  | +        credentials = self.get_credentials(model_id) if not credentials else credentials
 | 
	
		
			
				|  |  |          url = "{}/openai/deployments?api-version={}".format(
 | 
	
		
			
				|  |  | -            credentials.get('openai_api_base'),
 | 
	
		
			
				|  |  | -            credentials.get('openai_api_version')
 | 
	
		
			
				|  |  | +            str(credentials.get('openai_api_base')),
 | 
	
		
			
				|  |  | +            str(credentials.get('openai_api_version'))
 | 
	
		
			
				|  |  |          )
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          headers = {
 | 
	
		
			
				|  |  | -            "api-key": credentials.get('openai_api_key'),
 | 
	
		
			
				|  |  | +            "api-key": str(credentials.get('openai_api_key')),
 | 
	
		
			
				|  |  |              "content-type": "application/json; charset=utf-8"
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -29,8 +31,10 @@ class AzureProvider(BaseProvider):
 | 
	
		
			
				|  |  |                  'name': '{} ({})'.format(deployment['id'], deployment['model'])
 | 
	
		
			
				|  |  |              } for deployment in result['data'] if deployment['status'] == 'succeeded']
 | 
	
		
			
				|  |  |          else:
 | 
	
		
			
				|  |  | -            # TODO: optimize in future
 | 
	
		
			
				|  |  | -            raise Exception('Failed to get deployments from Azure OpenAI. Status code: {}'.format(response.status_code))
 | 
	
		
			
				|  |  | +            if response.status_code == 401:
 | 
	
		
			
				|  |  | +                raise AzureAuthenticationError()
 | 
	
		
			
				|  |  | +            else:
 | 
	
		
			
				|  |  | +                raise AzureRequestFailedError('Failed to request Azure OpenAI. Status code: {}'.format(response.status_code))
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def get_credentials(self, model_id: Optional[str] = None) -> dict:
 | 
	
		
			
				|  |  |          """
 | 
	
	
		
			
				|  | @@ -38,7 +42,7 @@ class AzureProvider(BaseProvider):
 | 
	
		
			
				|  |  |          """
 | 
	
		
			
				|  |  |          config = self.get_provider_api_key(model_id=model_id)
 | 
	
		
			
				|  |  |          config['openai_api_type'] = 'azure'
 | 
	
		
			
				|  |  | -        config['deployment_name'] = model_id.replace('.', '')
 | 
	
		
			
				|  |  | +        config['deployment_name'] = model_id.replace('.', '') if model_id else None
 | 
	
		
			
				|  |  |          return config
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def get_provider_name(self):
 | 
	
	
		
			
				|  | @@ -54,7 +58,7 @@ class AzureProvider(BaseProvider):
 | 
	
		
			
				|  |  |              config = {
 | 
	
		
			
				|  |  |                  'openai_api_type': 'azure',
 | 
	
		
			
				|  |  |                  'openai_api_version': '2023-03-15-preview',
 | 
	
		
			
				|  |  | -                'openai_api_base': 'https://<your-domain-prefix>.openai.azure.com/',
 | 
	
		
			
				|  |  | +                'openai_api_base': '',
 | 
	
		
			
				|  |  |                  'openai_api_key': ''
 | 
	
		
			
				|  |  |              }
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -63,7 +67,7 @@ class AzureProvider(BaseProvider):
 | 
	
		
			
				|  |  |                  config = {
 | 
	
		
			
				|  |  |                      'openai_api_type': 'azure',
 | 
	
		
			
				|  |  |                      'openai_api_version': '2023-03-15-preview',
 | 
	
		
			
				|  |  | -                    'openai_api_base': 'https://<your-domain-prefix>.openai.azure.com/',
 | 
	
		
			
				|  |  | +                    'openai_api_base': '',
 | 
	
		
			
				|  |  |                      'openai_api_key': ''
 | 
	
		
			
				|  |  |                  }
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -80,8 +84,23 @@ class AzureProvider(BaseProvider):
 | 
	
		
			
				|  |  |          """
 | 
	
		
			
				|  |  |          Validates the given config.
 | 
	
		
			
				|  |  |          """
 | 
	
		
			
				|  |  | -        # TODO: implement
 | 
	
		
			
				|  |  | -        pass
 | 
	
		
			
				|  |  | +        try:
 | 
	
		
			
				|  |  | +            if not isinstance(config, dict):
 | 
	
		
			
				|  |  | +                raise ValueError('Config must be a object.')
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            if 'openai_api_version' not in config:
 | 
	
		
			
				|  |  | +                config['openai_api_version'] = '2023-03-15-preview'
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            self.get_models(credentials=config)
 | 
	
		
			
				|  |  | +        except AzureAuthenticationError:
 | 
	
		
			
				|  |  | +            raise ValidateFailedError('Azure OpenAI Credentials validation failed, please check your API Key.')
 | 
	
		
			
				|  |  | +        except requests.ConnectionError:
 | 
	
		
			
				|  |  | +            raise ValidateFailedError('Azure OpenAI Credentials validation failed, please check your API Base Endpoint.')
 | 
	
		
			
				|  |  | +        except AzureRequestFailedError as ex:
 | 
	
		
			
				|  |  | +            raise ValidateFailedError('Azure OpenAI Credentials validation failed, error: {}.'.format(str(ex)))
 | 
	
		
			
				|  |  | +        except Exception as ex:
 | 
	
		
			
				|  |  | +            logging.exception('Azure OpenAI Credentials validation failed')
 | 
	
		
			
				|  |  | +            raise ex
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def get_encrypted_token(self, config: Union[dict | str]):
 | 
	
		
			
				|  |  |          """
 | 
	
	
		
			
				|  | @@ -101,3 +120,11 @@ class AzureProvider(BaseProvider):
 | 
	
		
			
				|  |  |          config = json.loads(token)
 | 
	
		
			
				|  |  |          config['openai_api_key'] = self.decrypt_token(config['openai_api_key'])
 | 
	
		
			
				|  |  |          return config
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +class AzureAuthenticationError(Exception):
 | 
	
		
			
				|  |  | +    pass
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +class AzureRequestFailedError(Exception):
 | 
	
		
			
				|  |  | +    pass
 |