wenxin.py

  1. """Wrapper around Wenxin APIs."""
  2. from __future__ import annotations
  3. import json
  4. import logging
  5. from typing import (
  6. Any,
  7. Dict,
  8. List,
  9. Optional, Iterator,
  10. )
  11. import requests
  12. from langchain.llms.utils import enforce_stop_tokens
  13. from langchain.schema.output import GenerationChunk
  14. from pydantic import BaseModel, Extra, Field, PrivateAttr, root_validator
  15. from langchain.callbacks.manager import (
  16. CallbackManagerForLLMRun,
  17. )
  18. from langchain.llms.base import LLM
  19. from langchain.utils import get_from_dict_or_env
  20. logger = logging.getLogger(__name__)


class _WenxinEndpointClient(BaseModel):
    """An API client that talks to a Wenxin llm endpoint."""

    base_url: str = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/"
    secret_key: str
    api_key: str

    def get_access_token(self) -> str:
        url = (
            f"https://aip.baidubce.com/oauth/2.0/token?client_id={self.api_key}"
            f"&client_secret={self.secret_key}&grant_type=client_credentials"
        )
        headers = {
            "Content-Type": "application/json",
            "Accept": "application/json",
        }
        response = requests.post(url, headers=headers)
        if not response.ok:
            raise ValueError(
                f"Wenxin HTTP {response.status_code} error: {response.text}"
            )
        json_response = response.json()
        if "error" in json_response:
            raise ValueError(
                f"Wenxin API {json_response['error']}"
                f" error: {json_response['error_description']}"
            )
        # TODO: add a cache so the token is not re-fetched on every request.
        return json_response["access_token"]
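
    # A possible shape for the caching TODO above (an assumption, not part of
    # the original code): Baidu access tokens are long-lived, so the token
    # could be memoized on the instance instead of re-fetched per request:
    #
    #     _token: Optional[str] = PrivateAttr(default=None)  # hypothetical field
    #
    #     def get_access_token(self) -> str:
    #         if self._token is None:
    #             self._token = self._fetch_access_token()  # hypothetical helper
    #         return self._token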

    def post(self, request: dict) -> Any:
        if "model" not in request:
            raise ValueError("Wenxin Model name is required")
        # Map public model names to the path segment of the chat endpoint.
        model_url_map = {
            "ernie-bot": "completions",
            "ernie-bot-turbo": "eb-instant",
            "bloomz-7b": "bloomz_7b1",
        }
        if request["model"] not in model_url_map:
            raise ValueError(f"Wenxin unsupported model: {request['model']}")
        stream = bool(request.get("stream"))
        access_token = self.get_access_token()
        api_url = (
            f"{self.base_url}{model_url_map[request['model']]}"
            f"?access_token={access_token}"
        )
        headers = {"Content-Type": "application/json"}
        response = requests.post(
            api_url,
            headers=headers,
            json=request,
            stream=stream,
        )
        if not response.ok:
            raise ValueError(
                f"Wenxin HTTP {response.status_code} error: {response.text}"
            )
        if not stream:
            json_response = response.json()
            if "error_code" in json_response:
                raise ValueError(
                    f"Wenxin API {json_response['error_code']}"
                    f" error: {json_response['error_msg']}"
                )
            return json_response["result"]
        else:
            # Hand the raw response back so the caller can iterate over the
            # server-sent event lines.
            return response
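
# Illustrative request shape consumed by _WenxinEndpointClient.post; this
# mirrors what Wenxin._call builds below (a sketch, not part of the original
# module):
#
#     client = _WenxinEndpointClient(api_key="...", secret_key="...")
#     result = client.post({
#         "model": "ernie-bot",
#         "messages": [{"role": "user", "content": "Hello"}],
#         "temperature": 0.7,
#         "top_p": 0.95,
#         "stream": False,
#     })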

class Wenxin(LLM):
    """Wrapper around Wenxin large language models.

    To use, you should have the environment variables
    ``WENXIN_API_KEY`` and ``WENXIN_SECRET_KEY`` set with your API key and
    secret key, or pass them as named parameters to the constructor.

    Example:
        .. code-block:: python

            from langchain.llms.wenxin import Wenxin
            wenxin = Wenxin(model="<model_name>", api_key="my-api-key",
                            secret_key="my-secret-key")
    """

    _client: _WenxinEndpointClient = PrivateAttr()
    model: str = "ernie-bot"
    """Model name to use."""
    temperature: float = 0.7
    """A non-negative float that tunes the degree of randomness in generation."""
    top_p: float = 0.95
    """Total probability mass of tokens to consider at each step."""
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for the `create` call not explicitly specified."""
    streaming: bool = False
    """Whether to stream the response or return it all at once."""
    api_key: Optional[str] = None
    secret_key: Optional[str] = None

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the API key and secret key exist in the environment."""
        values["api_key"] = get_from_dict_or_env(
            values, "api_key", "WENXIN_API_KEY"
        )
        values["secret_key"] = get_from_dict_or_env(
            values, "secret_key", "WENXIN_SECRET_KEY"
        )
        return values
    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling the Wenxin API."""
        return {
            "model": self.model,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "stream": self.streaming,
            **self.model_kwargs,
        }

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {**{"model": self.model}, **self._default_params}

    @property
    def _llm_type(self) -> str:
        """Return the type of llm."""
        return "wenxin"
    def __init__(self, **data: Any):
        super().__init__(**data)
        self._client = _WenxinEndpointClient(
            api_key=self.api_key,
            secret_key=self.secret_key,
        )
    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        r"""Call out to Wenxin's chat completion endpoint.

        Args:
            prompt: The prompt to pass into the model.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                response = wenxin("Tell me a joke.")
        """
        if self.streaming:
            completion = ""
            for chunk in self._stream(
                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
            ):
                completion += chunk.text
        else:
            request = self._default_params
            request["messages"] = [{"role": "user", "content": prompt}]
            request.update(kwargs)
            completion = self._client.post(request)
        if stop is not None:
            completion = enforce_stop_tokens(completion, stop)
        return completion
    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        r"""Call Wenxin's completion endpoint in streaming mode and return a generator.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            A generator representing the stream of tokens from Wenxin.

        Example:
            .. code-block:: python

                prompt = "Write a poem about a stream."
                prompt = f"\n\nHuman: {prompt}\n\nAssistant:"
                generator = wenxin.stream(prompt)
                for token in generator:
                    yield token
        """
        request = self._default_params
        request["messages"] = [{"role": "user", "content": prompt}]
        request.update(kwargs)
        # Force streaming on the endpoint even if the instance default
        # (`self.streaming`) is False; otherwise post() returns a plain
        # string and iter_lines() below would fail.
        request["stream"] = True
        for token in self._client.post(request).iter_lines():
            if token:
                # Each server-sent event line looks like `data: {...}`;
                # strip the `data:` prefix before decoding the JSON payload.
                token = token.decode("utf-8")
                completion = json.loads(token[5:])
                yield GenerationChunk(text=completion["result"])
                if run_manager:
                    run_manager.on_llm_new_token(completion["result"])
                # The API marks the final chunk with `is_end`.
                if completion["is_end"]:
                    break
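

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It assumes that
# WENXIN_API_KEY and WENXIN_SECRET_KEY are set in the environment and that
# the Baidu endpoints above are reachable.
if __name__ == "__main__":
    # Blocking call: returns the full completion as a single string.
    llm = Wenxin(model="ernie-bot", temperature=0.7)
    print(llm("Tell me a joke."))

    # Streaming call: chunks arrive incrementally as GenerationChunk objects.
    streaming_llm = Wenxin(model="ernie-bot", streaming=True)
    for chunk in streaming_llm._stream("Write a poem about a stream."):
        print(chunk.text, end="", flush=True)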