@@ -0,0 +1,438 @@
+import base64
+import json
+import logging
+from collections.abc import Generator
+from typing import Optional, Union
+
+import google.api_core.exceptions as exceptions
+import vertexai.generative_models as glm
+from google.cloud import aiplatform
+from google.oauth2 import service_account
+from vertexai.generative_models import HarmBlockThreshold, HarmCategory
+
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    PromptMessage,
+    PromptMessageContentType,
+    PromptMessageTool,
+    SystemPromptMessage,
+    ToolPromptMessage,
+    UserPromptMessage,
+)
+from core.model_runtime.errors.invoke import (
+    InvokeAuthorizationError,
+    InvokeBadRequestError,
+    InvokeConnectionError,
+    InvokeError,
+    InvokeRateLimitError,
+    InvokeServerUnavailableError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+
+logger = logging.getLogger(__name__)
+
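+# Default system instruction used by _generate when no SystemPromptMessage is supplied.
+# The {{block}} and {{instructions}} placeholders are presumably substituted by an
+# upstream prompt-transform step before the template reaches the model.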
+GEMINI_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
+The structure of the {{block}} object can be found in the instructions; use {"answer": "$your_answer"} as the default structure
+if you are not sure about the structure.
+
+<instructions>
+{{instructions}}
+</instructions>
+"""
+
+
+class VertexAiLargeLanguageModel(LargeLanguageModel):
+
+    def _invoke(self, model: str, credentials: dict,
+                prompt_messages: list[PromptMessage], model_parameters: dict,
+                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
+                stream: bool = True, user: Optional[str] = None) \
+            -> Union[LLMResult, Generator]:
+        """
+        Invoke large language model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param prompt_messages: prompt messages
+        :param model_parameters: model parameters
+        :param tools: tools for tool calling
+        :param stop: stop words
+        :param stream: is stream response
+        :param user: unique user id
+        :return: full response or stream response chunk generator result
+        """
+        # invoke model
+        return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
+
+    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                       tools: Optional[list[PromptMessageTool]] = None) -> int:
+        """
+        Get number of tokens for given prompt messages
+
+        :param model: model name
+        :param credentials: model credentials
+        :param prompt_messages: prompt messages
+        :param tools: tools for tool calling
+        :return: number of tokens
+        """
+        prompt = self._convert_messages_to_prompt(prompt_messages)
+
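+        # The count below is an approximation: it uses the runtime's shared GPT-2 tokenizer
+        # rather than Gemini's own tokenizer, so it may differ from what Vertex AI bills.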
+        return self._get_num_tokens_by_gpt2(prompt)
+
+    def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:
+        """
+        Format a list of messages into a full prompt for the Google model
+
+        :param messages: List of PromptMessage to combine.
+        :return: Combined string with necessary human_prompt and ai_prompt tags.
+        """
+        messages = messages.copy()  # don't mutate the original list
+
+        text = "".join(
+            self._convert_one_message_to_text(message)
+            for message in messages
+        )
+
+        return text.rstrip()
+
+    def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> glm.Tool:
+        """
+        Convert tool messages to glm tools
+
+        :param tools: tool messages
+        :return: glm tools
+        """
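+        # Example (hypothetical): a tool whose JSON-schema parameters are
+        #   {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
+        # is mapped to a glm.Schema of type OBJECT with a single STRING property "city".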
+        return glm.Tool(
+            function_declarations=[
+                glm.FunctionDeclaration(
+                    name=tool.name,
+                    parameters=glm.Schema(
+                        type=glm.Type.OBJECT,
+                        properties={
+                            key: {
+                                'type_': value.get('type', 'string').upper(),
+                                'description': value.get('description', ''),
+                                'enum': value.get('enum', [])
+                            } for key, value in tool.parameters.get('properties', {}).items()
+                        },
+                        required=tool.parameters.get('required', [])
+                    ),
+                ) for tool in tools
+            ]
+        )
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+        Validate model credentials
+
+        :param model: model name
+        :param credentials: model credentials
+        :return:
+        """
+        try:
+            ping_message = SystemPromptMessage(content="ping")
+            self._generate(model, credentials, [ping_message], {"max_tokens_to_sample": 5})
+        except Exception as ex:
+            raise CredentialsValidateFailedError(str(ex))
+
+    def _generate(self, model: str, credentials: dict,
+                  prompt_messages: list[PromptMessage], model_parameters: dict,
+                  tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
+                  stream: bool = True, user: Optional[str] = None
+                  ) -> Union[LLMResult, Generator]:
+        """
+        Invoke large language model
+
+        :param model: model name
+        :param credentials: credentials kwargs
+        :param prompt_messages: prompt messages
+        :param model_parameters: model parameters
+        :param tools: tools for tool calling
+        :param stop: stop words
+        :param stream: is stream response
+        :param user: unique user id
+        :return: full response or stream response chunk generator result
+        """
+        config_kwargs = model_parameters.copy()
+        config_kwargs['max_output_tokens'] = config_kwargs.pop('max_tokens_to_sample', None)
+
+        if stop:
+            config_kwargs["stop_sequences"] = stop
+
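+        # "vertex_service_account_key" is expected to hold a base64-encoded service-account
+        # JSON key; it is decoded here and used to initialise the Vertex AI SDK for the
+        # configured project and location.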
+        service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
+        service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
+        project_id = credentials["vertex_project_id"]
+        location = credentials["vertex_location"]
+        aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
+
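+        # Build the chat history: system messages become the system_instruction, and
+        # consecutive messages with the same role are merged into one glm.Content so the
+        # history alternates between "user" and "model" turns.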
+        history = []
+        system_instruction = GEMINI_BLOCK_MODE_PROMPT
+        # hack for gemini-pro-vision, which currently does not support multi-turn chat
+        if model == "gemini-1.0-pro-vision-001":
+            last_msg = prompt_messages[-1]
+            content = self._format_message_to_glm_content(last_msg)
+            history.append(content)
+        else:
+            for msg in prompt_messages:
+                if isinstance(msg, SystemPromptMessage):
+                    system_instruction = msg.content
+                else:
+                    content = self._format_message_to_glm_content(msg)
+                    if history and history[-1].role == content.role:
+                        history[-1].parts.extend(content.parts)
+                    else:
+                        history.append(content)
+
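+        # Every harm category is set to BLOCK_NONE, which effectively disables Vertex AI's
+        # safety filtering for these requests; any content moderation has to happen outside
+        # this provider.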
+        safety_settings = {
+            HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
+            HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
+            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
+            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
+        }
+
+        google_model = glm.GenerativeModel(
+            model_name=model,
+            system_instruction=system_instruction
+        )
+
+        response = google_model.generate_content(
+            contents=history,
+            generation_config=glm.GenerationConfig(
+                **config_kwargs
+            ),
+            stream=stream,
+            safety_settings=safety_settings,
+            tools=self._convert_tools_to_glm_tool(tools) if tools else None
+        )
+
+        if stream:
+            return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
+
+        return self._handle_generate_response(model, credentials, response, prompt_messages)
+
+    def _handle_generate_response(self, model: str, credentials: dict, response: glm.GenerationResponse,
+                                  prompt_messages: list[PromptMessage]) -> LLMResult:
+        """
+        Handle llm response
+
+        :param model: model name
+        :param credentials: credentials
+        :param response: response
+        :param prompt_messages: prompt messages
+        :return: llm response
+        """
+        # transform assistant message to prompt message
+        assistant_prompt_message = AssistantPromptMessage(
+            content=response.candidates[0].content.parts[0].text
+        )
+
+        # calculate num tokens
+        prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
+        completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
+
+        # transform usage
+        usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+        # transform response
+        result = LLMResult(
+            model=model,
+            prompt_messages=prompt_messages,
+            message=assistant_prompt_message,
+            usage=usage,
+        )
+
+        return result
+
+    def _handle_generate_stream_response(self, model: str, credentials: dict, response: glm.GenerationResponse,
+                                         prompt_messages: list[PromptMessage]) -> Generator:
+        """
+        Handle llm stream response
+
+        :param model: model name
+        :param credentials: credentials
+        :param response: response
+        :param prompt_messages: prompt messages
+        :return: llm response chunk generator result
+        """
+        index = -1
+        for chunk in response:
+            for part in chunk.candidates[0].content.parts:
+                assistant_prompt_message = AssistantPromptMessage(
+                    content=''
+                )
+
+                if part.text:
+                    assistant_prompt_message.content += part.text
+
+                if part.function_call:
+                    assistant_prompt_message.tool_calls = [
+                        AssistantPromptMessage.ToolCall(
+                            id=part.function_call.name,
+                            type='function',
+                            function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                                name=part.function_call.name,
+                                arguments=json.dumps({
+                                    key: value
+                                    for key, value in part.function_call.args.items()
+                                })
+                            )
+                        )
+                    ]
+
+                index += 1
+
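+                # Intermediate chunks are yielded as plain deltas; the final chunk (the one
+                # carrying a finish_reason) additionally reports estimated token usage for
+                # the whole generation.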
+                if not hasattr(chunk, 'finish_reason') or not chunk.finish_reason:
+                    # transform assistant message to prompt message
+                    yield LLMResultChunk(
+                        model=model,
+                        prompt_messages=prompt_messages,
+                        delta=LLMResultChunkDelta(
+                            index=index,
+                            message=assistant_prompt_message
+                        )
+                    )
+                else:
+                    # calculate num tokens
+                    prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
+                    completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
+
+                    # transform usage
+                    usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+                    yield LLMResultChunk(
+                        model=model,
+                        prompt_messages=prompt_messages,
+                        delta=LLMResultChunkDelta(
+                            index=index,
+                            message=assistant_prompt_message,
+                            finish_reason=chunk.candidates[0].finish_reason,
+                            usage=usage
+                        )
+                    )
+
+    def _convert_one_message_to_text(self, message: PromptMessage) -> str:
+        """
+        Convert a single message to a string.
+
+        :param message: PromptMessage to convert.
+        :return: String representation of the message.
+        """
+        human_prompt = "\n\nuser:"
+        ai_prompt = "\n\nmodel:"
+
+        content = message.content
+        if isinstance(content, list):
+            content = "".join(
+                c.data for c in content if c.type != PromptMessageContentType.IMAGE
+            )
+
+        if isinstance(message, UserPromptMessage):
+            message_text = f"{human_prompt} {content}"
+        elif isinstance(message, AssistantPromptMessage):
+            message_text = f"{ai_prompt} {content}"
+        elif isinstance(message, SystemPromptMessage):
+            message_text = f"{human_prompt} {content}"
+        elif isinstance(message, ToolPromptMessage):
+            message_text = f"{human_prompt} {content}"
+        else:
+            raise ValueError(f"Got unknown type {message}")
+
+        return message_text
+
+    def _format_message_to_glm_content(self, message: PromptMessage) -> glm.Content:
+        """
+        Format a single message into glm.Content for Google API
+
+        :param message: one PromptMessage
+        :return: glm Content representation of message
+        """
+        if isinstance(message, UserPromptMessage):
+            glm_content = glm.Content(role="user", parts=[])
+
+            if isinstance(message.content, str):
+                glm_content = glm.Content(role="user", parts=[glm.Part.from_text(message.content)])
+            else:
+                parts = []
+                for c in message.content:
+                    if c.type == PromptMessageContentType.TEXT:
+                        parts.append(glm.Part.from_text(c.data))
+                    else:
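+                        # Non-text content is expected to arrive as a data URL
+                        # ("data:<mime_type>;base64,<payload>"); split it into the MIME type
+                        # and the base64 payload so it can be passed as an inline_data blob.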
+                        metadata, data = c.data.split(',', 1)
+                        mime_type = metadata.split(';', 1)[0].split(':')[1]
+                        blob = {"inline_data": {"mime_type": mime_type, "data": data}}
+                        parts.append(blob)
+
+                glm_content = glm.Content(role="user", parts=parts)
+            return glm_content
+        elif isinstance(message, AssistantPromptMessage):
+            if message.content:
+                glm_content = glm.Content(role="model", parts=[glm.Part.from_text(message.content)])
+            if message.tool_calls:
+                glm_content = glm.Content(role="model", parts=[glm.Part.from_function_response(glm.FunctionCall(
+                    name=message.tool_calls[0].function.name,
+                    args=json.loads(message.tool_calls[0].function.arguments),
+                ))])
+            return glm_content
+        elif isinstance(message, ToolPromptMessage):
+            glm_content = glm.Content(role="function", parts=[glm.Part(function_response=glm.FunctionResponse(
+                name=message.name,
+                response={
+                    "response": message.content
+                }
+            ))])
+            return glm_content
+        else:
+            raise ValueError(f"Got unknown type {message}")
+
+    @property
+    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+        """
+        Map model invoke error to unified error
+        The key is the error type thrown to the caller
+        The value is the error type thrown by the model,
+        which needs to be converted into a unified error type for the caller.
+
+        :return: Invoke error mapping
+        """
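+        # google.api_core exceptions raised by the Vertex AI SDK are grouped into the
+        # runtime's unified InvokeError subclasses so callers can handle them uniformly.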
+        return {
+            InvokeConnectionError: [
+                exceptions.RetryError
+            ],
+            InvokeServerUnavailableError: [
+                exceptions.ServiceUnavailable,
+                exceptions.InternalServerError,
+                exceptions.BadGateway,
+                exceptions.GatewayTimeout,
+                exceptions.DeadlineExceeded
+            ],
+            InvokeRateLimitError: [
+                exceptions.ResourceExhausted,
+                exceptions.TooManyRequests
+            ],
+            InvokeAuthorizationError: [
+                exceptions.Unauthenticated,
+                exceptions.PermissionDenied,
+                exceptions.Forbidden
+            ],
+            InvokeBadRequestError: [
+                exceptions.BadRequest,
+                exceptions.InvalidArgument,
+                exceptions.FailedPrecondition,
+                exceptions.OutOfRange,
+                exceptions.NotFound,
+                exceptions.MethodNotAllowed,
+                exceptions.Conflict,
+                exceptions.AlreadyExists,
+                exceptions.Aborted,
+                exceptions.LengthRequired,
+                exceptions.PreconditionFailed,
+                exceptions.RequestRangeNotSatisfiable,
+                exceptions.Cancelled,
+            ]
+        }