configuration.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
import logging
from typing import Any, Dict, List

from pydantic import BaseModel

from core.helper import encrypter
from core.tools.entities.tool_entities import ToolProviderCredentials
from core.tools.provider.tool_provider import ToolProviderController
  6. class ToolConfiguration(BaseModel):
  7. tenant_id: str
  8. provider_controller: ToolProviderController
  9. def _deep_copy(self, credentials: Dict[str, str]) -> Dict[str, str]:
  10. """
  11. deep copy credentials
  12. """
  13. return {key: value for key, value in credentials.items()}
  14. def encrypt_tool_credentials(self, credentials: Dict[str, str]) -> Dict[str, str]:
  15. """
  16. encrypt tool credentials with tenant id
  17. return a deep copy of credentials with encrypted values
  18. """
  19. credentials = self._deep_copy(credentials)
  20. # get fields need to be decrypted
  21. fields = self.provider_controller.get_credentials_schema()
  22. for field_name, field in fields.items():
  23. if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT:
  24. if field_name in credentials:
  25. encrypted = encrypter.encrypt_token(self.tenant_id, credentials[field_name])
  26. credentials[field_name] = encrypted
  27. return credentials
  28. def mask_tool_credentials(self, credentials: Dict[str, Any]) -> Dict[str, Any]:
  29. """
  30. mask tool credentials
  31. return a deep copy of credentials with masked values
  32. """
  33. credentials = self._deep_copy(credentials)
  34. # get fields need to be decrypted
  35. fields = self.provider_controller.get_credentials_schema()
  36. for field_name, field in fields.items():
  37. if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT:
  38. if field_name in credentials:
  39. if len(credentials[field_name]) > 6:
  40. credentials[field_name] = \
  41. credentials[field_name][:2] + \
  42. '*' * (len(credentials[field_name]) - 4) +\
  43. credentials[field_name][-2:]
  44. else:
  45. credentials[field_name] = '*' * len(credentials[field_name])
  46. return credentials
  47. def decrypt_tool_credentials(self, credentials: Dict[str, str]) -> Dict[str, str]:
  48. """
  49. decrypt tool credentials with tenant id
  50. return a deep copy of credentials with decrypted values
  51. """
  52. credentials = self._deep_copy(credentials)
  53. # get fields need to be decrypted
  54. fields = self.provider_controller.get_credentials_schema()
  55. for field_name, field in fields.items():
  56. if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT:
  57. if field_name in credentials:
  58. try:
  59. credentials[field_name] = encrypter.decrypt_token(self.tenant_id, credentials[field_name])
  60. except:
  61. pass
  62. return credentials