|
@@ -8,8 +8,9 @@ from typing import (
|
|
Union
|
|
Union
|
|
)
|
|
)
|
|
|
|
|
|
-from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, UserPromptMessage, AssistantPromptMessage, \
|
|
|
|
- SystemPromptMessage
|
|
|
|
|
|
+from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, UserPromptMessage, \
|
|
|
|
+ AssistantPromptMessage, \
|
|
|
|
+ SystemPromptMessage, PromptMessageRole
|
|
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, \
|
|
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, \
|
|
LLMResultChunkDelta
|
|
LLMResultChunkDelta
|
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
|
@@ -111,16 +112,39 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
|
if len(prompt_messages) == 0:
|
|
if len(prompt_messages) == 0:
|
|
raise ValueError('At least one message is required')
|
|
raise ValueError('At least one message is required')
|
|
|
|
|
|
- if prompt_messages[0].role.value == 'system':
|
|
|
|
|
|
+ if prompt_messages[0].role == PromptMessageRole.SYSTEM:
|
|
if not prompt_messages[0].content:
|
|
if not prompt_messages[0].content:
|
|
prompt_messages = prompt_messages[1:]
|
|
prompt_messages = prompt_messages[1:]
|
|
|
|
|
|
|
|
+ # resolve zhipuai model not support system message and user message, assistant message must be in sequence
|
|
|
|
+ new_prompt_messages = []
|
|
|
|
+ for prompt_message in prompt_messages:
|
|
|
|
+ copy_prompt_message = prompt_message.copy()
|
|
|
|
+ if copy_prompt_message.role in [PromptMessageRole.USER, PromptMessageRole.SYSTEM, PromptMessageRole.TOOL]:
|
|
|
|
+ if not isinstance(copy_prompt_message.content, str):
|
|
|
|
+ # not support image message
|
|
|
|
+ continue
|
|
|
|
+
|
|
|
|
+ if new_prompt_messages and new_prompt_messages[-1].role == PromptMessageRole.USER:
|
|
|
|
+ new_prompt_messages[-1].content += "\n\n" + copy_prompt_message.content
|
|
|
|
+ else:
|
|
|
|
+ if copy_prompt_message.role == PromptMessageRole.USER:
|
|
|
|
+ new_prompt_messages.append(copy_prompt_message)
|
|
|
|
+ else:
|
|
|
|
+ new_prompt_message = UserPromptMessage(content=copy_prompt_message.content)
|
|
|
|
+ new_prompt_messages.append(new_prompt_message)
|
|
|
|
+ else:
|
|
|
|
+ if new_prompt_messages and new_prompt_messages[-1].role == PromptMessageRole.ASSISTANT:
|
|
|
|
+ new_prompt_messages[-1].content += "\n\n" + copy_prompt_message.content
|
|
|
|
+ else:
|
|
|
|
+ new_prompt_messages.append(copy_prompt_message)
|
|
|
|
+
|
|
params = {
|
|
params = {
|
|
'model': model,
|
|
'model': model,
|
|
'prompt': [{
|
|
'prompt': [{
|
|
- 'role': prompt_message.role.value if prompt_message.role.value != 'system' else 'user',
|
|
|
|
|
|
+ 'role': prompt_message.role.value,
|
|
'content': prompt_message.content
|
|
'content': prompt_message.content
|
|
- } for prompt_message in prompt_messages],
|
|
|
|
|
|
+ } for prompt_message in new_prompt_messages],
|
|
**model_parameters
|
|
**model_parameters
|
|
}
|
|
}
|
|
|
|
|