streamable_azure_open_ai.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. from langchain.callbacks.manager import Callbacks
  2. from langchain.llms import AzureOpenAI
  3. from langchain.schema import LLMResult
  4. from typing import Optional, List, Dict, Mapping, Any, Union, Tuple
  5. from pydantic import root_validator
  6. from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
  class StreamableAzureOpenAI(AzureOpenAI):
      """Azure OpenAI completion LLM with explicit per-instance Azure settings.

      The Azure connection settings (type, base URL, API version, key,
      organization) are merged into both the invocation and the identifying
      parameters, and ``generate`` is wrapped by ``handle_openai_exceptions``.
      """

      openai_api_type: str = "azure"
      openai_api_version: str = ""
      request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0)
      """Timeout for requests to the OpenAI completion API, in seconds.

      Either a single float or a ``(connect, read)`` pair. Default is
      ``(5.0, 300.0)``: 5 s to connect, 300 s to read.
      """
      max_retries: int = 1
      """Maximum number of retries to make when generating."""

      @root_validator()
      def validate_environment(cls, values: Dict) -> Dict:
          """Attach the ``openai.Completion`` client and check streaming constraints.

          No API key is validated here; the Azure credentials are forwarded
          through ``_invocation_params`` instead.

          Args:
              values: Field values collected by pydantic for this model.

          Returns:
              ``values`` with ``"client"`` set to ``openai.Completion``.

          Raises:
              ValueError: If the ``openai`` package cannot be imported, or if
                  streaming is requested together with ``n > 1`` or
                  ``best_of > 1``.
          """
          try:
              import openai
          except ImportError as err:
              raise ValueError(
                  "Could not import openai python package. "
                  "Please install it with `pip install openai`."
              ) from err
          values["client"] = openai.Completion

          # pydantic v1 omits fields that failed their own validation from
          # `values`; use .get() so such a failure is reported by pydantic
          # instead of being masked by a KeyError raised here.
          if values.get("streaming"):
              if values.get("n", 1) > 1:
                  raise ValueError("Cannot stream results when n > 1.")
              if values.get("best_of", 1) > 1:
                  raise ValueError("Cannot stream results when best_of > 1.")
          return values

      def _azure_params(self) -> Dict[str, Any]:
          """Return the Azure connection settings shared by both param views."""
          return {
              "api_type": self.openai_api_type,
              "api_base": self.openai_api_base,
              "api_version": self.openai_api_version,
              "api_key": self.openai_api_key,
              # Normalize a falsy organization (e.g. "") to None.
              "organization": self.openai_organization or None,
          }

      @property
      def _invocation_params(self) -> Dict[str, Any]:
          """Parent invocation params overlaid with the Azure connection settings."""
          return {**super()._invocation_params, **self._azure_params()}

      @property
      def _identifying_params(self) -> Mapping[str, Any]:
          """Parent identifying params overlaid with the Azure connection settings."""
          # NOTE(review): this includes the raw API key. If LangChain uses
          # identifying params for repr, caching or serialization, the key may
          # be exposed there -- confirm whether it should be dropped or masked.
          return {**super()._identifying_params, **self._azure_params()}

      @handle_openai_exceptions
      def generate(
          self,
          prompts: List[str],
          stop: Optional[List[str]] = None,
          callbacks: Callbacks = None,
          **kwargs: Any,
      ) -> LLMResult:
          """Delegate to the parent ``generate``, wrapped by ``handle_openai_exceptions``."""
          return super().generate(prompts, stop, callbacks, **kwargs)

      @classmethod
      def get_kwargs_from_model_params(cls, params: dict) -> dict:
          """Return ``params`` unchanged; no translation is needed for this model."""
          return params