123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135 |
- import enum
- from typing import Any, cast
- from langchain.schema import AIMessage, BaseMessage, FunctionMessage, HumanMessage, SystemMessage
- from pydantic import BaseModel
- from core.model_runtime.entities.message_entities import (
- AssistantPromptMessage,
- ImagePromptMessageContent,
- PromptMessage,
- SystemPromptMessage,
- TextPromptMessageContent,
- ToolPromptMessage,
- UserPromptMessage,
- )
class PromptMessageFileType(enum.Enum):
    """Kinds of file that can be attached to a prompt message."""

    IMAGE = 'image'

    @staticmethod
    def value_of(value):
        """Return the member whose value equals ``value``.

        Delegates to the Enum's built-in value lookup instead of scanning the
        members by hand; passing an existing member returns that member.

        Raises:
            ValueError: if no member has the given value.
        """
        try:
            return PromptMessageFileType(value)
        except ValueError:
            # Keep this module's historical error message rather than Enum's own.
            raise ValueError(f"No matching enum found for value '{value}'") from None
class PromptMessageFile(BaseModel):
    """Base model for a file attached to a prompt message.

    Subclasses such as ``ImagePromptMessageFile`` supply a default ``type``
    and add type-specific fields.
    """

    # Kind of attachment; converters in this module dispatch on it.
    type: PromptMessageFileType
    # Raw file payload; its format is not constrained by this model.
    data: Any
class ImagePromptMessageFile(PromptMessageFile):
    """Image attachment for a prompt message."""

    class DETAIL(enum.Enum):
        """Requested image detail level."""

        LOW = 'low'
        HIGH = 'high'

    # Defaults to IMAGE so callers normally only supply ``data``.
    type: PromptMessageFileType = PromptMessageFileType.IMAGE
    # Detail level; LOW unless the caller explicitly asks for HIGH.
    detail: DETAIL = DETAIL.LOW
class LCHumanMessageWithFiles(HumanMessage):
    """LangChain human message that carries file attachments next to its text.

    ``content`` is narrowed to plain text; attachments travel separately in
    ``files`` and are turned into prompt contents by
    ``lc_messages_to_prompt_messages``.
    """

    content: str
    files: list[PromptMessageFile]
def _lc_image_detail(value):
    """Map a 'high'/'low' detail string to ImagePromptMessageContent.DETAIL (anything but 'high' -> LOW)."""
    if value == 'high':
        return ImagePromptMessageContent.DETAIL.HIGH
    return ImagePromptMessageContent.DETAIL.LOW


def _lc_human_message_to_prompt_message(message: HumanMessage) -> UserPromptMessage:
    """Convert a LangChain human message (plain, multimodal list, or with files) to a UserPromptMessage."""
    if isinstance(message, LCHumanMessageWithFiles):
        # Text first, then one content entry per image attachment.
        contents = [TextPromptMessageContent(data=message.content)]
        for file in message.files:
            if file.type == PromptMessageFileType.IMAGE:
                image_file = cast(ImagePromptMessageFile, file)
                # A bare PromptMessageFile tagged IMAGE has no ``detail`` attribute;
                # fall back to ImagePromptMessageFile's LOW default instead of raising.
                detail = getattr(image_file, 'detail', ImagePromptMessageFile.DETAIL.LOW)
                contents.append(ImagePromptMessageContent(
                    data=image_file.data,
                    detail=_lc_image_detail(detail.value),
                ))
        return UserPromptMessage(content=contents)

    if isinstance(message.content, str):
        return UserPromptMessage(content=message.content)

    # List content, e.g. as produced by prompt_messages_to_lc_messages:
    # plain strings are text parts, {'type': 'image', 'data', 'detail'} dicts are images.
    contents = []
    for item in message.content:
        if isinstance(item, str):
            contents.append(TextPromptMessageContent(data=item))
        elif isinstance(item, dict) and item.get('type') == 'image':
            contents.append(ImagePromptMessageContent(
                data=item.get('data'),
                detail=_lc_image_detail(item.get('detail')),
            ))
        else:
            # Unrecognised parts are handed to UserPromptMessage validation unchanged.
            contents.append(item)
    return UserPromptMessage(content=contents)


def _lc_ai_message_to_prompt_message(message: AIMessage) -> AssistantPromptMessage:
    """Convert a LangChain AI message, including an optional function_call, to an AssistantPromptMessage."""
    message_kwargs = {
        'content': message.content
    }
    function_call = message.additional_kwargs.get('function_call')
    if function_call:
        message_kwargs['tool_calls'] = [
            AssistantPromptMessage.ToolCall(
                # prompt_messages_to_lc_messages stores an 'id', but OpenAI-style
                # function_call dicts (name/arguments only) may lack one. Fall back to
                # the function name: a FunctionMessage answering the call carries that
                # name, which this module maps to ToolPromptMessage.tool_call_id.
                id=function_call.get('id', function_call['name']),
                type='function',
                function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                    name=function_call['name'],
                    arguments=function_call['arguments'],
                ),
            )
        ]
    return AssistantPromptMessage(**message_kwargs)


def lc_messages_to_prompt_messages(messages: list[BaseMessage]) -> list[PromptMessage]:
    """Convert LangChain messages into model-runtime PromptMessages.

    Mapping:
        HumanMessage    -> UserPromptMessage (plain text, multimodal list content,
                           or LCHumanMessageWithFiles attachments)
        AIMessage       -> AssistantPromptMessage (a function_call becomes one tool call)
        SystemMessage   -> SystemPromptMessage
        FunctionMessage -> ToolPromptMessage (``name`` is used as ``tool_call_id``)

    Messages of any other type are skipped.

    Args:
        messages: LangChain messages in conversation order.

    Returns:
        The converted prompt messages, in the same order.
    """
    prompt_messages: list[PromptMessage] = []
    for message in messages:
        if isinstance(message, HumanMessage):
            prompt_messages.append(_lc_human_message_to_prompt_message(message))
        elif isinstance(message, AIMessage):
            prompt_messages.append(_lc_ai_message_to_prompt_message(message))
        elif isinstance(message, SystemMessage):
            prompt_messages.append(SystemPromptMessage(content=message.content))
        elif isinstance(message, FunctionMessage):
            # The FunctionMessage name doubles as the tool call id; the reverse
            # conversion writes tool_call_id back into ``name``.
            prompt_messages.append(ToolPromptMessage(content=message.content, tool_call_id=message.name))
    return prompt_messages
def _prompt_user_content_to_lc(prompt_message: UserPromptMessage):
    """Return HumanMessage content (str or list of str/dict parts) for a UserPromptMessage."""
    if isinstance(prompt_message.content, str):
        return prompt_message.content
    lc_content = []
    for content in prompt_message.content:
        if isinstance(content, TextPromptMessageContent):
            lc_content.append(content.data)
        elif isinstance(content, ImagePromptMessageContent):
            lc_content.append({
                'type': 'image',
                'data': content.data,
                'detail': content.detail.value
            })
        # Other content types have no LangChain mapping here and are dropped.
    return lc_content


def _prompt_assistant_message_to_lc(prompt_message: AssistantPromptMessage) -> AIMessage:
    """Convert an AssistantPromptMessage into an AIMessage, carrying at most one tool call."""
    # AIMessage does not accept None content; assistant turns that only carry tool
    # calls can have content None, so map it to an empty string.
    message_kwargs = {
        'content': '' if prompt_message.content is None else prompt_message.content
    }
    if prompt_message.tool_calls:
        # ``function_call`` holds a single call, so only the first tool call is kept;
        # any additional (parallel) tool calls are dropped.
        tool_call = prompt_message.tool_calls[0]
        message_kwargs['additional_kwargs'] = {
            'function_call': {
                'id': tool_call.id,
                'name': tool_call.function.name,
                'arguments': tool_call.function.arguments
            }
        }
    return AIMessage(**message_kwargs)


def prompt_messages_to_lc_messages(prompt_messages: list[PromptMessage]) -> list[BaseMessage]:
    """Convert model-runtime PromptMessages into LangChain messages.

    Mapping:
        UserPromptMessage      -> HumanMessage (list content becomes str text parts
                                  and {'type': 'image', 'data', 'detail'} dicts)
        AssistantPromptMessage -> AIMessage (first tool call only, as ``function_call``)
        SystemPromptMessage    -> SystemMessage
        ToolPromptMessage      -> FunctionMessage (``tool_call_id`` stored as ``name``)

    Prompt messages of any other type are skipped.

    Args:
        prompt_messages: Prompt messages in conversation order.

    Returns:
        The converted LangChain messages, in the same order.
    """
    messages: list[BaseMessage] = []
    for prompt_message in prompt_messages:
        if isinstance(prompt_message, UserPromptMessage):
            messages.append(HumanMessage(content=_prompt_user_content_to_lc(prompt_message)))
        elif isinstance(prompt_message, AssistantPromptMessage):
            messages.append(_prompt_assistant_message_to_lc(prompt_message))
        elif isinstance(prompt_message, SystemPromptMessage):
            messages.append(SystemMessage(content=prompt_message.content))
        elif isinstance(prompt_message, ToolPromptMessage):
            messages.append(FunctionMessage(name=prompt_message.tool_call_id, content=prompt_message.content))
    return messages
|