input_moderation.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. import logging
  2. from typing import Optional
  3. from core.app.app_config.entities import AppConfig
  4. from core.moderation.base import ModerationAction, ModerationException
  5. from core.moderation.factory import ModerationFactory
  6. from core.ops.entities.trace_entity import TraceTaskName
  7. from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
  8. from core.ops.utils import measure_time
  9. logger = logging.getLogger(__name__)
  10. class InputModeration:
  11. def check(
  12. self, app_id: str,
  13. tenant_id: str,
  14. app_config: AppConfig,
  15. inputs: dict,
  16. query: str,
  17. message_id: str,
  18. trace_manager: Optional[TraceQueueManager] = None
  19. ) -> tuple[bool, dict, str]:
  20. """
  21. Process sensitive_word_avoidance.
  22. :param app_id: app id
  23. :param tenant_id: tenant id
  24. :param app_config: app config
  25. :param inputs: inputs
  26. :param query: query
  27. :param message_id: message id
  28. :param trace_manager: trace manager
  29. :return:
  30. """
  31. if not app_config.sensitive_word_avoidance:
  32. return False, inputs, query
  33. sensitive_word_avoidance_config = app_config.sensitive_word_avoidance
  34. moderation_type = sensitive_word_avoidance_config.type
  35. moderation_factory = ModerationFactory(
  36. name=moderation_type,
  37. app_id=app_id,
  38. tenant_id=tenant_id,
  39. config=sensitive_word_avoidance_config.config
  40. )
  41. with measure_time() as timer:
  42. moderation_result = moderation_factory.moderation_for_inputs(inputs, query)
  43. if trace_manager:
  44. trace_manager.add_trace_task(
  45. TraceTask(
  46. TraceTaskName.MODERATION_TRACE,
  47. message_id=message_id,
  48. moderation_result=moderation_result,
  49. inputs=inputs,
  50. timer=timer
  51. )
  52. )
  53. if not moderation_result.flagged:
  54. return False, inputs, query
  55. if moderation_result.action == ModerationAction.DIRECT_OUTPUT:
  56. raise ModerationException(moderation_result.preset_response)
  57. elif moderation_result.action == ModerationAction.OVERRIDED:
  58. inputs = moderation_result.inputs
  59. query = moderation_result.query
  60. return True, inputs, query