cloud_service.py 3.3 KB

from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult
  2. class CloudServiceModeration(Moderation):
  3. """
  4. The name of custom type must be unique, keep the same with directory and file name.
  5. """
  6. name: str = "cloud_service"
  7. @classmethod
  8. def validate_config(cls, tenant_id: str, config: dict) -> None:
  9. """
  10. schema.json validation. It will be called when user save the config.
  11. Example:
  12. .. code-block:: python
  13. config = {
  14. "cloud_provider": "GoogleCloud",
  15. "api_endpoint": "https://api.example.com",
  16. "api_keys": "123456",
  17. "inputs_config": {
  18. "enabled": True,
  19. "preset_response": "Your content violates our usage policy. Please revise and try again."
  20. },
  21. "outputs_config": {
  22. "enabled": True,
  23. "preset_response": "Your content violates our usage policy. Please revise and try again."
  24. }
  25. }
  26. :param tenant_id: the id of workspace
  27. :param config: the variables of form config
  28. :return:
  29. """
  30. cls._validate_inputs_and_outputs_config(config, True)
  31. if not config.get("cloud_provider"):
  32. raise ValueError("cloud_provider is required")
  33. if not config.get("api_endpoint"):
  34. raise ValueError("api_endpoint is required")
  35. if not config.get("api_keys"):
  36. raise ValueError("api_keys is required")
  37. def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
  38. """
  39. Moderation for inputs.
  40. :param inputs: user inputs
  41. :param query: the query of chat app, there is empty if is completion app
  42. :return: the moderation result
  43. """
  44. flagged = False
  45. preset_response = ""
  46. if self.config['inputs_config']['enabled']:
  47. preset_response = self.config['inputs_config']['preset_response']
  48. if query:
  49. inputs['query__'] = query
  50. flagged = self._is_violated(inputs)
  51. # return ModerationInputsResult(flagged=flagged, action=ModerationAction.OVERRIDED, inputs=inputs, query=query)
  52. return ModerationInputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)
  53. def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
  54. """
  55. Moderation for outputs.
  56. :param text: the text of LLM response
  57. :return: the moderation result
  58. """
  59. flagged = False
  60. preset_response = ""
  61. if self.config['outputs_config']['enabled']:
  62. preset_response = self.config['outputs_config']['preset_response']
  63. flagged = self._is_violated({'text': text})
  64. # return ModerationOutputsResult(flagged=flagged, action=ModerationAction.OVERRIDED, text=text)
  65. return ModerationOutputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)
  66. def _is_violated(self, inputs: dict):
  67. """
  68. The main logic of moderation.
  69. :param inputs:
  70. :return: the moderation result
  71. """
  72. return False