@@ -48,6 +48,28 @@ logger = logging.getLogger(__name__)
 
 class BedrockLargeLanguageModel(LargeLanguageModel):
 
+    # Model prefixes supported by the Converse API; see https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html
+    # TODO: Cohere models currently hit a context-limit issue on invoke; add them here once that is fixed.
+    CONVERSE_API_ENABLED_MODEL_INFO = [
+        {'prefix': 'anthropic.claude-v2', 'support_system_prompts': True, 'support_tool_use': False},
+        {'prefix': 'anthropic.claude-v1', 'support_system_prompts': True, 'support_tool_use': False},
+        {'prefix': 'anthropic.claude-3', 'support_system_prompts': True, 'support_tool_use': True},
+        {'prefix': 'meta.llama', 'support_system_prompts': True, 'support_tool_use': False},
+        {'prefix': 'mistral.mistral-7b-instruct', 'support_system_prompts': False, 'support_tool_use': False},
+        {'prefix': 'mistral.mixtral-8x7b-instruct', 'support_system_prompts': False, 'support_tool_use': False},
+        {'prefix': 'mistral.mistral-large', 'support_system_prompts': True, 'support_tool_use': True},
+        {'prefix': 'mistral.mistral-small', 'support_system_prompts': True, 'support_tool_use': True},
+        {'prefix': 'amazon.titan', 'support_system_prompts': False, 'support_tool_use': False}
+    ]
+
+    @staticmethod
+    def _find_model_info(model_id):
+        for model in BedrockLargeLanguageModel.CONVERSE_API_ENABLED_MODEL_INFO:
+            if model_id.startswith(model['prefix']):
+                return model
+        logger.info(f"current model id: {model_id} is not supported by the Converse API")
+        return None
+
     def _invoke(self, model: str, credentials: dict,
                 prompt_messages: list[PromptMessage], model_parameters: dict,
                 tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
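For context, the prefix lookup above resolves a Bedrock model ID to its Converse capabilities; a minimal usage sketch (the model IDs are illustrative, the names come from this patch):

```python
# Claude 3 IDs start with 'anthropic.claude-3', so tool use is reported as supported.
info = BedrockLargeLanguageModel._find_model_info('anthropic.claude-3-sonnet-20240229-v1:0')
# info -> {'prefix': 'anthropic.claude-3', 'support_system_prompts': True, 'support_tool_use': True}

# Cohere models are deliberately absent from the table (see the TODO above), so the
# lookup returns None and _invoke falls through to the legacy per-provider paths.
assert BedrockLargeLanguageModel._find_model_info('cohere.command-r-v1:0') is None
```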
@@ -66,10 +88,12 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         :param user: unique user id
         :return: full response or stream response chunk generator result
         """
-        # TODO: consolidate different invocation methods for models based on base model capabilities
-        # invoke anthropic models via boto3 client
-        if "anthropic" in model:
-            return self._generate_anthropic(model, credentials, prompt_messages, model_parameters, stop, stream, user, tools)
+
+        model_info = BedrockLargeLanguageModel._find_model_info(model)
+        if model_info:
+            model_info['model'] = model
+            # invoke models via the boto3 Converse API
+            return self._generate_with_converse(model_info, credentials, prompt_messages, model_parameters, stop, stream, user, tools)
         # invoke Cohere models via boto3 client
         if "cohere.command-r" in model:
             return self._generate_cohere_chat(model, credentials, prompt_messages, model_parameters, stop, stream, user, tools)
@@ -151,12 +175,12 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         return self._handle_generate_response(model, credentials, response, prompt_messages)
 
 
-    def _generate_anthropic(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
+    def _generate_with_converse(self, model_info: dict, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
                 stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, tools: Optional[list[PromptMessageTool]] = None,) -> Union[LLMResult, Generator]:
         """
-        Invoke Anthropic large language model
+        Invoke large language model with the Converse API
 
-        :param model: model name
+        :param model_info: model information, including the resolved model ID and its Converse capabilities
         :param credentials: model credentials
         :param prompt_messages: prompt messages
         :param model_parameters: model parameters
@@ -173,24 +197,24 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         inference_config, additional_model_fields = self._convert_converse_api_model_parameters(model_parameters, stop)
 
         parameters = {
-            'modelId': model,
+            'modelId': model_info['model'],
             'messages': prompt_message_dicts,
             'inferenceConfig': inference_config,
             'additionalModelRequestFields': additional_model_fields,
         }
 
-        if system and len(system) > 0:
+        if model_info['support_system_prompts'] and system and len(system) > 0:
             parameters['system'] = system
 
-        if tools:
+        if model_info['support_tool_use'] and tools:
             parameters['toolConfig'] = self._convert_converse_tool_config(tools=tools)
 
         if stream:
             response = bedrock_client.converse_stream(**parameters)
-            return self._handle_converse_stream_response(model, credentials, response, prompt_messages)
+            return self._handle_converse_stream_response(model_info['model'], credentials, response, prompt_messages)
         else:
             response = bedrock_client.converse(**parameters)
-            return self._handle_converse_response(model, credentials, response, prompt_messages)
+            return self._handle_converse_response(model_info['model'], credentials, response, prompt_messages)
 
     def _handle_converse_response(self, model: str, credentials: dict, response: dict,
                                 prompt_messages: list[PromptMessage]) -> LLMResult:
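The assembled `parameters` dict maps directly onto the boto3 bedrock-runtime Converse calls used above; a self-contained sketch of the same request shape (region, model ID, and prompt are illustrative):

```python
import boto3

# Hypothetical standalone call mirroring the request this method builds.
client = boto3.client('bedrock-runtime', region_name='us-east-1')
response = client.converse(
    modelId='anthropic.claude-3-sonnet-20240229-v1:0',
    messages=[{'role': 'user', 'content': [{'text': 'Hello!'}]}],
    # 'system' is only attached when the model's support_system_prompts flag is set
    system=[{'text': 'You are a concise assistant.'}],
    inferenceConfig={'maxTokens': 512, 'temperature': 0.5},
)
print(response['output']['message']['content'][0]['text'])
```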
@@ -203,10 +227,30 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         :param prompt_messages: prompt messages
         :return: full response chunk generator result
         """
+        response_content = response['output']['message']['content']
         # transform assistant message to prompt message
-        assistant_prompt_message = AssistantPromptMessage(
-            content=response['output']['message']['content'][0]['text']
-        )
+        if response['stopReason'] == 'tool_use':
+            tool_calls = []
+            text, tool_use = self._extract_tool_use(response_content)
+
+            tool_call = AssistantPromptMessage.ToolCall(
+                id=tool_use['toolUseId'],
+                type='function',
+                function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                    name=tool_use['name'],
+                    arguments=json.dumps(tool_use['input'])
+                )
+            )
+            tool_calls.append(tool_call)
+
+            assistant_prompt_message = AssistantPromptMessage(
+                content=text,
+                tool_calls=tool_calls
+            )
+        else:
+            assistant_prompt_message = AssistantPromptMessage(
+                content=response_content[0]['text']
+            )
 
         # calculate num tokens
         if response['usage']:
@@ -229,6 +273,18 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         )
         return result
 
+    def _extract_tool_use(self, content: list[dict]) -> tuple[str, dict]:
+        tool_use = {}
+        text = ''
+        for item in content:
+            if 'toolUse' in item:
+                tool_use = item['toolUse']
+            elif 'text' in item:
+                text = item['text']
+            else:
+                raise ValueError(f"Got unknown item: {item}")
+        return text, tool_use
+
     def _handle_converse_stream_response(self, model: str, credentials: dict, response: dict,
                                         prompt_messages: list[PromptMessage], ) -> Generator:
         """
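To make the tool-use branch concrete, here is the shape of a hypothetical Converse response content list and what `_extract_tool_use` recovers from it (field names follow the Converse API content-block format; the values and the `model` instance are made up):

```python
# Content blocks returned when stopReason == 'tool_use'.
response_content = [
    {'text': 'Let me look that up.'},
    {'toolUse': {
        'toolUseId': 'tooluse_abc123',   # illustrative ID
        'name': 'get_weather',
        'input': {'city': 'Berlin'},
    }},
]

# Assuming `model` is a BedrockLargeLanguageModel instance:
text, tool_use = model._extract_tool_use(response_content)
# text     -> 'Let me look that up.'
# tool_use -> {'toolUseId': 'tooluse_abc123', 'name': 'get_weather', 'input': {'city': 'Berlin'}}
```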
@@ -340,14 +396,12 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         """
 
         system = []
+        prompt_message_dicts = []
         for message in prompt_messages:
             if isinstance(message, SystemPromptMessage):
                 message.content=message.content.strip()
                 system.append({"text": message.content})
-
-        prompt_message_dicts = []
-        for message in prompt_messages:
-            if not isinstance(message, SystemPromptMessage):
+            else:
                 prompt_message_dicts.append(self._convert_prompt_message_to_dict(message))
 
         return system, prompt_message_dicts
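The single pass above replaces the two earlier loops over `prompt_messages`; a small illustration of the resulting split (the inputs are made up, and the user-message dict shape assumes `_convert_prompt_message_to_dict` emits Converse-style content blocks):

```python
prompt_messages = [
    SystemPromptMessage(content='  You are a helpful assistant.  '),
    UserPromptMessage(content='What is the capital of France?'),
]

# After the loop:
# system               -> [{'text': 'You are a helpful assistant.'}]   (whitespace stripped)
# prompt_message_dicts -> [{'role': 'user', 'content': [{'text': 'What is the capital of France?'}]}]
```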
@@ -448,7 +502,6 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
             }
         else:
             raise ValueError(f"Got unknown type {message}")
-
         return message_dict
 
     def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage] | str,