api_key_auth_service.py 3.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. import json
  2. from core.helper import encrypter
  3. from extensions.ext_database import db
  4. from models.source import DataSourceApiKeyAuthBinding
  5. from services.auth.api_key_auth_factory import ApiKeyAuthFactory
  6. class ApiKeyAuthService:
  7. @staticmethod
  8. def get_provider_auth_list(tenant_id: str) -> list:
  9. data_source_api_key_bindings = (
  10. db.session.query(DataSourceApiKeyAuthBinding)
  11. .filter(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False))
  12. .all()
  13. )
  14. return data_source_api_key_bindings
  15. @staticmethod
  16. def create_provider_auth(tenant_id: str, args: dict):
  17. auth_result = ApiKeyAuthFactory(args["provider"], args["credentials"]).validate_credentials()
  18. if auth_result:
  19. # Encrypt the api key
  20. api_key = encrypter.encrypt_token(tenant_id, args["credentials"]["config"]["api_key"])
  21. args["credentials"]["config"]["api_key"] = api_key
  22. data_source_api_key_binding = DataSourceApiKeyAuthBinding()
  23. data_source_api_key_binding.tenant_id = tenant_id
  24. data_source_api_key_binding.category = args["category"]
  25. data_source_api_key_binding.provider = args["provider"]
  26. data_source_api_key_binding.credentials = json.dumps(args["credentials"], ensure_ascii=False)
  27. db.session.add(data_source_api_key_binding)
  28. db.session.commit()
  29. @staticmethod
  30. def get_auth_credentials(tenant_id: str, category: str, provider: str):
  31. data_source_api_key_bindings = (
  32. db.session.query(DataSourceApiKeyAuthBinding)
  33. .filter(
  34. DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
  35. DataSourceApiKeyAuthBinding.category == category,
  36. DataSourceApiKeyAuthBinding.provider == provider,
  37. DataSourceApiKeyAuthBinding.disabled.is_(False),
  38. )
  39. .first()
  40. )
  41. if not data_source_api_key_bindings:
  42. return None
  43. credentials = json.loads(data_source_api_key_bindings.credentials)
  44. return credentials
  45. @staticmethod
  46. def delete_provider_auth(tenant_id: str, binding_id: str):
  47. data_source_api_key_binding = (
  48. db.session.query(DataSourceApiKeyAuthBinding)
  49. .filter(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.id == binding_id)
  50. .first()
  51. )
  52. if data_source_api_key_binding:
  53. db.session.delete(data_source_api_key_binding)
  54. db.session.commit()
  55. @classmethod
  56. def validate_api_key_auth_args(cls, args):
  57. if "category" not in args or not args["category"]:
  58. raise ValueError("category is required")
  59. if "provider" not in args or not args["provider"]:
  60. raise ValueError("provider is required")
  61. if "credentials" not in args or not args["credentials"]:
  62. raise ValueError("credentials is required")
  63. if not isinstance(args["credentials"], dict):
  64. raise ValueError("credentials must be a dictionary")
  65. if "auth_type" not in args["credentials"] or not args["credentials"]["auth_type"]:
  66. raise ValueError("auth_type is required")