# moderation_handler.py
  1. import logging
  2. import threading
  3. import time
  4. from typing import Any, Optional
  5. from flask import Flask, current_app
  6. from pydantic import BaseModel
  7. from core.application_queue_manager import PublishFrom
  8. from core.moderation.base import ModerationAction, ModerationOutputsResult
  9. from core.moderation.factory import ModerationFactory
  10. logger = logging.getLogger(__name__)
  11. class ModerationRule(BaseModel):
  12. type: str
  13. config: dict[str, Any]
  14. class OutputModerationHandler(BaseModel):
  15. DEFAULT_BUFFER_SIZE: int = 300
  16. tenant_id: str
  17. app_id: str
  18. rule: ModerationRule
  19. on_message_replace_func: Any
  20. thread: Optional[threading.Thread] = None
  21. thread_running: bool = True
  22. buffer: str = ''
  23. is_final_chunk: bool = False
  24. final_output: Optional[str] = None
  25. class Config:
  26. arbitrary_types_allowed = True
  27. def should_direct_output(self):
  28. return self.final_output is not None
  29. def get_final_output(self):
  30. return self.final_output
  31. def append_new_token(self, token: str):
  32. self.buffer += token
  33. if not self.thread:
  34. self.thread = self.start_thread()
  35. def moderation_completion(self, completion: str, public_event: bool = False) -> str:
  36. self.buffer = completion
  37. self.is_final_chunk = True
  38. result = self.moderation(
  39. tenant_id=self.tenant_id,
  40. app_id=self.app_id,
  41. moderation_buffer=completion
  42. )
  43. if not result or not result.flagged:
  44. return completion
  45. if result.action == ModerationAction.DIRECT_OUTPUT:
  46. final_output = result.preset_response
  47. else:
  48. final_output = result.text
  49. if public_event:
  50. self.on_message_replace_func(final_output, PublishFrom.TASK_PIPELINE)
  51. return final_output
  52. def start_thread(self) -> threading.Thread:
  53. buffer_size = int(current_app.config.get('MODERATION_BUFFER_SIZE', self.DEFAULT_BUFFER_SIZE))
  54. thread = threading.Thread(target=self.worker, kwargs={
  55. 'flask_app': current_app._get_current_object(),
  56. 'buffer_size': buffer_size if buffer_size > 0 else self.DEFAULT_BUFFER_SIZE
  57. })
  58. thread.start()
  59. return thread
  60. def stop_thread(self):
  61. if self.thread and self.thread.is_alive():
  62. self.thread_running = False
  63. def worker(self, flask_app: Flask, buffer_size: int):
  64. with flask_app.app_context():
  65. current_length = 0
  66. while self.thread_running:
  67. moderation_buffer = self.buffer
  68. buffer_length = len(moderation_buffer)
  69. if not self.is_final_chunk:
  70. chunk_length = buffer_length - current_length
  71. if 0 <= chunk_length < buffer_size:
  72. time.sleep(1)
  73. continue
  74. current_length = buffer_length
  75. result = self.moderation(
  76. tenant_id=self.tenant_id,
  77. app_id=self.app_id,
  78. moderation_buffer=moderation_buffer
  79. )
  80. if not result or not result.flagged:
  81. continue
  82. if result.action == ModerationAction.DIRECT_OUTPUT:
  83. final_output = result.preset_response
  84. self.final_output = final_output
  85. else:
  86. final_output = result.text + self.buffer[len(moderation_buffer):]
  87. # trigger replace event
  88. if self.thread_running:
  89. self.on_message_replace_func(final_output, PublishFrom.TASK_PIPELINE)
  90. if result.action == ModerationAction.DIRECT_OUTPUT:
  91. break
  92. def moderation(self, tenant_id: str, app_id: str, moderation_buffer: str) -> Optional[ModerationOutputsResult]:
  93. try:
  94. moderation_factory = ModerationFactory(
  95. name=self.rule.type,
  96. app_id=app_id,
  97. tenant_id=tenant_id,
  98. config=self.rule.config
  99. )
  100. result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer)
  101. return result
  102. except Exception as e:
  103. logger.error("Moderation Output error: %s", e)
  104. return None