# streamable_chat_open_ai.py
  1. import os
  2. from langchain.callbacks.manager import Callbacks
  3. from langchain.schema import BaseMessage, LLMResult
  4. from langchain.chat_models import ChatOpenAI
  5. from typing import Optional, List, Dict, Any
  6. from pydantic import root_validator
  7. from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
  8. class StreamableChatOpenAI(ChatOpenAI):
  9. @root_validator()
  10. def validate_environment(cls, values: Dict) -> Dict:
  11. """Validate that api key and python package exists in environment."""
  12. try:
  13. import openai
  14. except ImportError:
  15. raise ValueError(
  16. "Could not import openai python package. "
  17. "Please install it with `pip install openai`."
  18. )
  19. try:
  20. values["client"] = openai.ChatCompletion
  21. except AttributeError:
  22. raise ValueError(
  23. "`openai` has no `ChatCompletion` attribute, this is likely "
  24. "due to an old version of the openai package. Try upgrading it "
  25. "with `pip install --upgrade openai`."
  26. )
  27. if values["n"] < 1:
  28. raise ValueError("n must be at least 1.")
  29. if values["n"] > 1 and values["streaming"]:
  30. raise ValueError("n must be 1 when streaming.")
  31. return values
  32. @property
  33. def _default_params(self) -> Dict[str, Any]:
  34. """Get the default parameters for calling OpenAI API."""
  35. return {
  36. **super()._default_params,
  37. "api_type": 'openai',
  38. "api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
  39. "api_version": None,
  40. "api_key": self.openai_api_key,
  41. "organization": self.openai_organization if self.openai_organization else None,
  42. }
  43. @handle_openai_exceptions
  44. def generate(
  45. self,
  46. messages: List[List[BaseMessage]],
  47. stop: Optional[List[str]] = None,
  48. callbacks: Callbacks = None,
  49. **kwargs: Any,
  50. ) -> LLMResult:
  51. return super().generate(messages, stop, callbacks, **kwargs)
  52. @classmethod
  53. def get_kwargs_from_model_params(cls, params: dict):
  54. model_kwargs = {
  55. 'top_p': params.get('top_p', 1),
  56. 'frequency_penalty': params.get('frequency_penalty', 0),
  57. 'presence_penalty': params.get('presence_penalty', 0),
  58. }
  59. del params['top_p']
  60. del params['frequency_penalty']
  61. del params['presence_penalty']
  62. params['model_kwargs'] = model_kwargs
  63. return params