message.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
import enum
from typing import Any, Dict, List, Optional, Union, cast

from langchain.schema import AIMessage, BaseMessage, FunctionMessage, HumanMessage, SystemMessage
from pydantic import BaseModel
  5. class LLMRunResult(BaseModel):
  6. content: str
  7. prompt_tokens: int
  8. completion_tokens: int
  9. source: list = None
  10. function_call: dict = None
  11. class MessageType(enum.Enum):
  12. USER = 'user'
  13. ASSISTANT = 'assistant'
  14. SYSTEM = 'system'
  15. class PromptMessageFileType(enum.Enum):
  16. IMAGE = 'image'
  17. @staticmethod
  18. def value_of(value):
  19. for member in PromptMessageFileType:
  20. if member.value == value:
  21. return member
  22. raise ValueError(f"No matching enum found for value '{value}'")
  23. class PromptMessageFile(BaseModel):
  24. type: PromptMessageFileType
  25. data: Any
  26. class ImagePromptMessageFile(PromptMessageFile):
  27. class DETAIL(enum.Enum):
  28. LOW = 'low'
  29. HIGH = 'high'
  30. type: PromptMessageFileType = PromptMessageFileType.IMAGE
  31. detail: DETAIL = DETAIL.LOW
  32. class PromptMessage(BaseModel):
  33. type: MessageType = MessageType.USER
  34. content: str = ''
  35. files: list[PromptMessageFile] = []
  36. function_call: dict = None
  37. class LCHumanMessageWithFiles(HumanMessage):
  38. # content: Union[str, List[Union[str, Dict]]]
  39. content: str
  40. files: list[PromptMessageFile]
  41. def to_lc_messages(messages: list[PromptMessage]):
  42. lc_messages = []
  43. for message in messages:
  44. if message.type == MessageType.USER:
  45. if not message.files:
  46. lc_messages.append(HumanMessage(content=message.content))
  47. else:
  48. lc_messages.append(LCHumanMessageWithFiles(content=message.content, files=message.files))
  49. elif message.type == MessageType.ASSISTANT:
  50. additional_kwargs = {}
  51. if message.function_call:
  52. additional_kwargs['function_call'] = message.function_call
  53. lc_messages.append(AIMessage(content=message.content, additional_kwargs=additional_kwargs))
  54. elif message.type == MessageType.SYSTEM:
  55. lc_messages.append(SystemMessage(content=message.content))
  56. return lc_messages
  57. def to_prompt_messages(messages: list[BaseMessage]):
  58. prompt_messages = []
  59. for message in messages:
  60. if isinstance(message, HumanMessage):
  61. if isinstance(message, LCHumanMessageWithFiles):
  62. prompt_messages.append(PromptMessage(
  63. content=message.content,
  64. type=MessageType.USER,
  65. files=message.files
  66. ))
  67. else:
  68. prompt_messages.append(PromptMessage(content=message.content, type=MessageType.USER))
  69. elif isinstance(message, AIMessage):
  70. message_kwargs = {
  71. 'content': message.content,
  72. 'type': MessageType.ASSISTANT
  73. }
  74. if 'function_call' in message.additional_kwargs:
  75. message_kwargs['function_call'] = message.additional_kwargs['function_call']
  76. prompt_messages.append(PromptMessage(**message_kwargs))
  77. elif isinstance(message, SystemMessage):
  78. prompt_messages.append(PromptMessage(content=message.content, type=MessageType.SYSTEM))
  79. elif isinstance(message, FunctionMessage):
  80. prompt_messages.append(PromptMessage(content=message.content, type=MessageType.USER))
  81. return prompt_messages
  82. def str_to_prompt_messages(texts: list[str]):
  83. prompt_messages = []
  84. for text in texts:
  85. prompt_messages.append(PromptMessage(content=text))
  86. return prompt_messages