conversation_service.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
import logging
from typing import Optional, Union

from core.generator.llm_generator import LLMGenerator
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.account import Account
from models.model import App, Conversation, EndUser, Message
from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError
from services.errors.message import MessageNotExistsError
  9. class ConversationService:
  10. @classmethod
  11. def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account | EndUser]],
  12. last_id: Optional[str], limit: int,
  13. include_ids: Optional[list] = None, exclude_ids: Optional[list] = None,
  14. exclude_debug_conversation: bool = False) -> InfiniteScrollPagination:
  15. if not user:
  16. return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
  17. base_query = db.session.query(Conversation).filter(
  18. Conversation.is_deleted == False,
  19. Conversation.app_id == app_model.id,
  20. Conversation.from_source == ('api' if isinstance(user, EndUser) else 'console'),
  21. Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
  22. Conversation.from_account_id == (user.id if isinstance(user, Account) else None),
  23. )
  24. if include_ids is not None:
  25. base_query = base_query.filter(Conversation.id.in_(include_ids))
  26. if exclude_ids is not None:
  27. base_query = base_query.filter(~Conversation.id.in_(exclude_ids))
  28. if exclude_debug_conversation:
  29. base_query = base_query.filter(Conversation.override_model_configs == None)
  30. if last_id:
  31. last_conversation = base_query.filter(
  32. Conversation.id == last_id,
  33. ).first()
  34. if not last_conversation:
  35. raise LastConversationNotExistsError()
  36. conversations = base_query.filter(
  37. Conversation.created_at < last_conversation.created_at,
  38. Conversation.id != last_conversation.id
  39. ).order_by(Conversation.created_at.desc()).limit(limit).all()
  40. else:
  41. conversations = base_query.order_by(Conversation.created_at.desc()).limit(limit).all()
  42. has_more = False
  43. if len(conversations) == limit:
  44. current_page_first_conversation = conversations[-1]
  45. rest_count = base_query.filter(
  46. Conversation.created_at < current_page_first_conversation.created_at,
  47. Conversation.id != current_page_first_conversation.id
  48. ).count()
  49. if rest_count > 0:
  50. has_more = True
  51. return InfiniteScrollPagination(
  52. data=conversations,
  53. limit=limit,
  54. has_more=has_more
  55. )
  56. @classmethod
  57. def rename(cls, app_model: App, conversation_id: str,
  58. user: Optional[Union[Account | EndUser]], name: str, auto_generate: bool):
  59. conversation = cls.get_conversation(app_model, conversation_id, user)
  60. if auto_generate:
  61. return cls.auto_generate_name(app_model, conversation)
  62. else:
  63. conversation.name = name
  64. db.session.commit()
  65. return conversation
  66. @classmethod
  67. def auto_generate_name(cls, app_model: App, conversation: Conversation):
  68. # get conversation first message
  69. message = db.session.query(Message) \
  70. .filter(
  71. Message.app_id == app_model.id,
  72. Message.conversation_id == conversation.id
  73. ).order_by(Message.created_at.asc()).first()
  74. if not message:
  75. raise MessageNotExistsError()
  76. # generate conversation name
  77. try:
  78. name = LLMGenerator.generate_conversation_name(app_model.tenant_id, message.query)
  79. conversation.name = name
  80. except:
  81. pass
  82. db.session.commit()
  83. return conversation
  84. @classmethod
  85. def get_conversation(cls, app_model: App, conversation_id: str, user: Optional[Union[Account | EndUser]]):
  86. conversation = db.session.query(Conversation) \
  87. .filter(
  88. Conversation.id == conversation_id,
  89. Conversation.app_id == app_model.id,
  90. Conversation.from_source == ('api' if isinstance(user, EndUser) else 'console'),
  91. Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
  92. Conversation.from_account_id == (user.id if isinstance(user, Account) else None),
  93. Conversation.is_deleted == False
  94. ).first()
  95. if not conversation:
  96. raise ConversationNotExistsError()
  97. return conversation
  98. @classmethod
  99. def delete(cls, app_model: App, conversation_id: str, user: Optional[Union[Account | EndUser]]):
  100. conversation = cls.get_conversation(app_model, conversation_id, user)
  101. conversation.is_deleted = True
  102. db.session.commit()