llm_callback_handler.py

import logging
import threading
import time
from typing import Any, Dict, List, Union, Optional

from flask import Flask, current_app
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import LLMResult, BaseMessage
from pydantic import BaseModel

from core.callback_handler.entity.llm_message import LLMMessage
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
    ConversationTaskInterruptException
from core.model_providers.models.entity.message import to_prompt_messages, PromptMessage
from core.model_providers.models.llm.base import BaseLLM
from core.moderation.base import ModerationOutputsResult, ModerationAction
from core.moderation.factory import ModerationFactory


class ModerationRule(BaseModel):
    """A single output-moderation rule: the moderation provider type and its config."""

    type: str
    config: Dict[str, Any]


class LLMCallbackHandler(BaseCallbackHandler):
    """LangChain callback handler that records the prompt, streamed tokens and token
    usage into an LLMMessage, optionally routing the output through moderation."""

    raise_error: bool = True

    def __init__(self, model_instance: BaseLLM,
                 conversation_message_task: ConversationMessageTask):
        self.model_instance = model_instance
        self.llm_message = LLMMessage()
        self.start_at = None
        self.conversation_message_task = conversation_message_task

        self.output_moderation_handler = None
        self.init_output_moderation()

    def init_output_moderation(self):
        app_model_config = self.conversation_message_task.app_model_config
        sensitive_word_avoidance_dict = app_model_config.sensitive_word_avoidance_dict

        if sensitive_word_avoidance_dict and sensitive_word_avoidance_dict.get("enabled"):
            self.output_moderation_handler = OutputModerationHandler(
                tenant_id=self.conversation_message_task.tenant_id,
                app_id=self.conversation_message_task.app.id,
                rule=ModerationRule(
                    type=sensitive_word_avoidance_dict.get("type"),
                    config=sensitive_word_avoidance_dict.get("config")
                ),
                on_message_replace_func=self.conversation_message_task.on_message_replace
            )
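
    # The app config is expected to carry a dict roughly of this shape; the
    # "keywords" type and its config keys are illustrative assumptions, not
    # defined in this file:
    #
    #     {"enabled": True, "type": "keywords", "config": {"keywords": "..."}}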

    @property
    def always_verbose(self) -> bool:
        """Whether to call verbose callbacks even if verbose is False."""
        return True

    def on_chat_model_start(
            self,
            serialized: Dict[str, Any],
            messages: List[List[BaseMessage]],
            **kwargs: Any
    ) -> Any:
        """Record the chat prompt and its token count before the model runs."""
        real_prompts = []
        for message in messages[0]:
            if message.type == 'human':
                role = 'user'
            elif message.type == 'ai':
                role = 'assistant'
            else:
                role = 'system'

            real_prompts.append({
                "role": role,
                "text": message.content
            })

        self.llm_message.prompt = real_prompts
        self.llm_message.prompt_tokens = self.model_instance.get_num_tokens(to_prompt_messages(messages[0]))
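
    # Note: only messages[0] (the first prompt in the batch) is recorded; e.g. a
    # HumanMessage("Hi") becomes {"role": "user", "text": "Hi"} (example ours).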

    def on_llm_start(
            self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Record the plain-text prompt and its token count for completion models."""
        self.llm_message.prompt = [{
            "role": 'user',
            "text": prompts[0]
        }]

        self.llm_message.prompt_tokens = self.model_instance.get_num_tokens([PromptMessage(content=prompts[0])])

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Finalize the completion: moderate the full text, prefer provider-reported
        token usage over local counting, and persist the message."""
        if self.output_moderation_handler:
            self.output_moderation_handler.stop_thread()

            self.llm_message.completion = self.output_moderation_handler.moderation_completion(
                completion=response.generations[0][0].text,
                public_event=True if self.conversation_message_task.streaming else False
            )
        else:
            self.llm_message.completion = response.generations[0][0].text

        if not self.conversation_message_task.streaming:
            self.conversation_message_task.append_message_text(self.llm_message.completion)

        # Prefer token counts reported by the provider; fall back to counting locally.
        if response.llm_output and 'token_usage' in response.llm_output:
            if 'prompt_tokens' in response.llm_output['token_usage']:
                self.llm_message.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']

            if 'completion_tokens' in response.llm_output['token_usage']:
                self.llm_message.completion_tokens = response.llm_output['token_usage']['completion_tokens']
            else:
                self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
                    [PromptMessage(content=self.llm_message.completion)])
        else:
            self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
                [PromptMessage(content=self.llm_message.completion)])

        self.conversation_message_task.save_message(self.llm_message)
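
    # OpenAI-style providers populate response.llm_output roughly as
    # {"token_usage": {"prompt_tokens": 12, "completion_tokens": 34, ...}};
    # the numbers here are illustrative.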

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Stream a new token to the client and into the moderation buffer."""
        if self.output_moderation_handler and self.output_moderation_handler.should_direct_output():
            # Stop consuming new tokens once moderation has decided on a direct output.
            ex = ConversationTaskInterruptException()
            self.on_llm_error(error=ex)
            raise ex

        try:
            self.conversation_message_task.append_message_text(token)
            self.llm_message.completion += token

            if self.output_moderation_handler:
                self.output_moderation_handler.append_new_token(token)
        except ConversationTaskStoppedException as ex:
            self.on_llm_error(error=ex)
            raise ex

    def on_llm_error(
            self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Handle a stopped or interrupted task by saving whatever was generated so far."""
        if self.output_moderation_handler:
            self.output_moderation_handler.stop_thread()

        if isinstance(error, ConversationTaskStoppedException):
            if self.conversation_message_task.streaming:
                self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
                    [PromptMessage(content=self.llm_message.completion)]
                )
                self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True)

        if isinstance(error, ConversationTaskInterruptException):
            self.llm_message.completion = self.output_moderation_handler.get_final_output()
            self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
                [PromptMessage(content=self.llm_message.completion)]
            )
            self.conversation_message_task.save_message(llm_message=self.llm_message)
        else:
            logging.debug("on_llm_error: %s", error)
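

# Usage sketch (illustrative): the handler is attached to a LangChain model via
# `callbacks`; `model_instance` and `task` are assumed to come from the caller.
#
#     handler = LLMCallbackHandler(model_instance=model_instance,
#                                  conversation_message_task=task)
#     llm = ChatOpenAI(streaming=True, callbacks=[handler])
#     llm.predict("Hello")  # fires on_chat_model_start / on_llm_new_token / on_llm_end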


class OutputModerationHandler(BaseModel):
    """Runs output moderation over the streamed completion in a background thread,
    replacing the visible message when a rule flags the text."""

    DEFAULT_BUFFER_SIZE: int = 300

    tenant_id: str
    app_id: str

    rule: ModerationRule
    on_message_replace_func: Any

    thread: Optional[threading.Thread] = None
    thread_running: bool = True
    buffer: str = ''
    is_final_chunk: bool = False
    final_output: Optional[str] = None

    class Config:
        arbitrary_types_allowed = True

    def should_direct_output(self):
        return self.final_output is not None

    def get_final_output(self):
        return self.final_output

    def append_new_token(self, token: str):
        self.buffer += token

        # Lazily start the moderation worker on the first streamed token.
        if not self.thread:
            self.thread = self.start_thread()

    def moderation_completion(self, completion: str, public_event: bool = False) -> str:
        """Moderate the final, complete text and return the (possibly replaced) output."""
        self.buffer = completion
        self.is_final_chunk = True

        result = self.moderation(
            tenant_id=self.tenant_id,
            app_id=self.app_id,
            moderation_buffer=completion
        )

        if not result or not result.flagged:
            return completion

        if result.action == ModerationAction.DIRECT_OUTPUT:
            final_output = result.preset_response
        else:
            final_output = result.text

        if public_event:
            self.on_message_replace_func(final_output)

        return final_output

    def start_thread(self) -> threading.Thread:
        buffer_size = int(current_app.config.get('MODERATION_BUFFER_SIZE', self.DEFAULT_BUFFER_SIZE))
        thread = threading.Thread(target=self.worker, kwargs={
            'flask_app': current_app._get_current_object(),
            'buffer_size': buffer_size if buffer_size > 0 else self.DEFAULT_BUFFER_SIZE
        })

        thread.start()

        return thread

    def stop_thread(self):
        # Cooperative stop: the worker loop checks this flag on each iteration.
        if self.thread and self.thread.is_alive():
            self.thread_running = False

    def worker(self, flask_app: Flask, buffer_size: int):
        """Poll the buffer and moderate it in chunks of at least `buffer_size` characters."""
        with flask_app.app_context():
            current_length = 0
            while self.thread_running:
                moderation_buffer = self.buffer
                buffer_length = len(moderation_buffer)
                if not self.is_final_chunk:
                    chunk_length = buffer_length - current_length
                    if 0 <= chunk_length < buffer_size:
                        # Not enough new text yet; wait for more tokens.
                        time.sleep(1)
                        continue

                current_length = buffer_length

                result = self.moderation(
                    tenant_id=self.tenant_id,
                    app_id=self.app_id,
                    moderation_buffer=moderation_buffer
                )

                if not result or not result.flagged:
                    continue

                if result.action == ModerationAction.DIRECT_OUTPUT:
                    final_output = result.preset_response
                    self.final_output = final_output
                else:
                    # Keep any tokens that arrived while moderation was running.
                    final_output = result.text + self.buffer[len(moderation_buffer):]

                # Trigger the replace event unless the thread was asked to stop.
                if self.thread_running:
                    self.on_message_replace_func(final_output)

                if result.action == ModerationAction.DIRECT_OUTPUT:
                    break
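
    # Example timeline with buffer_size=300 (values illustrative): the worker
    # sleeps while fewer than 300 new characters have arrived, then moderates
    # the whole buffer seen so far; a DIRECT_OUTPUT verdict sets final_output,
    # which makes on_llm_new_token raise and end the stream.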

    def moderation(self, tenant_id: str, app_id: str, moderation_buffer: str) -> Optional[ModerationOutputsResult]:
        """Run the configured moderation over the buffer; return None on failure."""
        try:
            moderation_factory = ModerationFactory(
                name=self.rule.type,
                app_id=app_id,
                tenant_id=tenant_id,
                config=self.rule.config
            )

            result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer)
            return result
        except Exception as e:
            logging.error("Moderation Output error: %s", e)

        return None
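

# Standalone sketch of the moderation flow (hypothetical values; a real call needs
# an active Flask app context and a moderation type registered with ModerationFactory):
#
#     handler = OutputModerationHandler(
#         tenant_id="tenant-1", app_id="app-1",
#         rule=ModerationRule(type="keywords", config={"keywords": "badword"}),
#         on_message_replace_func=lambda text: print("replace with:", text),
#     )
#     handler.moderation_completion("a completion containing badword")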