streamable_chat_open_ai.py

from typing import List, Optional

from langchain.chat_models import ChatOpenAI
from langchain.schema import BaseMessage, ChatResult, LLMResult

from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async


class StreamableChatOpenAI(ChatOpenAI):

    def get_messages_tokens(self, messages: List[BaseMessage]) -> int:
        """Get the number of tokens in a list of messages.

        Args:
            messages: The messages to count the tokens of.

        Returns:
            The number of tokens in the messages.
        """
        tokens_per_message = 5
        tokens_per_request = 3

        message_tokens = tokens_per_request
        message_strs = ''
        for message in messages:
            message_strs += message.content
            message_tokens += tokens_per_message

        # calc once
        message_tokens += self.get_num_tokens(message_strs)

        return message_tokens
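    # Illustrative arithmetic (an assumed example, not from the original file):
    # for two messages whose concatenated content encodes to 40 tokens, the
    # method above returns 3 (per request) + 2 * 5 (per message) + 40 = 53.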
    def _generate(
            self, messages: List[BaseMessage], stop: Optional[List[str]] = None
    ) -> ChatResult:
        self.callback_manager.on_llm_start(
            {"name": self.__class__.__name__},
            [(message.type + ": " + message.content) for message in messages],
            verbose=self.verbose
        )

        chat_result = super()._generate(messages, stop)

        result = LLMResult(
            generations=[chat_result.generations],
            llm_output=chat_result.llm_output
        )
        self.callback_manager.on_llm_end(result, verbose=self.verbose)

        return chat_result
    async def _agenerate(
            self, messages: List[BaseMessage], stop: Optional[List[str]] = None
    ) -> ChatResult:
        if self.callback_manager.is_async:
            await self.callback_manager.on_llm_start(
                {"name": self.__class__.__name__},
                [(message.type + ": " + message.content) for message in messages],
                verbose=self.verbose
            )
        else:
            self.callback_manager.on_llm_start(
                {"name": self.__class__.__name__},
                [(message.type + ": " + message.content) for message in messages],
                verbose=self.verbose
            )

        # Delegate to the parent's async implementation instead of the
        # blocking sync `_generate`.
        chat_result = await super()._agenerate(messages, stop)

        result = LLMResult(
            generations=[chat_result.generations],
            llm_output=chat_result.llm_output
        )

        if self.callback_manager.is_async:
            await self.callback_manager.on_llm_end(result, verbose=self.verbose)
        else:
            self.callback_manager.on_llm_end(result, verbose=self.verbose)

        return chat_result
    @handle_llm_exceptions
    def generate(
            self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
    ) -> LLMResult:
        return super().generate(messages, stop)

    @handle_llm_exceptions_async
    async def agenerate(
            self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
    ) -> LLMResult:
        return await super().agenerate(messages, stop)
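
# --- Usage sketch (not part of the original module) --------------------------
# A minimal, hedged example of how this subclass might be exercised, assuming
# OPENAI_API_KEY is set in the environment and that the project-local
# `core.llm.error_handle_wraps` module is importable. The model name and
# messages below are illustrative only.
if __name__ == "__main__":
    from langchain.schema import HumanMessage, SystemMessage

    llm = StreamableChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)
    messages = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="Summarize LangChain in one sentence."),
    ]

    # Estimate the prompt size before sending the request.
    print("estimated prompt tokens:", llm.get_messages_tokens(messages))

    # `generate` takes a batch of message lists and returns an LLMResult.
    result = llm.generate([messages])
    print(result.generations[0][0].text)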