base.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. import base64
  2. from abc import ABC, abstractmethod
  3. from typing import Optional
  4. from extensions.ext_database import db
  5. from libs import rsa
  6. from models.account import Tenant
  7. from models.tool import ToolProvider, ToolProviderName
  8. class BaseToolProvider(ABC):
  9. def __init__(self, tenant_id: str):
  10. self.tenant_id = tenant_id
  11. @abstractmethod
  12. def get_provider_name(self) -> ToolProviderName:
  13. raise NotImplementedError
  14. @abstractmethod
  15. def encrypt_credentials(self, credentials: dict) -> Optional[dict]:
  16. raise NotImplementedError
  17. @abstractmethod
  18. def get_credentials(self, obfuscated: bool = False) -> Optional[dict]:
  19. raise NotImplementedError
  20. @abstractmethod
  21. def credentials_to_func_kwargs(self) -> Optional[dict]:
  22. raise NotImplementedError
  23. @abstractmethod
  24. def credentials_validate(self, credentials: dict):
  25. raise NotImplementedError
  26. def get_provider(self, must_enabled: bool = False) -> Optional[ToolProvider]:
  27. """
  28. Returns the Provider instance for the given tenant_id and tool_name.
  29. """
  30. query = db.session.query(ToolProvider).filter(
  31. ToolProvider.tenant_id == self.tenant_id,
  32. ToolProvider.tool_name == self.get_provider_name().value
  33. )
  34. if must_enabled:
  35. query = query.filter(ToolProvider.is_enabled == True)
  36. return query.first()
  37. def encrypt_token(self, token) -> str:
  38. tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
  39. encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
  40. return base64.b64encode(encrypted_token).decode()
  41. def decrypt_token(self, token: str, obfuscated: bool = False) -> str:
  42. token = rsa.decrypt(base64.b64decode(token), self.tenant_id)
  43. if obfuscated:
  44. return self._obfuscated_token(token)
  45. return token
  46. def _obfuscated_token(self, token: str) -> str:
  47. return token[:6] + '*' * (len(token) - 8) + token[-2:]