
feat: optimize model when app create (#875)

takatost · 1 year ago · commit b7c29ea1b6

+ 38 - 0
.github/workflows/api-unit-tests.yml

@@ -0,0 +1,38 @@
+name: Run Pytest
+
+on:
+  pull_request:
+    branches:
+      - main
+  push:
+    branches:
+      - deploy/dev
+
+jobs:
+  test:
+    runs-on: ubuntu-latest
+
+    steps:
+    - name: Checkout code
+      uses: actions/checkout@v2
+
+    - name: Set up Python
+      uses: actions/setup-python@v2
+      with:
+        python-version: '3.10'
+
+    - name: Cache pip dependencies
+      uses: actions/cache@v2
+      with:
+        path: ~/.cache/pip
+        key: ${{ runner.os }}-pip-${{ hashFiles('api/requirements.txt') }}
+        restore-keys: ${{ runner.os }}-pip-
+
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        pip install pytest
+        pip install -r api/requirements.txt
+
+    - name: Run pytest
+      run: pytest api/tests/unit_tests
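
The new workflow installs the API requirements and runs `pytest api/tests/unit_tests` on pushes to `deploy/dev` and on pull requests against `main`. For reference, a minimal local equivalent of that final step (a sketch only, assuming `pytest` and the `api/requirements.txt` dependencies are already installed in the active environment, as the workflow does via pip):

```python
# Local stand-in for the workflow's final step (illustrative only).
import sys
import pytest

if __name__ == "__main__":
    # Mirrors `pytest api/tests/unit_tests` from the workflow above.
    sys.exit(pytest.main(["api/tests/unit_tests"]))
```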

+ 43 - 18
api/controllers/console/app/app.py

@@ -1,5 +1,6 @@
 # -*- coding:utf-8 -*-
 import json
+import logging
 from datetime import datetime
 
 from flask_login import login_required, current_user
@@ -11,7 +12,9 @@ from controllers.console import api
 from controllers.console.app.error import AppNotFoundError, ProviderNotInitializeError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
+from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
 from core.model_providers.model_factory import ModelFactory
+from core.model_providers.model_provider_factory import ModelProviderFactory
 from core.model_providers.models.entity.model_params import ModelType
 from events.app_event import app_was_created, app_was_deleted
 from libs.helper import TimestampField
@@ -124,24 +127,34 @@ class AppListApi(Resource):
         if current_user.current_tenant.current_role not in ['admin', 'owner']:
             raise Forbidden()
 
-        default_model = ModelFactory.get_default_model(
-            tenant_id=current_user.current_tenant_id,
-            model_type=ModelType.TEXT_GENERATION
-        )
-
-        if default_model:
-            default_model_provider = default_model.provider_name
-            default_model_name = default_model.model_name
-        else:
-            raise ProviderNotInitializeError(
-                f"No Text Generation Model available. Please configure a valid provider "
-                f"in the Settings -> Model Provider.")
+        try:
+            default_model = ModelFactory.get_text_generation_model(
+                tenant_id=current_user.current_tenant_id
+            )
+        except (ProviderTokenNotInitError, LLMBadRequestError):
+            default_model = None
+        except Exception as e:
+            logging.exception(e)
+            default_model = None
 
         if args['model_config'] is not None:
             # validate config
             model_config_dict = args['model_config']
-            model_config_dict["model"]["provider"] = default_model_provider
-            model_config_dict["model"]["name"] = default_model_name
+
+            # get model provider
+            model_provider = ModelProviderFactory.get_preferred_model_provider(
+                current_user.current_tenant_id,
+                model_config_dict["model"]["provider"]
+            )
+
+            if not model_provider:
+                if not default_model:
+                    raise ProviderNotInitializeError(
+                        f"No Default System Reasoning Model available. Please configure "
+                        f"in the Settings -> Model Provider.")
+                else:
+                    model_config_dict["model"]["provider"] = default_model.model_provider.provider_name
+                    model_config_dict["model"]["name"] = default_model.name
 
             model_configuration = AppModelConfigService.validate_configuration(
                 tenant_id=current_user.current_tenant_id,
@@ -169,10 +182,22 @@ class AppListApi(Resource):
             app = App(**model_config_template['app'])
             app_model_config = AppModelConfig(**model_config_template['model_config'])
 
-            model_dict = app_model_config.model_dict
-            model_dict['provider'] = default_model_provider
-            model_dict['name'] = default_model_name
-            app_model_config.model = json.dumps(model_dict)
+            # get model provider
+            model_provider = ModelProviderFactory.get_preferred_model_provider(
+                current_user.current_tenant_id,
+                app_model_config.model_dict["provider"]
+            )
+
+            if not model_provider:
+                if not default_model:
+                    raise ProviderNotInitializeError(
+                        f"No Default System Reasoning Model available. Please configure "
+                        f"in the Settings -> Model Provider.")
+                else:
+                    model_dict = app_model_config.model_dict
+                    model_dict['provider'] = default_model.model_provider.provider_name
+                    model_dict['name'] = default_model.name
+                    app_model_config.model = json.dumps(model_dict)
 
         app.name = args['name']
         app.mode = args['mode']
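
With this change, app creation no longer hard-fails when no default text-generation model is configured: the default-model lookup is wrapped in a try/except, the provider named in the submitted config (or the mode template) is checked via `ModelProviderFactory.get_preferred_model_provider`, and only if that provider is unavailable does the code fall back to the tenant's default model, raising `ProviderNotInitializeError` only when neither exists. A condensed sketch of that decision follows; the helper parameters are hypothetical stand-ins for the `ModelFactory` / `ModelProviderFactory` calls, not the project's actual signatures:

```python
# Hedged sketch of the fallback logic introduced above.
import logging


def resolve_model_config(tenant_id, model_dict,
                         get_preferred_provider, get_default_model):
    try:
        default_model = get_default_model(tenant_id)
    except Exception:
        # Missing provider credentials no longer abort app creation outright.
        logging.exception("could not load default text-generation model")
        default_model = None

    if get_preferred_provider(tenant_id, model_dict["provider"]):
        # The requested provider is configured for this tenant: keep it as-is.
        return model_dict

    if not default_model:
        raise RuntimeError(
            "No Default System Reasoning Model available. "
            "Please configure one in Settings -> Model Provider.")

    # Otherwise fall back to the tenant's default model.
    model_dict["provider"] = default_model.model_provider.provider_name
    model_dict["name"] = default_model.name
    return model_dict
```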

+ 3 - 2
api/tests/unit_tests/model_providers/test_huggingface_hub_provider.py

@@ -30,8 +30,9 @@ def decrypt_side_effect(tenant_id, encrypted_key):
 
 @patch('huggingface_hub.hf_api.ModelInfo')
 def test_hosted_inference_api_is_credentials_valid_or_raise_valid(mock_model_info, mocker):
-    mock_model_info.return_value = MagicMock(pipeline_tag='text2text-generation')
-    mocker.patch('langchain.llms.huggingface_hub.HuggingFaceHub._call', return_value="abc")
+    mock_model_info.return_value = MagicMock(pipeline_tag='text2text-generation', cardData={'inference': True})
+    mocker.patch('huggingface_hub.hf_api.HfApi.whoami', return_value="abc")
+    mocker.patch('huggingface_hub.hf_api.HfApi.model_info', return_value=mock_model_info.return_value)
 
     MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
         model_name='test_model_name',

+ 19 - 2
api/tests/unit_tests/model_providers/test_replicate_provider.py

@@ -23,14 +23,31 @@ def decrypt_side_effect(tenant_id, encrypted_key):
     return encrypted_key.replace('encrypted_', '')
 
 
+def version_effect(id: str):
+    mock_version = MagicMock()
+    mock_version.openapi_schema = {
+        'components': {
+            'schemas': {
+                'Output': {
+                    'items': {
+                        'type': 'string'
+                    }
+                }
+            }
+        }
+    }
+
+    return mock_version
+
+@patch('replicate.version.VersionCollection.get', side_effect=version_effect)
 def test_is_credentials_valid_or_raise_valid(mocker):
     mock_query = MagicMock()
     mock_query.return_value = None
+
     mocker.patch('replicate.model.ModelCollection.get', return_value=mock_query)
-    mocker.patch('replicate.model.Model.versions', return_value=mock_query)
 
     MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
-        model_name='test_model_name',
+        model_name='username/test_model_name',
         model_type=ModelType.TEXT_GENERATION,
         credentials=VALIDATE_CREDENTIAL.copy()
     )
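
The Replicate test now builds its fake version object through a `side_effect` function, so every call to `VersionCollection.get` returns a `MagicMock` whose `openapi_schema` matches what the provider code inspects, and the model name uses the required `username/model` form. A self-contained sketch of the `@patch(..., side_effect=...)` pattern, with a toy class standing in for the Replicate client (names here are hypothetical, not the replicate library's API):

```python
# Toy reproduction of the side_effect pattern above.
from unittest.mock import MagicMock, patch


class VersionClient:
    """Hypothetical stand-in for replicate.version.VersionCollection."""
    def get(self, id: str):
        raise RuntimeError("would hit the network in real code")


def fake_get(id: str):
    version = MagicMock()
    version.openapi_schema = {
        "components": {"schemas": {"Output": {"items": {"type": "string"}}}}
    }
    return version


# Used as a decorator, patch passes the replacement mock as the first argument.
@patch.object(VersionClient, "get", side_effect=fake_get)
def run_demo(mock_get):
    schema = VersionClient().get("some-version-id").openapi_schema
    assert schema["components"]["schemas"]["Output"]["items"]["type"] == "string"
    mock_get.assert_called_once_with("some-version-id")


run_demo()
```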

+ 1 - 1
api/tests/unit_tests/model_providers/test_tongyi_provider.py

@@ -26,7 +26,7 @@ def decrypt_side_effect(tenant_id, encrypted_key):
 
 
 def test_is_provider_credentials_valid_or_raise_valid(mocker):
-    mocker.patch('langchain.llms.tongyi.Tongyi._generate', return_value=LLMResult(generations=[[Generation(text="abc")]]))
+    mocker.patch('core.third_party.langchain.llms.tongyi_llm.EnhanceTongyi._generate', return_value=LLMResult(generations=[[Generation(text="abc")]]))
 
     MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)
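
The Tongyi test now patches `_generate` on the project's own `EnhanceTongyi` wrapper rather than on `langchain.llms.tongyi.Tongyi`, because with `mocker.patch` (and `unittest.mock.patch`) the dotted target must be the path the code under test actually looks up at call time. A minimal, dependency-free illustration of string-path patching, using a stdlib target purely for demonstration:

```python
# Minimal illustration of dotted-path patching (stdlib target used only for demo).
from unittest.mock import patch
import os.path

with patch("os.path.exists", return_value=True) as mock_exists:
    # The attribute at the given path is replaced for the duration of the block.
    assert os.path.exists("/definitely/not/a/real/path")
    mock_exists.assert_called_once_with("/definitely/not/a/real/path")
```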