feat: server xinference support (#927)

takatost 1 year ago
parent
commit
da3f10a55e

+ 3 - 0
api/core/model_providers/model_provider_factory.py

@@ -57,6 +57,9 @@ class ModelProviderFactory:
         elif provider_name == 'huggingface_hub':
             from core.model_providers.providers.huggingface_hub_provider import HuggingfaceHubProvider
             return HuggingfaceHubProvider
+        elif provider_name == 'xinference':
+            from core.model_providers.providers.xinference_provider import XinferenceProvider
+            return XinferenceProvider
         else:
             raise NotImplementedError
 

+ 69 - 0
api/core/model_providers/models/llm/xinference_model.py

@@ -0,0 +1,69 @@
+from typing import List, Optional, Any
+
+from langchain.callbacks.manager import Callbacks
+from langchain.llms import Xinference
+from langchain.schema import LLMResult
+
+from core.model_providers.error import LLMBadRequestError
+from core.model_providers.models.llm.base import BaseLLM
+from core.model_providers.models.entity.message import PromptMessage
+from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
+
+
+class XinferenceModel(BaseLLM):
+    model_mode: ModelMode = ModelMode.COMPLETION
+
+    def _init_client(self) -> Any:
+        self.provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
+
+        client = Xinference(
+            **self.credentials,
+        )
+
+        client.callbacks = self.callbacks
+
+        return client
+
+    def _run(self, messages: List[PromptMessage],
+             stop: Optional[List[str]] = None,
+             callbacks: Callbacks = None,
+             **kwargs) -> LLMResult:
+        """
+        Run prediction on the given prompt messages, honoring the optional stop words.
+
+        :param messages:
+        :param stop:
+        :param callbacks:
+        :return:
+        """
+        prompts = self._get_prompt_from_messages(messages)
+        return self._client.generate(
+            [prompts],
+            stop,
+            callbacks,
+            generate_config={
+                "stop": stop,
+                "stream": self.streaming,
+                **self.provider_model_kwargs,
+            }
+        )
+
+    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
+        """
+        Get the number of tokens in the given prompt messages.
+
+        :param messages:
+        :return:
+        """
+        prompts = self._get_prompt_from_messages(messages)
+        return max(self._client.get_num_tokens(prompts), 0)
+
+    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
+        pass
+
+    def handle_exceptions(self, ex: Exception) -> Exception:
+        return LLMBadRequestError(f"Xinference: {str(ex)}")
+
+    @classmethod
+    def support_streaming(cls):
+        return True
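
For context, the client above wraps LangChain's Xinference LLM class. A minimal standalone sketch of the same call pattern, assuming a reachable Xinference server and an already-launched model (the URL and UID below are illustrative placeholders, not values from this PR):

    from langchain.llms import Xinference

    # Illustrative credentials; point these at your own Xinference deployment.
    llm = Xinference(server_url='http://127.0.0.1:9997', model_uid='<model-uid>')

    # Same call shape the provider below uses for its credential check.
    print(llm('ping', generate_config={'max_tokens': 10}))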

+ 141 - 0
api/core/model_providers/providers/xinference_provider.py

@@ -0,0 +1,141 @@
+import json
+from typing import Type
+
+from langchain.llms import Xinference
+
+from core.helper import encrypter
+from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
+from core.model_providers.models.llm.xinference_model import XinferenceModel
+from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
+
+from core.model_providers.models.base import BaseProviderModel
+from models.provider import ProviderType
+
+
+class XinferenceProvider(BaseModelProvider):
+    @property
+    def provider_name(self):
+        """
+        Returns the name of a provider.
+        """
+        return 'xinference'
+
+    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
+        return []
+
+    def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
+        """
+        Returns the model class.
+
+        :param model_type:
+        :return:
+        """
+        if model_type == ModelType.TEXT_GENERATION:
+            model_class = XinferenceModel
+        else:
+            raise NotImplementedError
+
+        return model_class
+
+    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
+        """
+        get model parameter rules.
+
+        :param model_name:
+        :param model_type:
+        :return:
+        """
+        return ModelKwargsRules(
+            temperature=KwargRule[float](min=0, max=2, default=1),
+            top_p=KwargRule[float](min=0, max=1, default=0.7),
+            presence_penalty=KwargRule[float](min=-2, max=2, default=0),
+            frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
+            max_tokens=KwargRule[int](alias='max_token', min=10, max=4000, default=256),
+        )
+
+    @classmethod
+    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
+        """
+        Check that the model credentials are valid; raise CredentialsValidateFailedError otherwise.
+
+        :param model_name:
+        :param model_type:
+        :param credentials:
+        """
+        if 'server_url' not in credentials:
+            raise CredentialsValidateFailedError('Xinference Server URL must be provided.')
+
+        if 'model_uid' not in credentials:
+            raise CredentialsValidateFailedError('Xinference Model UID must be provided.')
+
+        try:
+            credential_kwargs = {
+                'server_url': credentials['server_url'],
+                'model_uid': credentials['model_uid'],
+            }
+
+            llm = Xinference(
+                **credential_kwargs
+            )
+
+            llm("ping", generate_config={'max_tokens': 10})
+        except Exception as ex:
+            raise CredentialsValidateFailedError(str(ex))
+
+    @classmethod
+    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
+                                  credentials: dict) -> dict:
+        """
+        Encrypt model credentials before saving.
+
+        :param tenant_id:
+        :param model_name:
+        :param model_type:
+        :param credentials:
+        :return:
+        """
+        credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url'])
+        return credentials
+
+    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
+        """
+        get credentials for llm use.
+
+        :param model_name:
+        :param model_type:
+        :param obfuscated:
+        :return:
+        """
+        if self.provider.provider_type != ProviderType.CUSTOM.value:
+            raise NotImplementedError
+
+        provider_model = self._get_provider_model(model_name, model_type)
+
+        if not provider_model.encrypted_config:
+            return {
+                'server_url': None,
+                'model_uid': None,
+            }
+
+        credentials = json.loads(provider_model.encrypted_config)
+        if credentials['server_url']:
+            credentials['server_url'] = encrypter.decrypt_token(
+                self.provider.tenant_id,
+                credentials['server_url']
+            )
+
+            if obfuscated:
+                credentials['server_url'] = encrypter.obfuscated_token(credentials['server_url'])
+
+        return credentials
+
+    @classmethod
+    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
+        return
+
+    @classmethod
+    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
+        return {}
+
+    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
+        return {}
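
In short, the provider keeps a per-model credential dict with server_url and model_uid, and only server_url is encrypted at rest. A hedged sketch of calling the validation entry point directly (model name and credential values are illustrative):

    from core.model_providers.models.entity.model_params import ModelType
    from core.model_providers.providers.xinference_provider import XinferenceProvider

    # Raises CredentialsValidateFailedError when a field is missing or the
    # test completion against the server fails; otherwise it simply returns.
    XinferenceProvider.is_model_credentials_valid_or_raise(
        model_name='llama-2-chat',
        model_type=ModelType.TEXT_GENERATION,
        credentials={
            'server_url': 'http://127.0.0.1:9997',  # illustrative
            'model_uid': '<model-uid>',             # illustrative
        },
    )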

+ 2 - 1
api/core/model_providers/rules/_providers.json

@@ -8,5 +8,6 @@
   "wenxin",
   "wenxin",
   "chatglm",
   "chatglm",
   "replicate",
   "replicate",
-  "huggingface_hub"
+  "huggingface_hub",
+  "xinference"
 ]

+ 7 - 0
api/core/model_providers/rules/xinference.json

@@ -0,0 +1,7 @@
+{
+    "support_provider_types": [
+        "custom"
+    ],
+    "system_config": null,
+    "model_flexibility": "configurable"
+}

+ 2 - 1
api/requirements.txt

@@ -48,4 +48,5 @@ dashscope~=1.5.0
 huggingface_hub~=0.16.4
 transformers~=4.31.0
 stripe~=5.5.0
-pandas==1.5.3
+pandas==1.5.3
+xinference==0.2.0

+ 5 - 1
api/tests/integration_tests/.env.example

@@ -32,4 +32,8 @@ WENXIN_API_KEY=
 WENXIN_SECRET_KEY=
 
 # ChatGLM Credentials
-CHATGLM_API_BASE=
+CHATGLM_API_BASE=
+
+# Xinference Credentials
+XINFERENCE_SERVER_URL=
+XINFERENCE_MODEL_UID=
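
These two variables feed the new Xinference integration test; they should point at a running Xinference server and the UID of a model launched on it, for example (illustrative values only):

    XINFERENCE_SERVER_URL=http://127.0.0.1:9997
    XINFERENCE_MODEL_UID=<uid-of-a-launched-model>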

+ 3 - 2
api/tests/integration_tests/models/llm/test_anthropic_model.py

@@ -50,7 +50,9 @@ def test_get_num_tokens(mock_decrypt):
 
 
 @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
-def test_run(mock_decrypt):
+def test_run(mock_decrypt, mocker):
+    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
+
     model = get_mock_model('claude-2')
     messages = [PromptMessage(content='Human: 1 + 1=? \nAssistant: ')]
     rst = model.run(
@@ -58,4 +60,3 @@ def test_run(mock_decrypt):
         stop=['\nHuman:'],
     )
     assert len(rst.content) > 0
-    assert rst.content.strip() == '2'

+ 2 - 1
api/tests/integration_tests/models/llm/test_azure_openai_model.py

@@ -76,6 +76,8 @@ def test_chat_get_num_tokens(mock_decrypt, mocker):
 
 @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
 def test_run(mock_decrypt, mocker):
+    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
+
     openai_model = get_mock_azure_openai_model('gpt-35-turbo', mocker)
     messages = [PromptMessage(content='Human: Are you Human? you MUST only answer `y` or `n`? \nAssistant: ')]
     rst = openai_model.run(
@@ -83,4 +85,3 @@ def test_run(mock_decrypt, mocker):
         stop=['\nHuman:'],
     )
     assert len(rst.content) > 0
-    assert rst.content.strip() == 'n'

+ 4 - 1
api/tests/integration_tests/models/llm/test_huggingface_hub_model.py

@@ -95,6 +95,8 @@ def test_inference_endpoints_get_num_tokens(mock_decrypt, mock_model_info, mocke
 
 @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
 def test_hosted_inference_api_run(mock_decrypt, mocker):
+    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
+
     model = get_mock_model(
         'google/flan-t5-base',
         'hosted_inference_api',
@@ -111,6 +113,8 @@ def test_hosted_inference_api_run(mock_decrypt, mocker):
 
 @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
 def test_inference_endpoints_run(mock_decrypt, mocker):
+    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
+
     model = get_mock_model(
         '',
         'inference_endpoints',
@@ -121,4 +125,3 @@ def test_inference_endpoints_run(mock_decrypt, mocker):
         [PromptMessage(content='Answer the following yes/no question. Can you write a whole Haiku in a single tweet?')],
     )
     assert len(rst.content) > 0
-    assert rst.content.strip() == 'no'

+ 3 - 2
api/tests/integration_tests/models/llm/test_minimax_model.py

@@ -54,11 +54,12 @@ def test_get_num_tokens(mock_decrypt):
 
 
 @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
-def test_run(mock_decrypt):
+def test_run(mock_decrypt, mocker):
+    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
+
     model = get_mock_model('abab5.5-chat')
     rst = model.run(
         [PromptMessage(content='Human: Are you a real Human? you MUST only answer `y` or `n`? \nAssistant: ')],
         stop=['\nHuman:'],
     )
     assert len(rst.content) > 0
-    assert rst.content.strip() == 'n'

+ 6 - 3
api/tests/integration_tests/models/llm/test_openai_model.py

@@ -58,7 +58,9 @@ def test_chat_get_num_tokens(mock_decrypt):
 
 
 @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
-def test_run(mock_decrypt):
+def test_run(mock_decrypt, mocker):
+    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
+
     openai_model = get_mock_openai_model('text-davinci-003')
     rst = openai_model.run(
         [PromptMessage(content='Human: Are you Human? you MUST only answer `y` or `n`? \nAssistant: ')],
@@ -69,7 +71,9 @@ def test_run(mock_decrypt):
 
 
 @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
-def test_chat_run(mock_decrypt):
+def test_chat_run(mock_decrypt, mocker):
+    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
+
     openai_model = get_mock_openai_model('gpt-3.5-turbo')
     messages = [PromptMessage(content='Human: Are you Human? you MUST only answer `y` or `n`? \nAssistant: ')]
     rst = openai_model.run(
@@ -77,4 +81,3 @@ def test_chat_run(mock_decrypt):
         stop=['\nHuman:'],
     )
     assert len(rst.content) > 0
-    assert rst.content.strip() == 'n'

+ 2 - 0
api/tests/integration_tests/models/llm/test_replicate_model.py

@@ -65,6 +65,8 @@ def test_get_num_tokens(mock_decrypt, mocker):
 
 @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
 def test_run(mock_decrypt, mocker):
+    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
+
     model = get_mock_model('a16z-infra/llama-2-13b-chat', '2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52', mocker)
     messages = [PromptMessage(content='Human: 1+1=? \nAnswer: ')]
     rst = model.run(

+ 3 - 2
api/tests/integration_tests/models/llm/test_spark_model.py

@@ -58,7 +58,9 @@ def test_get_num_tokens(mock_decrypt):
 
 
 @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
-def test_run(mock_decrypt):
+def test_run(mock_decrypt, mocker):
+    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
+
     model = get_mock_model('spark')
     messages = [PromptMessage(content='Human: 1 + 1=? \nAssistant: Integer answer is:')]
     rst = model.run(
@@ -66,4 +68,3 @@ def test_run(mock_decrypt):
         stop=['\nHuman:'],
     )
     assert len(rst.content) > 0
-    assert rst.content.strip() == '2'

+ 3 - 1
api/tests/integration_tests/models/llm/test_tongyi_model.py

@@ -52,7 +52,9 @@ def test_get_num_tokens(mock_decrypt):
 
 
 @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
-def test_run(mock_decrypt):
+def test_run(mock_decrypt, mocker):
+    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
+
     model = get_mock_model('qwen-v1')
     rst = model.run(
         [PromptMessage(content='Human: Are you Human? you MUST only answer `y` or `n`? \nAssistant: ')],

+ 3 - 2
api/tests/integration_tests/models/llm/test_wenxin_model.py

@@ -52,7 +52,9 @@ def test_get_num_tokens(mock_decrypt):
 
 
 @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
-def test_run(mock_decrypt):
+def test_run(mock_decrypt, mocker):
+    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
+
     model = get_mock_model('ernie-bot')
     messages = [PromptMessage(content='Human: 1 + 1=? \nAssistant: Integer answer is:')]
     rst = model.run(
@@ -60,4 +62,3 @@ def test_run(mock_decrypt):
         stop=['\nHuman:'],
     )
     assert len(rst.content) > 0
-    assert rst.content.strip() == '2'

+ 74 - 0
api/tests/integration_tests/models/llm/test_xinference_model.py

@@ -0,0 +1,74 @@
+import json
+import os
+from unittest.mock import patch, MagicMock
+
+from core.model_providers.models.entity.message import PromptMessage, MessageType
+from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
+from core.model_providers.models.llm.xinference_model import XinferenceModel
+from core.model_providers.providers.xinference_provider import XinferenceProvider
+from models.provider import Provider, ProviderType, ProviderModel
+
+
+def get_mock_provider():
+    return Provider(
+        id='provider_id',
+        tenant_id='tenant_id',
+        provider_name='xinference',
+        provider_type=ProviderType.CUSTOM.value,
+        encrypted_config='',
+        is_valid=True,
+    )
+
+
+def get_mock_model(model_name, mocker):
+    model_kwargs = ModelKwargs(
+        max_tokens=10,
+        temperature=0.01
+    )
+    server_url = os.environ['XINFERENCE_SERVER_URL']
+    model_uid = os.environ['XINFERENCE_MODEL_UID']
+    model_provider = XinferenceProvider(provider=get_mock_provider())
+
+    mock_query = MagicMock()
+    mock_query.filter.return_value.first.return_value = ProviderModel(
+        provider_name='xinference',
+        model_name=model_name,
+        model_type=ModelType.TEXT_GENERATION.value,
+        encrypted_config=json.dumps({
+            'server_url': server_url,
+            'model_uid': model_uid
+        }),
+        is_valid=True,
+    )
+    mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
+
+    return XinferenceModel(
+        model_provider=model_provider,
+        name=model_name,
+        model_kwargs=model_kwargs
+    )
+
+
+def decrypt_side_effect(tenant_id, encrypted_api_key):
+    return encrypted_api_key
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_get_num_tokens(mock_decrypt, mocker):
+    model = get_mock_model('llama-2-chat', mocker)
+    rst = model.get_num_tokens([
+        PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
+    ])
+    assert rst == 5
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_run(mock_decrypt, mocker):
+    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
+
+    model = get_mock_model('llama-2-chat', mocker)
+    messages = [PromptMessage(content='Human: 1+1=? \nAnswer: ')]
+    rst = model.run(
+        messages
+    )
+    assert len(rst.content) > 0

+ 124 - 0
api/tests/unit_tests/model_providers/test_xinference_provider.py

@@ -0,0 +1,124 @@
+import pytest
+from unittest.mock import patch, MagicMock
+import json
+
+from core.model_providers.models.entity.model_params import ModelType
+from core.model_providers.providers.base import CredentialsValidateFailedError
+from core.model_providers.providers.replicate_provider import ReplicateProvider
+from core.model_providers.providers.xinference_provider import XinferenceProvider
+from models.provider import ProviderType, Provider, ProviderModel
+
+PROVIDER_NAME = 'xinference'
+MODEL_PROVIDER_CLASS = XinferenceProvider
+VALIDATE_CREDENTIAL = {
+    'model_uid': 'fake-model-uid',
+    'server_url': 'http://127.0.0.1:9997/'
+}
+
+
+def encrypt_side_effect(tenant_id, encrypt_key):
+    return f'encrypted_{encrypt_key}'
+
+
+def decrypt_side_effect(tenant_id, encrypted_key):
+    return encrypted_key.replace('encrypted_', '')
+
+
+def test_is_credentials_valid_or_raise_valid(mocker):
+    mocker.patch('langchain.llms.xinference.Xinference._call',
+                 return_value="abc")
+
+    MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
+        model_name='username/test_model_name',
+        model_type=ModelType.TEXT_GENERATION,
+        credentials=VALIDATE_CREDENTIAL.copy()
+    )
+
+
+def test_is_credentials_valid_or_raise_invalid():
+    # raise CredentialsValidateFailedError if server_url or model_uid is missing from credentials
+    with pytest.raises(CredentialsValidateFailedError):
+        MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
+            model_name='test_model_name',
+            model_type=ModelType.TEXT_GENERATION,
+            credentials={}
+        )
+
+    # raise CredentialsValidateFailedError if model_uid is missing (only server_url is provided)
+    with pytest.raises(CredentialsValidateFailedError):
+        MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
+            model_name='test_model_name',
+            model_type=ModelType.TEXT_GENERATION,
+            credentials={'server_url': 'invalid'})
+
+
+@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
+def test_encrypt_model_credentials(mock_encrypt):
+    api_key = 'http://127.0.0.1:9997/'
+    result = MODEL_PROVIDER_CLASS.encrypt_model_credentials(
+        tenant_id='tenant_id',
+        model_name='test_model_name',
+        model_type=ModelType.TEXT_GENERATION,
+        credentials=VALIDATE_CREDENTIAL.copy()
+    )
+    mock_encrypt.assert_called_with('tenant_id', api_key)
+    assert result['server_url'] == f'encrypted_{api_key}'
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_get_model_credentials_custom(mock_decrypt, mocker):
+    provider = Provider(
+        id='provider_id',
+        tenant_id='tenant_id',
+        provider_name=PROVIDER_NAME,
+        provider_type=ProviderType.CUSTOM.value,
+        encrypted_config=None,
+        is_valid=True,
+    )
+
+    encrypted_credential = VALIDATE_CREDENTIAL.copy()
+    encrypted_credential['server_url'] = 'encrypted_' + encrypted_credential['server_url']
+
+    mock_query = MagicMock()
+    mock_query.filter.return_value.first.return_value = ProviderModel(
+        encrypted_config=json.dumps(encrypted_credential)
+    )
+    mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
+
+    model_provider = MODEL_PROVIDER_CLASS(provider=provider)
+    result = model_provider.get_model_credentials(
+        model_name='test_model_name',
+        model_type=ModelType.TEXT_GENERATION
+    )
+    assert result['server_url'] == 'http://127.0.0.1:9997/'
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_get_model_credentials_obfuscated(mock_decrypt, mocker):
+    provider = Provider(
+        id='provider_id',
+        tenant_id='tenant_id',
+        provider_name=PROVIDER_NAME,
+        provider_type=ProviderType.CUSTOM.value,
+        encrypted_config=None,
+        is_valid=True,
+    )
+
+    encrypted_credential = VALIDATE_CREDENTIAL.copy()
+    encrypted_credential['server_url'] = 'encrypted_' + encrypted_credential['server_url']
+
+    mock_query = MagicMock()
+    mock_query.filter.return_value.first.return_value = ProviderModel(
+        encrypted_config=json.dumps(encrypted_credential)
+    )
+    mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
+
+    model_provider = MODEL_PROVIDER_CLASS(provider=provider)
+    result = model_provider.get_model_credentials(
+        model_name='test_model_name',
+        model_type=ModelType.TEXT_GENERATION,
+        obfuscated=True
+    )
+    middle_token = result['server_url'][6:-2]
+    assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['server_url']) - 8, 0)
+    assert all(char == '*' for char in middle_token)