@@ -1,20 +1,38 @@
+import json
 import logging
-from collections.abc import Generator
+from collections.abc import Generator, Iterator
 from typing import Optional, Union, cast
 
 import cohere
-from cohere.responses import Chat, Generations
-from cohere.responses.chat import StreamEnd, StreamingChat, StreamTextGeneration
-from cohere.responses.generation import StreamingGenerations, StreamingText
+from cohere import (
+    ChatMessage,
+    ChatStreamRequestToolResultsItem,
+    GenerateStreamedResponse,
+    GenerateStreamedResponse_StreamEnd,
+    GenerateStreamedResponse_StreamError,
+    GenerateStreamedResponse_TextGeneration,
+    Generation,
+    NonStreamedChatResponse,
+    StreamedChatResponse,
+    StreamedChatResponse_StreamEnd,
+    StreamedChatResponse_TextGeneration,
+    StreamedChatResponse_ToolCallsGeneration,
+    Tool,
+    ToolCall,
+    ToolParameterDefinitionsValue,
+)
+from cohere.core import RequestOptions
 
 from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     PromptMessage,
     PromptMessageContentType,
+    PromptMessageRole,
     PromptMessageTool,
     SystemPromptMessage,
     TextPromptMessageContent,
+    ToolPromptMessage,
     UserPromptMessage,
 )
 from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, I18nObject, ModelType
@@ -64,6 +82,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
                 credentials=credentials,
                 prompt_messages=prompt_messages,
                 model_parameters=model_parameters,
+                tools=tools,
                 stop=stop,
                 stream=stream,
                 user=user
@@ -159,19 +178,26 @@ class CohereLargeLanguageModel(LargeLanguageModel):
         if stop:
             model_parameters['end_sequences'] = stop
 
-        response = client.generate(
-            prompt=prompt_messages[0].content,
-            model=model,
-            stream=stream,
-            **model_parameters,
-        )
-
         if stream:
+            response = client.generate_stream(
+                prompt=prompt_messages[0].content,
+                model=model,
+                **model_parameters,
+                request_options=RequestOptions(max_retries=0)
+            )
+
             return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
+        else:
+            response = client.generate(
+                prompt=prompt_messages[0].content,
+                model=model,
+                **model_parameters,
+                request_options=RequestOptions(max_retries=0)
+            )
 
-        return self._handle_generate_response(model, credentials, response, prompt_messages)
+            return self._handle_generate_response(model, credentials, response, prompt_messages)
 
-    def _handle_generate_response(self, model: str, credentials: dict, response: Generations,
+    def _handle_generate_response(self, model: str, credentials: dict, response: Generation,
                                   prompt_messages: list[PromptMessage]) \
             -> LLMResult:
         """
@@ -191,8 +217,8 @@ class CohereLargeLanguageModel(LargeLanguageModel):
         )
 
         # calculate num tokens
-        prompt_tokens = response.meta['billed_units']['input_tokens']
-        completion_tokens = response.meta['billed_units']['output_tokens']
+        prompt_tokens = int(response.meta.billed_units.input_tokens)
+        completion_tokens = int(response.meta.billed_units.output_tokens)
 
         # transform usage
         usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
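
Note (not part of the diff): with the v5 SDK, response.meta is a typed object rather than a dict, and the billed-unit fields are typed as optional floats, hence the attribute access and the int(...) casts above. Side by side:

# v4 SDK: meta was a plain dict
# prompt_tokens = response.meta['billed_units']['input_tokens']
# v5 SDK: meta is a typed model; cast the float counts to int
prompt_tokens = int(response.meta.billed_units.input_tokens)
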
@@ -207,7 +233,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
 
         return response
 
-    def _handle_generate_stream_response(self, model: str, credentials: dict, response: StreamingGenerations,
+    def _handle_generate_stream_response(self, model: str, credentials: dict, response: Iterator[GenerateStreamedResponse],
                                          prompt_messages: list[PromptMessage]) -> Generator:
         """
         Handle llm stream response
@@ -220,8 +246,8 @@ class CohereLargeLanguageModel(LargeLanguageModel):
         index = 1
         full_assistant_content = ''
         for chunk in response:
-            if isinstance(chunk, StreamingText):
-                chunk = cast(StreamingText, chunk)
+            if isinstance(chunk, GenerateStreamedResponse_TextGeneration):
+                chunk = cast(GenerateStreamedResponse_TextGeneration, chunk)
                 text = chunk.text
 
                 if text is None:
@@ -244,10 +270,16 @@ class CohereLargeLanguageModel(LargeLanguageModel):
                 )
 
                 index += 1
-            elif chunk is None:
+            elif isinstance(chunk, GenerateStreamedResponse_StreamEnd):
+                chunk = cast(GenerateStreamedResponse_StreamEnd, chunk)
+
                 # calculate num tokens
-                prompt_tokens = response.meta['billed_units']['input_tokens']
-                completion_tokens = response.meta['billed_units']['output_tokens']
+                prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages)
+                completion_tokens = self._num_tokens_from_messages(
+                    model,
+                    credentials,
+                    [AssistantPromptMessage(content=full_assistant_content)]
+                )
 
                 # transform usage
                 usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
@@ -258,14 +290,18 @@ class CohereLargeLanguageModel(LargeLanguageModel):
                     delta=LLMResultChunkDelta(
                         index=index,
                         message=AssistantPromptMessage(content=''),
-                        finish_reason=response.finish_reason,
+                        finish_reason=chunk.finish_reason,
                         usage=usage
                     )
                 )
                 break
+            elif isinstance(chunk, GenerateStreamedResponse_StreamError):
+                chunk = cast(GenerateStreamedResponse_StreamError, chunk)
+                raise InvokeBadRequestError(chunk.err)
 
     def _chat_generate(self, model: str, credentials: dict,
-                       prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None,
+                       prompt_messages: list[PromptMessage], model_parameters: dict,
+                       tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                        stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
         """
         Invoke llm chat model
@@ -274,6 +310,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
         :param credentials: credentials
         :param prompt_messages: prompt messages
         :param model_parameters: model parameters
+        :param tools: tools for tool calling
         :param stop: stop words
         :param stream: is stream response
         :param user: unique user id
@@ -282,31 +319,46 @@ class CohereLargeLanguageModel(LargeLanguageModel):
         # initialize client
         client = cohere.Client(credentials.get('api_key'))
 
-        if user:
-            model_parameters['user_name'] = user
+        if stop:
+            model_parameters['stop_sequences'] = stop
+
+        if tools:
+            model_parameters['tools'] = self._convert_tools(tools)
 
-        message, chat_histories = self._convert_prompt_messages_to_message_and_chat_histories(prompt_messages)
+        message, chat_histories, tool_results \
+            = self._convert_prompt_messages_to_message_and_chat_histories(prompt_messages)
+
+        if tool_results:
+            model_parameters['tool_results'] = tool_results
 
         # chat model
         real_model = model
         if self.get_model_schema(model, credentials).fetch_from == FetchFrom.PREDEFINED_MODEL:
             real_model = model.removesuffix('-chat')
 
-        response = client.chat(
-            message=message,
-            chat_history=chat_histories,
-            model=real_model,
-            stream=stream,
-            **model_parameters,
-        )
-
         if stream:
-            return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, stop)
+            response = client.chat_stream(
+                message=message,
+                chat_history=chat_histories,
+                model=real_model,
+                **model_parameters,
+                request_options=RequestOptions(max_retries=0)
+            )
 
-        return self._handle_chat_generate_response(model, credentials, response, prompt_messages, stop)
+            return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages)
+        else:
+            response = client.chat(
+                message=message,
+                chat_history=chat_histories,
+                model=real_model,
+                **model_parameters,
+                request_options=RequestOptions(max_retries=0)
+            )
 
-    def _handle_chat_generate_response(self, model: str, credentials: dict, response: Chat,
-                                       prompt_messages: list[PromptMessage], stop: Optional[list[str]] = None) \
+            return self._handle_chat_generate_response(model, credentials, response, prompt_messages)
+
+    def _handle_chat_generate_response(self, model: str, credentials: dict, response: NonStreamedChatResponse,
+                                       prompt_messages: list[PromptMessage]) \
             -> LLMResult:
         """
         Handle llm chat response
@@ -315,14 +367,27 @@ class CohereLargeLanguageModel(LargeLanguageModel):
         :param credentials: credentials
         :param response: response
         :param prompt_messages: prompt messages
-        :param stop: stop words
         :return: llm response
         """
         assistant_text = response.text
 
+        tool_calls = []
+        if response.tool_calls:
+            for cohere_tool_call in response.tool_calls:
+                tool_call = AssistantPromptMessage.ToolCall(
+                    id=cohere_tool_call.name,
+                    type='function',
+                    function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                        name=cohere_tool_call.name,
+                        arguments=json.dumps(cohere_tool_call.parameters)
+                    )
+                )
+                tool_calls.append(tool_call)
+
         # transform assistant message to prompt message
         assistant_prompt_message = AssistantPromptMessage(
-            content=assistant_text
+            content=assistant_text,
+            tool_calls=tool_calls
         )
 
         # calculate num tokens
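
Note (not part of the diff): on a NonStreamedChatResponse, each entry of response.tool_calls is a cohere ToolCall whose parameters field is a plain dict, so json.dumps(...) produces the JSON arguments string expected by the runtime's ToolCallFunction. A tiny illustration with a hypothetical call:

import json
from cohere import ToolCall

cohere_tool_call = ToolCall(name="get_weather", parameters={"location": "Berlin"})  # hypothetical
arguments = json.dumps(cohere_tool_call.parameters)  # '{"location": "Berlin"}'
# The framework-side tool call then carries id/name "get_weather" and this arguments string.
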
@@ -332,44 +397,38 @@ class CohereLargeLanguageModel(LargeLanguageModel):
         # transform usage
         usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
 
-        if stop:
-            # enforce stop tokens
-            assistant_text = self.enforce_stop_tokens(assistant_text, stop)
-            assistant_prompt_message = AssistantPromptMessage(
-                content=assistant_text
-            )
-
         # transform response
         response = LLMResult(
             model=model,
             prompt_messages=prompt_messages,
             message=assistant_prompt_message,
-            usage=usage,
-            system_fingerprint=response.preamble
+            usage=usage
         )
 
         return response
 
-    def _handle_chat_generate_stream_response(self, model: str, credentials: dict, response: StreamingChat,
-                                              prompt_messages: list[PromptMessage],
-                                              stop: Optional[list[str]] = None) -> Generator:
+    def _handle_chat_generate_stream_response(self, model: str, credentials: dict,
+                                              response: Iterator[StreamedChatResponse],
+                                              prompt_messages: list[PromptMessage]) -> Generator:
         """
         Handle llm chat stream response
 
         :param model: model name
         :param response: response
         :param prompt_messages: prompt messages
-        :param stop: stop words
         :return: llm response chunk generator
         """
 
-        def final_response(full_text: str, index: int, finish_reason: Optional[str] = None,
-                           preamble: Optional[str] = None) -> LLMResultChunk:
+        def final_response(full_text: str,
+                           tool_calls: list[AssistantPromptMessage.ToolCall],
+                           index: int,
+                           finish_reason: Optional[str] = None) -> LLMResultChunk:
             # calculate num tokens
             prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages)
 
             full_assistant_prompt_message = AssistantPromptMessage(
-                content=full_text
+                content=full_text,
+                tool_calls=tool_calls
             )
             completion_tokens = self._num_tokens_from_messages(model, credentials, [full_assistant_prompt_message])
 
@@ -379,10 +438,9 @@ class CohereLargeLanguageModel(LargeLanguageModel):
             return LLMResultChunk(
                 model=model,
                 prompt_messages=prompt_messages,
-                system_fingerprint=preamble,
                 delta=LLMResultChunkDelta(
                     index=index,
-                    message=AssistantPromptMessage(content=''),
+                    message=AssistantPromptMessage(content='', tool_calls=tool_calls),
                     finish_reason=finish_reason,
                     usage=usage
                 )
@@ -390,9 +448,10 @@ class CohereLargeLanguageModel(LargeLanguageModel):
 
         index = 1
         full_assistant_content = ''
+        tool_calls = []
         for chunk in response:
-            if isinstance(chunk, StreamTextGeneration):
-                chunk = cast(StreamTextGeneration, chunk)
+            if isinstance(chunk, StreamedChatResponse_TextGeneration):
+                chunk = cast(StreamedChatResponse_TextGeneration, chunk)
                 text = chunk.text
 
                 if text is None:
@@ -403,12 +462,6 @@ class CohereLargeLanguageModel(LargeLanguageModel):
                     content=text
                 )
 
-                # stop
-                # notice: This logic can only cover few stop scenarios
-                if stop and text in stop:
-                    yield final_response(full_assistant_content, index, 'stop')
-                    break
-
                 full_assistant_content += text
 
                 yield LLMResultChunk(
@@ -421,39 +474,98 @@ class CohereLargeLanguageModel(LargeLanguageModel):
                 )
 
                 index += 1
-            elif isinstance(chunk, StreamEnd):
-                chunk = cast(StreamEnd, chunk)
-                yield final_response(full_assistant_content, index, chunk.finish_reason, response.preamble)
+            elif isinstance(chunk, StreamedChatResponse_ToolCallsGeneration):
+                chunk = cast(StreamedChatResponse_ToolCallsGeneration, chunk)
+
+                tool_calls = []
+                if chunk.tool_calls:
+                    for cohere_tool_call in chunk.tool_calls:
+                        tool_call = AssistantPromptMessage.ToolCall(
+                            id=cohere_tool_call.name,
+                            type='function',
+                            function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                                name=cohere_tool_call.name,
+                                arguments=json.dumps(cohere_tool_call.parameters)
+                            )
+                        )
+                        tool_calls.append(tool_call)
+            elif isinstance(chunk, StreamedChatResponse_StreamEnd):
+                chunk = cast(StreamedChatResponse_StreamEnd, chunk)
+                yield final_response(full_assistant_content, tool_calls, index, chunk.finish_reason)
                 index += 1
 
     def _convert_prompt_messages_to_message_and_chat_histories(self, prompt_messages: list[PromptMessage]) \
-            -> tuple[str, list[dict]]:
+            -> tuple[str, list[ChatMessage], list[ChatStreamRequestToolResultsItem]]:
         """
         Convert prompt messages to message and chat histories
         :param prompt_messages: prompt messages
         :return:
         """
         chat_histories = []
+        latest_tool_call_n_outputs = []
         for prompt_message in prompt_messages:
-            chat_histories.append(self._convert_prompt_message_to_dict(prompt_message))
+            if prompt_message.role == PromptMessageRole.ASSISTANT:
+                prompt_message = cast(AssistantPromptMessage, prompt_message)
+                if prompt_message.tool_calls:
+                    for tool_call in prompt_message.tool_calls:
+                        latest_tool_call_n_outputs.append(ChatStreamRequestToolResultsItem(
+                            call=ToolCall(
+                                name=tool_call.function.name,
+                                parameters=json.loads(tool_call.function.arguments)
+                            ),
+                            outputs=[]
+                        ))
+                else:
+                    cohere_prompt_message = self._convert_prompt_message_to_dict(prompt_message)
+                    if cohere_prompt_message:
+                        chat_histories.append(cohere_prompt_message)
+            elif prompt_message.role == PromptMessageRole.TOOL:
+                prompt_message = cast(ToolPromptMessage, prompt_message)
+                if latest_tool_call_n_outputs:
+                    i = 0
+                    for tool_call_n_outputs in latest_tool_call_n_outputs:
+                        if tool_call_n_outputs.call.name == prompt_message.tool_call_id:
+                            latest_tool_call_n_outputs[i] = ChatStreamRequestToolResultsItem(
+                                call=ToolCall(
+                                    name=tool_call_n_outputs.call.name,
+                                    parameters=tool_call_n_outputs.call.parameters
+                                ),
+                                outputs=[{
+                                    "result": prompt_message.content
+                                }]
+                            )
+                            break
+                        i += 1
+            else:
+                cohere_prompt_message = self._convert_prompt_message_to_dict(prompt_message)
+                if cohere_prompt_message:
+                    chat_histories.append(cohere_prompt_message)
+
+        if latest_tool_call_n_outputs:
+            new_latest_tool_call_n_outputs = []
+            for tool_call_n_outputs in latest_tool_call_n_outputs:
+                if tool_call_n_outputs.outputs:
+                    new_latest_tool_call_n_outputs.append(tool_call_n_outputs)
+
+            latest_tool_call_n_outputs = new_latest_tool_call_n_outputs
 
         # get latest message from chat histories and pop it
         if len(chat_histories) > 0:
             latest_message = chat_histories.pop()
-            message = latest_message['message']
+            message = latest_message.message
         else:
             raise ValueError('Prompt messages is empty')
 
-        return message, chat_histories
+        return message, chat_histories, latest_tool_call_n_outputs
 
-    def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
+    def _convert_prompt_message_to_dict(self, message: PromptMessage) -> Optional[ChatMessage]:
         """
         Convert PromptMessage to dict for Cohere model
         """
         if isinstance(message, UserPromptMessage):
             message = cast(UserPromptMessage, message)
             if isinstance(message.content, str):
-                message_dict = {"role": "USER", "message": message.content}
+                chat_message = ChatMessage(role="USER", message=message.content)
             else:
                 sub_message_text = ''
                 for message_content in message.content:
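
Note (not part of the diff): the rewritten conversion walks the prompt history, pairs each assistant tool call with the ToolPromptMessage that answers it (matched via tool_call_id), and returns the tool_results payload for the Chat API; entries whose outputs remain empty are dropped at the end. A minimal sketch of the Cohere-side structure one completed round trip becomes, with illustrative values:

from cohere import ChatStreamRequestToolResultsItem, ToolCall

tool_result = ChatStreamRequestToolResultsItem(
    call=ToolCall(name="get_weather", parameters={"location": "Berlin"}),
    outputs=[{"result": "Sunny, 21 degrees"}]  # taken from the ToolPromptMessage content
)
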
@@ -461,20 +573,57 @@ class CohereLargeLanguageModel(LargeLanguageModel):
                         message_content = cast(TextPromptMessageContent, message_content)
                         sub_message_text += message_content.data
 
-                message_dict = {"role": "USER", "message": sub_message_text}
+                chat_message = ChatMessage(role="USER", message=sub_message_text)
         elif isinstance(message, AssistantPromptMessage):
             message = cast(AssistantPromptMessage, message)
-            message_dict = {"role": "CHATBOT", "message": message.content}
+            if not message.content:
+                return None
+            chat_message = ChatMessage(role="CHATBOT", message=message.content)
         elif isinstance(message, SystemPromptMessage):
             message = cast(SystemPromptMessage, message)
-            message_dict = {"role": "USER", "message": message.content}
+            chat_message = ChatMessage(role="USER", message=message.content)
+        elif isinstance(message, ToolPromptMessage):
+            return None
         else:
             raise ValueError(f"Got unknown type {message}")
 
-        if message.name:
-            message_dict["user_name"] = message.name
+        return chat_message
+
+    def _convert_tools(self, tools: list[PromptMessageTool]) -> list[Tool]:
+        """
+        Convert tools to Cohere model
+        """
+        cohere_tools = []
+        for tool in tools:
+            properties = tool.parameters['properties']
+            required_properties = tool.parameters['required']
+
+            parameter_definitions = {}
+            for p_key, p_val in properties.items():
+                required = False
+                if p_key in required_properties:
+                    required = True
+
+                desc = p_val['description']
+                if 'enum' in p_val:
+                    desc += (f"; Only accepts one of the following predefined options: "
+                             f"[{', '.join(p_val['enum'])}]")
+
+                parameter_definitions[p_key] = ToolParameterDefinitionsValue(
+                    description=desc,
+                    type=p_val['type'],
+                    required=required
+                )
 
-        return message_dict
+            cohere_tool = Tool(
+                name=tool.name,
+                description=tool.description,
+                parameter_definitions=parameter_definitions
+            )
+
+            cohere_tools.append(cohere_tool)
+
+        return cohere_tools
 
     def _num_tokens_from_string(self, model: str, credentials: dict, text: str) -> int:
         """
@@ -493,12 +642,16 @@ class CohereLargeLanguageModel(LargeLanguageModel):
             model=model
         )
 
-        return response.length
+        return len(response.tokens)
 
     def _num_tokens_from_messages(self, model: str, credentials: dict, messages: list[PromptMessage]) -> int:
         """Calculate num tokens Cohere model."""
-        messages = [self._convert_prompt_message_to_dict(m) for m in messages]
-        message_strs = [f"{message['role']}: {message['message']}" for message in messages]
+        calc_messages = []
+        for message in messages:
+            cohere_message = self._convert_prompt_message_to_dict(message)
+            if cohere_message:
+                calc_messages.append(cohere_message)
+        message_strs = [f"{message.role}: {message.message}" for message in calc_messages]
         message_str = "\n".join(message_strs)
 
         real_model = model
@@ -564,13 +717,21 @@ class CohereLargeLanguageModel(LargeLanguageModel):
         """
         return {
             InvokeConnectionError: [
-                cohere.CohereConnectionError
+                cohere.errors.service_unavailable_error.ServiceUnavailableError
+            ],
+            InvokeServerUnavailableError: [
+                cohere.errors.internal_server_error.InternalServerError
+            ],
+            InvokeRateLimitError: [
+                cohere.errors.too_many_requests_error.TooManyRequestsError
+            ],
+            InvokeAuthorizationError: [
+                cohere.errors.unauthorized_error.UnauthorizedError,
+                cohere.errors.forbidden_error.ForbiddenError
             ],
-            InvokeServerUnavailableError: [],
-            InvokeRateLimitError: [],
-            InvokeAuthorizationError: [],
             InvokeBadRequestError: [
-                cohere.CohereAPIError,
-                cohere.CohereError,
+                cohere.core.api_error.ApiError,
+                cohere.errors.bad_request_error.BadRequestError,
+                cohere.errors.not_found_error.NotFoundError,
             ]
         }