Yeuoly преди 1 година
родител
ревизия
5a756ca981

+ 0 - 10
api/core/model_runtime/model_providers/__base/ai_model.py

@@ -236,16 +236,6 @@ class AIModel(ABC):
         :param credentials: model credentials
         :return: model schema
         """
-        if 'schema' in credentials:
-            schema_dict = json.loads(credentials['schema'])
-
-            try:
-                model_instance = AIModelEntity.parse_obj(schema_dict)
-                return model_instance
-            except ValidationError as e:
-                logging.exception(f"Invalid model schema for {model}")
-                return self._get_customizable_model_schema(model, credentials)
-
         return self._get_customizable_model_schema(model, credentials)
     
     def _get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:

+ 4 - 4
api/core/model_runtime/model_providers/localai/llm/llm.py

@@ -1,7 +1,7 @@
 from typing import Generator, List, Optional, Union, cast
 from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage, LLMResultChunk, LLMResultChunkDelta, LLMMode
 from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, AssistantPromptMessage, UserPromptMessage, SystemPromptMessage
-from core.model_runtime.entities.model_entities import AIModelEntity, ParameterRule, ParameterType, FetchFrom, ModelType
+from core.model_runtime.entities.model_entities import AIModelEntity, ParameterRule, ParameterType, FetchFrom, ModelType, ModelPropertyKey
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, InvokeRateLimitError, \
@@ -156,9 +156,9 @@ class LocalAILarguageModel(LargeLanguageModel):
     def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
         completion_model = None
         if credentials['completion_type'] == 'chat_completion':
-            completion_model = LLMMode.CHAT
+            completion_model = LLMMode.CHAT.value
         elif credentials['completion_type'] == 'completion':
-            completion_model = LLMMode.COMPLETION
+            completion_model = LLMMode.COMPLETION.value
         else:
             raise ValueError(f"Unknown completion type {credentials['completion_type']}")
             
@@ -202,7 +202,7 @@ class LocalAILarguageModel(LargeLanguageModel):
             ),
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_type=ModelType.LLM,
-            model_properties={ 'mode': completion_model } if completion_model else {},
+            model_properties={ ModelPropertyKey.MODE: completion_model } if completion_model else {},
             parameter_rules=rules
         )
 

+ 2 - 2
api/core/model_runtime/model_providers/openai_api_compatible/_common.py

@@ -117,9 +117,9 @@ class _CommonOAI_API_Compat:
 
         if model_type == ModelType.LLM:
             if credentials['mode'] == 'chat':
-                entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT
+                entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT.value
             elif credentials['mode'] == 'completion':
-                entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION
+                entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION.value
             else:
                 raise ValueError(f"Unknown completion type {credentials['completion_type']}")
         

+ 2 - 2
api/core/model_runtime/model_providers/openllm/llm/llm.py

@@ -6,7 +6,7 @@ from core.model_runtime.model_providers.openllm.llm.openllm_generate import Open
 from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage, LLMResultChunk, LLMResultChunkDelta, LLMMode
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, AssistantPromptMessage, UserPromptMessage, SystemPromptMessage
-from core.model_runtime.entities.model_entities import AIModelEntity, ParameterRule, ParameterType, FetchFrom, ModelType
+from core.model_runtime.entities.model_entities import AIModelEntity, ParameterRule, ParameterType, FetchFrom, ModelType, ModelPropertyKey
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, InvokeRateLimitError, \
     InvokeAuthorizationError, InvokeBadRequestError, InvokeError
@@ -198,7 +198,7 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel):
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_type=ModelType.LLM,
             model_properties={ 
-                'mode':  LLMMode.COMPLETION,
+                ModelPropertyKey.MODE: LLMMode.COMPLETION.value,
             },
             parameter_rules=rules
         )

+ 2 - 2
api/core/model_runtime/model_providers/replicate/llm/llm.py

@@ -8,7 +8,7 @@ from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.llm_entities import LLMResult, LLMMode, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, AssistantPromptMessage, \
     PromptMessageRole, UserPromptMessage, SystemPromptMessage
-from core.model_runtime.entities.model_entities import ParameterRule, AIModelEntity, FetchFrom, ModelType
+from core.model_runtime.entities.model_entities import ParameterRule, AIModelEntity, FetchFrom, ModelType, ModelPropertyKey
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.model_providers.replicate._common import _CommonReplicate
@@ -91,7 +91,7 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_type=ModelType.LLM,
             model_properties={
-                'mode': model_type.value
+                ModelPropertyKey.MODE: model_type.value
             },
             parameter_rules=self._get_customizable_model_parameter_rules(model, credentials)
         )

+ 29 - 12
api/core/model_runtime/model_providers/xinference/llm/llm.py

@@ -18,7 +18,7 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
 from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, UserPromptMessage, SystemPromptMessage, AssistantPromptMessage
 from core.model_runtime.entities.common_entities import I18nObject
-from core.model_runtime.entities.model_entities import FetchFrom, ModelType, ParameterRule, ParameterType
+from core.model_runtime.entities.model_entities import FetchFrom, ModelType, ParameterRule, ParameterType, ModelPropertyKey
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.xinference.llm.xinference_helper import XinferenceHelper, XinferenceModelExtraParameter
 from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, InvokeRateLimitError, \
@@ -56,10 +56,18 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             }
         """
         try:
-            XinferenceHelper.get_xinference_extra_parameter(
+            extra_param = XinferenceHelper.get_xinference_extra_parameter(
                 server_url=credentials['server_url'],
                 model_uid=credentials['model_uid']
             )
+            if 'completion_type' not in credentials:
+                if 'chat' in extra_param.model_ability:
+                    credentials['completion_type'] = 'chat'
+                elif 'generate' in extra_param.model_ability:
+                    credentials['completion_type'] = 'completion'
+                else:
+                    raise ValueError(f'xinference model ability {extra_param.model_ability} is not supported')
+
         except RuntimeError as e:
             raise CredentialsValidateFailedError(f'Xinference credentials validate failed: {e}')
         except KeyError as e:
@@ -256,17 +264,26 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         ]
 
         completion_type = None
-        extra_args = XinferenceHelper.get_xinference_extra_parameter(
-            server_url=credentials['server_url'],
-            model_uid=credentials['model_uid']
-        )
 
-        if 'chat' in extra_args.model_ability:
-            completion_type = LLMMode.CHAT
-        elif 'generate' in extra_args.model_ability:
-            completion_type = LLMMode.COMPLETION
+        if 'completion_type' in credentials:
+            if credentials['completion_type'] == 'chat':
+                completion_type = LLMMode.CHAT.value
+            elif credentials['completion_type'] == 'completion':
+                completion_type = LLMMode.COMPLETION.value
+            else:
+                raise ValueError(f'completion_type {credentials["completion_type"]} is not supported')
         else:
-            raise NotImplementedError(f'xinference model ability {extra_args.model_ability} is not supported')
+            extra_args = XinferenceHelper.get_xinference_extra_parameter(
+                server_url=credentials['server_url'],
+                model_uid=credentials['model_uid']
+            )
+
+            if 'chat' in extra_args.model_ability:
+                completion_type = LLMMode.CHAT.value
+            elif 'generate' in extra_args.model_ability:
+                completion_type = LLMMode.COMPLETION.value
+            else:
+                raise ValueError(f'xinference model ability {extra_args.model_ability} is not supported')
 
         entity = AIModelEntity(
             model=model,
@@ -276,7 +293,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_type=ModelType.LLM,
             model_properties={ 
-                'mode':  completion_type,
+                ModelPropertyKey.MODE: completion_type,
             },
             parameter_rules=rules
         )