@@ -8,12 +8,15 @@ from typing import (
     Any,
     Dict,
     List,
-    Optional, Iterator,
+    Optional, Iterator, Tuple,
 )
 
 import requests
+from langchain.chat_models.base import BaseChatModel
 from langchain.llms.utils import enforce_stop_tokens
-from langchain.schema.output import GenerationChunk
+from langchain.schema import BaseMessage, ChatMessage, HumanMessage, AIMessage, SystemMessage
+from langchain.schema.messages import AIMessageChunk
+from langchain.schema.output import GenerationChunk, ChatResult, ChatGenerationChunk, ChatGeneration
 from pydantic import BaseModel, Extra, Field, PrivateAttr, root_validator
 
 from langchain.callbacks.manager import (
@@ -61,6 +64,7 @@ class _WenxinEndpointClient(BaseModel):
             raise ValueError(f"Wenxin Model name is required")
 
         model_url_map = {
+            'ernie-bot-4': 'completions_pro',
             'ernie-bot': 'completions',
             'ernie-bot-turbo': 'eb-instant',
             'bloomz-7b': 'bloomz_7b1',
@@ -70,6 +74,7 @@ class _WenxinEndpointClient(BaseModel):
 
         access_token = self.get_access_token()
         api_url = f"{self.base_url}{model_url_map[request['model']]}?access_token={access_token}"
+        del request['model']
 
         headers = {"Content-Type": "application/json"}
         response = requests.post(api_url,
@@ -86,22 +91,21 @@ class _WenxinEndpointClient(BaseModel):
                     f"Wenxin API {json_response['error_code']}"
                     f" error: {json_response['error_msg']}"
                 )
-            return json_response["result"]
+            return json_response
         else:
             return response
 
 
-class Wenxin(LLM):
-    """Wrapper around Wenxin large language models.
-    To use, you should have the environment variable
-    ``WENXIN_API_KEY`` and ``WENXIN_SECRET_KEY`` set with your API key,
-    or pass them as a named parameter to the constructor.
-    Example:
-     .. code-block:: python
-         from langchain.llms.wenxin import Wenxin
-         wenxin = Wenxin(model="<model_name>", api_key="my-api-key",
-          secret_key="my-group-id")
-    """
+class Wenxin(BaseChatModel):
+    """Wrapper around Wenxin large language models."""
+
+    @property
+    def lc_secrets(self) -> Dict[str, str]:
+        return {"api_key": "API_KEY", "secret_key": "SECRET_KEY"}
+
+    @property
+    def lc_serializable(self) -> bool:
+        return True
 
     _client: _WenxinEndpointClient = PrivateAttr()
     model: str = "ernie-bot"
@@ -161,64 +165,89 @@ class Wenxin(LLM):
             secret_key=self.secret_key,
         )
 
-    def _call(
+    def _convert_message_to_dict(self, message: BaseMessage) -> dict:
+        if isinstance(message, ChatMessage):
+            message_dict = {"role": message.role, "content": message.content}
+        elif isinstance(message, HumanMessage):
+            message_dict = {"role": "user", "content": message.content}
+        elif isinstance(message, AIMessage):
+            message_dict = {"role": "assistant", "content": message.content}
+        elif isinstance(message, SystemMessage):
+            message_dict = {"role": "system", "content": message.content}
+        else:
+            raise ValueError(f"Got unknown type {message}")
+        return message_dict
+
+    def _create_message_dicts(
+        self, messages: List[BaseMessage]
+    ) -> Tuple[List[Dict[str, Any]], str]:
+        dict_messages = []
+        system = None
+        for m in messages:
+            message = self._convert_message_to_dict(m)
+            if message['role'] == 'system':
+                if not system:
+                    system = message['content']
+                else:
+                    system += f"\n{message['content']}"
+                continue
+
+            if dict_messages:
+                previous_message = dict_messages[-1]
+                if previous_message['role'] == message['role']:
+                    dict_messages[-1]['content'] += f"\n{message['content']}"
+                else:
+                    dict_messages.append(message)
+            else:
+                dict_messages.append(message)
+
+        return dict_messages, system
+
+    def _generate(
         self,
-        prompt: str,
+        messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> str:
-        r"""Call out to Wenxin's completion endpoint to chat
-        Args:
-            prompt: The prompt to pass into the model.
-        Returns:
-            The string generated by the model.
-        Example:
-            .. code-block:: python
-                response = wenxin("Tell me a joke.")
-        """
+    ) -> ChatResult:
         if self.streaming:
-            completion = ""
+            generation: Optional[ChatGenerationChunk] = None
+            llm_output: Optional[Dict] = None
             for chunk in self._stream(
-                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+                    messages=messages, stop=stop, run_manager=run_manager, **kwargs
             ):
-                completion += chunk.text
+                if chunk.generation_info is not None \
+                        and 'token_usage' in chunk.generation_info:
+                    llm_output = {"token_usage": chunk.generation_info['token_usage'], "model_name": self.model}
+
+                if generation is None:
+                    generation = chunk
+                else:
+                    generation += chunk
+            assert generation is not None
+            return ChatResult(generations=[generation], llm_output=llm_output)
         else:
+            message_dicts, system = self._create_message_dicts(messages)
             request = self._default_params
-            request["messages"] = [{"role": "user", "content": prompt}]
+            request["messages"] = message_dicts
+            if system:
+                request["system"] = system
             request.update(kwargs)
-            completion = self._client.post(request)
-
-        if stop is not None:
-            completion = enforce_stop_tokens(completion, stop)
-
-        return completion
+            response = self._client.post(request)
+            return self._create_chat_result(response)
 
     def _stream(
-        self,
-        prompt: str,
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> Iterator[GenerationChunk]:
-        r"""Call wenxin completion_stream and return the resulting generator.
-
-        Args:
-            prompt: The prompt to pass into the model.
-            stop: Optional list of stop words to use when generating.
-        Returns:
-            A generator representing the stream of tokens from Wenxin.
-        Example:
-            .. code-block:: python
-
-                prompt = "Write a poem about a stream."
-                prompt = f"\n\nHuman: {prompt}\n\nAssistant:"
-                generator = wenxin.stream(prompt)
-                for token in generator:
-                    yield token
-        """
+            self,
+            messages: List[BaseMessage],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        message_dicts, system = self._create_message_dicts(messages)
         request = self._default_params
-        request["messages"] = [{"role": "user", "content": prompt}]
+        request["messages"] = message_dicts
+        if system:
+            request["system"] = system
         request.update(kwargs)
 
         for token in self._client.post(request).iter_lines():
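As a note on the hunk above: `_create_message_dicts` collects system messages into a separate `system` string and joins consecutive messages of the same role with a newline, which matches the alternating user/assistant format the Wenxin endpoint expects. A minimal sketch of the behavior, using a hypothetical `chat` instance (the expected values in the comments follow directly from the merging logic above):

from langchain.schema import SystemMessage, HumanMessage, AIMessage

# chat = Wenxin(model="ernie-bot", api_key="...", secret_key="...")  # hypothetical instance
messages = [
    SystemMessage(content="You are a helpful assistant."),
    HumanMessage(content="Hello"),
    HumanMessage(content="What is ERNIE Bot?"),   # same role as the previous message
    AIMessage(content="ERNIE Bot is Baidu's LLM."),
]
# message_dicts, system = chat._create_message_dicts(messages)
# system == "You are a helpful assistant."
# message_dicts == [
#     {"role": "user", "content": "Hello\nWhat is ERNIE Bot?"},
#     {"role": "assistant", "content": "ERNIE Bot is Baidu's LLM."},
# ]
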
@@ -228,12 +257,18 @@ class Wenxin(LLM):
                 if token.startswith('data:'):
                     completion = json.loads(token[5:])
 
-                    yield GenerationChunk(text=completion['result'])
-                    if run_manager:
-                        run_manager.on_llm_new_token(completion['result'])
+                    chunk_dict = {
+                        'message': AIMessageChunk(content=completion['result']),
+                    }
 
                     if completion['is_end']:
-                        break
+                        token_usage = completion['usage']
+                        token_usage['completion_tokens'] = token_usage['total_tokens'] - token_usage['prompt_tokens']
+                        chunk_dict['generation_info'] = dict({'token_usage': token_usage})
+
+                    yield ChatGenerationChunk(**chunk_dict)
+                    if run_manager:
+                        run_manager.on_llm_new_token(completion['result'])
                 else:
                     try:
                         json_response = json.loads(token)
@@ -245,3 +280,40 @@ class Wenxin(LLM):
                         f" error: {json_response['error_msg']}, "
                         f"please confirm if the model you have chosen is already paid for."
                     )
+
+    def _create_chat_result(self, response: Dict[str, Any]) -> ChatResult:
+        generations = [ChatGeneration(
+            message=AIMessage(content=response['result']),
+        )]
+        token_usage = response.get("usage")
+        token_usage['completion_tokens'] = token_usage['total_tokens'] - token_usage['prompt_tokens']
+
+        llm_output = {"token_usage": token_usage, "model_name": self.model}
+        return ChatResult(generations=generations, llm_output=llm_output)
+
+    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+        """Get the number of tokens in the messages.
+
+        Useful for checking if an input will fit in a model's context window.
+
+        Args:
+            messages: The message inputs to tokenize.
+
+        Returns:
+            The sum of the number of tokens across the messages.
+        """
+        return sum([self.get_num_tokens(m.content) for m in messages])
+
+    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
+        overall_token_usage: dict = {}
+        for output in llm_outputs:
+            if output is None:
+                # Happens in streaming
+                continue
+            token_usage = output["token_usage"]
+            for k, v in token_usage.items():
+                if k in overall_token_usage:
+                    overall_token_usage[k] += v
+                else:
+                    overall_token_usage[k] = v
+        return {"token_usage": overall_token_usage, "model_name": self.model}
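With these changes, `Wenxin` is called like any other `BaseChatModel`: it takes a list of messages and returns an AI message built by `_create_chat_result`. A minimal usage sketch, assuming the module keeps its original import path `langchain.llms.wenxin` and that `api_key`/`secret_key` are the credentials referenced by `lc_secrets`:

from langchain.schema import HumanMessage, SystemMessage
from langchain.llms.wenxin import Wenxin  # assumed import path; adjust if the module was moved

chat = Wenxin(
    model="ernie-bot",           # or "ernie-bot-4", "ernie-bot-turbo", "bloomz-7b"
    api_key="my-api-key",        # Baidu Qianfan API key
    secret_key="my-secret-key",  # Baidu Qianfan secret key
)

# The system message is sent via the request's "system" field;
# the remaining messages become the "messages" payload.
response = chat([
    SystemMessage(content="You are a concise assistant."),
    HumanMessage(content="Tell me a joke."),
])
print(response.content)  # content of the AIMessage returned by the endpoint

Setting `streaming=True` on the instance routes `_generate` through `_stream`, which accumulates `ChatGenerationChunk`s and attaches token usage from the final `is_end` chunk.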