streamable_open_ai.py

import os
from typing import Optional, List, Dict, Any, Mapping

from langchain import OpenAI
from langchain.schema import LLMResult
from pydantic import root_validator

from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async


class StreamableOpenAI(OpenAI):

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the api key and python package exist in the environment."""
        try:
            import openai

            values["client"] = openai.Completion
        except ImportError:
            raise ValueError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            )
        # Streaming only supports a single completion per prompt.
        if values["streaming"] and values["n"] > 1:
            raise ValueError("Cannot stream results when n > 1.")
        if values["streaming"] and values["best_of"] > 1:
            raise ValueError("Cannot stream results when best_of > 1.")
        return values

    @property
    def _invocation_params(self) -> Dict[str, Any]:
        """Parameters passed to the OpenAI client on each call."""
        return {**super()._invocation_params, **{
            "api_type": 'openai',
            "api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
            "api_version": None,
            "api_key": self.openai_api_key,
            "organization": self.openai_organization if self.openai_organization else None,
        }}

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Parameters that identify this LLM configuration."""
        return {**super()._identifying_params, **{
            "api_type": 'openai',
            "api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
            "api_version": None,
            "api_key": self.openai_api_key,
            "organization": self.openai_organization if self.openai_organization else None,
        }}

    @handle_llm_exceptions
    def generate(
            self, prompts: List[str], stop: Optional[List[str]] = None
    ) -> LLMResult:
        return super().generate(prompts, stop)

    @handle_llm_exceptions_async
    async def agenerate(
            self, prompts: List[str], stop: Optional[List[str]] = None
    ) -> LLMResult:
        return await super().agenerate(prompts, stop)
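

# --- Hedged usage sketch (illustrative addition, not part of the original module) ---
# Assumes OPENAI_API_KEY is set in the environment; the model name and parameter
# values below follow langchain's OpenAI wrapper and are examples, not values
# taken from this file.
if __name__ == "__main__":
    llm = StreamableOpenAI(
        model_name="text-davinci-003",
        temperature=0.7,
        streaming=True,
        n=1,        # validate_environment rejects streaming with n > 1
        best_of=1,  # ...or with best_of > 1
    )
    result: LLMResult = llm.generate(["Say hello in one sentence."])
    print(result.generations[0][0].text)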