import logging
from typing import Generator, List, Optional, Tuple, Union, cast

import cohere
from cohere.responses import Chat, Generations
from cohere.responses.chat import StreamEnd, StreamingChat, StreamTextGeneration
from cohere.responses.generation import StreamingGenerations, StreamingText

from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage,
                                                          PromptMessageContentType, PromptMessageTool,
                                                          SystemPromptMessage, TextPromptMessageContent,
                                                          UserPromptMessage)
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, I18nObject, ModelType
from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError,
                                              InvokeConnectionError, InvokeError, InvokeRateLimitError,
                                              InvokeServerUnavailableError)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel

logger = logging.getLogger(__name__)


class CohereLargeLanguageModel(LargeLanguageModel):
    """
    Model class for Cohere large language model.
    """

    def _invoke(self, model: str, credentials: dict,
                prompt_messages: list[PromptMessage], model_parameters: dict,
                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
                stream: bool = True, user: Optional[str] = None) \
            -> Union[LLMResult, Generator]:
        """
        Invoke large language model

        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        :return: full response or stream response chunk generator result
        """
        # get model mode
        model_mode = self.get_model_mode(model, credentials)

        if model_mode == LLMMode.CHAT:
            return self._chat_generate(
                model=model,
                credentials=credentials,
                prompt_messages=prompt_messages,
                model_parameters=model_parameters,
                stop=stop,
                stream=stream,
                user=user
            )
        else:
            return self._generate(
                model=model,
                credentials=credentials,
                prompt_messages=prompt_messages,
                model_parameters=model_parameters,
                stop=stop,
                stream=stream,
                user=user
            )
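
    # Example (illustrative values): dispatching through _invoke. A chat-mode
    # model such as 'command-light-chat' is routed to _chat_generate, anything
    # else to _generate:
    #
    #   llm = CohereLargeLanguageModel()
    #   result = llm._invoke(
    #       model='command-light-chat',
    #       credentials={'api_key': 'YOUR_API_KEY'},
    #       prompt_messages=[UserPromptMessage(content='Hello')],
    #       model_parameters={'temperature': 0.5, 'max_tokens': 100},
    #       stream=False,
    #   )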

    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                       tools: Optional[list[PromptMessageTool]] = None) -> int:
        """
        Get number of tokens for given prompt messages

        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param tools: tools for tool calling
        :return: number of tokens
        """
        # get model mode
        model_mode = self.get_model_mode(model, credentials)

        try:
            if model_mode == LLMMode.CHAT:
                return self._num_tokens_from_messages(model, credentials, prompt_messages)
            else:
                return self._num_tokens_from_string(model, credentials, prompt_messages[0].content)
        except Exception as e:
            raise self._transform_invoke_error(e)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        try:
            # get model mode
            model_mode = self.get_model_mode(model, credentials)

            if model_mode == LLMMode.CHAT:
                self._chat_generate(
                    model=model,
                    credentials=credentials,
                    prompt_messages=[UserPromptMessage(content='ping')],
                    model_parameters={
                        'max_tokens': 20,
                        'temperature': 0,
                    },
                    stream=False
                )
            else:
                self._generate(
                    model=model,
                    credentials=credentials,
                    prompt_messages=[UserPromptMessage(content='ping')],
                    model_parameters={
                        'max_tokens': 20,
                        'temperature': 0,
                    },
                    stream=False
                )
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    def _generate(self, model: str, credentials: dict,
                  prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None,
                  stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
        """
        Invoke llm completion model

        :param model: model name
        :param credentials: credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        :return: full response or stream response chunk generator result
        """
        # initialize client
        client = cohere.Client(credentials.get('api_key'))

        if stop:
            model_parameters['end_sequences'] = stop

        response = client.generate(
            prompt=prompt_messages[0].content,
            model=model,
            stream=stream,
            **model_parameters,
        )

        if stream:
            return self._handle_generate_stream_response(model, credentials, response, prompt_messages)

        return self._handle_generate_response(model, credentials, response, prompt_messages)
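
    # Sketch of a completion call (parameter values are illustrative; they are
    # passed straight through to client.generate as **model_parameters):
    #
    #   self._generate(
    #       model='command-light',
    #       credentials={'api_key': 'YOUR_API_KEY'},
    #       prompt_messages=[UserPromptMessage(content='Write a haiku')],
    #       model_parameters={'temperature': 0.9, 'max_tokens': 64},
    #       stop=['\n\n'],
    #       stream=False,
    #   )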

    def _handle_generate_response(self, model: str, credentials: dict, response: Generations,
                                  prompt_messages: list[PromptMessage]) \
            -> LLMResult:
        """
        Handle llm completion response

        :param model: model name
        :param credentials: credentials
        :param response: response
        :param prompt_messages: prompt messages
        :return: llm response
        """
        assistant_text = response.generations[0].text

        # transform assistant message to prompt message
        assistant_prompt_message = AssistantPromptMessage(
            content=assistant_text
        )

        # calculate num tokens from the billed units reported by the API
        prompt_tokens = response.meta['billed_units']['input_tokens']
        completion_tokens = response.meta['billed_units']['output_tokens']

        # transform usage
        usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

        # transform response
        result = LLMResult(
            model=model,
            prompt_messages=prompt_messages,
            message=assistant_prompt_message,
            usage=usage
        )

        return result

    def _handle_generate_stream_response(self, model: str, credentials: dict, response: StreamingGenerations,
                                         prompt_messages: list[PromptMessage]) -> Generator:
        """
        Handle llm completion stream response

        :param model: model name
        :param credentials: credentials
        :param response: response
        :param prompt_messages: prompt messages
        :return: llm response chunk generator
        """
        index = 1
        full_assistant_content = ''
        for chunk in response:
            if isinstance(chunk, StreamingText):
                chunk = cast(StreamingText, chunk)
                text = chunk.text

                if text is None:
                    continue

                # transform assistant message to prompt message
                assistant_prompt_message = AssistantPromptMessage(
                    content=text
                )

                full_assistant_content += text

                yield LLMResultChunk(
                    model=model,
                    prompt_messages=prompt_messages,
                    delta=LLMResultChunkDelta(
                        index=index,
                        message=assistant_prompt_message,
                    )
                )

                index += 1
            elif chunk is None:
                # a None chunk is treated as the end of the stream;
                # usage is then read from the response metadata
                prompt_tokens = response.meta['billed_units']['input_tokens']
                completion_tokens = response.meta['billed_units']['output_tokens']

                # transform usage
                usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

                yield LLMResultChunk(
                    model=model,
                    prompt_messages=prompt_messages,
                    delta=LLMResultChunkDelta(
                        index=index,
                        message=AssistantPromptMessage(content=''),
                        finish_reason=response.finish_reason,
                        usage=usage
                    )
                )
                break
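
    # The resulting chunk sequence looks roughly like this (illustrative):
    #   delta(index=1, 'Hel'), delta(index=2, 'lo'), ...,
    #   delta(message='', finish_reason=..., usage=...)
    # so consumers can concatenate delta messages and stop once a
    # finish_reason arrives.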

    def _chat_generate(self, model: str, credentials: dict,
                       prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None,
                       stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
        """
        Invoke llm chat model

        :param model: model name
        :param credentials: credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        :return: full response or stream response chunk generator result
        """
        # initialize client
        client = cohere.Client(credentials.get('api_key'))

        if user:
            model_parameters['user_name'] = user

        message, chat_histories = self._convert_prompt_messages_to_message_and_chat_histories(prompt_messages)

        # predefined chat models are named with a '-chat' suffix in this provider;
        # strip it to get the upstream model name
        real_model = model
        if self.get_model_schema(model, credentials).fetch_from == FetchFrom.PREDEFINED_MODEL:
            real_model = model.removesuffix('-chat')

        response = client.chat(
            message=message,
            chat_history=chat_histories,
            model=real_model,
            stream=stream,
            return_preamble=True,
            **model_parameters,
        )

        if stream:
            return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, stop)

        return self._handle_chat_generate_response(model, credentials, response, prompt_messages, stop)

    def _handle_chat_generate_response(self, model: str, credentials: dict, response: Chat,
                                       prompt_messages: list[PromptMessage], stop: Optional[List[str]] = None) \
            -> LLMResult:
        """
        Handle llm chat response

        :param model: model name
        :param credentials: credentials
        :param response: response
        :param prompt_messages: prompt messages
        :param stop: stop words
        :return: llm response
        """
        assistant_text = response.text

        # transform assistant message to prompt message
        assistant_prompt_message = AssistantPromptMessage(
            content=assistant_text
        )

        # calculate num tokens
        prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages)
        completion_tokens = self._num_tokens_from_messages(model, credentials, [assistant_prompt_message])

        # transform usage
        usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

        if stop:
            # enforce stop tokens by truncating the text client-side
            assistant_text = self.enforce_stop_tokens(assistant_text, stop)
            assistant_prompt_message = AssistantPromptMessage(
                content=assistant_text
            )

        # transform response
        result = LLMResult(
            model=model,
            prompt_messages=prompt_messages,
            message=assistant_prompt_message,
            usage=usage,
            system_fingerprint=response.preamble
        )

        return result
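
    # Illustration (assuming enforce_stop_tokens truncates at the first stop
    # occurrence): with stop=['\nHuman:'], a reply of
    #   'Fine, thanks.\nHuman: and you?'
    # is cut down to 'Fine, thanks.' before being returned.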

    def _handle_chat_generate_stream_response(self, model: str, credentials: dict, response: StreamingChat,
                                              prompt_messages: list[PromptMessage],
                                              stop: Optional[List[str]] = None) -> Generator:
        """
        Handle llm chat stream response

        :param model: model name
        :param credentials: credentials
        :param response: response
        :param prompt_messages: prompt messages
        :param stop: stop words
        :return: llm response chunk generator
        """

        def final_response(full_text: str, index: int, finish_reason: Optional[str] = None,
                           preamble: Optional[str] = None) -> LLMResultChunk:
            # calculate num tokens
            prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages)

            full_assistant_prompt_message = AssistantPromptMessage(
                content=full_text
            )
            completion_tokens = self._num_tokens_from_messages(model, credentials, [full_assistant_prompt_message])

            # transform usage
            usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

            return LLMResultChunk(
                model=model,
                prompt_messages=prompt_messages,
                system_fingerprint=preamble,
                delta=LLMResultChunkDelta(
                    index=index,
                    message=AssistantPromptMessage(content=''),
                    finish_reason=finish_reason,
                    usage=usage
                )
            )

        index = 1
        full_assistant_content = ''
        for chunk in response:
            if isinstance(chunk, StreamTextGeneration):
                chunk = cast(StreamTextGeneration, chunk)
                text = chunk.text

                if text is None:
                    continue

                # transform assistant message to prompt message
                assistant_prompt_message = AssistantPromptMessage(
                    content=text
                )

                # enforce stop words
                # note: this only covers the case where a stop word arrives as a whole chunk
                if stop and text in stop:
                    yield final_response(full_assistant_content, index, 'stop')
                    break

                full_assistant_content += text

                yield LLMResultChunk(
                    model=model,
                    prompt_messages=prompt_messages,
                    delta=LLMResultChunkDelta(
                        index=index,
                        message=assistant_prompt_message,
                    )
                )

                index += 1
            elif isinstance(chunk, StreamEnd):
                chunk = cast(StreamEnd, chunk)
                yield final_response(full_assistant_content, index, chunk.finish_reason, response.preamble)
                index += 1

    def _convert_prompt_messages_to_message_and_chat_histories(self, prompt_messages: list[PromptMessage]) \
            -> Tuple[str, list[dict]]:
        """
        Convert prompt messages to message and chat histories
        :param prompt_messages: prompt messages
        :return: message and chat histories
        """
        chat_histories = []
        for prompt_message in prompt_messages:
            chat_histories.append(self._convert_prompt_message_to_dict(prompt_message))

        # get the latest message from chat histories and pop it out as the current message
        if len(chat_histories) > 0:
            latest_message = chat_histories.pop()
            message = latest_message['message']
        else:
            raise ValueError('prompt_messages is empty')

        return message, chat_histories
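
    # Worked example (illustrative): given
    #   [SystemPromptMessage(content='Be brief'), UserPromptMessage(content='Hi'),
    #    AssistantPromptMessage(content='Hello!'), UserPromptMessage(content='How are you?')]
    # the conversion yields
    #   message = 'How are you?'
    #   chat_histories = [{'role': 'USER', 'message': 'Be brief'},
    #                     {'role': 'USER', 'message': 'Hi'},
    #                     {'role': 'CHATBOT', 'message': 'Hello!'}]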

    def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
        """
        Convert PromptMessage to dict for Cohere model
        """
        if isinstance(message, UserPromptMessage):
            message = cast(UserPromptMessage, message)
            if isinstance(message.content, str):
                message_dict = {"role": "USER", "message": message.content}
            else:
                sub_message_text = ''
                for message_content in message.content:
                    if message_content.type == PromptMessageContentType.TEXT:
                        message_content = cast(TextPromptMessageContent, message_content)
                        sub_message_text += message_content.data

                message_dict = {"role": "USER", "message": sub_message_text}
        elif isinstance(message, AssistantPromptMessage):
            message = cast(AssistantPromptMessage, message)
            message_dict = {"role": "CHATBOT", "message": message.content}
        elif isinstance(message, SystemPromptMessage):
            message = cast(SystemPromptMessage, message)
            # map system prompts to USER turns, since this history format only uses USER/CHATBOT roles
            message_dict = {"role": "USER", "message": message.content}
        else:
            raise ValueError(f"Got unknown type {message}")

        if message.name is not None:
            message_dict["user_name"] = message.name

        return message_dict

    def _num_tokens_from_string(self, model: str, credentials: dict, text: str) -> int:
        """
        Calculate num tokens for text completion model.

        :param model: model name
        :param credentials: credentials
        :param text: prompt text
        :return: number of tokens
        """
        # initialize client
        client = cohere.Client(credentials.get('api_key'))

        response = client.tokenize(
            text=text,
            model=model
        )

        return response.length

    def _num_tokens_from_messages(self, model: str, credentials: dict, messages: List[PromptMessage]) -> int:
        """Calculate num tokens for the Cohere chat model."""
        messages = [self._convert_prompt_message_to_dict(m) for m in messages]
        message_strs = [f"{message['role']}: {message['message']}" for message in messages]
        message_str = "\n".join(message_strs)

        real_model = model
        if self.get_model_schema(model, credentials).fetch_from == FetchFrom.PREDEFINED_MODEL:
            real_model = model.removesuffix('-chat')

        return self._num_tokens_from_string(real_model, credentials, message_str)
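
    # Counting example (illustrative): the messages above are flattened to
    #   "USER: Hi\nCHATBOT: Hello!"
    # and that single string is tokenized via the API. This approximates chat
    # token usage; the exact prompt Cohere assembles server-side may differ.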

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
        """
        Cohere supports fine-tuning of their models. This method returns the schema of the base model
        but renamed to the fine-tuned model name.

        :param model: model name
        :param credentials: credentials
        :return: model schema
        """
        # get the base model schema matching the configured mode
        models = self.predefined_models()
        model_map = {m.model: m for m in models}

        mode = credentials.get('mode')

        if mode == 'chat':
            base_model_schema = model_map['command-light-chat']
        else:
            base_model_schema = model_map['command-light']

        base_model_schema = cast(AIModelEntity, base_model_schema)

        base_model_schema_features = base_model_schema.features or []
        base_model_schema_model_properties = base_model_schema.model_properties or {}
        base_model_schema_parameters_rules = base_model_schema.parameter_rules or []

        entity = AIModelEntity(
            model=model,
            label=I18nObject(
                zh_Hans=model,
                en_US=model
            ),
            model_type=ModelType.LLM,
            features=list(base_model_schema_features),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_properties=dict(base_model_schema_model_properties),
            parameter_rules=list(base_model_schema_parameters_rules),
            pricing=base_model_schema.pricing
        )

        return entity
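
    # Example (illustrative): a fine-tuned model configured with
    #   credentials = {'api_key': '...', 'mode': 'chat'}
    # inherits features, properties, parameter rules and pricing from the
    # 'command-light-chat' base schema, relabeled with the custom model name.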

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        return {
            InvokeConnectionError: [
                cohere.CohereConnectionError
            ],
            InvokeServerUnavailableError: [],
            InvokeRateLimitError: [],
            InvokeAuthorizationError: [],
            InvokeBadRequestError: [
                cohere.CohereAPIError,
                cohere.CohereError,
            ]
        }