import logging

from core.app.app_config.entities import AppConfig
from core.moderation.base import ModerationAction, ModerationException
from core.moderation.factory import ModerationFactory

logger = logging.getLogger(__name__)
  6. class InputModeration:
  7. def check(self, app_id: str,
  8. tenant_id: str,
  9. app_config: AppConfig,
  10. inputs: dict,
  11. query: str) -> tuple[bool, dict, str]:
  12. """
  13. Process sensitive_word_avoidance.
  14. :param app_id: app id
  15. :param tenant_id: tenant id
  16. :param app_config: app config
  17. :param inputs: inputs
  18. :param query: query
  19. :return:
  20. """
  21. if not app_config.sensitive_word_avoidance:
  22. return False, inputs, query
  23. sensitive_word_avoidance_config = app_config.sensitive_word_avoidance
  24. moderation_type = sensitive_word_avoidance_config.type
  25. moderation_factory = ModerationFactory(
  26. name=moderation_type,
  27. app_id=app_id,
  28. tenant_id=tenant_id,
  29. config=sensitive_word_avoidance_config.config
  30. )
  31. moderation_result = moderation_factory.moderation_for_inputs(inputs, query)
  32. if not moderation_result.flagged:
  33. return False, inputs, query
  34. if moderation_result.action == ModerationAction.DIRECT_OUTPUT:
  35. raise ModerationException(moderation_result.preset_response)
  36. elif moderation_result.action == ModerationAction.OVERRIDED:
  37. inputs = moderation_result.inputs
  38. query = moderation_result.query
  39. return True, inputs, query