# message.py
import enum
from typing import Optional

from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage
from pydantic import BaseModel
  4. class LLMRunResult(BaseModel):
  5. content: str
  6. prompt_tokens: int
  7. completion_tokens: int
  8. source: list = None
  9. class MessageType(enum.Enum):
  10. HUMAN = 'human'
  11. ASSISTANT = 'assistant'
  12. SYSTEM = 'system'
  13. class PromptMessage(BaseModel):
  14. type: MessageType = MessageType.HUMAN
  15. content: str = ''
  16. def to_lc_messages(messages: list[PromptMessage]):
  17. lc_messages = []
  18. for message in messages:
  19. if message.type == MessageType.HUMAN:
  20. lc_messages.append(HumanMessage(content=message.content))
  21. elif message.type == MessageType.ASSISTANT:
  22. lc_messages.append(AIMessage(content=message.content))
  23. elif message.type == MessageType.SYSTEM:
  24. lc_messages.append(SystemMessage(content=message.content))
  25. return lc_messages
  26. def to_prompt_messages(messages: list[BaseMessage]):
  27. prompt_messages = []
  28. for message in messages:
  29. if isinstance(message, HumanMessage):
  30. prompt_messages.append(PromptMessage(content=message.content, type=MessageType.HUMAN))
  31. elif isinstance(message, AIMessage):
  32. prompt_messages.append(PromptMessage(content=message.content, type=MessageType.ASSISTANT))
  33. elif isinstance(message, SystemMessage):
  34. prompt_messages.append(PromptMessage(content=message.content, type=MessageType.SYSTEM))
  35. return prompt_messages
  36. def str_to_prompt_messages(texts: list[str]):
  37. prompt_messages = []
  38. for text in texts:
  39. prompt_messages.append(PromptMessage(content=text))
  40. return prompt_messages