streamable_azure_open_ai.py

from typing import Any, Dict, List, Mapping, Optional

from langchain.llms import AzureOpenAI
from langchain.schema import LLMResult
from pydantic import root_validator

from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async


class StreamableAzureOpenAI(AzureOpenAI):
    """AzureOpenAI subclass that validates streaming-related options and wraps
    generate/agenerate with the shared LLM exception handlers."""

    openai_api_type: str = "azure"
    openai_api_version: str = ""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the API key and the openai python package exist in the environment."""
        try:
            import openai

            values["client"] = openai.Completion
        except ImportError:
            raise ValueError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            )
        # Streaming yields a single token stream, so it is incompatible with
        # requesting multiple completions per prompt.
        if values["streaming"] and values["n"] > 1:
            raise ValueError("Cannot stream results when n > 1.")
        if values["streaming"] and values["best_of"] > 1:
            raise ValueError("Cannot stream results when best_of > 1.")
        return values

    @property
    def _invocation_params(self) -> Dict[str, Any]:
        """Merge the Azure-specific connection settings into the request parameters."""
        return {**super()._invocation_params, **{
            "api_type": self.openai_api_type,
            "api_base": self.openai_api_base,
            "api_version": self.openai_api_version,
            "api_key": self.openai_api_key,
            "organization": self.openai_organization if self.openai_organization else None,
        }}

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Expose the same Azure settings as part of the model's identifying parameters."""
        return {**super()._identifying_params, **{
            "api_type": self.openai_api_type,
            "api_base": self.openai_api_base,
            "api_version": self.openai_api_version,
            "api_key": self.openai_api_key,
            "organization": self.openai_organization if self.openai_organization else None,
        }}

    @handle_llm_exceptions
    def generate(
            self, prompts: List[str], stop: Optional[List[str]] = None
    ) -> LLMResult:
        return super().generate(prompts, stop)

    @handle_llm_exceptions_async
    async def agenerate(
            self, prompts: List[str], stop: Optional[List[str]] = None
    ) -> LLMResult:
        return await super().agenerate(prompts, stop)
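
A minimal usage sketch, assuming a hypothetical Azure deployment: the deployment name, endpoint, API version, and environment variable below are illustrative assumptions, not values defined by this module.

import os

# Hypothetical values for illustration only; substitute your own Azure
# resource endpoint, deployment name, and credentials.
llm = StreamableAzureOpenAI(
    deployment_name="text-davinci-003",  # assumed Azure deployment name
    openai_api_base="https://my-resource.openai.azure.com",  # assumed endpoint
    openai_api_version="2023-05-15",  # assumed API version
    openai_api_key=os.environ["AZURE_OPENAI_API_KEY"],  # assumed env var
    streaming=True,  # n and best_of must stay at 1 when streaming
)

result = llm.generate(["Say hello."])
print(result.generations[0][0].text)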