api_based_extension_service.py 3.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor
  2. from core.helper.encrypter import decrypt_token, encrypt_token
  3. from extensions.ext_database import db
  4. from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint
  class APIBasedExtensionService:
      """Service-layer helpers for tenant-scoped ``APIBasedExtension`` records.

      API keys are passed through ``encrypt_token`` before being committed and
      through ``decrypt_token`` when records are returned by the getters, both
      keyed by the record's ``tenant_id``.
      """

      @staticmethod
      def get_all_by_tenant_id(tenant_id: str) -> list[APIBasedExtension]:
          """Return every extension of a tenant, newest first.

          Args:
              tenant_id: Tenant whose extensions are listed.

          Returns:
              The matching records ordered by ``created_at`` descending, each
              with ``api_key`` replaced by the output of ``decrypt_token``.
          """
          extension_list = (
              db.session.query(APIBasedExtension)
              .filter_by(tenant_id=tenant_id)
              .order_by(APIBasedExtension.created_at.desc())
              .all()
          )

          # NOTE(review): the decrypted key is assigned onto the queried objects
          # themselves; presumably they are still attached to db.session, so a
          # later commit in the same session could write it back -- confirm.
          for extension in extension_list:
              extension.api_key = decrypt_token(extension.tenant_id, extension.api_key)

          return extension_list

      @classmethod
      def save(cls, extension_data: APIBasedExtension) -> APIBasedExtension:
          """Validate, encrypt the API key, and commit the extension.

          Args:
              extension_data: Record to store; its ``api_key`` is expected to be
                  the raw key, since it is validated and then encrypted here.

          Returns:
              The same object, whose ``api_key`` now holds the output of
              ``encrypt_token``.

          Raises:
              ValueError: If ``_validation`` rejects the record.
          """
          # Validation runs first because it pings the endpoint with the raw key.
          cls._validation(extension_data)

          extension_data.api_key = encrypt_token(extension_data.tenant_id, extension_data.api_key)

          db.session.add(extension_data)
          db.session.commit()
          return extension_data

      @staticmethod
      def delete(extension_data: APIBasedExtension) -> None:
          """Delete the given extension and commit the session."""
          db.session.delete(extension_data)
          db.session.commit()

      @staticmethod
      def get_with_tenant_id(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension:
          """Fetch one extension by id, restricted to the given tenant.

          Args:
              tenant_id: Tenant the extension must belong to.
              api_based_extension_id: Id of the extension to load.

          Returns:
              The matching record with ``api_key`` replaced by the output of
              ``decrypt_token``.

          Raises:
              ValueError: If no record matches both the tenant and the id.
          """
          extension = (
              db.session.query(APIBasedExtension)
              .filter_by(tenant_id=tenant_id)
              .filter_by(id=api_based_extension_id)
              .first()
          )

          if not extension:
              raise ValueError("API based extension is not found")

          # NOTE(review): same in-place overwrite of api_key as in
          # get_all_by_tenant_id; verify callers never commit it unencrypted.
          extension.api_key = decrypt_token(extension.tenant_id, extension.api_key)

          return extension

      @classmethod
      def _validation(cls, extension_data: APIBasedExtension) -> None:
          """Check required fields, name uniqueness, and endpoint reachability.

          Args:
              extension_data: Record to validate; ``api_key`` is used as-is for
                  the ping request.

          Raises:
              ValueError: If ``name``, ``api_endpoint`` or ``api_key`` is empty,
                  if another record of the same tenant already uses the name,
                  if ``api_key`` is shorter than 5 characters, or if the ping
                  request fails.
          """
          # name
          if not extension_data.name:
              raise ValueError("name must not be empty")

          # The name must be unique within the tenant.
          name_query = (
              db.session.query(APIBasedExtension)
              .filter_by(tenant_id=extension_data.tenant_id)
              .filter_by(name=extension_data.name)
          )
          if extension_data.id:
              # Existing record: exclude the record itself from the lookup so
              # that saving it without renaming is not reported as a clash.
              name_query = name_query.filter(APIBasedExtension.id != extension_data.id)

          if name_query.first():
              raise ValueError("name must be unique, it is already existed")

          # api_endpoint
          if not extension_data.api_endpoint:
              raise ValueError("api_endpoint must not be empty")

          # api_key
          if not extension_data.api_key:
              raise ValueError("api_key must not be empty")

          if len(extension_data.api_key) < 5:
              raise ValueError("api_key must be at least 5 characters")

          # check endpoint
          cls._ping_connection(extension_data)

      @staticmethod
      def _ping_connection(extension_data: APIBasedExtension) -> None:
          """Send a PING request to the extension endpoint and expect ``pong``.

          Args:
              extension_data: Supplies ``api_endpoint`` and ``api_key`` for the
                  requestor.

          Raises:
              ValueError: On any failure -- an exception raised while building
                  or sending the request, or a response whose ``result`` is not
                  ``'pong'``. The underlying exception is attached as
                  ``__cause__``.
          """
          try:
              client = APIBasedExtensionRequestor(extension_data.api_endpoint, extension_data.api_key)
              resp = client.request(point=APIBasedExtensionPoint.PING, params={})
              if resp.get('result') != 'pong':
                  # Deliberately raised inside the try so an unexpected reply is
                  # reported with the same "connection error" prefix.
                  raise ValueError(resp)
          except Exception as e:
              # Every failure is surfaced as ValueError; chain the cause so the
              # original traceback is not lost.
              raise ValueError(f"connection error: {e}") from e