output_moderation.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. import logging
  2. import threading
  3. import time
  4. from typing import Any, Optional
  5. from flask import Flask, current_app
  6. from pydantic import BaseModel
  7. from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
  8. from core.app.entities.queue_entities import QueueMessageReplaceEvent
  9. from core.moderation.base import ModerationAction, ModerationOutputsResult
  10. from core.moderation.factory import ModerationFactory
  11. logger = logging.getLogger(__name__)
  12. class ModerationRule(BaseModel):
  13. type: str
  14. config: dict[str, Any]
  15. class OutputModeration(BaseModel):
  16. DEFAULT_BUFFER_SIZE: int = 300
  17. tenant_id: str
  18. app_id: str
  19. rule: ModerationRule
  20. queue_manager: AppQueueManager
  21. thread: Optional[threading.Thread] = None
  22. thread_running: bool = True
  23. buffer: str = ''
  24. is_final_chunk: bool = False
  25. final_output: Optional[str] = None
  26. class Config:
  27. arbitrary_types_allowed = True
  28. def should_direct_output(self):
  29. return self.final_output is not None
  30. def get_final_output(self):
  31. return self.final_output
  32. def append_new_token(self, token: str):
  33. self.buffer += token
  34. if not self.thread:
  35. self.thread = self.start_thread()
  36. def moderation_completion(self, completion: str, public_event: bool = False) -> str:
  37. self.buffer = completion
  38. self.is_final_chunk = True
  39. result = self.moderation(
  40. tenant_id=self.tenant_id,
  41. app_id=self.app_id,
  42. moderation_buffer=completion
  43. )
  44. if not result or not result.flagged:
  45. return completion
  46. if result.action == ModerationAction.DIRECT_OUTPUT:
  47. final_output = result.preset_response
  48. else:
  49. final_output = result.text
  50. if public_event:
  51. self.queue_manager.publish(
  52. QueueMessageReplaceEvent(
  53. text=final_output
  54. ),
  55. PublishFrom.TASK_PIPELINE
  56. )
  57. return final_output
  58. def start_thread(self) -> threading.Thread:
  59. buffer_size = int(current_app.config.get('MODERATION_BUFFER_SIZE', self.DEFAULT_BUFFER_SIZE))
  60. thread = threading.Thread(target=self.worker, kwargs={
  61. 'flask_app': current_app._get_current_object(),
  62. 'buffer_size': buffer_size if buffer_size > 0 else self.DEFAULT_BUFFER_SIZE
  63. })
  64. thread.start()
  65. return thread
  66. def stop_thread(self):
  67. if self.thread and self.thread.is_alive():
  68. self.thread_running = False
  69. def worker(self, flask_app: Flask, buffer_size: int):
  70. with flask_app.app_context():
  71. current_length = 0
  72. while self.thread_running:
  73. moderation_buffer = self.buffer
  74. buffer_length = len(moderation_buffer)
  75. if not self.is_final_chunk:
  76. chunk_length = buffer_length - current_length
  77. if 0 <= chunk_length < buffer_size:
  78. time.sleep(1)
  79. continue
  80. current_length = buffer_length
  81. result = self.moderation(
  82. tenant_id=self.tenant_id,
  83. app_id=self.app_id,
  84. moderation_buffer=moderation_buffer
  85. )
  86. if not result or not result.flagged:
  87. continue
  88. if result.action == ModerationAction.DIRECT_OUTPUT:
  89. final_output = result.preset_response
  90. self.final_output = final_output
  91. else:
  92. final_output = result.text + self.buffer[len(moderation_buffer):]
  93. # trigger replace event
  94. if self.thread_running:
  95. self.queue_manager.publish(
  96. QueueMessageReplaceEvent(
  97. text=final_output
  98. ),
  99. PublishFrom.TASK_PIPELINE
  100. )
  101. if result.action == ModerationAction.DIRECT_OUTPUT:
  102. break
  103. def moderation(self, tenant_id: str, app_id: str, moderation_buffer: str) -> Optional[ModerationOutputsResult]:
  104. try:
  105. moderation_factory = ModerationFactory(
  106. name=self.rule.type,
  107. app_id=app_id,
  108. tenant_id=tenant_id,
  109. config=self.rule.config
  110. )
  111. result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer)
  112. return result
  113. except Exception as e:
  114. logger.error("Moderation Output error: %s", e)
  115. return None