api_key_auth_service.py 3.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. import json
  2. from core.helper import encrypter
  3. from extensions.ext_database import db
  4. from models.source import DataSourceApiKeyAuthBinding
  5. from services.auth.api_key_auth_factory import ApiKeyAuthFactory
  6. class ApiKeyAuthService:
  7. @staticmethod
  8. def get_provider_auth_list(tenant_id: str) -> list:
  9. data_source_api_key_bindings = db.session.query(DataSourceApiKeyAuthBinding).filter(
  10. DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
  11. DataSourceApiKeyAuthBinding.disabled.is_(False)
  12. ).all()
  13. return data_source_api_key_bindings
  14. @staticmethod
  15. def create_provider_auth(tenant_id: str, args: dict):
  16. auth_result = ApiKeyAuthFactory(args['provider'], args['credentials']).validate_credentials()
  17. if auth_result:
  18. # Encrypt the api key
  19. api_key = encrypter.encrypt_token(tenant_id, args['credentials']['config']['api_key'])
  20. args['credentials']['config']['api_key'] = api_key
  21. data_source_api_key_binding = DataSourceApiKeyAuthBinding()
  22. data_source_api_key_binding.tenant_id = tenant_id
  23. data_source_api_key_binding.category = args['category']
  24. data_source_api_key_binding.provider = args['provider']
  25. data_source_api_key_binding.credentials = json.dumps(args['credentials'], ensure_ascii=False)
  26. db.session.add(data_source_api_key_binding)
  27. db.session.commit()
  28. @staticmethod
  29. def get_auth_credentials(tenant_id: str, category: str, provider: str):
  30. data_source_api_key_bindings = db.session.query(DataSourceApiKeyAuthBinding).filter(
  31. DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
  32. DataSourceApiKeyAuthBinding.category == category,
  33. DataSourceApiKeyAuthBinding.provider == provider,
  34. DataSourceApiKeyAuthBinding.disabled.is_(False)
  35. ).first()
  36. if not data_source_api_key_bindings:
  37. return None
  38. credentials = json.loads(data_source_api_key_bindings.credentials)
  39. return credentials
  40. @staticmethod
  41. def delete_provider_auth(tenant_id: str, binding_id: str):
  42. data_source_api_key_binding = db.session.query(DataSourceApiKeyAuthBinding).filter(
  43. DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
  44. DataSourceApiKeyAuthBinding.id == binding_id
  45. ).first()
  46. if data_source_api_key_binding:
  47. db.session.delete(data_source_api_key_binding)
  48. db.session.commit()
  49. @classmethod
  50. def validate_api_key_auth_args(cls, args):
  51. if 'category' not in args or not args['category']:
  52. raise ValueError('category is required')
  53. if 'provider' not in args or not args['provider']:
  54. raise ValueError('provider is required')
  55. if 'credentials' not in args or not args['credentials']:
  56. raise ValueError('credentials is required')
  57. if not isinstance(args['credentials'], dict):
  58. raise ValueError('credentials must be a dictionary')
  59. if 'auth_type' not in args['credentials'] or not args['credentials']['auth_type']:
  60. raise ValueError('auth_type is required')