# input_moderation.py
  1. import logging
  2. from typing import Optional
  3. from core.app.app_config.entities import AppConfig
  4. from core.moderation.base import ModerationAction, ModerationException
  5. from core.moderation.factory import ModerationFactory
  6. from core.ops.ops_trace_manager import TraceQueueManager, TraceTask, TraceTaskName
  7. from core.ops.utils import measure_time
  8. logger = logging.getLogger(__name__)
  9. class InputModeration:
  10. def check(
  11. self, app_id: str,
  12. tenant_id: str,
  13. app_config: AppConfig,
  14. inputs: dict,
  15. query: str,
  16. message_id: str,
  17. trace_manager: Optional[TraceQueueManager] = None
  18. ) -> tuple[bool, dict, str]:
  19. """
  20. Process sensitive_word_avoidance.
  21. :param app_id: app id
  22. :param tenant_id: tenant id
  23. :param app_config: app config
  24. :param inputs: inputs
  25. :param query: query
  26. :param message_id: message id
  27. :param trace_manager: trace manager
  28. :return:
  29. """
  30. if not app_config.sensitive_word_avoidance:
  31. return False, inputs, query
  32. sensitive_word_avoidance_config = app_config.sensitive_word_avoidance
  33. moderation_type = sensitive_word_avoidance_config.type
  34. moderation_factory = ModerationFactory(
  35. name=moderation_type,
  36. app_id=app_id,
  37. tenant_id=tenant_id,
  38. config=sensitive_word_avoidance_config.config
  39. )
  40. with measure_time() as timer:
  41. moderation_result = moderation_factory.moderation_for_inputs(inputs, query)
  42. if trace_manager:
  43. trace_manager.add_trace_task(
  44. TraceTask(
  45. TraceTaskName.MODERATION_TRACE,
  46. message_id=message_id,
  47. moderation_result=moderation_result,
  48. inputs=inputs,
  49. timer=timer
  50. )
  51. )
  52. if not moderation_result.flagged:
  53. return False, inputs, query
  54. if moderation_result.action == ModerationAction.DIRECT_OUTPUT:
  55. raise ModerationException(moderation_result.preset_response)
  56. elif moderation_result.action == ModerationAction.OVERRIDED:
  57. inputs = moderation_result.inputs
  58. query = moderation_result.query
  59. return True, inputs, query