streamable_azure_open_ai.py

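"""Streaming-capable Azure OpenAI completion model built on LangChain's AzureOpenAI."""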
from langchain.callbacks.manager import Callbacks, CallbackManagerForLLMRun
from langchain.llms import AzureOpenAI
from langchain.llms.openai import _streaming_response_template, completion_with_retry, _update_response, \
    update_token_usage
from langchain.schema import LLMResult
from typing import Optional, List, Dict, Mapping, Any, Union, Tuple
from pydantic import root_validator

from core.llm.wrappers.openai_wrapper import handle_openai_exceptions


class StreamableAzureOpenAI(AzureOpenAI):
    openai_api_type: str = "azure"
    openai_api_version: str = ""

    request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0)
    """Timeout for requests to the OpenAI completion API: 5 seconds to connect, 300 seconds to read."""
    max_retries: int = 1
    """Maximum number of retries to make when generating."""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the api key and python package exist in the environment."""
        try:
            import openai

            values["client"] = openai.Completion
        except ImportError:
            raise ValueError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            )
        if values["streaming"] and values["n"] > 1:
            raise ValueError("Cannot stream results when n > 1.")
        if values["streaming"] and values["best_of"] > 1:
            raise ValueError("Cannot stream results when best_of > 1.")
        return values

    @property
    def _invocation_params(self) -> Dict[str, Any]:
        return {**super()._invocation_params, **{
            "api_type": self.openai_api_type,
            "api_base": self.openai_api_base,
            "api_version": self.openai_api_version,
            "api_key": self.openai_api_key,
            "organization": self.openai_organization if self.openai_organization else None,
        }}

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {**super()._identifying_params, **{
            "api_type": self.openai_api_type,
            "api_base": self.openai_api_base,
            "api_version": self.openai_api_version,
            "api_key": self.openai_api_key,
            "organization": self.openai_organization if self.openai_organization else None,
        }}

    @handle_openai_exceptions
    def generate(
            self,
            prompts: List[str],
            stop: Optional[List[str]] = None,
            callbacks: Callbacks = None,
            **kwargs: Any,
    ) -> LLMResult:
        return super().generate(prompts, stop, callbacks, **kwargs)

    @classmethod
    def get_kwargs_from_model_params(cls, params: dict):
        return params

    def _generate(
            self,
            prompts: List[str],
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> LLMResult:
        """Call out to OpenAI's endpoint with k unique prompts.

        Args:
            prompts: The prompts to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The full LLM output.

        Example:
            .. code-block:: python

                response = openai.generate(["Tell me a joke."])
        """
        params = self._invocation_params
        params = {**params, **kwargs}
        sub_prompts = self.get_sub_prompts(params, prompts, stop)
        choices = []
        token_usage: Dict[str, int] = {}
        # Get the token usage from the response.
        # Includes prompt, completion, and total tokens used.
        _keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
        for _prompts in sub_prompts:
            if self.streaming:
                if len(_prompts) > 1:
                    raise ValueError("Cannot stream results with multiple prompts.")
                params["stream"] = True
                response = _streaming_response_template()
                # Accumulate streamed chunks into a single response, emitting
                # each new token to the callback manager as it arrives.
                for stream_resp in completion_with_retry(
                    self, prompt=_prompts, **params
                ):
                    if len(stream_resp["choices"]) > 0:
                        if run_manager:
                            run_manager.on_llm_new_token(
                                stream_resp["choices"][0]["text"],
                                verbose=self.verbose,
                                logprobs=stream_resp["choices"][0]["logprobs"],
                            )
                        _update_response(response, stream_resp)
                choices.extend(response["choices"])
            else:
                response = completion_with_retry(self, prompt=_prompts, **params)
                choices.extend(response["choices"])
            if not self.streaming:
                # Can't update token usage if streaming
                update_token_usage(_keys, response, token_usage)
        return self.create_llm_result(choices, prompts, token_usage)
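
# Example usage (a minimal sketch; the deployment name, endpoint, API key and
# API version below are placeholders, not values defined in this module):
#
#     llm = StreamableAzureOpenAI(
#         deployment_name="text-davinci-003",
#         openai_api_base="https://<your-resource>.openai.azure.com/",
#         openai_api_key="<azure-api-key>",
#         openai_api_version="2023-05-15",
#         streaming=True,
#     )
#     result = llm.generate(["Tell me a joke."])
#     print(result.generations[0][0].text)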