|  | @@ -57,7 +57,7 @@ class BaiduAccessToken:
 | 
	
		
			
				|  |  |                  raise RateLimitReachedError(f'Rate limit reached: {resp["error_description"]}')
 | 
	
		
			
				|  |  |              else:
 | 
	
		
			
				|  |  |                  raise Exception(f'Unknown error: {resp["error_description"]}')
 | 
	
		
			
				|  |  | -                
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |          return resp['access_token']
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      @staticmethod
 | 
	
	
		
			
				|  | @@ -114,7 +114,7 @@ class ErnieMessage:
 | 
	
		
			
				|  |  |              'role': self.role,
 | 
	
		
			
				|  |  |              'content': self.content,
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  | -    
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |      def __init__(self, content: str, role: str = 'user') -> None:
 | 
	
		
			
				|  |  |          self.content = content
 | 
	
		
			
				|  |  |          self.role = role
 | 
	
	
		
			
				|  | @@ -131,6 +131,7 @@ class ErnieBotModel:
 | 
	
		
			
				|  |  |          'ernie-3.5-4k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205',
 | 
	
		
			
				|  |  |          'ernie-3.5-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-128k',
 | 
	
		
			
				|  |  |          'ernie-4.0-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
 | 
	
		
			
				|  |  | +        'ernie-4.0-8k-latest': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
 | 
	
		
			
				|  |  |          'ernie-speed-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed',
 | 
	
		
			
				|  |  |          'ernie-speed-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-speed-128k',
 | 
	
		
			
				|  |  |          'ernie-speed-appbuilder': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas',
 | 
	
	
		
			
				|  | @@ -157,7 +158,7 @@ class ErnieBotModel:
 | 
	
		
			
				|  |  |          self.api_key = api_key
 | 
	
		
			
				|  |  |          self.secret_key = secret_key
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    def generate(self, model: str, stream: bool, messages: list[ErnieMessage], 
 | 
	
		
			
				|  |  | +    def generate(self, model: str, stream: bool, messages: list[ErnieMessage],
 | 
	
		
			
				|  |  |                   parameters: dict[str, Any], timeout: int, tools: list[PromptMessageTool], \
 | 
	
		
			
				|  |  |                   stop: list[str], user: str) \
 | 
	
		
			
				|  |  |          -> Union[Generator[ErnieMessage, None, None], ErnieMessage]:
 | 
	
	
		
			
				|  | @@ -189,7 +190,7 @@ class ErnieBotModel:
 | 
	
		
			
				|  |  |          if stream:
 | 
	
		
			
				|  |  |              return self._handle_chat_stream_generate_response(resp)
 | 
	
		
			
				|  |  |          return self._handle_chat_generate_response(resp)
 | 
	
		
			
				|  |  | -    
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |      def _handle_error(self, code: int, msg: str):
 | 
	
		
			
				|  |  |          error_map = {
 | 
	
		
			
				|  |  |              1: InternalServerError,
 | 
	
	
		
			
				|  | @@ -234,15 +235,15 @@ class ErnieBotModel:
 | 
	
		
			
				|  |  |      def _get_access_token(self) -> str:
 | 
	
		
			
				|  |  |          token = BaiduAccessToken.get_access_token(self.api_key, self.secret_key)
 | 
	
		
			
				|  |  |          return token.access_token
 | 
	
		
			
				|  |  | -    
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |      def _copy_messages(self, messages: list[ErnieMessage]) -> list[ErnieMessage]:
 | 
	
		
			
				|  |  |          return [ErnieMessage(message.content, message.role) for message in messages]
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    def _check_parameters(self, model: str, parameters: dict[str, Any], 
 | 
	
		
			
				|  |  | +    def _check_parameters(self, model: str, parameters: dict[str, Any],
 | 
	
		
			
				|  |  |                            tools: list[PromptMessageTool], stop: list[str]) -> None:
 | 
	
		
			
				|  |  |          if model not in self.api_bases:
 | 
	
		
			
				|  |  |              raise BadRequestError(f'Invalid model: {model}')
 | 
	
		
			
				|  |  | -        
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |          # if model not in self.function_calling_supports and tools is not None and len(tools) > 0:
 | 
	
		
			
				|  |  |          #     raise BadRequestError(f'Model {model} does not support calling function.')
 | 
	
		
			
				|  |  |          # ErnieBot supports function calling, however, there is lots of limitations.
 | 
	
	
		
			
				|  | @@ -259,32 +260,32 @@ class ErnieBotModel:
 | 
	
		
			
				|  |  |              for s in stop:
 | 
	
		
			
				|  |  |                  if len(s) > 20:
 | 
	
		
			
				|  |  |                      raise BadRequestError('stop item should not exceed 20 characters.')
 | 
	
		
			
				|  |  | -        
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |      def _build_request_body(self, model: str, messages: list[ErnieMessage], stream: bool, parameters: dict[str, Any],
 | 
	
		
			
				|  |  |                              tools: list[PromptMessageTool], stop: list[str], user: str) -> dict[str, Any]:
 | 
	
		
			
				|  |  |          # if model in self.function_calling_supports:
 | 
	
		
			
				|  |  |          #     return self._build_function_calling_request_body(model, messages, parameters, tools, stop, user)
 | 
	
		
			
				|  |  |          return self._build_chat_request_body(model, messages, stream, parameters, stop, user)
 | 
	
		
			
				|  |  | -        
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |      def _build_function_calling_request_body(self, model: str, messages: list[ErnieMessage], stream: bool,
 | 
	
		
			
				|  |  | -                                                parameters: dict[str, Any], tools: list[PromptMessageTool], 
 | 
	
		
			
				|  |  | +                                                parameters: dict[str, Any], tools: list[PromptMessageTool],
 | 
	
		
			
				|  |  |                                                  stop: list[str], user: str) \
 | 
	
		
			
				|  |  |          -> dict[str, Any]:
 | 
	
		
			
				|  |  |          if len(messages) % 2 == 0:
 | 
	
		
			
				|  |  |              raise BadRequestError('The number of messages should be odd.')
 | 
	
		
			
				|  |  |          if messages[0].role == 'function':
 | 
	
		
			
				|  |  |              raise BadRequestError('The first message should be user message.')
 | 
	
		
			
				|  |  | -        
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |          """
 | 
	
		
			
				|  |  |          TODO: implement function calling
 | 
	
		
			
				|  |  |          """
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    def _build_chat_request_body(self, model: str, messages: list[ErnieMessage], stream: bool, 
 | 
	
		
			
				|  |  | +    def _build_chat_request_body(self, model: str, messages: list[ErnieMessage], stream: bool,
 | 
	
		
			
				|  |  |                                   parameters: dict[str, Any], stop: list[str], user: str) \
 | 
	
		
			
				|  |  |          -> dict[str, Any]:
 | 
	
		
			
				|  |  |          if len(messages) == 0:
 | 
	
		
			
				|  |  |              raise BadRequestError('The number of messages should not be zero.')
 | 
	
		
			
				|  |  | -        
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |          # check if the first element is system, shift it
 | 
	
		
			
				|  |  |          system_message = ''
 | 
	
		
			
				|  |  |          if messages[0].role == 'system':
 | 
	
	
		
			
				|  | @@ -313,7 +314,7 @@ class ErnieBotModel:
 | 
	
		
			
				|  |  |              body['system'] = system_message
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          return body
 | 
	
		
			
				|  |  | -        
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |      def _handle_chat_generate_response(self, response: Response) -> ErnieMessage:
 | 
	
		
			
				|  |  |          data = response.json()
 | 
	
		
			
				|  |  |          if 'error_code' in data:
 | 
	
	
		
			
				|  | @@ -349,7 +350,7 @@ class ErnieBotModel:
 | 
	
		
			
				|  |  |                          self._handle_error(code, msg)
 | 
	
		
			
				|  |  |                  except Exception as e:
 | 
	
		
			
				|  |  |                      raise InternalServerError(f'Failed to parse response: {e}')
 | 
	
		
			
				|  |  | -            
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |              if line.startswith('data:'):
 | 
	
		
			
				|  |  |                  line = line[5:].strip()
 | 
	
		
			
				|  |  |              else:
 | 
	
	
		
			
				|  | @@ -361,7 +362,7 @@ class ErnieBotModel:
 | 
	
		
			
				|  |  |                  data = loads(line)
 | 
	
		
			
				|  |  |              except Exception as e:
 | 
	
		
			
				|  |  |                  raise InternalServerError(f'Failed to parse response: {e}')
 | 
	
		
			
				|  |  | -            
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |              result = data['result']
 | 
	
		
			
				|  |  |              is_end = data['is_end']
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -379,4 +380,4 @@ class ErnieBotModel:
 | 
	
		
			
				|  |  |                  yield message
 | 
	
		
			
				|  |  |              else:
 | 
	
		
			
				|  |  |                  message = ErnieMessage(content=result, role='assistant')
 | 
	
		
			
				|  |  | -                yield message
 | 
	
		
			
				|  |  | +                yield message
 |