Przeglądaj źródła

improve: extract method for safe loading yaml file and avoid using PyYaml's FullLoader (#4031)

Bowen Liang 11 miesięcy temu
rodzic
commit
3fda2245a4

+ 2 - 4
api/core/model_runtime/model_providers/__base/ai_model.py

@@ -3,8 +3,6 @@ import os
 from abc import ABC, abstractmethod
 from abc import ABC, abstractmethod
 from typing import Optional
 from typing import Optional
 
 
-import yaml
-
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
 from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
 from core.model_runtime.entities.model_entities import (
 from core.model_runtime.entities.model_entities import (
@@ -18,6 +16,7 @@ from core.model_runtime.entities.model_entities import (
 )
 )
 from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
 from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
 from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
 from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
+from core.tools.utils.yaml_utils import load_yaml_file
 from core.utils.position_helper import get_position_map, sort_by_position_map
 from core.utils.position_helper import get_position_map, sort_by_position_map
 
 
 
 
@@ -154,8 +153,7 @@ class AIModel(ABC):
         # traverse all model_schema_yaml_paths
         # traverse all model_schema_yaml_paths
         for model_schema_yaml_path in model_schema_yaml_paths:
         for model_schema_yaml_path in model_schema_yaml_paths:
             # read yaml data from yaml file
             # read yaml data from yaml file
-            with open(model_schema_yaml_path, encoding='utf-8') as f:
-                yaml_data = yaml.safe_load(f)
+            yaml_data = load_yaml_file(model_schema_yaml_path, ignore_error=True)
 
 
             new_parameter_rules = []
             new_parameter_rules = []
             for parameter_rule in yaml_data.get('parameter_rules', []):
             for parameter_rule in yaml_data.get('parameter_rules', []):

+ 2 - 6
api/core/model_runtime/model_providers/__base/model_provider.py

@@ -1,11 +1,10 @@
 import os
 import os
 from abc import ABC, abstractmethod
 from abc import ABC, abstractmethod
 
 
-import yaml
-
 from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
 from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
 from core.model_runtime.entities.provider_entities import ProviderEntity
 from core.model_runtime.entities.provider_entities import ProviderEntity
 from core.model_runtime.model_providers.__base.ai_model import AIModel
 from core.model_runtime.model_providers.__base.ai_model import AIModel
+from core.tools.utils.yaml_utils import load_yaml_file
 from core.utils.module_import_helper import get_subclasses_from_module, import_module_from_source
 from core.utils.module_import_helper import get_subclasses_from_module, import_module_from_source
 
 
 
 
@@ -44,10 +43,7 @@ class ModelProvider(ABC):
 
 
         # read provider schema from yaml file
         # read provider schema from yaml file
         yaml_path = os.path.join(current_path, f'{provider_name}.yaml')
         yaml_path = os.path.join(current_path, f'{provider_name}.yaml')
-        yaml_data = {}
-        if os.path.exists(yaml_path):
-            with open(yaml_path, encoding='utf-8') as f:
-                yaml_data = yaml.safe_load(f)
+        yaml_data = load_yaml_file(yaml_path, ignore_error=True)
 
 
         try:
         try:
             # yaml_data to entity
             # yaml_data to entity

+ 16 - 18
api/core/tools/provider/builtin_tool_provider.py

@@ -2,8 +2,6 @@ from abc import abstractmethod
 from os import listdir, path
 from os import listdir, path
 from typing import Any
 from typing import Any
 
 
-from yaml import FullLoader, load
-
 from core.tools.entities.tool_entities import ToolParameter, ToolProviderCredentials, ToolProviderType
 from core.tools.entities.tool_entities import ToolParameter, ToolProviderCredentials, ToolProviderType
 from core.tools.entities.user_entities import UserToolProviderCredentials
 from core.tools.entities.user_entities import UserToolProviderCredentials
 from core.tools.errors import (
 from core.tools.errors import (
@@ -15,6 +13,7 @@ from core.tools.errors import (
 from core.tools.provider.tool_provider import ToolProviderController
 from core.tools.provider.tool_provider import ToolProviderController
 from core.tools.tool.builtin_tool import BuiltinTool
 from core.tools.tool.builtin_tool import BuiltinTool
 from core.tools.tool.tool import Tool
 from core.tools.tool.tool import Tool
+from core.tools.utils.yaml_utils import load_yaml_file
 from core.utils.module_import_helper import load_single_subclass_from_source
 from core.utils.module_import_helper import load_single_subclass_from_source
 
 
 
 
@@ -28,10 +27,9 @@ class BuiltinToolProviderController(ToolProviderController):
         provider = self.__class__.__module__.split('.')[-1]
         provider = self.__class__.__module__.split('.')[-1]
         yaml_path = path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, f'{provider}.yaml')
         yaml_path = path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, f'{provider}.yaml')
         try:
         try:
-            with open(yaml_path, 'rb') as f:
-                provider_yaml = load(f.read(), FullLoader)
-        except:
-            raise ToolProviderNotFoundError(f'can not load provider yaml for {provider}')
+            provider_yaml = load_yaml_file(yaml_path)
+        except Exception as e:
+            raise ToolProviderNotFoundError(f'can not load provider yaml for {provider}: {e}')
 
 
         if 'credentials_for_provider' in provider_yaml and provider_yaml['credentials_for_provider'] is not None:
         if 'credentials_for_provider' in provider_yaml and provider_yaml['credentials_for_provider'] is not None:
             # set credentials name
             # set credentials name
@@ -58,18 +56,18 @@ class BuiltinToolProviderController(ToolProviderController):
         tool_files = list(filter(lambda x: x.endswith(".yaml") and not x.startswith("__"), listdir(tool_path)))
         tool_files = list(filter(lambda x: x.endswith(".yaml") and not x.startswith("__"), listdir(tool_path)))
         tools = []
         tools = []
         for tool_file in tool_files:
         for tool_file in tool_files:
-            with open(path.join(tool_path, tool_file), encoding='utf-8') as f:
-                # get tool name
-                tool_name = tool_file.split(".")[0]
-                tool = load(f.read(), FullLoader)
-                # get tool class, import the module
-                assistant_tool_class = load_single_subclass_from_source(
-                    module_name=f'core.tools.provider.builtin.{provider}.tools.{tool_name}',
-                    script_path=path.join(path.dirname(path.realpath(__file__)),
-                                           'builtin', provider, 'tools', f'{tool_name}.py'),
-                    parent_type=BuiltinTool)
-                tool["identity"]["provider"] = provider
-                tools.append(assistant_tool_class(**tool))
+            # get tool name
+            tool_name = tool_file.split(".")[0]
+            tool = load_yaml_file(path.join(tool_path, tool_file))
+
+            # get tool class, import the module
+            assistant_tool_class = load_single_subclass_from_source(
+                module_name=f'core.tools.provider.builtin.{provider}.tools.{tool_name}',
+                script_path=path.join(path.dirname(path.realpath(__file__)),
+                                       'builtin', provider, 'tools', f'{tool_name}.py'),
+                parent_type=BuiltinTool)
+            tool["identity"]["provider"] = provider
+            tools.append(assistant_tool_class(**tool))
 
 
         self.tools = tools
         self.tools = tools
         return tools
         return tools

+ 17 - 17
api/core/tools/utils/configuration.py

@@ -23,7 +23,7 @@ class ToolConfigurationManager(BaseModel):
         deep copy credentials
         deep copy credentials
         """
         """
         return deepcopy(credentials)
         return deepcopy(credentials)
-    
+
     def encrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str]:
     def encrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str]:
         """
         """
         encrypt tool credentials with tenant id
         encrypt tool credentials with tenant id
@@ -39,9 +39,9 @@ class ToolConfigurationManager(BaseModel):
                 if field_name in credentials:
                 if field_name in credentials:
                     encrypted = encrypter.encrypt_token(self.tenant_id, credentials[field_name])
                     encrypted = encrypter.encrypt_token(self.tenant_id, credentials[field_name])
                     credentials[field_name] = encrypted
                     credentials[field_name] = encrypted
-        
+
         return credentials
         return credentials
-    
+
     def mask_tool_credentials(self, credentials: dict[str, Any]) -> dict[str, Any]:
     def mask_tool_credentials(self, credentials: dict[str, Any]) -> dict[str, Any]:
         """
         """
         mask tool credentials
         mask tool credentials
@@ -58,7 +58,7 @@ class ToolConfigurationManager(BaseModel):
                     if len(credentials[field_name]) > 6:
                     if len(credentials[field_name]) > 6:
                         credentials[field_name] = \
                         credentials[field_name] = \
                             credentials[field_name][:2] + \
                             credentials[field_name][:2] + \
-                            '*' * (len(credentials[field_name]) - 4) +\
+                            '*' * (len(credentials[field_name]) - 4) + \
                             credentials[field_name][-2:]
                             credentials[field_name][-2:]
                     else:
                     else:
                         credentials[field_name] = '*' * len(credentials[field_name])
                         credentials[field_name] = '*' * len(credentials[field_name])
@@ -72,7 +72,7 @@ class ToolConfigurationManager(BaseModel):
         return a deep copy of credentials with decrypted values
         return a deep copy of credentials with decrypted values
         """
         """
         cache = ToolProviderCredentialsCache(
         cache = ToolProviderCredentialsCache(
-            tenant_id=self.tenant_id, 
+            tenant_id=self.tenant_id,
             identity_id=f'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}',
             identity_id=f'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}',
             cache_type=ToolProviderCredentialsCacheType.PROVIDER
             cache_type=ToolProviderCredentialsCacheType.PROVIDER
         )
         )
@@ -92,10 +92,10 @@ class ToolConfigurationManager(BaseModel):
 
 
         cache.set(credentials)
         cache.set(credentials)
         return credentials
         return credentials
-    
+
     def delete_tool_credentials_cache(self):
     def delete_tool_credentials_cache(self):
         cache = ToolProviderCredentialsCache(
         cache = ToolProviderCredentialsCache(
-            tenant_id=self.tenant_id, 
+            tenant_id=self.tenant_id,
             identity_id=f'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}',
             identity_id=f'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}',
             cache_type=ToolProviderCredentialsCacheType.PROVIDER
             cache_type=ToolProviderCredentialsCacheType.PROVIDER
         )
         )
@@ -116,7 +116,7 @@ class ToolParameterConfigurationManager(BaseModel):
         deep copy parameters
         deep copy parameters
         """
         """
         return deepcopy(parameters)
         return deepcopy(parameters)
-    
+
     def _merge_parameters(self) -> list[ToolParameter]:
     def _merge_parameters(self) -> list[ToolParameter]:
         """
         """
         merge parameters
         merge parameters
@@ -139,7 +139,7 @@ class ToolParameterConfigurationManager(BaseModel):
                 current_parameters.append(runtime_parameter)
                 current_parameters.append(runtime_parameter)
 
 
         return current_parameters
         return current_parameters
-    
+
     def mask_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
     def mask_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
         """
         """
         mask tool parameters
         mask tool parameters
@@ -157,13 +157,13 @@ class ToolParameterConfigurationManager(BaseModel):
                     if len(parameters[parameter.name]) > 6:
                     if len(parameters[parameter.name]) > 6:
                         parameters[parameter.name] = \
                         parameters[parameter.name] = \
                             parameters[parameter.name][:2] + \
                             parameters[parameter.name][:2] + \
-                            '*' * (len(parameters[parameter.name]) - 4) +\
+                            '*' * (len(parameters[parameter.name]) - 4) + \
                             parameters[parameter.name][-2:]
                             parameters[parameter.name][-2:]
                     else:
                     else:
                         parameters[parameter.name] = '*' * len(parameters[parameter.name])
                         parameters[parameter.name] = '*' * len(parameters[parameter.name])
 
 
         return parameters
         return parameters
-    
+
     def encrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
     def encrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
         """
         """
         encrypt tool parameters with tenant id
         encrypt tool parameters with tenant id
@@ -180,9 +180,9 @@ class ToolParameterConfigurationManager(BaseModel):
                 if parameter.name in parameters:
                 if parameter.name in parameters:
                     encrypted = encrypter.encrypt_token(self.tenant_id, parameters[parameter.name])
                     encrypted = encrypter.encrypt_token(self.tenant_id, parameters[parameter.name])
                     parameters[parameter.name] = encrypted
                     parameters[parameter.name] = encrypted
-        
+
         return parameters
         return parameters
-    
+
     def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
     def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
         """
         """
         decrypt tool parameters with tenant id
         decrypt tool parameters with tenant id
@@ -190,7 +190,7 @@ class ToolParameterConfigurationManager(BaseModel):
         return a deep copy of parameters with decrypted values
         return a deep copy of parameters with decrypted values
         """
         """
         cache = ToolParameterCache(
         cache = ToolParameterCache(
-            tenant_id=self.tenant_id, 
+            tenant_id=self.tenant_id,
             provider=f'{self.provider_type}.{self.provider_name}',
             provider=f'{self.provider_type}.{self.provider_name}',
             tool_name=self.tool_runtime.identity.name,
             tool_name=self.tool_runtime.identity.name,
             cache_type=ToolParameterCacheType.PARAMETER,
             cache_type=ToolParameterCacheType.PARAMETER,
@@ -212,15 +212,15 @@ class ToolParameterConfigurationManager(BaseModel):
                         parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name])
                         parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name])
                     except:
                     except:
                         pass
                         pass
-        
+
         if has_secret_input:
         if has_secret_input:
             cache.set(parameters)
             cache.set(parameters)
 
 
         return parameters
         return parameters
-    
+
     def delete_tool_parameters_cache(self):
     def delete_tool_parameters_cache(self):
         cache = ToolParameterCache(
         cache = ToolParameterCache(
-            tenant_id=self.tenant_id, 
+            tenant_id=self.tenant_id,
             provider=f'{self.provider_type}.{self.provider_name}',
             provider=f'{self.provider_type}.{self.provider_name}',
             tool_name=self.tool_runtime.identity.name,
             tool_name=self.tool_runtime.identity.name,
             cache_type=ToolParameterCacheType.PARAMETER,
             cache_type=ToolParameterCacheType.PARAMETER,

+ 34 - 0
api/core/tools/utils/yaml_utils.py

@@ -0,0 +1,34 @@
+import logging
+import os
+
+import yaml
+from yaml import YAMLError
+
+
+def load_yaml_file(file_path: str, ignore_error: bool = False) -> dict:
+    """
+    Safe loading a YAML file to a dict
+    :param file_path: the path of the YAML file
+    :param ignore_error:
+        if True, return empty dict if error occurs and the error will be logged in warning level
+        if False, raise error if error occurs
+    :return: a dict of the YAML content
+    """
+    try:
+        if not file_path or not os.path.exists(file_path):
+            raise FileNotFoundError(f'Failed to load YAML file {file_path}: file not found')
+
+        with open(file_path, encoding='utf-8') as file:
+            try:
+                return yaml.safe_load(file)
+            except Exception as e:
+                raise YAMLError(f'Failed to load YAML file {file_path}: {e}')
+    except FileNotFoundError as e:
+        logging.debug(f'Failed to load YAML file {file_path}: {e}')
+        return {}
+    except Exception as e:
+        if ignore_error:
+            logging.warning(f'Failed to load YAML file {file_path}: {e}')
+            return {}
+        else:
+            raise e

+ 10 - 17
api/core/utils/position_helper.py

@@ -1,10 +1,9 @@
-import logging
 import os
 import os
 from collections import OrderedDict
 from collections import OrderedDict
 from collections.abc import Callable
 from collections.abc import Callable
 from typing import Any, AnyStr
 from typing import Any, AnyStr
 
 
-import yaml
+from core.tools.utils.yaml_utils import load_yaml_file
 
 
 
 
 def get_position_map(
 def get_position_map(
@@ -17,21 +16,15 @@ def get_position_map(
     :param file_name: the YAML file name, default to '_position.yaml'
     :param file_name: the YAML file name, default to '_position.yaml'
     :return: a dict with name as key and index as value
     :return: a dict with name as key and index as value
     """
     """
-    try:
-        position_file_name = os.path.join(folder_path, file_name)
-        if not os.path.exists(position_file_name):
-            return {}
-
-        with open(position_file_name, encoding='utf-8') as f:
-            positions = yaml.safe_load(f)
-        position_map = {}
-        for index, name in enumerate(positions):
-            if name and isinstance(name, str):
-                position_map[name.strip()] = index
-        return position_map
-    except:
-        logging.warning(f'Failed to load the YAML position file {folder_path}/{file_name}.')
-        return {}
+    position_file_name = os.path.join(folder_path, file_name)
+    positions = load_yaml_file(position_file_name, ignore_error=True)
+    position_map = {}
+    index = 0
+    for _, name in enumerate(positions):
+        if name and isinstance(name, str):
+            position_map[name.strip()] = index
+            index += 1
+    return position_map
 
 
 
 
 def sort_by_position_map(
 def sort_by_position_map(

+ 1 - 0
api/pyproject.toml

@@ -14,6 +14,7 @@ select = [
     "I", # isort rules
     "I", # isort rules
     "UP",   # pyupgrade rules
     "UP",   # pyupgrade rules
     "RUF019", # unnecessary-key-check
     "RUF019", # unnecessary-key-check
+    "S506", # unsafe-yaml-load
 ]
 ]
 ignore = [
 ignore = [
     "F403", # undefined-local-with-import-star
     "F403", # undefined-local-with-import-star

+ 0 - 0
api/tests/unit_tests/utils/__init__.py


+ 34 - 0
api/tests/unit_tests/utils/position_helper/test_position_helper.py

@@ -0,0 +1,34 @@
+from textwrap import dedent
+
+import pytest
+
+from core.utils.position_helper import get_position_map
+
+
+@pytest.fixture
+def prepare_example_positions_yaml(tmp_path, monkeypatch) -> str:
+    monkeypatch.chdir(tmp_path)
+    tmp_path.joinpath("example_positions.yaml").write_text(dedent(
+        """\
+        - first
+        - second
+        # - commented
+        - third
+        
+        - 9999999999999
+        - forth
+        """))
+    return str(tmp_path)
+
+
+def test_position_helper(prepare_example_positions_yaml):
+    position_map = get_position_map(
+        folder_path=prepare_example_positions_yaml,
+        file_name='example_positions.yaml')
+    assert len(position_map) == 4
+    assert position_map == {
+        'first': 0,
+        'second': 1,
+        'third': 2,
+        'forth': 3,
+    }

+ 0 - 0
api/tests/unit_tests/utils/yaml/__init__.py


+ 74 - 0
api/tests/unit_tests/utils/yaml/test_yaml_utils.py

@@ -0,0 +1,74 @@
+from textwrap import dedent
+
+import pytest
+from yaml import YAMLError
+
+from core.tools.utils.yaml_utils import load_yaml_file
+
+EXAMPLE_YAML_FILE = 'example_yaml.yaml'
+INVALID_YAML_FILE = 'invalid_yaml.yaml'
+NON_EXISTING_YAML_FILE = 'non_existing_file.yaml'
+
+
+@pytest.fixture
+def prepare_example_yaml_file(tmp_path, monkeypatch) -> str:
+    monkeypatch.chdir(tmp_path)
+    file_path = tmp_path.joinpath(EXAMPLE_YAML_FILE)
+    file_path.write_text(dedent(
+        """\
+        address:
+            city: Example City
+            country: Example Country
+        age: 30
+        gender: male
+        languages:
+            - Python
+            - Java
+            - C++
+        empty_key:
+        """))
+    return str(file_path)
+
+
+@pytest.fixture
+def prepare_invalid_yaml_file(tmp_path, monkeypatch) -> str:
+    monkeypatch.chdir(tmp_path)
+    file_path = tmp_path.joinpath(INVALID_YAML_FILE)
+    file_path.write_text(dedent(
+        """\
+        address:
+                   city: Example City
+            country: Example Country
+        age: 30
+        gender: male
+        languages:
+        - Python
+        - Java
+        - C++
+        """))
+    return str(file_path)
+
+
+def test_load_yaml_non_existing_file():
+    assert load_yaml_file(file_path=NON_EXISTING_YAML_FILE) == {}
+    assert load_yaml_file(file_path='') == {}
+
+
+def test_load_valid_yaml_file(prepare_example_yaml_file):
+    yaml_data = load_yaml_file(file_path=prepare_example_yaml_file)
+    assert len(yaml_data) > 0
+    assert yaml_data['age'] == 30
+    assert yaml_data['gender'] == 'male'
+    assert yaml_data['address']['city'] == 'Example City'
+    assert set(yaml_data['languages']) == {'Python', 'Java', 'C++'}
+    assert yaml_data.get('empty_key') is None
+    assert yaml_data.get('non_existed_key') is None
+
+
+def test_load_invalid_yaml_file(prepare_invalid_yaml_file):
+    # yaml syntax error
+    with pytest.raises(YAMLError):
+        load_yaml_file(file_path=prepare_invalid_yaml_file)
+
+    # ignore error
+    assert load_yaml_file(file_path=prepare_invalid_yaml_file, ignore_error=True) == {}