# wenxin.py
  1. """Wrapper around Wenxin APIs."""
  2. from __future__ import annotations
  3. import json
  4. import logging
  5. from json import JSONDecodeError
  6. from typing import (
  7. Any,
  8. Dict,
  9. List,
  10. Optional, Iterator, Tuple,
  11. )
  12. import requests
  13. from langchain.chat_models.base import BaseChatModel
  14. from langchain.llms.utils import enforce_stop_tokens
  15. from langchain.schema import BaseMessage, ChatMessage, HumanMessage, AIMessage, SystemMessage
  16. from langchain.schema.messages import AIMessageChunk
  17. from langchain.schema.output import GenerationChunk, ChatResult, ChatGenerationChunk, ChatGeneration
  18. from pydantic import BaseModel, Extra, Field, PrivateAttr, root_validator
  19. from langchain.callbacks.manager import (
  20. CallbackManagerForLLMRun,
  21. )
  22. from langchain.llms.base import LLM
  23. from langchain.utils import get_from_dict_or_env
  24. logger = logging.getLogger(__name__)


class _WenxinEndpointClient(BaseModel):
    """An API client that talks to a Wenxin LLM endpoint."""

    base_url: str = (
        "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/"
    )
    secret_key: str
    api_key: str

    def get_access_token(self) -> str:
        """Exchange the API key/secret key pair for a short-lived access token."""
        url = "https://aip.baidubce.com/oauth/2.0/token"
        params = {
            "grant_type": "client_credentials",
            "client_id": self.api_key,
            "client_secret": self.secret_key,
        }
        headers = {
            "Content-Type": "application/json",
            "Accept": "application/json",
        }
        response = requests.post(url, params=params, headers=headers)
        if not response.ok:
            raise ValueError(
                f"Wenxin HTTP {response.status_code} error: {response.text}"
            )
        payload = response.json()
        if "error" in payload:
            raise ValueError(
                f"Wenxin API {payload['error']}"
                f" error: {payload['error_description']}"
            )
        # TODO: cache the token until it expires instead of fetching a fresh
        # one for every request.
        return payload["access_token"]

    def post(self, request: dict) -> Any:
        """POST a chat request to the endpoint for the model named in `request`."""
        if "model" not in request:
            raise ValueError("Wenxin model name is required")
        # Each supported model maps to a distinct URL suffix on the chat endpoint.
        model_url_map = {
            "ernie-bot-4": "completions_pro",
            "ernie-bot": "completions",
            "ernie-bot-turbo": "eb-instant",
            "bloomz-7b": "bloomz_7b1",
        }
        if request["model"] not in model_url_map:
            raise ValueError(f"Wenxin unsupported model: {request['model']}")
        stream = bool(request.get("stream"))
        access_token = self.get_access_token()
        api_url = (
            f"{self.base_url}{model_url_map[request['model']]}"
            f"?access_token={access_token}"
        )
        # The model is addressed through the URL, not the request body.
        del request["model"]
        headers = {"Content-Type": "application/json"}
        response = requests.post(
            api_url, headers=headers, json=request, stream=stream
        )
        if not response.ok:
            raise ValueError(
                f"Wenxin HTTP {response.status_code} error: {response.text}"
            )
        if not stream:
            json_response = response.json()
            if "error_code" in json_response:
                raise ValueError(
                    f"Wenxin API {json_response['error_code']}"
                    f" error: {json_response['error_msg']}"
                )
            return json_response
        else:
            # Streaming responses are consumed line by line by the caller.
            return response
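

# A minimal caching sketch for the TODO in get_access_token (not part of the
# original code): it assumes the OAuth response carries the standard
# "expires_in" field in seconds, and the attribute names are illustrative only.
#
#     import time
#
#     def get_access_token(self) -> str:
#         if self._cached_token and time.monotonic() < self._token_expiry:
#             return self._cached_token
#         payload = ...  # fetch as above
#         self._cached_token = payload["access_token"]
#         # Refresh one minute early to avoid racing the expiry.
#         self._token_expiry = time.monotonic() + payload["expires_in"] - 60
#         return self._cached_token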


class Wenxin(BaseChatModel):
    """Wrapper around Wenxin large language models."""

    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"api_key": "API_KEY", "secret_key": "SECRET_KEY"}

    @property
    def lc_serializable(self) -> bool:
        return True

    _client: _WenxinEndpointClient = PrivateAttr()
    model: str = "ernie-bot"
    """Model name to use."""
    temperature: float = 0.7
    """A non-negative float that tunes the degree of randomness in generation."""
    top_p: float = 0.95
    """Total probability mass of tokens to consider at each step."""
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for the create call not explicitly specified."""
    streaming: bool = False
    """Whether to stream the response or return it all at once."""
    api_key: Optional[str] = None
    secret_key: Optional[str] = None

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the API key and secret key exist in the environment."""
        values["api_key"] = get_from_dict_or_env(values, "api_key", "WENXIN_API_KEY")
        values["secret_key"] = get_from_dict_or_env(
            values, "secret_key", "WENXIN_SECRET_KEY"
        )
        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling the Wenxin API."""
        return {
            "model": self.model,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "stream": self.streaming,
            **self.model_kwargs,
        }

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {**{"model": self.model}, **self._default_params}

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "wenxin"

    def __init__(self, **data: Any):
        super().__init__(**data)
        self._client = _WenxinEndpointClient(
            api_key=self.api_key,
            secret_key=self.secret_key,
        )

    def _convert_message_to_dict(self, message: BaseMessage) -> dict:
        if isinstance(message, ChatMessage):
            message_dict = {"role": message.role, "content": message.content}
        elif isinstance(message, HumanMessage):
            message_dict = {"role": "user", "content": message.content}
        elif isinstance(message, AIMessage):
            message_dict = {"role": "assistant", "content": message.content}
        elif isinstance(message, SystemMessage):
            message_dict = {"role": "system", "content": message.content}
        else:
            raise ValueError(f"Got unknown message type: {message}")
        return message_dict

    def _create_message_dicts(
        self, messages: List[BaseMessage]
    ) -> Tuple[List[Dict[str, Any]], Optional[str]]:
        dict_messages = []
        system = None
        for m in messages:
            message = self._convert_message_to_dict(m)
            if message["role"] == "system":
                # Wenxin takes the system prompt as a separate request field,
                # so pull system messages out and concatenate them.
                if not system:
                    system = message["content"]
                else:
                    system += f"\n{message['content']}"
                continue
            if dict_messages:
                previous_message = dict_messages[-1]
                if previous_message["role"] == message["role"]:
                    # The API expects strictly alternating roles, so merge
                    # consecutive messages from the same role into one.
                    previous_message["content"] += f"\n{message['content']}"
                else:
                    dict_messages.append(message)
            else:
                dict_messages.append(message)
        return dict_messages, system
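
    # Example of the conversion above (illustrative, not from the original):
    #   [SystemMessage("Be terse."), HumanMessage("Hi"), HumanMessage("there")]
    # becomes
    #   ([{"role": "user", "content": "Hi\nthere"}], "Be terse.")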

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if self.streaming:
            # Accumulate streamed chunks into a single generation.
            generation: Optional[ChatGenerationChunk] = None
            llm_output: Optional[Dict] = None
            for chunk in self._stream(
                messages=messages, stop=stop, run_manager=run_manager, **kwargs
            ):
                if (
                    chunk.generation_info is not None
                    and "token_usage" in chunk.generation_info
                ):
                    llm_output = {
                        "token_usage": chunk.generation_info["token_usage"],
                        "model_name": self.model,
                    }
                if generation is None:
                    generation = chunk
                else:
                    generation += chunk
            assert generation is not None
            return ChatResult(generations=[generation], llm_output=llm_output)
        else:
            message_dicts, system = self._create_message_dicts(messages)
            request = self._default_params
            request["messages"] = message_dicts
            if system:
                request["system"] = system
            request.update(kwargs)
            response = self._client.post(request)
            return self._create_chat_result(response)

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        message_dicts, system = self._create_message_dicts(messages)
        request = self._default_params
        request["messages"] = message_dicts
        if system:
            request["system"] = system
        request.update(kwargs)
        for token in self._client.post(request).iter_lines():
            if not token:
                continue
            token = token.decode("utf-8")
            if token.startswith("data:"):
                # Server-sent events: strip the "data:" prefix and parse the
                # JSON payload of this chunk.
                completion = json.loads(token[5:])
                chunk_dict = {
                    "message": AIMessageChunk(content=completion["result"]),
                }
                if completion["is_end"]:
                    # The final chunk carries the usage stats; derive the
                    # completion tokens since the API only reports totals.
                    token_usage = completion["usage"]
                    token_usage["completion_tokens"] = (
                        token_usage["total_tokens"] - token_usage["prompt_tokens"]
                    )
                    chunk_dict["generation_info"] = {"token_usage": token_usage}
                yield ChatGenerationChunk(**chunk_dict)
                if run_manager:
                    run_manager.on_llm_new_token(completion["result"])
            else:
                # Non-SSE lines indicate an API error payload.
                try:
                    json_response = json.loads(token)
                except JSONDecodeError:
                    raise ValueError(f"Wenxin response error: {token}")
                raise ValueError(
                    f"Wenxin API {json_response.get('error_code')}"
                    f" error: {json_response.get('error_msg')}, "
                    f"please confirm that the model you have chosen is already paid for."
                )
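
    # Illustrative shape of the streamed lines handled above (field names match
    # the parsing; the exact payload is an assumption about the API):
    #   data: {"result": "Hello", "is_end": false}
    #   data: {"result": "!", "is_end": true,
    #          "usage": {"prompt_tokens": 5, "total_tokens": 7}}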

    def _create_chat_result(self, response: Dict[str, Any]) -> ChatResult:
        generations = [
            ChatGeneration(
                message=AIMessage(content=response["result"]),
            )
        ]
        # Derive completion tokens, since the API reports only prompt and
        # total counts.
        token_usage = response["usage"]
        token_usage["completion_tokens"] = (
            token_usage["total_tokens"] - token_usage["prompt_tokens"]
        )
        llm_output = {"token_usage": token_usage, "model_name": self.model}
        return ChatResult(generations=generations, llm_output=llm_output)

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Get the number of tokens in the messages.

        Useful for checking if an input will fit in a model's context window.

        Args:
            messages: The message inputs to tokenize.

        Returns:
            The sum of the number of tokens across the messages.
        """
        return sum(self.get_num_tokens(m.content) for m in messages)

    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
        overall_token_usage: dict = {}
        for output in llm_outputs:
            if output is None:
                # Happens in streaming.
                continue
            token_usage = output["token_usage"]
            # Sum token counts across generations, keyed by usage field.
            for k, v in token_usage.items():
                if k in overall_token_usage:
                    overall_token_usage[k] += v
                else:
                    overall_token_usage[k] = v
        return {"token_usage": overall_token_usage, "model_name": self.model}
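

# Usage sketch (not part of the original module). Assumes valid Baidu
# credentials in WENXIN_API_KEY / WENXIN_SECRET_KEY and that the chosen model
# is enabled for the account; the prompts are illustrative only.
if __name__ == "__main__":
    from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

    # Non-streaming: one blocking call, token usage comes back in llm_output.
    chat = Wenxin(model="ernie-bot", temperature=0.7)
    reply = chat(
        [
            SystemMessage(content="You are a concise assistant."),
            HumanMessage(content="What is ERNIE Bot?"),
        ]
    )
    print(reply.content)

    # Streaming: tokens are printed as they arrive via the callback handler.
    streaming_chat = Wenxin(
        model="ernie-bot",
        streaming=True,
        callbacks=[StreamingStdOutCallbackHandler()],
    )
    streaming_chat([HumanMessage(content="Count from 1 to 5.")])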