- from langchain.callbacks.manager import Callbacks, CallbackManagerForLLMRun
- from langchain.llms import AzureOpenAI
- from langchain.llms.openai import _streaming_response_template, completion_with_retry, _update_response, \
-     update_token_usage
- from langchain.schema import LLMResult
- from typing import Optional, List, Dict, Mapping, Any, Union, Tuple
- from pydantic import root_validator
- from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
-
-
- class StreamableAzureOpenAI(AzureOpenAI):
-     openai_api_type: str = "azure"
-     openai_api_version: str = ""
-
-     request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0)
-     """Timeout for requests to the OpenAI completion API: a 5-second connect timeout and a 300-second read timeout."""
-     max_retries: int = 1
-     """Maximum number of retries to make when generating."""
-
-     @root_validator()
-     def validate_environment(cls, values: Dict) -> Dict:
-         """Validate that api key and python package exists in environment."""
-         try:
-             import openai
-             values["client"] = openai.Completion
-         except ImportError:
-             raise ValueError(
-                 "Could not import openai python package. "
-                 "Please install it with `pip install openai`."
-             )
-         if values["streaming"] and values["n"] > 1:
-             raise ValueError("Cannot stream results when n > 1.")
-         if values["streaming"] and values["best_of"] > 1:
-             raise ValueError("Cannot stream results when best_of > 1.")
-         return values
-
-     @property
-     def _invocation_params(self) -> Dict[str, Any]:
-         return {**super()._invocation_params, **{
-             "api_type": self.openai_api_type,
-             "api_base": self.openai_api_base,
-             "api_version": self.openai_api_version,
-             "api_key": self.openai_api_key,
-             "organization": self.openai_organization if self.openai_organization else None,
-         }}
-
-     @property
-     def _identifying_params(self) -> Mapping[str, Any]:
-         return {**super()._identifying_params, **{
-             "api_type": self.openai_api_type,
-             "api_base": self.openai_api_base,
-             "api_version": self.openai_api_version,
-             "api_key": self.openai_api_key,
-             "organization": self.openai_organization if self.openai_organization else None,
-         }}
-
-     @handle_openai_exceptions
-     def generate(
-         self,
-         prompts: List[str],
-         stop: Optional[List[str]] = None,
-         callbacks: Callbacks = None,
-         **kwargs: Any,
-     ) -> LLMResult:
-         return super().generate(prompts, stop, callbacks, **kwargs)
-
-     @classmethod
-     def get_kwargs_from_model_params(cls, params: dict):
-         return params
-
-     def _generate(
-         self,
-         prompts: List[str],
-         stop: Optional[List[str]] = None,
-         run_manager: Optional[CallbackManagerForLLMRun] = None,
-         **kwargs: Any,
-     ) -> LLMResult:
-         """Call out to OpenAI's endpoint with k unique prompts.
-
-         Args:
-             prompts: The prompts to pass into the model.
-             stop: Optional list of stop words to use when generating.
-
-         Returns:
-             The full LLM output.
-
-         Example:
-             .. code-block:: python
-
-                 response = openai.generate(["Tell me a joke."])
-         """
-         params = self._invocation_params
-         params = {**params, **kwargs}
-         sub_prompts = self.get_sub_prompts(params, prompts, stop)
-         choices = []
-         token_usage: Dict[str, int] = {}
-         # Get the token usage from the response.
-         # Includes prompt, completion, and total tokens used.
-         _keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
-         for _prompts in sub_prompts:
-             if self.streaming:
-                 if len(_prompts) > 1:
-                     raise ValueError("Cannot stream results with multiple prompts.")
-                 params["stream"] = True
-                 response = _streaming_response_template()
-                 for stream_resp in completion_with_retry(
-                     self, prompt=_prompts, **params
-                 ):
-                     if len(stream_resp["choices"]) > 0:
-                         if run_manager:
-                             run_manager.on_llm_new_token(
-                                 stream_resp["choices"][0]["text"],
-                                 verbose=self.verbose,
-                                 logprobs=stream_resp["choices"][0]["logprobs"],
-                             )
-                         _update_response(response, stream_resp)
-                 choices.extend(response["choices"])
-             else:
-                 response = completion_with_retry(self, prompt=_prompts, **params)
-                 choices.extend(response["choices"])
-             if not self.streaming:
-                 # Can't update token usage if streaming
-                 update_token_usage(_keys, response, token_usage)
-         return self.create_llm_result(choices, prompts, token_usage)
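
For reference, a minimal sketch of how the removed wrapper would have been driven. The deployment name, endpoint, API key, API version, and the choice of StreamingStdOutCallbackHandler below are illustrative placeholders, not values taken from this diff:

# Hypothetical usage of the removed StreamableAzureOpenAI wrapper (all values are placeholders).
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

llm = StreamableAzureOpenAI(
    deployment_name="text-davinci-003",                    # placeholder Azure deployment
    openai_api_base="https://example.openai.azure.com/",   # placeholder endpoint
    openai_api_key="<azure-api-key>",                      # placeholder credential
    openai_api_version="2023-05-15",                       # example API version
    streaming=True,
    callbacks=[StreamingStdOutCallbackHandler()],          # prints tokens as they stream in
)

# generate() is the @handle_openai_exceptions-wrapped entry point defined above.
result = llm.generate(["Tell me a joke."])
print(result.generations[0][0].text)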