@@ -18,7 +18,7 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
 from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, UserPromptMessage, SystemPromptMessage, AssistantPromptMessage
 from core.model_runtime.entities.common_entities import I18nObject
-from core.model_runtime.entities.model_entities import FetchFrom, ModelType, ParameterRule, ParameterType
+from core.model_runtime.entities.model_entities import FetchFrom, ModelType, ParameterRule, ParameterType, ModelPropertyKey
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.xinference.llm.xinference_helper import XinferenceHelper, XinferenceModelExtraParameter
 from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, InvokeRateLimitError, \
@@ -56,10 +56,18 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         }
         """
         try:
-            XinferenceHelper.get_xinference_extra_parameter(
+            extra_param = XinferenceHelper.get_xinference_extra_parameter(
                 server_url=credentials['server_url'],
                 model_uid=credentials['model_uid']
             )
+            if 'completion_type' not in credentials:
+                if 'chat' in extra_param.model_ability:
+                    credentials['completion_type'] = 'chat'
+                elif 'generate' in extra_param.model_ability:
+                    credentials['completion_type'] = 'completion'
+                else:
+                    raise ValueError(f'xinference model ability {extra_param.model_ability} is not supported')
+
         except RuntimeError as e:
             raise CredentialsValidateFailedError(f'Xinference credentials validate failed: {e}')
         except KeyError as e:
@@ -256,17 +264,26 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         ]
 
         completion_type = None
-        extra_args = XinferenceHelper.get_xinference_extra_parameter(
-            server_url=credentials['server_url'],
-            model_uid=credentials['model_uid']
-        )
 
-        if 'chat' in extra_args.model_ability:
-            completion_type = LLMMode.CHAT
-        elif 'generate' in extra_args.model_ability:
-            completion_type = LLMMode.COMPLETION
+        if 'completion_type' in credentials:
+            if credentials['completion_type'] == 'chat':
+                completion_type = LLMMode.CHAT.value
+            elif credentials['completion_type'] == 'completion':
+                completion_type = LLMMode.COMPLETION.value
+            else:
+                raise ValueError(f'completion_type {credentials["completion_type"]} is not supported')
         else:
-            raise NotImplementedError(f'xinference model ability {extra_args.model_ability} is not supported')
+            extra_args = XinferenceHelper.get_xinference_extra_parameter(
+                server_url=credentials['server_url'],
+                model_uid=credentials['model_uid']
+            )
+
+            if 'chat' in extra_args.model_ability:
+                completion_type = LLMMode.CHAT.value
+            elif 'generate' in extra_args.model_ability:
+                completion_type = LLMMode.COMPLETION.value
+            else:
+                raise ValueError(f'xinference model ability {extra_args.model_ability} is not supported')
 
         entity = AIModelEntity(
             model=model,
@@ -276,7 +293,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_type=ModelType.LLM,
             model_properties={
-                'mode': completion_type,
+                ModelPropertyKey.MODE: completion_type,
             },
             parameter_rules=rules
         )
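
The sketch below is not part of the patch; it is a minimal, self-contained illustration of the resolution order the diff introduces: an explicit 'completion_type' credential takes precedence, and the model ability reported by Xinference is consulted only as a fallback. The LLMMode enum and the fetch_model_ability callable here are hypothetical local stand-ins, not the Dify or Xinference APIs.

from enum import Enum


class LLMMode(Enum):
    # stand-in for core.model_runtime.entities.llm_entities.LLMMode
    CHAT = 'chat'
    COMPLETION = 'completion'


def resolve_completion_type(credentials: dict, fetch_model_ability) -> str:
    # A condensed mirror of the patched logic (a sketch, not the Dify implementation).
    if 'completion_type' in credentials:
        # An explicit credential wins and skips the Xinference round trip.
        if credentials['completion_type'] == 'chat':
            return LLMMode.CHAT.value
        if credentials['completion_type'] == 'completion':
            return LLMMode.COMPLETION.value
        raise ValueError(f"completion_type {credentials['completion_type']} is not supported")

    # Otherwise decide from the abilities the server reports for this model uid.
    ability = fetch_model_ability(credentials['server_url'], credentials['model_uid'])
    if 'chat' in ability:
        return LLMMode.CHAT.value
    if 'generate' in ability:
        return LLMMode.COMPLETION.value
    raise ValueError(f'xinference model ability {ability} is not supported')


# Usage: the explicit credential short-circuits the ability lookup entirely.
print(resolve_completion_type(
    {'server_url': 'http://localhost:9997', 'model_uid': 'my-model', 'completion_type': 'chat'},
    fetch_model_ability=lambda url, uid: ['generate'],
))  # -> chat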