Browse Source

refactor: extract cors configs into dify config and cleanup the config class (#5507)

Co-authored-by: takatost <takatost@gmail.com>
Bowen Liang 9 months ago
parent
commit
2a0f03a511

+ 0 - 2
.github/workflows/api-tests.yml

@@ -77,7 +77,6 @@ jobs:
             docker/docker-compose.pgvecto-rs.yaml
             docker/docker-compose.pgvector.yaml
             docker/docker-compose.chroma.yaml
-            docker/docker-compose.oracle.yaml
           services: |
             weaviate
             qdrant
@@ -87,7 +86,6 @@ jobs:
             pgvecto-rs
             pgvector
             chroma
-            oracle
 
       - name: Test Vector Stores
         run: poetry run -C api bash dev/pytest/pytest_vdb.sh

+ 1 - 3
api/app.py

@@ -24,7 +24,6 @@ from flask_cors import CORS
 from werkzeug.exceptions import Unauthorized
 
 from commands import register_commands
-from config import Config
 
 # DO NOT REMOVE BELOW
 from events import event_handlers
@@ -82,7 +81,6 @@ def create_flask_app_with_configs() -> Flask:
     with configs loaded from .env file
     """
     dify_app = DifyApp(__name__)
-    dify_app.config.from_object(Config())
     dify_app.config.from_mapping(DifyConfig().model_dump())
     return dify_app
 
@@ -232,7 +230,7 @@ def register_blueprints(app):
 app = create_app()
 celery = app.extensions["celery"]
 
-if app.config['TESTING']:
+if app.config.get('TESTING'):
     print("App is running in TESTING mode")
 
 

+ 0 - 42
api/config.py

@@ -1,42 +0,0 @@
-import os
-
-import dotenv
-
-DEFAULTS = {
-}
-
-
-def get_env(key):
-    return os.environ.get(key, DEFAULTS.get(key))
-
-
-def get_bool_env(key):
-    value = get_env(key)
-    return value.lower() == 'true' if value is not None else False
-
-
-def get_cors_allow_origins(env, default):
-    cors_allow_origins = []
-    if get_env(env):
-        for origin in get_env(env).split(','):
-            cors_allow_origins.append(origin)
-    else:
-        cors_allow_origins = [default]
-
-    return cors_allow_origins
-
-
-class Config:
-    """Application configuration class."""
-
-    def __init__(self):
-        dotenv.load_dotenv()
-
-        self.TESTING = False
-        self.APPLICATION_NAME = "langgenius/dify"
-
-        # cors settings
-        self.CONSOLE_CORS_ALLOW_ORIGINS = get_cors_allow_origins(
-            'CONSOLE_CORS_ALLOW_ORIGINS', get_env('CONSOLE_WEB_URL'))
-        self.WEB_API_CORS_ALLOW_ORIGINS = get_cors_allow_origins(
-            'WEB_API_CORS_ALLOW_ORIGINS', '*')

+ 10 - 0
api/configs/deploy/__init__.py

@@ -5,6 +5,16 @@ class DeploymentConfig(BaseModel):
     """
     Deployment configs
     """
+    APPLICATION_NAME: str = Field(
+        description='application name',
+        default='langgenius/dify',
+    )
+
+    TESTING: bool = Field(
+        description='',
+        default=False,
+    )
+
     EDITION: str = Field(
         description='deployment edition',
         default='SELF_HOSTED',

+ 23 - 1
api/configs/feature/__init__.py

@@ -1,6 +1,6 @@
 from typing import Optional
 
-from pydantic import AliasChoices, BaseModel, Field, NonNegativeInt, PositiveInt
+from pydantic import AliasChoices, BaseModel, Field, NonNegativeInt, PositiveInt, computed_field
 
 from configs.feature.hosted_service import HostedServiceConfig
 
@@ -125,6 +125,28 @@ class HttpConfig(BaseModel):
         default=False,
     )
 
+    inner_CONSOLE_CORS_ALLOW_ORIGINS: str = Field(
+        description='',
+        validation_alias=AliasChoices('CONSOLE_CORS_ALLOW_ORIGINS', 'CONSOLE_WEB_URL'),
+        default='',
+    )
+
+    @computed_field
+    @property
+    def CONSOLE_CORS_ALLOW_ORIGINS(self) -> list[str]:
+        return self.inner_CONSOLE_CORS_ALLOW_ORIGINS.split(',')
+
+    inner_WEB_API_CORS_ALLOW_ORIGINS: Optional[str] = Field(
+        description='',
+        validation_alias=AliasChoices('WEB_API_CORS_ALLOW_ORIGINS'),
+        default='*',
+    )
+
+    @computed_field
+    @property
+    def WEB_API_CORS_ALLOW_ORIGINS(self) -> list[str]:
+        return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(',')
+
 
 class InnerAPIConfig(BaseModel):
     """

+ 3 - 3
api/core/helper/code_executor/code_executor.py

@@ -1,4 +1,5 @@
 import logging
+import os
 import time
 from enum import Enum
 from threading import Lock
@@ -8,7 +9,6 @@ from httpx import get, post
 from pydantic import BaseModel
 from yarl import URL
 
-from config import get_env
 from core.helper.code_executor.entities import CodeDependency
 from core.helper.code_executor.javascript.javascript_transformer import NodeJsTemplateTransformer
 from core.helper.code_executor.jinja2.jinja2_transformer import Jinja2TemplateTransformer
@@ -18,8 +18,8 @@ from core.helper.code_executor.template_transformer import TemplateTransformer
 logger = logging.getLogger(__name__)
 
 # Code Executor
-CODE_EXECUTION_ENDPOINT = get_env('CODE_EXECUTION_ENDPOINT')
-CODE_EXECUTION_API_KEY = get_env('CODE_EXECUTION_API_KEY')
+CODE_EXECUTION_ENDPOINT = os.environ.get('CODE_EXECUTION_ENDPOINT', 'http://sandbox:8194')
+CODE_EXECUTION_API_KEY = os.environ.get('CODE_EXECUTION_API_KEY', 'dify-sandbox')
 
 CODE_EXECUTION_TIMEOUT= (10, 60)
 

+ 7 - 3
api/tests/unit_tests/configs/test_dify_config.py

@@ -15,6 +15,7 @@ def example_env_file(tmp_path, monkeypatch) -> str:
     file_path.write_text(dedent(
         """
         CONSOLE_API_URL=https://example.com
+        CONSOLE_WEB_URL=https://example.com
         """))
     return str(file_path)
 
@@ -47,14 +48,13 @@ def test_flask_configs(example_env_file):
     flask_app.config.from_mapping(DifyConfig(_env_file=example_env_file).model_dump())
     config = flask_app.config
 
-    # configs read from dotenv directly
-    assert config['LOG_LEVEL'] == 'INFO'
-
     # configs read from pydantic-settings
+    assert config['LOG_LEVEL'] == 'INFO'
     assert config['COMMIT_SHA'] == ''
     assert config['EDITION'] == 'SELF_HOSTED'
     assert config['API_COMPRESSION_ENABLED'] is False
     assert config['SENTRY_TRACES_SAMPLE_RATE'] == 1.0
+    assert config['TESTING'] == False
 
     # value from env file
     assert config['CONSOLE_API_URL'] == 'https://example.com'
@@ -71,3 +71,7 @@ def test_flask_configs(example_env_file):
         'pool_recycle': 3600,
         'pool_size': 30,
     }
+
+    assert config['CONSOLE_WEB_URL']=='https://example.com'
+    assert config['CONSOLE_CORS_ALLOW_ORIGINS']==['https://example.com']
+    assert config['WEB_API_CORS_ALLOW_ORIGINS'] == ['*']

+ 7 - 1
dev/pytest/pytest_vdb.sh

@@ -1,4 +1,10 @@
 #!/bin/bash
 set -x
 
-pytest api/tests/integration_tests/vdb/
+pytest api/tests/integration_tests/vdb/chroma \
+  api/tests/integration_tests/vdb/milvus \
+  api/tests/integration_tests/vdb/pgvecto_rs \
+  api/tests/integration_tests/vdb/pgvector \
+  api/tests/integration_tests/vdb/qdrant \
+  api/tests/integration_tests/vdb/weaviate \
+  api/tests/integration_tests/vdb/test_vector_store.py