import logging
import time
from typing import Any, Dict, List, Union, Optional

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult, HumanMessage, AIMessage, SystemMessage

from core.callback_handler.entity.llm_message import LLMMessage
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI
  10. class LLMCallbackHandler(BaseCallbackHandler):
  11. def __init__(self, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
  12. conversation_message_task: ConversationMessageTask):
  13. self.llm = llm
  14. self.llm_message = LLMMessage()
  15. self.start_at = None
  16. self.conversation_message_task = conversation_message_task
  17. @property
  18. def always_verbose(self) -> bool:
  19. """Whether to call verbose callbacks even if verbose is False."""
  20. return True
  21. def on_llm_start(
  22. self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
  23. ) -> None:
  24. self.start_at = time.perf_counter()
  25. if 'Chat' in serialized['name']:
  26. real_prompts = []
  27. messages = []
  28. for prompt in prompts:
  29. role, content = prompt.split(': ', maxsplit=1)
  30. if role == 'human':
  31. role = 'user'
  32. message = HumanMessage(content=content)
  33. elif role == 'ai':
  34. role = 'assistant'
  35. message = AIMessage(content=content)
  36. else:
  37. message = SystemMessage(content=content)
  38. real_prompt = {
  39. "role": role,
  40. "text": content
  41. }
  42. real_prompts.append(real_prompt)
  43. messages.append(message)
  44. self.llm_message.prompt = real_prompts
  45. self.llm_message.prompt_tokens = self.llm.get_messages_tokens(messages)
  46. else:
  47. self.llm_message.prompt = [{
  48. "role": 'user',
  49. "text": prompts[0]
  50. }]
  51. self.llm_message.prompt_tokens = self.llm.get_num_tokens(prompts[0])
  52. def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
  53. end_at = time.perf_counter()
  54. self.llm_message.latency = end_at - self.start_at
  55. if not self.conversation_message_task.streaming:
  56. self.conversation_message_task.append_message_text(response.generations[0][0].text)
  57. self.llm_message.completion = response.generations[0][0].text
  58. self.llm_message.completion_tokens = response.llm_output['token_usage']['completion_tokens']
  59. else:
  60. self.llm_message.completion_tokens = self.llm.get_num_tokens(self.llm_message.completion)
  61. self.conversation_message_task.save_message(self.llm_message)
  62. def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
  63. self.conversation_message_task.append_message_text(token)
  64. self.llm_message.completion += token
  65. def on_llm_error(
  66. self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
  67. ) -> None:
  68. """Do nothing."""
  69. if isinstance(error, ConversationTaskStoppedException):
  70. if self.conversation_message_task.streaming:
  71. end_at = time.perf_counter()
  72. self.llm_message.latency = end_at - self.start_at
  73. self.llm_message.completion_tokens = self.llm.get_num_tokens(self.llm_message.completion)
  74. self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True)
  75. else:
  76. logging.error(error)
  77. def on_chain_start(
  78. self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
  79. ) -> None:
  80. pass
  81. def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
  82. pass
  83. def on_chain_error(
  84. self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
  85. ) -> None:
  86. pass
  87. def on_tool_start(
  88. self,
  89. serialized: Dict[str, Any],
  90. input_str: str,
  91. **kwargs: Any,
  92. ) -> None:
  93. pass
  94. def on_agent_action(
  95. self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
  96. ) -> Any:
  97. pass
  98. def on_tool_end(
  99. self,
  100. output: str,
  101. color: Optional[str] = None,
  102. observation_prefix: Optional[str] = None,
  103. llm_prefix: Optional[str] = None,
  104. **kwargs: Any,
  105. ) -> None:
  106. pass
  107. def on_tool_error(
  108. self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
  109. ) -> None:
  110. pass
  111. def on_text(
  112. self,
  113. text: str,
  114. color: Optional[str] = None,
  115. end: str = "",
  116. **kwargs: Optional[str],
  117. ) -> None:
  118. pass
  119. def on_agent_finish(
  120. self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
  121. ) -> None:
  122. pass