base.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. from abc import ABC, abstractmethod
  2. from typing import Optional
  3. from pydantic import BaseModel
  4. from enum import Enum
  5. from core.extension.extensible import Extensible, ExtensionModule
  6. class ModerationAction(Enum):
  7. DIRECT_OUTPUT = 'direct_output'
  8. OVERRIDED = 'overrided'
  9. class ModerationInputsResult(BaseModel):
  10. flagged: bool = False
  11. action: ModerationAction
  12. preset_response: str = ""
  13. inputs: dict = {}
  14. query: str = ""
  15. class ModerationOutputsResult(BaseModel):
  16. flagged: bool = False
  17. action: ModerationAction
  18. preset_response: str = ""
  19. text: str = ""
  20. class Moderation(Extensible, ABC):
  21. """
  22. The base class of moderation.
  23. """
  24. module: ExtensionModule = ExtensionModule.MODERATION
  25. def __init__(self, app_id: str, tenant_id: str, config: Optional[dict] = None) -> None:
  26. super().__init__(tenant_id, config)
  27. self.app_id = app_id
  28. @classmethod
  29. @abstractmethod
  30. def validate_config(cls, tenant_id: str, config: dict) -> None:
  31. """
  32. Validate the incoming form config data.
  33. :param tenant_id: the id of workspace
  34. :param config: the form config data
  35. :return:
  36. """
  37. raise NotImplementedError
  38. @abstractmethod
  39. def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
  40. """
  41. Moderation for inputs.
  42. After the user inputs, this method will be called to perform sensitive content review
  43. on the user inputs and return the processed results.
  44. :param inputs: user inputs
  45. :param query: query string (required in chat app)
  46. :return:
  47. """
  48. raise NotImplementedError
  49. @abstractmethod
  50. def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
  51. """
  52. Moderation for outputs.
  53. When LLM outputs content, the front end will pass the output content (may be segmented)
  54. to this method for sensitive content review, and the output content will be shielded if the review fails.
  55. :param text: LLM output content
  56. :return:
  57. """
  58. raise NotImplementedError
  59. @classmethod
  60. def _validate_inputs_and_outputs_config(self, config: dict, is_preset_response_required: bool) -> None:
  61. # inputs_config
  62. inputs_config = config.get("inputs_config")
  63. if not isinstance(inputs_config, dict):
  64. raise ValueError("inputs_config must be a dict")
  65. # outputs_config
  66. outputs_config = config.get("outputs_config")
  67. if not isinstance(outputs_config, dict):
  68. raise ValueError("outputs_config must be a dict")
  69. inputs_config_enabled = inputs_config.get("enabled")
  70. outputs_config_enabled = outputs_config.get("enabled")
  71. if not inputs_config_enabled and not outputs_config_enabled:
  72. raise ValueError("At least one of inputs_config or outputs_config must be enabled")
  73. # preset_response
  74. if not is_preset_response_required:
  75. return
  76. if inputs_config_enabled:
  77. if not inputs_config.get("preset_response"):
  78. raise ValueError("inputs_config.preset_response is required")
  79. if len(inputs_config.get("preset_response")) > 100:
  80. raise ValueError("inputs_config.preset_response must be less than 100 characters")
  81. if outputs_config_enabled:
  82. if not outputs_config.get("preset_response"):
  83. raise ValueError("outputs_config.preset_response is required")
  84. if len(outputs_config.get("preset_response")) > 100:
  85. raise ValueError("outputs_config.preset_response must be less than 100 characters")
  86. class ModerationException(Exception):
  87. pass