# message_entities.py
  1. import enum
  2. from typing import Any, cast
  3. from langchain.schema import AIMessage, BaseMessage, FunctionMessage, HumanMessage, SystemMessage
  4. from pydantic import BaseModel
  5. from core.model_runtime.entities.message_entities import (
  6. AssistantPromptMessage,
  7. ImagePromptMessageContent,
  8. PromptMessage,
  9. SystemPromptMessage,
  10. TextPromptMessageContent,
  11. ToolPromptMessage,
  12. UserPromptMessage,
  13. )
  14. class PromptMessageFileType(enum.Enum):
  15. IMAGE = 'image'
  16. @staticmethod
  17. def value_of(value):
  18. for member in PromptMessageFileType:
  19. if member.value == value:
  20. return member
  21. raise ValueError(f"No matching enum found for value '{value}'")
  22. class PromptMessageFile(BaseModel):
  23. type: PromptMessageFileType
  24. data: Any
  25. class ImagePromptMessageFile(PromptMessageFile):
  26. class DETAIL(enum.Enum):
  27. LOW = 'low'
  28. HIGH = 'high'
  29. type: PromptMessageFileType = PromptMessageFileType.IMAGE
  30. detail: DETAIL = DETAIL.LOW
  31. class LCHumanMessageWithFiles(HumanMessage):
  32. # content: Union[str, List[Union[str, Dict]]]
  33. content: str
  34. files: list[PromptMessageFile]
  35. def lc_messages_to_prompt_messages(messages: list[BaseMessage]) -> list[PromptMessage]:
  36. prompt_messages = []
  37. for message in messages:
  38. if isinstance(message, HumanMessage):
  39. if isinstance(message, LCHumanMessageWithFiles):
  40. file_prompt_message_contents = []
  41. for file in message.files:
  42. if file.type == PromptMessageFileType.IMAGE:
  43. file = cast(ImagePromptMessageFile, file)
  44. file_prompt_message_contents.append(ImagePromptMessageContent(
  45. data=file.data,
  46. detail=ImagePromptMessageContent.DETAIL.HIGH
  47. if file.detail.value == "high" else ImagePromptMessageContent.DETAIL.LOW
  48. ))
  49. prompt_message_contents = [TextPromptMessageContent(data=message.content)]
  50. prompt_message_contents.extend(file_prompt_message_contents)
  51. prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
  52. else:
  53. prompt_messages.append(UserPromptMessage(content=message.content))
  54. elif isinstance(message, AIMessage):
  55. message_kwargs = {
  56. 'content': message.content
  57. }
  58. if 'function_call' in message.additional_kwargs:
  59. message_kwargs['tool_calls'] = [
  60. AssistantPromptMessage.ToolCall(
  61. id=message.additional_kwargs['function_call']['id'],
  62. type='function',
  63. function=AssistantPromptMessage.ToolCall.ToolCallFunction(
  64. name=message.additional_kwargs['function_call']['name'],
  65. arguments=message.additional_kwargs['function_call']['arguments']
  66. )
  67. )
  68. ]
  69. prompt_messages.append(AssistantPromptMessage(**message_kwargs))
  70. elif isinstance(message, SystemMessage):
  71. prompt_messages.append(SystemPromptMessage(content=message.content))
  72. elif isinstance(message, FunctionMessage):
  73. prompt_messages.append(ToolPromptMessage(content=message.content, tool_call_id=message.name))
  74. return prompt_messages
  75. def prompt_messages_to_lc_messages(prompt_messages: list[PromptMessage]) -> list[BaseMessage]:
  76. messages = []
  77. for prompt_message in prompt_messages:
  78. if isinstance(prompt_message, UserPromptMessage):
  79. if isinstance(prompt_message.content, str):
  80. messages.append(HumanMessage(content=prompt_message.content))
  81. else:
  82. message_contents = []
  83. for content in prompt_message.content:
  84. if isinstance(content, TextPromptMessageContent):
  85. message_contents.append(content.data)
  86. elif isinstance(content, ImagePromptMessageContent):
  87. message_contents.append({
  88. 'type': 'image',
  89. 'data': content.data,
  90. 'detail': content.detail.value
  91. })
  92. messages.append(HumanMessage(content=message_contents))
  93. elif isinstance(prompt_message, AssistantPromptMessage):
  94. message_kwargs = {
  95. 'content': prompt_message.content
  96. }
  97. if prompt_message.tool_calls:
  98. message_kwargs['additional_kwargs'] = {
  99. 'function_call': {
  100. 'id': prompt_message.tool_calls[0].id,
  101. 'name': prompt_message.tool_calls[0].function.name,
  102. 'arguments': prompt_message.tool_calls[0].function.arguments
  103. }
  104. }
  105. messages.append(AIMessage(**message_kwargs))
  106. elif isinstance(prompt_message, SystemPromptMessage):
  107. messages.append(SystemMessage(content=prompt_message.content))
  108. elif isinstance(prompt_message, ToolPromptMessage):
  109. messages.append(FunctionMessage(name=prompt_message.tool_call_id, content=prompt_message.content))
  110. return messages