# extensible.py
import enum
import importlib.util
import json
import logging
import os
import sys
from pathlib import Path
from typing import Any, Optional

from pydantic import BaseModel

from core.helper.position_helper import sort_to_dict_by_position_map
  10. class ExtensionModule(enum.Enum):
  11. MODERATION = "moderation"
  12. EXTERNAL_DATA_TOOL = "external_data_tool"
  13. class ModuleExtension(BaseModel):
  14. extension_class: Any = None
  15. name: str
  16. label: Optional[dict] = None
  17. form_schema: Optional[list] = None
  18. builtin: bool = True
  19. position: Optional[int] = None
  20. class Extensible:
  21. module: ExtensionModule
  22. name: str
  23. tenant_id: str
  24. config: Optional[dict] = None
  25. def __init__(self, tenant_id: str, config: Optional[dict] = None) -> None:
  26. self.tenant_id = tenant_id
  27. self.config = config
  28. @classmethod
  29. def scan_extensions(cls):
  30. extensions: list[ModuleExtension] = []
  31. position_map = {}
  32. # get the path of the current class
  33. current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + ".py")
  34. current_dir_path = os.path.dirname(current_path)
  35. # traverse subdirectories
  36. for subdir_name in os.listdir(current_dir_path):
  37. if subdir_name.startswith("__"):
  38. continue
  39. subdir_path = os.path.join(current_dir_path, subdir_name)
  40. extension_name = subdir_name
  41. if os.path.isdir(subdir_path):
  42. file_names = os.listdir(subdir_path)
  43. # is builtin extension, builtin extension
  44. # in the front-end page and business logic, there are special treatments.
  45. builtin = False
  46. position = None
  47. if "__builtin__" in file_names:
  48. builtin = True
  49. builtin_file_path = os.path.join(subdir_path, "__builtin__")
  50. if os.path.exists(builtin_file_path):
  51. position = int(Path(builtin_file_path).read_text(encoding="utf-8").strip())
  52. position_map[extension_name] = position
  53. if (extension_name + ".py") not in file_names:
  54. logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.")
  55. continue
  56. # Dynamic loading {subdir_name}.py file and find the subclass of Extensible
  57. py_path = os.path.join(subdir_path, extension_name + ".py")
  58. spec = importlib.util.spec_from_file_location(extension_name, py_path)
  59. if not spec or not spec.loader:
  60. raise Exception(f"Failed to load module {extension_name} from {py_path}")
  61. mod = importlib.util.module_from_spec(spec)
  62. spec.loader.exec_module(mod)
  63. extension_class = None
  64. for name, obj in vars(mod).items():
  65. if isinstance(obj, type) and issubclass(obj, cls) and obj != cls:
  66. extension_class = obj
  67. break
  68. if not extension_class:
  69. logging.warning(f"Missing subclass of {cls.__name__} in {py_path}, Skip.")
  70. continue
  71. json_data = {}
  72. if not builtin:
  73. if "schema.json" not in file_names:
  74. logging.warning(f"Missing schema.json file in {subdir_path}, Skip.")
  75. continue
  76. json_path = os.path.join(subdir_path, "schema.json")
  77. json_data = {}
  78. if os.path.exists(json_path):
  79. with open(json_path, encoding="utf-8") as f:
  80. json_data = json.load(f)
  81. extensions.append(
  82. ModuleExtension(
  83. extension_class=extension_class,
  84. name=extension_name,
  85. label=json_data.get("label"),
  86. form_schema=json_data.get("form_schema"),
  87. builtin=builtin,
  88. position=position,
  89. )
  90. )
  91. sorted_extensions = sort_to_dict_by_position_map(
  92. position_map=position_map, data=extensions, name_func=lambda x: x.name
  93. )
  94. return sorted_extensions