|  | @@ -1,7 +1,9 @@
 | 
	
		
			
				|  |  | +import json
 | 
	
		
			
				|  |  |  import logging
 | 
	
		
			
				|  |  |  from collections.abc import Generator
 | 
	
		
			
				|  |  |  from typing import Optional, Union
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +import google.ai.generativelanguage as glm
 | 
	
		
			
				|  |  |  import google.api_core.exceptions as exceptions
 | 
	
		
			
				|  |  |  import google.generativeai as genai
 | 
	
		
			
				|  |  |  import google.generativeai.client as client
 | 
	
	
		
			
				|  | @@ -13,9 +15,9 @@ from core.model_runtime.entities.message_entities import (
 | 
	
		
			
				|  |  |      AssistantPromptMessage,
 | 
	
		
			
				|  |  |      PromptMessage,
 | 
	
		
			
				|  |  |      PromptMessageContentType,
 | 
	
		
			
				|  |  | -    PromptMessageRole,
 | 
	
		
			
				|  |  |      PromptMessageTool,
 | 
	
		
			
				|  |  |      SystemPromptMessage,
 | 
	
		
			
				|  |  | +    ToolPromptMessage,
 | 
	
		
			
				|  |  |      UserPromptMessage,
 | 
	
		
			
				|  |  |  )
 | 
	
		
			
				|  |  |  from core.model_runtime.errors.invoke import (
 | 
	
	
		
			
				|  | @@ -62,7 +64,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
 | 
	
		
			
				|  |  |          :return: full response or stream response chunk generator result
 | 
	
		
			
				|  |  |          """
 | 
	
		
			
				|  |  |          # invoke model
 | 
	
		
			
				|  |  | -        return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
 | 
	
		
			
				|  |  | +        return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
 | 
	
		
			
				|  |  |      
 | 
	
		
			
				|  |  |      def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
 | 
	
		
			
				|  |  |                         tools: Optional[list[PromptMessageTool]] = None) -> int:
 | 
	
	
		
			
				|  | @@ -94,6 +96,32 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
 | 
	
		
			
				|  |  |          )
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          return text.rstrip()
 | 
	
		
			
				|  |  | +    
 | 
	
		
			
				|  |  | +    def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> glm.Tool:
 | 
	
		
			
				|  |  | +        """
 | 
	
		
			
				|  |  | +        Convert tool messages to glm tools
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        :param tools: tool messages
 | 
	
		
			
				|  |  | +        :return: glm tools
 | 
	
		
			
				|  |  | +        """
 | 
	
		
			
				|  |  | +        return glm.Tool(
 | 
	
		
			
				|  |  | +            function_declarations=[
 | 
	
		
			
				|  |  | +                glm.FunctionDeclaration(
 | 
	
		
			
				|  |  | +                    name=tool.name,
 | 
	
		
			
				|  |  | +                    parameters=glm.Schema(
 | 
	
		
			
				|  |  | +                        type=glm.Type.OBJECT,
 | 
	
		
			
				|  |  | +                        properties={
 | 
	
		
			
				|  |  | +                            key: {
 | 
	
		
			
				|  |  | +                                'type_': value.get('type', 'string').upper(),
 | 
	
		
			
				|  |  | +                                'description': value.get('description', ''),
 | 
	
		
			
				|  |  | +                                'enum': value.get('enum', [])
 | 
	
		
			
				|  |  | +                            } for key, value in tool.parameters.get('properties', {}).items()
 | 
	
		
			
				|  |  | +                        },
 | 
	
		
			
				|  |  | +                        required=tool.parameters.get('required', [])
 | 
	
		
			
				|  |  | +                    ),
 | 
	
		
			
				|  |  | +                ) for tool in tools
 | 
	
		
			
				|  |  | +            ]
 | 
	
		
			
				|  |  | +        )
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def validate_credentials(self, model: str, credentials: dict) -> None:
 | 
	
		
			
				|  |  |          """
 | 
	
	
		
			
				|  | @@ -105,7 +133,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
 | 
	
		
			
				|  |  |          """
 | 
	
		
			
				|  |  |          
 | 
	
		
			
				|  |  |          try:
 | 
	
		
			
				|  |  | -            ping_message = PromptMessage(content="ping", role="system")
 | 
	
		
			
				|  |  | +            ping_message = SystemPromptMessage(content="ping")
 | 
	
		
			
				|  |  |              self._generate(model, credentials, [ping_message], {"max_tokens_to_sample": 5})
 | 
	
		
			
				|  |  |              
 | 
	
		
			
				|  |  |          except Exception as ex:
 | 
	
	
		
			
				|  | @@ -114,8 +142,9 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def _generate(self, model: str, credentials: dict,
 | 
	
		
			
				|  |  |                    prompt_messages: list[PromptMessage], model_parameters: dict,
 | 
	
		
			
				|  |  | -                  stop: Optional[list[str]] = None, stream: bool = True,
 | 
	
		
			
				|  |  | -                  user: Optional[str] = None) -> Union[LLMResult, Generator]:
 | 
	
		
			
				|  |  | +                  tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, 
 | 
	
		
			
				|  |  | +                  stream: bool = True, user: Optional[str] = None
 | 
	
		
			
				|  |  | +        ) -> Union[LLMResult, Generator]:
 | 
	
		
			
				|  |  |          """
 | 
	
		
			
				|  |  |          Invoke large language model
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -153,7 +182,6 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
 | 
	
		
			
				|  |  |                  else:
 | 
	
		
			
				|  |  |                      history.append(content)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  |          # Create a new ClientManager with tenant's API key
 | 
	
		
			
				|  |  |          new_client_manager = client._ClientManager()
 | 
	
		
			
				|  |  |          new_client_manager.configure(api_key=credentials["google_api_key"])
 | 
	
	
		
			
				|  | @@ -167,14 +195,15 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
 | 
	
		
			
				|  |  |              HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
 | 
	
		
			
				|  |  |              HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  | -        
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |          response = google_model.generate_content(
 | 
	
		
			
				|  |  |              contents=history,
 | 
	
		
			
				|  |  |              generation_config=genai.types.GenerationConfig(
 | 
	
		
			
				|  |  |                  **config_kwargs
 | 
	
		
			
				|  |  |              ),
 | 
	
		
			
				|  |  |              stream=stream,
 | 
	
		
			
				|  |  | -            safety_settings=safety_settings
 | 
	
		
			
				|  |  | +            safety_settings=safety_settings,
 | 
	
		
			
				|  |  | +            tools=self._convert_tools_to_glm_tool(tools) if tools else None,
 | 
	
		
			
				|  |  |          )
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          if stream:
 | 
	
	
		
			
				|  | @@ -228,43 +257,61 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
 | 
	
		
			
				|  |  |          """
 | 
	
		
			
				|  |  |          index = -1
 | 
	
		
			
				|  |  |          for chunk in response:
 | 
	
		
			
				|  |  | -            content = chunk.text
 | 
	
		
			
				|  |  | -            index += 1
 | 
	
		
			
				|  |  | -           
 | 
	
		
			
				|  |  | -            assistant_prompt_message = AssistantPromptMessage(
 | 
	
		
			
				|  |  | -                content=content if content else '',
 | 
	
		
			
				|  |  | -            )
 | 
	
		
			
				|  |  | -  
 | 
	
		
			
				|  |  | -            if not response._done:
 | 
	
		
			
				|  |  | -                
 | 
	
		
			
				|  |  | -                # transform assistant message to prompt message
 | 
	
		
			
				|  |  | -                yield LLMResultChunk(
 | 
	
		
			
				|  |  | -                    model=model,
 | 
	
		
			
				|  |  | -                    prompt_messages=prompt_messages,
 | 
	
		
			
				|  |  | -                    delta=LLMResultChunkDelta(
 | 
	
		
			
				|  |  | -                        index=index,
 | 
	
		
			
				|  |  | -                        message=assistant_prompt_message
 | 
	
		
			
				|  |  | -                    )
 | 
	
		
			
				|  |  | +            for part in chunk.parts:
 | 
	
		
			
				|  |  | +                assistant_prompt_message = AssistantPromptMessage(
 | 
	
		
			
				|  |  | +                    content=''
 | 
	
		
			
				|  |  |                  )
 | 
	
		
			
				|  |  | -            else:
 | 
	
		
			
				|  |  | -                
 | 
	
		
			
				|  |  | -                # calculate num tokens
 | 
	
		
			
				|  |  | -                prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
 | 
	
		
			
				|  |  | -                completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -                # transform usage
 | 
	
		
			
				|  |  | -                usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
 | 
	
		
			
				|  |  | -                
 | 
	
		
			
				|  |  | -                yield LLMResultChunk(
 | 
	
		
			
				|  |  | -                    model=model,
 | 
	
		
			
				|  |  | -                    prompt_messages=prompt_messages,
 | 
	
		
			
				|  |  | -                    delta=LLMResultChunkDelta(
 | 
	
		
			
				|  |  | -                        index=index,
 | 
	
		
			
				|  |  | -                        message=assistant_prompt_message,
 | 
	
		
			
				|  |  | -                        finish_reason=chunk.candidates[0].finish_reason,
 | 
	
		
			
				|  |  | -                        usage=usage
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +                if part.text:
 | 
	
		
			
				|  |  | +                    assistant_prompt_message.content += part.text
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +                if part.function_call:
 | 
	
		
			
				|  |  | +                    assistant_prompt_message.tool_calls = [
 | 
	
		
			
				|  |  | +                        AssistantPromptMessage.ToolCall(
 | 
	
		
			
				|  |  | +                            id=part.function_call.name,
 | 
	
		
			
				|  |  | +                            type='function',
 | 
	
		
			
				|  |  | +                            function=AssistantPromptMessage.ToolCall.ToolCallFunction(
 | 
	
		
			
				|  |  | +                                name=part.function_call.name,
 | 
	
		
			
				|  |  | +                                arguments=json.dumps({
 | 
	
		
			
				|  |  | +                                    key: value 
 | 
	
		
			
				|  |  | +                                    for key, value in part.function_call.args.items()
 | 
	
		
			
				|  |  | +                                })
 | 
	
		
			
				|  |  | +                            )
 | 
	
		
			
				|  |  | +                        )
 | 
	
		
			
				|  |  | +                    ]
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +                index += 1
 | 
	
		
			
				|  |  | +    
 | 
	
		
			
				|  |  | +                if not response._done:
 | 
	
		
			
				|  |  | +                    
 | 
	
		
			
				|  |  | +                    # transform assistant message to prompt message
 | 
	
		
			
				|  |  | +                    yield LLMResultChunk(
 | 
	
		
			
				|  |  | +                        model=model,
 | 
	
		
			
				|  |  | +                        prompt_messages=prompt_messages,
 | 
	
		
			
				|  |  | +                        delta=LLMResultChunkDelta(
 | 
	
		
			
				|  |  | +                            index=index,
 | 
	
		
			
				|  |  | +                            message=assistant_prompt_message
 | 
	
		
			
				|  |  | +                        )
 | 
	
		
			
				|  |  | +                    )
 | 
	
		
			
				|  |  | +                else:
 | 
	
		
			
				|  |  | +                    
 | 
	
		
			
				|  |  | +                    # calculate num tokens
 | 
	
		
			
				|  |  | +                    prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
 | 
	
		
			
				|  |  | +                    completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +                    # transform usage
 | 
	
		
			
				|  |  | +                    usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
 | 
	
		
			
				|  |  | +                    
 | 
	
		
			
				|  |  | +                    yield LLMResultChunk(
 | 
	
		
			
				|  |  | +                        model=model,
 | 
	
		
			
				|  |  | +                        prompt_messages=prompt_messages,
 | 
	
		
			
				|  |  | +                        delta=LLMResultChunkDelta(
 | 
	
		
			
				|  |  | +                            index=index,
 | 
	
		
			
				|  |  | +                            message=assistant_prompt_message,
 | 
	
		
			
				|  |  | +                            finish_reason=chunk.candidates[0].finish_reason,
 | 
	
		
			
				|  |  | +                            usage=usage
 | 
	
		
			
				|  |  | +                        )
 | 
	
		
			
				|  |  |                      )
 | 
	
		
			
				|  |  | -                )
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def _convert_one_message_to_text(self, message: PromptMessage) -> str:
 | 
	
		
			
				|  |  |          """
 | 
	
	
		
			
				|  | @@ -288,6 +335,8 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
 | 
	
		
			
				|  |  |              message_text = f"{ai_prompt} {content}"
 | 
	
		
			
				|  |  |          elif isinstance(message, SystemPromptMessage):
 | 
	
		
			
				|  |  |              message_text = f"{human_prompt} {content}"
 | 
	
		
			
				|  |  | +        elif isinstance(message, ToolPromptMessage):
 | 
	
		
			
				|  |  | +            message_text = f"{human_prompt} {content}"
 | 
	
		
			
				|  |  |          else:
 | 
	
		
			
				|  |  |              raise ValueError(f"Got unknown type {message}")
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -300,26 +349,53 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
 | 
	
		
			
				|  |  |          :param message: one PromptMessage
 | 
	
		
			
				|  |  |          :return: glm Content representation of message
 | 
	
		
			
				|  |  |          """
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        parts = []
 | 
	
		
			
				|  |  | -        if (isinstance(message.content, str)):
 | 
	
		
			
				|  |  | -            parts.append(to_part(message.content))
 | 
	
		
			
				|  |  | +        if isinstance(message, UserPromptMessage):
 | 
	
		
			
				|  |  | +            glm_content = {
 | 
	
		
			
				|  |  | +                "role": "user",
 | 
	
		
			
				|  |  | +                "parts": []
 | 
	
		
			
				|  |  | +            }
 | 
	
		
			
				|  |  | +            if (isinstance(message.content, str)):
 | 
	
		
			
				|  |  | +                glm_content['parts'].append(to_part(message.content))
 | 
	
		
			
				|  |  | +            else:
 | 
	
		
			
				|  |  | +                for c in message.content:
 | 
	
		
			
				|  |  | +                    if c.type == PromptMessageContentType.TEXT:
 | 
	
		
			
				|  |  | +                        glm_content['parts'].append(to_part(c.data))
 | 
	
		
			
				|  |  | +                    else:
 | 
	
		
			
				|  |  | +                        metadata, data = c.data.split(',', 1)
 | 
	
		
			
				|  |  | +                        mime_type = metadata.split(';', 1)[0].split(':')[1]
 | 
	
		
			
				|  |  | +                        blob = {"inline_data":{"mime_type":mime_type,"data":data}}
 | 
	
		
			
				|  |  | +                        glm_content['parts'].append(blob)
 | 
	
		
			
				|  |  | +            return glm_content
 | 
	
		
			
				|  |  | +        elif isinstance(message, AssistantPromptMessage):
 | 
	
		
			
				|  |  | +            glm_content = {
 | 
	
		
			
				|  |  | +                "role": "model",
 | 
	
		
			
				|  |  | +                "parts": []
 | 
	
		
			
				|  |  | +            }
 | 
	
		
			
				|  |  | +            if message.content:
 | 
	
		
			
				|  |  | +                glm_content['parts'].append(to_part(message.content))
 | 
	
		
			
				|  |  | +            if message.tool_calls:
 | 
	
		
			
				|  |  | +                glm_content["parts"].append(to_part(glm.FunctionCall(
 | 
	
		
			
				|  |  | +                    name=message.tool_calls[0].function.name,
 | 
	
		
			
				|  |  | +                    args=json.loads(message.tool_calls[0].function.arguments),
 | 
	
		
			
				|  |  | +                )))
 | 
	
		
			
				|  |  | +            return glm_content
 | 
	
		
			
				|  |  | +        elif isinstance(message, SystemPromptMessage):
 | 
	
		
			
				|  |  | +            return {
 | 
	
		
			
				|  |  | +                "role": "user",
 | 
	
		
			
				|  |  | +                "parts": [to_part(message.content)]
 | 
	
		
			
				|  |  | +            }
 | 
	
		
			
				|  |  | +        elif isinstance(message, ToolPromptMessage):
 | 
	
		
			
				|  |  | +            return {
 | 
	
		
			
				|  |  | +                "role": "function",
 | 
	
		
			
				|  |  | +                "parts": [glm.Part(function_response=glm.FunctionResponse(
 | 
	
		
			
				|  |  | +                    name=message.name,
 | 
	
		
			
				|  |  | +                    response={
 | 
	
		
			
				|  |  | +                        "response": message.content
 | 
	
		
			
				|  |  | +                    }
 | 
	
		
			
				|  |  | +                ))]
 | 
	
		
			
				|  |  | +            }
 | 
	
		
			
				|  |  |          else:
 | 
	
		
			
				|  |  | -            for c in message.content:
 | 
	
		
			
				|  |  | -                if c.type == PromptMessageContentType.TEXT:
 | 
	
		
			
				|  |  | -                    parts.append(to_part(c.data))
 | 
	
		
			
				|  |  | -                else:
 | 
	
		
			
				|  |  | -                    metadata, data = c.data.split(',', 1)
 | 
	
		
			
				|  |  | -                    mime_type = metadata.split(';', 1)[0].split(':')[1]
 | 
	
		
			
				|  |  | -                    blob = {"inline_data":{"mime_type":mime_type,"data":data}}
 | 
	
		
			
				|  |  | -                    parts.append(blob)
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        glm_content = {
 | 
	
		
			
				|  |  | -            "role": "user" if message.role in (PromptMessageRole.USER, PromptMessageRole.SYSTEM) else "model",
 | 
	
		
			
				|  |  | -            "parts": parts
 | 
	
		
			
				|  |  | -        }
 | 
	
		
			
				|  |  | -        
 | 
	
		
			
				|  |  | -        return glm_content
 | 
	
		
			
				|  |  | +            raise ValueError(f"Got unknown type {message}")
 | 
	
		
			
				|  |  |      
 | 
	
		
			
				|  |  |      @property
 | 
	
		
			
				|  |  |      def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
 |