llm_callback_handler.py

import logging
import threading
import time
from typing import Any, Dict, List, Union, Optional

from flask import Flask, current_app
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import LLMResult, BaseMessage
from pydantic import BaseModel

from core.callback_handler.entity.llm_message import LLMMessage
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
    ConversationTaskInterruptException
from core.model_providers.models.entity.message import to_prompt_messages, PromptMessage, LCHumanMessageWithFiles, \
    ImagePromptMessageFile
from core.model_providers.models.llm.base import BaseLLM
from core.moderation.base import ModerationOutputsResult, ModerationAction
from core.moderation.factory import ModerationFactory


class ModerationRule(BaseModel):
    type: str
    config: Dict[str, Any]


class LLMCallbackHandler(BaseCallbackHandler):
    """LangChain callback handler that records prompts and completion tokens
    on the conversation task and applies output moderation to the result."""

    raise_error: bool = True

    def __init__(self, model_instance: BaseLLM,
                 conversation_message_task: ConversationMessageTask):
        self.model_instance = model_instance
        self.llm_message = LLMMessage()
        self.start_at = None
        self.conversation_message_task = conversation_message_task

        self.output_moderation_handler = None
        self.init_output_moderation()

    def init_output_moderation(self):
        app_model_config = self.conversation_message_task.app_model_config
        sensitive_word_avoidance_dict = app_model_config.sensitive_word_avoidance_dict

        if sensitive_word_avoidance_dict and sensitive_word_avoidance_dict.get("enabled"):
            self.output_moderation_handler = OutputModerationHandler(
                tenant_id=self.conversation_message_task.tenant_id,
                app_id=self.conversation_message_task.app.id,
                rule=ModerationRule(
                    type=sensitive_word_avoidance_dict.get("type"),
                    config=sensitive_word_avoidance_dict.get("config")
                ),
                on_message_replace_func=self.conversation_message_task.on_message_replace
            )
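
    # Illustrative shape of the config this method reads, inferred from the
    # .get() calls above (the concrete values are assumptions, not the real
    # schema):
    #
    #     sensitive_word_avoidance_dict = {
    #         "enabled": True,
    #         "type": "keywords",   # name of a registered moderation rule
    #         "config": {...}       # rule-specific settings passed through
    #     }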

    @property
    def always_verbose(self) -> bool:
        """Whether to call verbose callbacks even if verbose is False."""
        return True

    def on_chat_model_start(
            self,
            serialized: Dict[str, Any],
            messages: List[List[BaseMessage]],
            **kwargs: Any
    ) -> Any:
        real_prompts = []
        for message in messages[0]:
            if message.type == 'human':
                role = 'user'
            elif message.type == 'ai':
                role = 'assistant'
            else:
                role = 'system'

            real_prompts.append({
                "role": role,
                "text": message.content,
                "files": [{
                    "type": file.type.value,
                    "data": file.data[:10] + '...[TRUNCATED]...' + file.data[-10:],
                    "detail": file.detail.value if isinstance(file, ImagePromptMessageFile) else None,
                } for file in (message.files if isinstance(message, LCHumanMessageWithFiles) else [])]
            })

        self.llm_message.prompt = real_prompts
        self.llm_message.prompt_tokens = self.model_instance.get_num_tokens(to_prompt_messages(messages[0]))
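
    # Example of a single `real_prompts` entry as built above (values are
    # illustrative; file data is clipped to its first and last 10 chars):
    #
    #     {
    #         "role": "user",
    #         "text": "Describe this image",
    #         "files": [{"type": "image",
    #                    "data": "data:image...[TRUNCATED]...kJggg==",
    #                    "detail": "low"}]
    #     }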

    def on_llm_start(
            self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        self.llm_message.prompt = [{
            "role": 'user',
            "text": prompts[0]
        }]

        self.llm_message.prompt_tokens = self.model_instance.get_num_tokens([PromptMessage(content=prompts[0])])

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        if self.output_moderation_handler:
            self.output_moderation_handler.stop_thread()

            self.llm_message.completion = self.output_moderation_handler.moderation_completion(
                completion=response.generations[0][0].text,
                public_event=True if self.conversation_message_task.streaming else False
            )
        else:
            self.llm_message.completion = response.generations[0][0].text

        if not self.conversation_message_task.streaming:
            self.conversation_message_task.append_message_text(self.llm_message.completion)

        if response.llm_output and 'token_usage' in response.llm_output:
            if 'prompt_tokens' in response.llm_output['token_usage']:
                self.llm_message.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']

            if 'completion_tokens' in response.llm_output['token_usage']:
                self.llm_message.completion_tokens = response.llm_output['token_usage']['completion_tokens']
            else:
                self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
                    [PromptMessage(content=self.llm_message.completion)])
        else:
            self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
                [PromptMessage(content=self.llm_message.completion)])

        self.conversation_message_task.save_message(self.llm_message)
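
    # Token accounting above: provider-reported usage in
    # `response.llm_output['token_usage']` takes precedence; when the
    # provider omits it, tokens are recounted locally via
    # `model_instance.get_num_tokens`, which may differ slightly from the
    # provider's own count.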

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        if self.output_moderation_handler and self.output_moderation_handler.should_direct_output():
            # stop consuming new tokens once output moderation decides to direct-output
            ex = ConversationTaskInterruptException()
            self.on_llm_error(error=ex)
            raise ex

        try:
            self.conversation_message_task.append_message_text(token)
            self.llm_message.completion += token

            if self.output_moderation_handler:
                self.output_moderation_handler.append_new_token(token)
        except ConversationTaskStoppedException as ex:
            self.on_llm_error(error=ex)
            raise ex

    def on_llm_error(
            self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Persist the message on task-stop or moderation-interrupt; log everything else."""
        if self.output_moderation_handler:
            self.output_moderation_handler.stop_thread()

        if isinstance(error, ConversationTaskStoppedException):
            if self.conversation_message_task.streaming:
                self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
                    [PromptMessage(content=self.llm_message.completion)]
                )
                self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True)
        elif isinstance(error, ConversationTaskInterruptException):
            self.llm_message.completion = self.output_moderation_handler.get_final_output()
            self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
                [PromptMessage(content=self.llm_message.completion)]
            )
            self.conversation_message_task.save_message(llm_message=self.llm_message)
        else:
            logging.debug("on_llm_error: %s", error)


class OutputModerationHandler(BaseModel):
    """Runs output moderation over the streamed completion in a background
    thread and replaces flagged content via `on_message_replace_func`."""

    DEFAULT_BUFFER_SIZE: int = 300

    tenant_id: str
    app_id: str

    rule: ModerationRule
    on_message_replace_func: Any

    thread: Optional[threading.Thread] = None
    thread_running: bool = True
    buffer: str = ''
    is_final_chunk: bool = False
    final_output: Optional[str] = None

    class Config:
        arbitrary_types_allowed = True

    def should_direct_output(self):
        return self.final_output is not None

    def get_final_output(self):
        return self.final_output

    def append_new_token(self, token: str):
        self.buffer += token

        if not self.thread:
            self.thread = self.start_thread()

    def moderation_completion(self, completion: str, public_event: bool = False) -> str:
        self.buffer = completion
        self.is_final_chunk = True

        result = self.moderation(
            tenant_id=self.tenant_id,
            app_id=self.app_id,
            moderation_buffer=completion
        )

        if not result or not result.flagged:
            return completion

        if result.action == ModerationAction.DIRECT_OUTPUT:
            final_output = result.preset_response
        else:
            final_output = result.text

        if public_event:
            self.on_message_replace_func(final_output)

        return final_output

    def start_thread(self) -> threading.Thread:
        buffer_size = int(current_app.config.get('MODERATION_BUFFER_SIZE', self.DEFAULT_BUFFER_SIZE))
        thread = threading.Thread(target=self.worker, kwargs={
            'flask_app': current_app._get_current_object(),
            'buffer_size': buffer_size if buffer_size > 0 else self.DEFAULT_BUFFER_SIZE
        })

        thread.start()

        return thread

    def stop_thread(self):
        if self.thread and self.thread.is_alive():
            self.thread_running = False
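
    # Note: `stop_thread` only clears the `thread_running` flag; the worker
    # below exits cooperatively on its next loop iteration (up to ~1s later,
    # given its sleep) rather than being killed.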

    def worker(self, flask_app: Flask, buffer_size: int):
        with flask_app.app_context():
            current_length = 0
            while self.thread_running:
                moderation_buffer = self.buffer
                buffer_length = len(moderation_buffer)
                if not self.is_final_chunk:
                    chunk_length = buffer_length - current_length
                    if 0 <= chunk_length < buffer_size:
                        time.sleep(1)
                        continue

                current_length = buffer_length

                result = self.moderation(
                    tenant_id=self.tenant_id,
                    app_id=self.app_id,
                    moderation_buffer=moderation_buffer
                )

                if not result or not result.flagged:
                    continue

                if result.action == ModerationAction.DIRECT_OUTPUT:
                    final_output = result.preset_response
                    self.final_output = final_output
                else:
                    final_output = result.text + self.buffer[len(moderation_buffer):]

                # trigger replace event
                if self.thread_running:
                    self.on_message_replace_func(final_output)

                if result.action == ModerationAction.DIRECT_OUTPUT:
                    break
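
    # Worked example of the buffering logic above, with buffer_size = 300:
    # while fewer than 300 new characters have arrived since the last check
    # (0 <= chunk_length < 300), the worker sleeps 1s and re-reads the
    # buffer; once the 300th new character lands (or `is_final_chunk` is
    # set), it runs moderation over the whole buffer accumulated so far.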

    def moderation(self, tenant_id: str, app_id: str, moderation_buffer: str) -> Optional[ModerationOutputsResult]:
        try:
            moderation_factory = ModerationFactory(
                name=self.rule.type,
                app_id=app_id,
                tenant_id=tenant_id,
                config=self.rule.config
            )

            result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer)
            return result
        except Exception as e:
            logging.error("Moderation Output error: %s", e)

        return None
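
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module): how a
# caller might attach this handler to a streaming LangChain chat model. The
# `ChatOpenAI` class is LangChain's; `model_instance` and
# `conversation_message_task` stand in for objects this app constructs
# elsewhere.
#
#     from langchain.chat_models import ChatOpenAI
#
#     handler = LLMCallbackHandler(model_instance, conversation_message_task)
#     llm = ChatOpenAI(streaming=True, callbacks=[handler])
#     llm.predict("Hello!")   # each token flows through on_llm_new_token;
#                             # the final result lands in on_llm_end
# ---------------------------------------------------------------------------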