api.py 3.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. from pydantic import BaseModel
  2. from core.extension.api_based_extension_requestor import APIBasedExtensionPoint, APIBasedExtensionRequestor
  3. from core.helper.encrypter import decrypt_token
  4. from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult
  5. from extensions.ext_database import db
  6. from models.api_based_extension import APIBasedExtension
  7. class ModerationInputParams(BaseModel):
  8. app_id: str = ""
  9. inputs: dict = {}
  10. query: str = ""
  11. class ModerationOutputParams(BaseModel):
  12. app_id: str = ""
  13. text: str
  14. class ApiModeration(Moderation):
  15. name: str = "api"
  16. @classmethod
  17. def validate_config(cls, tenant_id: str, config: dict) -> None:
  18. """
  19. Validate the incoming form config data.
  20. :param tenant_id: the id of workspace
  21. :param config: the form config data
  22. :return:
  23. """
  24. cls._validate_inputs_and_outputs_config(config, False)
  25. api_based_extension_id = config.get("api_based_extension_id")
  26. if not api_based_extension_id:
  27. raise ValueError("api_based_extension_id is required")
  28. extension = cls._get_api_based_extension(tenant_id, api_based_extension_id)
  29. if not extension:
  30. raise ValueError("API-based Extension not found. Please check it again.")
  31. def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
  32. flagged = False
  33. preset_response = ""
  34. if self.config['inputs_config']['enabled']:
  35. params = ModerationInputParams(
  36. app_id=self.app_id,
  37. inputs=inputs,
  38. query=query
  39. )
  40. result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, params.model_dump())
  41. return ModerationInputsResult(**result)
  42. return ModerationInputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)
  43. def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
  44. flagged = False
  45. preset_response = ""
  46. if self.config['outputs_config']['enabled']:
  47. params = ModerationOutputParams(
  48. app_id=self.app_id,
  49. text=text
  50. )
  51. result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_OUTPUT, params.model_dump())
  52. return ModerationOutputsResult(**result)
  53. return ModerationOutputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)
  54. def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict) -> dict:
  55. extension = self._get_api_based_extension(self.tenant_id, self.config.get("api_based_extension_id"))
  56. requestor = APIBasedExtensionRequestor(extension.api_endpoint, decrypt_token(self.tenant_id, extension.api_key))
  57. result = requestor.request(extension_point, params)
  58. return result
  59. @staticmethod
  60. def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension:
  61. extension = db.session.query(APIBasedExtension).filter(
  62. APIBasedExtension.tenant_id == tenant_id,
  63. APIBasedExtension.id == api_based_extension_id
  64. ).first()
  65. return extension