streamable_chat_open_ai.py 2.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. import os
  2. from langchain.callbacks.manager import Callbacks
  3. from langchain.schema import BaseMessage, LLMResult
  4. from langchain.chat_models import ChatOpenAI
  5. from typing import Optional, List, Dict, Any, Union, Tuple
  6. from pydantic import root_validator
  7. from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
  8. class StreamableChatOpenAI(ChatOpenAI):
  9. request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0)
  10. """Timeout for requests to OpenAI completion API. Default is 600 seconds."""
  11. max_retries: int = 1
  12. """Maximum number of retries to make when generating."""
  13. @root_validator()
  14. def validate_environment(cls, values: Dict) -> Dict:
  15. """Validate that api key and python package exists in environment."""
  16. try:
  17. import openai
  18. except ImportError:
  19. raise ValueError(
  20. "Could not import openai python package. "
  21. "Please install it with `pip install openai`."
  22. )
  23. try:
  24. values["client"] = openai.ChatCompletion
  25. except AttributeError:
  26. raise ValueError(
  27. "`openai` has no `ChatCompletion` attribute, this is likely "
  28. "due to an old version of the openai package. Try upgrading it "
  29. "with `pip install --upgrade openai`."
  30. )
  31. if values["n"] < 1:
  32. raise ValueError("n must be at least 1.")
  33. if values["n"] > 1 and values["streaming"]:
  34. raise ValueError("n must be 1 when streaming.")
  35. return values
  36. @property
  37. def _default_params(self) -> Dict[str, Any]:
  38. """Get the default parameters for calling OpenAI API."""
  39. return {
  40. **super()._default_params,
  41. "api_type": 'openai',
  42. "api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
  43. "api_version": None,
  44. "api_key": self.openai_api_key,
  45. "organization": self.openai_organization if self.openai_organization else None,
  46. }
  47. @handle_openai_exceptions
  48. def generate(
  49. self,
  50. messages: List[List[BaseMessage]],
  51. stop: Optional[List[str]] = None,
  52. callbacks: Callbacks = None,
  53. **kwargs: Any,
  54. ) -> LLMResult:
  55. return super().generate(messages, stop, callbacks, **kwargs)
  56. @classmethod
  57. def get_kwargs_from_model_params(cls, params: dict):
  58. model_kwargs = {
  59. 'top_p': params.get('top_p', 1),
  60. 'frequency_penalty': params.get('frequency_penalty', 0),
  61. 'presence_penalty': params.get('presence_penalty', 0),
  62. }
  63. del params['top_p']
  64. del params['frequency_penalty']
  65. del params['presence_penalty']
  66. params['model_kwargs'] = model_kwargs
  67. return params