Просмотр исходного кода

generalize helper for loading module from source (#2862)

Bowen Liang 1 год назад
Родитель
Сommit
08b727833e

+ 4 - 12
api/core/extension/extensible.py

@@ -1,5 +1,4 @@
 import enum
-import importlib.util
 import json
 import logging
 import os
@@ -7,6 +6,7 @@ from typing import Any, Optional
 
 from pydantic import BaseModel
 
+from core.utils.module_import_helper import load_single_subclass_from_source
 from core.utils.position_helper import sort_to_dict_by_position_map
 
 
@@ -73,17 +73,9 @@ class Extensible:
 
                 # Dynamic loading {subdir_name}.py file and find the subclass of Extensible
                 py_path = os.path.join(subdir_path, extension_name + '.py')
-                spec = importlib.util.spec_from_file_location(extension_name, py_path)
-                mod = importlib.util.module_from_spec(spec)
-                spec.loader.exec_module(mod)
-
-                extension_class = None
-                for name, obj in vars(mod).items():
-                    if isinstance(obj, type) and issubclass(obj, cls) and obj != cls:
-                        extension_class = obj
-                        break
-
-                if not extension_class:
+                try:
+                    extension_class = load_single_subclass_from_source(extension_name, py_path, cls)
+                except Exception:
                     logging.warning(f"Missing subclass of {cls.__name__} in {py_path}, Skip.")
                     continue
 

+ 5 - 12
api/core/model_runtime/model_providers/__base/model_provider.py

@@ -1,4 +1,3 @@
-import importlib
 import os
 from abc import ABC, abstractmethod
 
@@ -7,6 +6,7 @@ import yaml
 from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
 from core.model_runtime.entities.provider_entities import ProviderEntity
 from core.model_runtime.model_providers.__base.ai_model import AIModel
+from core.utils.module_import_helper import get_subclasses_from_module, import_module_from_source
 
 
 class ModelProvider(ABC):
@@ -104,17 +104,10 @@ class ModelProvider(ABC):
 
         # Dynamic loading {model_type_name}.py file and find the subclass of AIModel
         parent_module = '.'.join(self.__class__.__module__.split('.')[:-1])
-        spec = importlib.util.spec_from_file_location(f"{parent_module}.{model_type_name}.{model_type_name}", model_type_py_path)
-        mod = importlib.util.module_from_spec(spec)
-        spec.loader.exec_module(mod)
-
-        model_class = None
-        for name, obj in vars(mod).items():
-            if (isinstance(obj, type) and issubclass(obj, AIModel) and not obj.__abstractmethods__
-                    and obj != AIModel and obj.__module__ == mod.__name__):
-                model_class = obj
-                break
-
+        mod = import_module_from_source(
+            f'{parent_module}.{model_type_name}.{model_type_name}', model_type_py_path)
+        model_class = next(filter(lambda x: x.__module__ == mod.__name__ and not x.__abstractmethods__,
+                                  get_subclasses_from_module(mod, AIModel)), None)
         if not model_class:
             raise Exception(f'Missing AIModel Class for model type {model_type} in {model_type_py_path}')
 

+ 5 - 10
api/core/model_runtime/model_providers/model_provider_factory.py

@@ -1,4 +1,3 @@
-import importlib
 import logging
 import os
 from typing import Optional
@@ -10,6 +9,7 @@ from core.model_runtime.entities.provider_entities import ProviderConfig, Provid
 from core.model_runtime.model_providers.__base.model_provider import ModelProvider
 from core.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator
 from core.model_runtime.schema_validators.provider_credential_schema_validator import ProviderCredentialSchemaValidator
+from core.utils.module_import_helper import load_single_subclass_from_source
 from core.utils.position_helper import get_position_map, sort_to_dict_by_position_map
 
 logger = logging.getLogger(__name__)
@@ -229,15 +229,10 @@ class ModelProviderFactory:
 
             # Dynamic loading {model_provider_name}.py file and find the subclass of ModelProvider
             py_path = os.path.join(model_provider_dir_path, model_provider_name + '.py')
-            spec = importlib.util.spec_from_file_location(f'core.model_runtime.model_providers.{model_provider_name}.{model_provider_name}', py_path)
-            mod = importlib.util.module_from_spec(spec)
-            spec.loader.exec_module(mod)
-
-            model_provider_class = None
-            for name, obj in vars(mod).items():
-                if isinstance(obj, type) and issubclass(obj, ModelProvider) and obj != ModelProvider:
-                    model_provider_class = obj
-                    break
+            model_provider_class = load_single_subclass_from_source(
+                module_name=f'core.model_runtime.model_providers.{model_provider_name}.{model_provider_name}',
+                script_path=py_path,
+                parent_type=ModelProvider)
 
             if not model_provider_class:
                 logger.warning(f"Missing Model Provider Class that extends ModelProvider in {py_path}, Skip.")

+ 6 - 11
api/core/tools/provider/builtin_tool_provider.py

@@ -1,4 +1,3 @@
-import importlib
 from abc import abstractmethod
 from os import listdir, path
 from typing import Any
@@ -16,6 +15,7 @@ from core.tools.errors import (
 from core.tools.provider.tool_provider import ToolProviderController
 from core.tools.tool.builtin_tool import BuiltinTool
 from core.tools.tool.tool import Tool
+from core.utils.module_import_helper import load_single_subclass_from_source
 
 
 class BuiltinToolProviderController(ToolProviderController):
@@ -63,16 +63,11 @@ class BuiltinToolProviderController(ToolProviderController):
                 tool_name = tool_file.split(".")[0]
                 tool = load(f.read(), FullLoader)
                 # get tool class, import the module
-                py_path = path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, 'tools', f'{tool_name}.py')
-                spec = importlib.util.spec_from_file_location(f'core.tools.provider.builtin.{provider}.tools.{tool_name}', py_path)
-                mod = importlib.util.module_from_spec(spec)
-                spec.loader.exec_module(mod)
-
-                # get all the classes in the module
-                classes = [x for _, x in vars(mod).items() 
-                    if isinstance(x, type) and x not in [BuiltinTool, Tool] and issubclass(x, BuiltinTool)
-                ]
-                assistant_tool_class = classes[0]
+                assistant_tool_class = load_single_subclass_from_source(
+                    module_name=f'core.tools.provider.builtin.{provider}.tools.{tool_name}',
+                    script_path=path.join(path.dirname(path.realpath(__file__)),
+                                           'builtin', provider, 'tools', f'{tool_name}.py'),
+                    parent_type=BuiltinTool)
                 tools.append(assistant_tool_class(**tool))
 
         self.tools = tools

+ 11 - 32
api/core/tools/tool_manager.py

@@ -1,4 +1,3 @@
-import importlib
 import json
 import logging
 import mimetypes
@@ -34,6 +33,7 @@ from core.tools.utils.configuration import (
     ToolParameterConfigurationManager,
 )
 from core.tools.utils.encoder import serialize_base_model_dict
+from core.utils.module_import_helper import load_single_subclass_from_source
 from extensions.ext_database import db
 from models.tools import ApiToolProvider, BuiltinToolProvider
 
@@ -72,21 +72,11 @@ class ToolManager:
 
         if provider_entity is None:
             # fetch the provider from .provider.builtin
-            py_path = path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, f'{provider}.py')
-            spec = importlib.util.spec_from_file_location(f'core.tools.provider.builtin.{provider}.{provider}', py_path)
-            mod = importlib.util.module_from_spec(spec)
-            spec.loader.exec_module(mod)
-
-            # get all the classes in the module
-            classes = [ x for _, x in vars(mod).items() 
-                       if isinstance(x, type) and x != ToolProviderController and issubclass(x, ToolProviderController)
-            ]
-            if len(classes) == 0:
-                raise ToolProviderNotFoundError(f'provider {provider} not found')
-            if len(classes) > 1:
-                raise ToolProviderNotFoundError(f'multiple providers found for {provider}')
-            
-            provider_entity = classes[0]()
+            provider_class = load_single_subclass_from_source(
+                module_name=f'core.tools.provider.builtin.{provider}.{provider}',
+                script_path=path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, f'{provider}.py'),
+                parent_type=ToolProviderController)
+            provider_entity = provider_class()
 
         return provider_entity.invoke(tool_id, tool_name, tool_parameters, credentials, prompt_messages)
     
@@ -330,23 +320,12 @@ class ToolManager:
                 if provider.startswith('__'):
                     continue
 
-                py_path = path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider, f'{provider}.py')
-                spec = importlib.util.spec_from_file_location(f'core.tools.provider.builtin.{provider}.{provider}', py_path)
-                mod = importlib.util.module_from_spec(spec)
-                spec.loader.exec_module(mod)
-
-                # load all classes
-                classes = [
-                    obj for name, obj in vars(mod).items() 
-                        if isinstance(obj, type) and obj != BuiltinToolProviderController and issubclass(obj, BuiltinToolProviderController)
-                ]
-                if len(classes) == 0:
-                    raise ToolProviderNotFoundError(f'provider {provider} not found')
-                if len(classes) > 1:
-                    raise ToolProviderNotFoundError(f'multiple providers found for {provider}')
-                
                 # init provider
-                provider_class = classes[0]
+                provider_class = load_single_subclass_from_source(
+                    module_name=f'core.tools.provider.builtin.{provider}.{provider}',
+                    script_path=path.join(path.dirname(path.realpath(__file__)),
+                                           'provider', 'builtin', provider, f'{provider}.py'),
+                    parent_type=BuiltinToolProviderController)
                 builtin_providers.append(provider_class())
 
         # cache the builtin providers

+ 62 - 0
api/core/utils/module_import_helper.py

@@ -0,0 +1,62 @@
+import importlib.util
+import logging
+import sys
+from types import ModuleType
+from typing import AnyStr
+
+
+def import_module_from_source(
+        module_name: str,
+        py_file_path: AnyStr,
+        use_lazy_loader: bool = False
+) -> ModuleType:
+    """
+    Importing a module from the source file directly
+    """
+    try:
+        existed_spec = importlib.util.find_spec(module_name)
+        if existed_spec:
+            spec = existed_spec
+        else:
+            # Refer to: https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
+            spec = importlib.util.spec_from_file_location(module_name, py_file_path)
+            if use_lazy_loader:
+                # Refer to: https://docs.python.org/3/library/importlib.html#implementing-lazy-imports
+                spec.loader = importlib.util.LazyLoader(spec.loader)
+        module = importlib.util.module_from_spec(spec)
+        if not existed_spec:
+            sys.modules[module_name] = module
+        spec.loader.exec_module(module)
+        return module
+    except Exception as e:
+        logging.exception(f'Failed to load module {module_name} from {py_file_path}: {str(e)}')
+        raise e
+
+
+def get_subclasses_from_module(mod: ModuleType, parent_type: type) -> list[type]:
+    """
+    Get all the subclasses of the parent type from the module
+    """
+    classes = [x for _, x in vars(mod).items()
+               if isinstance(x, type) and x != parent_type and issubclass(x, parent_type)]
+    return classes
+
+
+def load_single_subclass_from_source(
+        module_name: str,
+        script_path: AnyStr,
+        parent_type: type,
+        use_lazy_loader: bool = False,
+) -> type:
+    """
+    Load a single subclass from the source
+    """
+    module = import_module_from_source(module_name, script_path, use_lazy_loader)
+    subclasses = get_subclasses_from_module(module, parent_type)
+    match len(subclasses):
+        case 1:
+            return subclasses[0]
+        case 0:
+            raise Exception(f'Missing subclass of {parent_type.__name__} in {script_path}')
+        case _:
+            raise Exception(f'Multiple subclasses of {parent_type.__name__} in {script_path}')

+ 7 - 0
api/tests/integration_tests/utils/child_class.py

@@ -0,0 +1,7 @@
+from tests.integration_tests.utils.parent_class import ParentClass
+
+
+class ChildClass(ParentClass):
+    def __init__(self, name: str):
+        super().__init__(name)
+        self.name = name

+ 7 - 0
api/tests/integration_tests/utils/lazy_load_class.py

@@ -0,0 +1,7 @@
+from tests.integration_tests.utils.parent_class import ParentClass
+
+
+class LazyLoadChildClass(ParentClass):
+    def __init__(self, name: str):
+        super().__init__(name)
+        self.name = name

+ 6 - 0
api/tests/integration_tests/utils/parent_class.py

@@ -0,0 +1,6 @@
+class ParentClass:
+    def __init__(self, name):
+        self.name = name
+
+    def get_name(self):
+        return self.name

+ 32 - 0
api/tests/integration_tests/utils/test_module_import_helper.py

@@ -0,0 +1,32 @@
+import os
+
+from core.utils.module_import_helper import load_single_subclass_from_source, import_module_from_source
+from tests.integration_tests.utils.parent_class import ParentClass
+
+
+def test_loading_subclass_from_source():
+    current_path = os.getcwd()
+    module = load_single_subclass_from_source(
+        module_name='ChildClass',
+        script_path=os.path.join(current_path, 'child_class.py'),
+        parent_type=ParentClass)
+    assert module and module.__name__ == 'ChildClass'
+
+
+def test_load_import_module_from_source():
+    current_path = os.getcwd()
+    module = import_module_from_source(
+        module_name='ChildClass',
+        py_file_path=os.path.join(current_path, 'child_class.py'))
+    assert module and module.__name__ == 'ChildClass'
+
+
+def test_lazy_loading_subclass_from_source():
+    current_path = os.getcwd()
+    clz = load_single_subclass_from_source(
+        module_name='LazyLoadChildClass',
+        script_path=os.path.join(current_path, 'lazy_load_class.py'),
+        parent_type=ParentClass,
+        use_lazy_loader=True)
+    instance = clz('dify')
+    assert instance.get_name() == 'dify'