streamable_open_ai.py

import os

from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult
from typing import Optional, List, Dict, Any, Mapping
from langchain import OpenAI
from pydantic import root_validator

from core.llm.wrappers.openai_wrapper import handle_openai_exceptions


class StreamableOpenAI(OpenAI):

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        try:
            import openai

            values["client"] = openai.Completion
        except ImportError:
            raise ValueError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            )
        # Streaming yields a single incremental completion, so it is incompatible
        # with sampling multiple candidates (n > 1) or server-side reranking (best_of > 1).
        if values["streaming"] and values["n"] > 1:
            raise ValueError("Cannot stream results when n > 1.")
        if values["streaming"] and values["best_of"] > 1:
            raise ValueError("Cannot stream results when best_of > 1.")
        return values

    @property
    def _invocation_params(self) -> Dict[str, Any]:
        """Parameters passed to the OpenAI client on every call."""
        return {**super()._invocation_params, **{
            "api_type": 'openai',
            "api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
            "api_version": None,
            "api_key": self.openai_api_key,
            "organization": self.openai_organization if self.openai_organization else None,
        }}

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Parameters that identify this LLM configuration."""
        return {**super()._identifying_params, **{
            "api_type": 'openai',
            "api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
            "api_version": None,
            "api_key": self.openai_api_key,
            "organization": self.openai_organization if self.openai_organization else None,
        }}

    @handle_openai_exceptions
    def generate(
            self,
            prompts: List[str],
            stop: Optional[List[str]] = None,
            callbacks: Callbacks = None,
            **kwargs: Any,
    ) -> LLMResult:
        return super().generate(prompts, stop, callbacks, **kwargs)

    @classmethod
    def get_kwargs_from_model_params(cls, params: dict):
        # Completion models accept the model parameters unchanged.
        return params
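

# Usage sketch (illustrative addition, not part of the original module): shows one way
# the wrapper might be constructed and invoked through LangChain's generate() API.
# The model name, API key lookup, and parameter values below are assumptions, not
# values taken from this file.
if __name__ == "__main__":
    llm = StreamableOpenAI(
        model_name="text-davinci-003",                    # assumed completion model
        openai_api_key=os.environ.get("OPENAI_API_KEY"),  # assumed env-var lookup
        streaming=False,                                  # streaming requires n == 1 and best_of == 1
        max_tokens=256,
    )
    result: LLMResult = llm.generate(["Say hello in one sentence."])
    print(result.generations[0][0].text)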