wenxin.py

  1. """Wrapper around Wenxin APIs."""
  2. from __future__ import annotations
  3. import json
  4. import logging
  5. from json import JSONDecodeError
  6. from typing import (
  7. Any,
  8. Dict,
  9. List,
  10. Optional, Iterator,
  11. )
  12. import requests
  13. from langchain.llms.utils import enforce_stop_tokens
  14. from langchain.schema.output import GenerationChunk
  15. from pydantic import BaseModel, Extra, Field, PrivateAttr, root_validator
  16. from langchain.callbacks.manager import (
  17. CallbackManagerForLLMRun,
  18. )
  19. from langchain.llms.base import LLM
  20. from langchain.utils import get_from_dict_or_env
  21. logger = logging.getLogger(__name__)
class _WenxinEndpointClient(BaseModel):
    """An API client that talks to a Wenxin LLM endpoint."""

    base_url: str = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/"
    secret_key: str
    api_key: str

    def get_access_token(self) -> str:
        url = (
            f"https://aip.baidubce.com/oauth/2.0/token?client_id={self.api_key}"
            f"&client_secret={self.secret_key}&grant_type=client_credentials"
        )
        headers = {
            "Content-Type": "application/json",
            "Accept": "application/json",
        }
        response = requests.post(url, headers=headers)
        if not response.ok:
            raise ValueError(
                f"Wenxin HTTP {response.status_code} error: {response.text}"
            )
        json_response = response.json()
        if "error" in json_response:
            raise ValueError(
                f"Wenxin API {json_response['error']}"
                f" error: {json_response['error_description']}"
            )
        access_token = json_response["access_token"]
        # TODO: add cache
        return access_token
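    # A minimal caching sketch for the TODO above (illustrative assumption,
    # not part of the original code): the OAuth response typically also
    # carries an ``expires_in`` field (seconds), so the token could be
    # memoized on the instance and reused until it expires, e.g.:
    #
    #   self._token = json_response["access_token"]
    #   self._token_expiry = time.time() + json_response["expires_in"]
    #
    # with get_access_token() returning self._token for as long as
    # time.time() stays below self._token_expiry.
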
    def post(self, request: dict) -> Any:
        if "model" not in request:
            raise ValueError("Wenxin Model name is required")

        model_url_map = {
            "ernie-bot": "completions",
            "ernie-bot-turbo": "eb-instant",
            "bloomz-7b": "bloomz_7b1",
        }
        if request["model"] not in model_url_map:
            raise ValueError(f"Wenxin Model {request['model']} is not supported")

        stream = bool(request.get("stream"))
        access_token = self.get_access_token()
        api_url = (
            f"{self.base_url}{model_url_map[request['model']]}"
            f"?access_token={access_token}"
        )
        headers = {"Content-Type": "application/json"}
        response = requests.post(
            api_url,
            headers=headers,
            json=request,
            stream=stream,
        )
        if not response.ok:
            raise ValueError(
                f"Wenxin HTTP {response.status_code} error: {response.text}"
            )
        if not stream:
            json_response = response.json()
            if "error_code" in json_response:
                raise ValueError(
                    f"Wenxin API {json_response['error_code']}"
                    f" error: {json_response['error_msg']}"
                )
            return json_response["result"]
        else:
            return response
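
# Wire format handled by _WenxinEndpointClient.post(): the request body is the
# chat payload (``messages``, ``temperature``, ``top_p``, ``stream``, ...).
# Non-streaming responses are JSON objects whose generated text sits under
# ``result``; streaming responses arrive as server-sent-event lines of the
# form ``data: {"result": "...", "is_end": false, ...}``, which are parsed in
# Wenxin._stream() below.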


class Wenxin(LLM):
    """Wrapper around Wenxin large language models.

    To use, you should have the environment variables
    ``WENXIN_API_KEY`` and ``WENXIN_SECRET_KEY`` set with your API key and
    secret key, or pass them as named parameters to the constructor.

    Example:
        .. code-block:: python

            from langchain.llms.wenxin import Wenxin
            wenxin = Wenxin(model="<model_name>", api_key="my-api-key",
                            secret_key="my-secret-key")
    """

    _client: _WenxinEndpointClient = PrivateAttr()
    model: str = "ernie-bot"
    """Model name to use."""
    temperature: float = 0.7
    """A non-negative float that tunes the degree of randomness in generation."""
    top_p: float = 0.95
    """Total probability mass of tokens to consider at each step."""
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for the `create` call that are not
    explicitly specified."""
    streaming: bool = False
    """Whether to stream the response or return it all at once."""
    api_key: Optional[str] = None
    secret_key: Optional[str] = None

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the API key and secret key exist in the environment."""
        values["api_key"] = get_from_dict_or_env(
            values, "api_key", "WENXIN_API_KEY"
        )
        values["secret_key"] = get_from_dict_or_env(
            values, "secret_key", "WENXIN_SECRET_KEY"
        )
        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling the Wenxin API."""
        return {
            "model": self.model,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "stream": self.streaming,
            **self.model_kwargs,
        }

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {**{"model": self.model}, **self._default_params}

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "wenxin"

    def __init__(self, **data: Any):
        super().__init__(**data)
        self._client = _WenxinEndpointClient(
            api_key=self.api_key,
            secret_key=self.secret_key,
        )

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        r"""Call out to Wenxin's chat completion endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                response = wenxin("Tell me a joke.")
        """
        if self.streaming:
            completion = ""
            for chunk in self._stream(
                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
            ):
                completion += chunk.text
        else:
            request = self._default_params
            request["messages"] = [{"role": "user", "content": prompt}]
            request.update(kwargs)
            completion = self._client.post(request)
        if stop is not None:
            completion = enforce_stop_tokens(completion, stop)
        return completion

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        r"""Call Wenxin's completion endpoint with streaming enabled and
        return the resulting generator.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            A generator representing the stream of tokens from Wenxin.

        Example:
            .. code-block:: python

                prompt = "Write a poem about a stream."
                prompt = f"\n\nHuman: {prompt}\n\nAssistant:"
                generator = wenxin.stream(prompt)
                for token in generator:
                    yield token
        """
        request = self._default_params
        request["messages"] = [{"role": "user", "content": prompt}]
        request.update(kwargs)
        for token in self._client.post(request).iter_lines():
            if not token:
                continue
            token = token.decode("utf-8")
            if token.startswith("data:"):
                # Server-sent event line: strip the "data:" prefix and parse
                # the JSON chunk.
                completion = json.loads(token[5:])
                yield GenerationChunk(text=completion["result"])
                if run_manager:
                    run_manager.on_llm_new_token(completion["result"])
                if completion["is_end"]:
                    break
            else:
                # Any non-SSE line is an error object from the API.
                try:
                    json_response = json.loads(token)
                except JSONDecodeError:
                    raise ValueError(f"Wenxin Response Error {token}")
                raise ValueError(
                    f"Wenxin API {json_response['error_code']}"
                    f" error: {json_response['error_msg']}, "
                    f"please confirm whether the model you have chosen has "
                    f"already been paid for."
                )
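

# --- Usage sketch (illustrative addition, not part of the original module).
# Assumes WENXIN_API_KEY and WENXIN_SECRET_KEY are set in the environment, or
# are passed explicitly with placeholder values as below.
if __name__ == "__main__":
    llm = Wenxin(model="ernie-bot",
                 api_key="my-api-key",
                 secret_key="my-secret-key")
    # Blocking call: returns the full completion as a single string.
    print(llm("Tell me a joke."))

    # With streaming=True, _call() consumes _stream() internally; a callback
    # manager (if provided) receives each token via on_llm_new_token().
    streaming_llm = Wenxin(model="ernie-bot",
                           streaming=True,
                           api_key="my-api-key",
                           secret_key="my-secret-key")
    print(streaming_llm("Write a poem about a stream."))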