| 
					
				 | 
			
			
				@@ -1,11 +1,13 @@ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-from typing import Dict, Any, Optional, List, Tuple, Union 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from typing import Dict, Any, Optional, List, Tuple, Union, cast 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from langchain.callbacks.manager import CallbackManagerForLLMRun 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from langchain.chat_models import AzureChatOpenAI 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from langchain.chat_models.openai import _convert_dict_to_message 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-from langchain.schema import ChatResult, BaseMessage, ChatGeneration 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from pydantic import root_validator 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from langchain.schema import ChatResult, BaseMessage, ChatGeneration, ChatMessage, HumanMessage, AIMessage, SystemMessage, FunctionMessage 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from core.model_providers.models.entity.message import LCHumanMessageWithFiles, PromptMessageFileType, ImagePromptMessageFile 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 class EnhanceAzureChatOpenAI(AzureChatOpenAI): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0) 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -51,13 +53,18 @@ class EnhanceAzureChatOpenAI(AzureChatOpenAI): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def _generate( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        messages: List[BaseMessage], 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        stop: Optional[List[str]] = None, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        run_manager: Optional[CallbackManagerForLLMRun] = None, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        **kwargs: Any, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            self, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            messages: List[BaseMessage], 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            stop: Optional[List[str]] = None, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            run_manager: Optional[CallbackManagerForLLMRun] = None, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            **kwargs: Any, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     ) -> ChatResult: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        message_dicts, params = self._create_message_dicts(messages, stop) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        params = self._client_params 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if stop is not None: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            if "stop" in params: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                raise ValueError("`stop` found in both the input and default params.") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            params["stop"] = stop 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        message_dicts = [self._convert_message_to_dict(m) for m in messages] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         params = {**params, **kwargs} 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         if self.streaming: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             inner_completion = "" 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -65,7 +72,7 @@ class EnhanceAzureChatOpenAI(AzureChatOpenAI): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             params["stream"] = True 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             function_call: Optional[dict] = None 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             for stream_resp in self.completion_with_retry( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                messages=message_dicts, **params 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    messages=message_dicts, **params 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             ): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 if len(stream_resp["choices"]) > 0: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                     role = stream_resp["choices"][0]["delta"].get("role", role) 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -88,4 +95,47 @@ class EnhanceAzureChatOpenAI(AzureChatOpenAI): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             return ChatResult(generations=[ChatGeneration(message=message)]) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         response = self.completion_with_retry(messages=message_dicts, **params) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        return self._create_chat_result(response) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        return self._create_chat_result(response) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def _convert_message_to_dict(self, message: BaseMessage) -> dict: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if isinstance(message, ChatMessage): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            message_dict = {"role": message.role, "content": message.content} 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        elif isinstance(message, LCHumanMessageWithFiles): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            content = [ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    "type": "text", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    "text": message.content 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            ] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            for file in message.files: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                if file.type == PromptMessageFileType.IMAGE: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    file = cast(ImagePromptMessageFile, file) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    content.append({ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        "type": "image_url", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        "image_url": { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                            "url": file.data, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                            "detail": file.detail.value 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    }) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            message_dict = {"role": "user", "content": content} 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        elif isinstance(message, HumanMessage): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            message_dict = {"role": "user", "content": message.content} 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        elif isinstance(message, AIMessage): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            message_dict = {"role": "assistant", "content": message.content} 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            if "function_call" in message.additional_kwargs: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                message_dict["function_call"] = message.additional_kwargs["function_call"] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        elif isinstance(message, SystemMessage): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            message_dict = {"role": "system", "content": message.content} 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        elif isinstance(message, FunctionMessage): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            message_dict = { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                "role": "function", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                "content": message.content, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                "name": message.name, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            raise ValueError(f"Got unknown type {message}") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if "name" in message.additional_kwargs: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            message_dict["name"] = message.additional_kwargs["name"] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        return message_dict 
			 |