streamable_chat_open_ai.py 4.4 KB

import os
from typing import Optional, List, Dict, Any

from langchain.schema import BaseMessage, ChatResult, LLMResult
from langchain.chat_models import ChatOpenAI
from pydantic import root_validator

from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async


class StreamableChatOpenAI(ChatOpenAI):

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the api key and python package exist in the environment."""
        try:
            import openai
        except ImportError:
            raise ValueError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            )
        try:
            values["client"] = openai.ChatCompletion
        except AttributeError:
            raise ValueError(
                "`openai` has no `ChatCompletion` attribute, this is likely "
                "due to an old version of the openai package. Try upgrading it "
                "with `pip install --upgrade openai`."
            )

        if values["n"] < 1:
            raise ValueError("n must be at least 1.")
        if values["n"] > 1 and values["streaming"]:
            raise ValueError("n must be 1 when streaming.")

        return values
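
    # The root validator above runs when the model is constructed (pydantic v1), so
    # a missing `openai` package or an unsupported n/streaming combination fails at
    # instantiation rather than on the first request.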

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling OpenAI API."""
        return {
            **super()._default_params,
            "api_type": 'openai',
            "api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
            "api_version": None,
            "api_key": self.openai_api_key,
            "organization": self.openai_organization if self.openai_organization else None,
        }
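
    # Assumption for the keys above: the pre-1.0 `openai` client accepts api_key,
    # api_base, api_type, api_version and organization as per-call overrides, so
    # merging them into the request parameters avoids relying on module-level
    # openai globals.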

    def get_messages_tokens(self, messages: List[BaseMessage]) -> int:
        """Get the number of tokens in a list of messages.

        Args:
            messages: The messages to count the tokens of.

        Returns:
            The number of tokens in the messages.
        """
        tokens_per_message = 5
        tokens_per_request = 3

        message_tokens = tokens_per_request
        message_strs = ''
        for message in messages:
            message_strs += message.content
            message_tokens += tokens_per_message

        # Tokenize the concatenated contents once instead of once per message.
        message_tokens += self.get_num_tokens(message_strs)

        return message_tokens
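
    # Worked example of the heuristic above (illustrative numbers): for two messages
    # whose concatenated content encodes to 40 tokens, the estimate is
    # 3 (tokens_per_request) + 2 * 5 (tokens_per_message) + 40 = 53 tokens.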

    def _generate(
            self, messages: List[BaseMessage], stop: Optional[List[str]] = None
    ) -> ChatResult:
        self.callback_manager.on_llm_start(
            {"name": self.__class__.__name__},
            [(message.type + ": " + message.content) for message in messages],
            verbose=self.verbose
        )

        chat_result = super()._generate(messages, stop)

        result = LLMResult(
            generations=[chat_result.generations],
            llm_output=chat_result.llm_output
        )
        self.callback_manager.on_llm_end(result, verbose=self.verbose)

        return chat_result

    async def _agenerate(
            self, messages: List[BaseMessage], stop: Optional[List[str]] = None
    ) -> ChatResult:
        if self.callback_manager.is_async:
            await self.callback_manager.on_llm_start(
                {"name": self.__class__.__name__},
                [(message.type + ": " + message.content) for message in messages],
                verbose=self.verbose
            )
        else:
            self.callback_manager.on_llm_start(
                {"name": self.__class__.__name__},
                [(message.type + ": " + message.content) for message in messages],
                verbose=self.verbose
            )

        # Await the async parent implementation so the event loop is not blocked.
        chat_result = await super()._agenerate(messages, stop)

        result = LLMResult(
            generations=[chat_result.generations],
            llm_output=chat_result.llm_output
        )
        if self.callback_manager.is_async:
            await self.callback_manager.on_llm_end(result, verbose=self.verbose)
        else:
            self.callback_manager.on_llm_end(result, verbose=self.verbose)

        return chat_result

    @handle_llm_exceptions
    def generate(
            self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
    ) -> LLMResult:
        return super().generate(messages, stop)

    @handle_llm_exceptions_async
    async def agenerate(
            self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
    ) -> LLMResult:
        return await super().agenerate(messages, stop)
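

# Minimal usage sketch, assuming OPENAI_API_KEY is exported and the `langchain` /
# pre-1.0 `openai` versions this class targets are installed; the model name and
# HumanMessage prompt below are illustrative, not prescribed by this file.
if __name__ == "__main__":
    from langchain.schema import HumanMessage

    llm = StreamableChatOpenAI(
        model_name="gpt-3.5-turbo",
        openai_api_key=os.environ.get("OPENAI_API_KEY"),
    )
    demo_messages = [HumanMessage(content="Say hello in one short sentence.")]

    # Rough prompt size, using the heuristic defined in get_messages_tokens.
    print("approx. prompt tokens:", llm.get_messages_tokens(demo_messages))

    # generate() takes a batch of message lists and returns an LLMResult.
    result = llm.generate([demo_messages])
    print(result.generations[0][0].text)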