streamable_chat_open_ai.py

import os
from typing import Any, Dict, List, Optional

from langchain.callbacks.manager import Callbacks
from langchain.chat_models import ChatOpenAI
from langchain.schema import BaseMessage, LLMResult
from pydantic import root_validator

from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async


class StreamableChatOpenAI(ChatOpenAI):
    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the API key and python package exist in the environment."""
        try:
            import openai
        except ImportError:
            raise ValueError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            )
        try:
            values["client"] = openai.ChatCompletion
        except AttributeError:
            raise ValueError(
                "`openai` has no `ChatCompletion` attribute, this is likely "
                "due to an old version of the openai package. Try upgrading it "
                "with `pip install --upgrade openai`."
            )
        if values["n"] < 1:
            raise ValueError("n must be at least 1.")
        if values["n"] > 1 and values["streaming"]:
            raise ValueError("n must be 1 when streaming.")
        return values
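
    # NOTE: the validator above overrides ChatOpenAI's own environment check
    # so that `client` is always bound to `openai.ChatCompletion` and the
    # streaming constraint (n == 1 when streaming) is enforced up front.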
    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling the OpenAI API."""
        return {
            **super()._default_params,
            "api_type": "openai",
            "api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
            "api_version": None,
            "api_key": self.openai_api_key,
            "organization": self.openai_organization if self.openai_organization else None,
        }
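
    # NOTE: `_default_params` above extends the parent defaults with explicit
    # connection settings; the OPENAI_API_BASE environment variable lets the
    # base URL be redirected (e.g. to a compatible proxy) without code changes.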
    def get_messages_tokens(self, messages: List[BaseMessage]) -> int:
        """Get the number of tokens in a list of messages.

        Args:
            messages: The messages to count the tokens of.

        Returns:
            The number of tokens in the messages.
        """
        tokens_per_message = 5
        tokens_per_request = 3

        message_tokens = tokens_per_request
        message_strs = ''
        for message in messages:
            message_strs += message.content
            message_tokens += tokens_per_message

        # tokenize the concatenated contents once instead of once per message
        message_tokens += self.get_num_tokens(message_strs)

        return message_tokens
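
    # NOTE: `get_messages_tokens` above is an approximation: 3 tokens of
    # per-request overhead plus 5 per message, with all message contents
    # counted in one tokenizer pass. Worked example (illustrative numbers):
    # two messages whose contents tokenize to 10 and 15 tokens (and whose
    # concatenation tokenizes to 25) count as 3 + 2 * 5 + 25 = 38 tokens.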
    @handle_llm_exceptions
    def generate(
            self,
            messages: List[List[BaseMessage]],
            stop: Optional[List[str]] = None,
            callbacks: Callbacks = None,
            **kwargs: Any,
    ) -> LLMResult:
        return super().generate(messages, stop, callbacks, **kwargs)
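
    # NOTE: `generate` above is a thin wrapper whose only purpose is to apply
    # the `handle_llm_exceptions` decorator to the parent implementation;
    # `agenerate` below does the same for the async path.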
    @handle_llm_exceptions_async
    async def agenerate(
            self,
            messages: List[List[BaseMessage]],
            stop: Optional[List[str]] = None,
            callbacks: Callbacks = None,
            **kwargs: Any,
    ) -> LLMResult:
        return await super().agenerate(messages, stop, callbacks, **kwargs)
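

# Usage sketch (not part of the original file): a minimal, hedged example of
# driving this class with streaming enabled. The model name and environment
# variable below are illustrative assumptions, not values from this module.
if __name__ == "__main__":
    from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
    from langchain.schema import HumanMessage

    llm = StreamableChatOpenAI(
        model_name="gpt-3.5-turbo",  # assumed model name
        streaming=True,  # the validator above then requires n == 1
        openai_api_key=os.environ.get("OPENAI_API_KEY"),
    )

    prompt = [HumanMessage(content="Hello!")]
    print("estimated prompt tokens:", llm.get_messages_tokens(prompt))

    # generate() takes a batch: a list of message lists.
    result = llm.generate([prompt], callbacks=[StreamingStdOutCallbackHandler()])
    print(result.generations[0][0].text)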