message_entities.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. import enum
  2. from typing import Any, cast
  3. from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage, FunctionMessage
  4. from pydantic import BaseModel
  5. from core.model_runtime.entities.message_entities import PromptMessage, UserPromptMessage, TextPromptMessageContent, \
  6. ImagePromptMessageContent, AssistantPromptMessage, SystemPromptMessage, ToolPromptMessage
  7. class PromptMessageFileType(enum.Enum):
  8. IMAGE = 'image'
  9. @staticmethod
  10. def value_of(value):
  11. for member in PromptMessageFileType:
  12. if member.value == value:
  13. return member
  14. raise ValueError(f"No matching enum found for value '{value}'")
class PromptMessageFile(BaseModel):
    """A file attached to a prompt message.

    ``lc_messages_to_prompt_messages`` only converts files whose ``type`` is
    ``PromptMessageFileType.IMAGE``; files of other types are skipped there.
    """
    # Which kind of file this is; selects how the payload is converted.
    type: PromptMessageFileType
    # Raw file payload, forwarded unchanged as the content ``data``. Its format
    # is not constrained here (presumably a URL or encoded bytes — confirm).
    data: Any
class ImagePromptMessageFile(PromptMessageFile):
    """An image attachment carrying a requested detail level."""

    class DETAIL(enum.Enum):
        """Requested image detail level."""
        LOW = 'low'
        HIGH = 'high'

    # Defaults to IMAGE for this subclass.
    type: PromptMessageFileType = PromptMessageFileType.IMAGE
    # Defaults to LOW. lc_messages_to_prompt_messages maps HIGH to
    # ImagePromptMessageContent.DETAIL.HIGH and anything else to LOW.
    detail: DETAIL = DETAIL.LOW
class LCHumanMessageWithFiles(HumanMessage):
    """LangChain human message that carries file attachments next to its text."""

    # Narrowed to plain text (the LangChain base type presumably also allows a
    # list of str/dict parts — confirm). lc_messages_to_prompt_messages wraps
    # this value in a single TextPromptMessageContent.
    content: str
    # Attachments; only IMAGE-typed files are converted to prompt content.
    files: list[PromptMessageFile]
  28. def lc_messages_to_prompt_messages(messages: list[BaseMessage]) -> list[PromptMessage]:
  29. prompt_messages = []
  30. for message in messages:
  31. if isinstance(message, HumanMessage):
  32. if isinstance(message, LCHumanMessageWithFiles):
  33. file_prompt_message_contents = []
  34. for file in message.files:
  35. if file.type == PromptMessageFileType.IMAGE:
  36. file = cast(ImagePromptMessageFile, file)
  37. file_prompt_message_contents.append(ImagePromptMessageContent(
  38. data=file.data,
  39. detail=ImagePromptMessageContent.DETAIL.HIGH
  40. if file.detail.value == "high" else ImagePromptMessageContent.DETAIL.LOW
  41. ))
  42. prompt_message_contents = [TextPromptMessageContent(data=message.content)]
  43. prompt_message_contents.extend(file_prompt_message_contents)
  44. prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
  45. else:
  46. prompt_messages.append(UserPromptMessage(content=message.content))
  47. elif isinstance(message, AIMessage):
  48. message_kwargs = {
  49. 'content': message.content
  50. }
  51. if 'function_call' in message.additional_kwargs:
  52. message_kwargs['tool_calls'] = [
  53. AssistantPromptMessage.ToolCall(
  54. id=message.additional_kwargs['function_call']['id'],
  55. type='function',
  56. function=AssistantPromptMessage.ToolCall.ToolCallFunction(
  57. name=message.additional_kwargs['function_call']['name'],
  58. arguments=message.additional_kwargs['function_call']['arguments']
  59. )
  60. )
  61. ]
  62. prompt_messages.append(AssistantPromptMessage(**message_kwargs))
  63. elif isinstance(message, SystemMessage):
  64. prompt_messages.append(SystemPromptMessage(content=message.content))
  65. elif isinstance(message, FunctionMessage):
  66. prompt_messages.append(ToolPromptMessage(content=message.content, tool_call_id=message.name))
  67. return prompt_messages
  68. def prompt_messages_to_lc_messages(prompt_messages: list[PromptMessage]) -> list[BaseMessage]:
  69. messages = []
  70. for prompt_message in prompt_messages:
  71. if isinstance(prompt_message, UserPromptMessage):
  72. if isinstance(prompt_message.content, str):
  73. messages.append(HumanMessage(content=prompt_message.content))
  74. else:
  75. message_contents = []
  76. for content in prompt_message.content:
  77. if isinstance(content, TextPromptMessageContent):
  78. message_contents.append(content.data)
  79. elif isinstance(content, ImagePromptMessageContent):
  80. message_contents.append({
  81. 'type': 'image',
  82. 'data': content.data,
  83. 'detail': content.detail.value
  84. })
  85. messages.append(HumanMessage(content=message_contents))
  86. elif isinstance(prompt_message, AssistantPromptMessage):
  87. message_kwargs = {
  88. 'content': prompt_message.content
  89. }
  90. if prompt_message.tool_calls:
  91. message_kwargs['additional_kwargs'] = {
  92. 'function_call': {
  93. 'id': prompt_message.tool_calls[0].id,
  94. 'name': prompt_message.tool_calls[0].function.name,
  95. 'arguments': prompt_message.tool_calls[0].function.arguments
  96. }
  97. }
  98. messages.append(AIMessage(**message_kwargs))
  99. elif isinstance(prompt_message, SystemPromptMessage):
  100. messages.append(SystemMessage(content=prompt_message.content))
  101. elif isinstance(prompt_message, ToolPromptMessage):
  102. messages.append(FunctionMessage(name=prompt_message.tool_call_id, content=prompt_message.content))
  103. return messages