# application_queue_manager.py
  1. import queue
  2. import time
  3. from typing import Generator, Any
  4. from sqlalchemy.orm import DeclarativeMeta
  5. from core.entities.application_entities import InvokeFrom
  6. from core.entities.queue_entities import QueueStopEvent, AppQueueEvent, QueuePingEvent, QueueErrorEvent, \
  7. QueueAgentThoughtEvent, QueueMessageEndEvent, QueueRetrieverResourcesEvent, QueueMessageReplaceEvent, \
  8. QueueMessageEvent, QueueMessage, AnnotationReplyEvent
  9. from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
  10. from extensions.ext_redis import redis_client
  11. from models.model import MessageAgentThought
  12. class ApplicationQueueManager:
  13. def __init__(self, task_id: str,
  14. user_id: str,
  15. invoke_from: InvokeFrom,
  16. conversation_id: str,
  17. app_mode: str,
  18. message_id: str) -> None:
  19. if not user_id:
  20. raise ValueError("user is required")
  21. self._task_id = task_id
  22. self._user_id = user_id
  23. self._invoke_from = invoke_from
  24. self._conversation_id = str(conversation_id)
  25. self._app_mode = app_mode
  26. self._message_id = str(message_id)
  27. user_prefix = 'account' if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user'
  28. redis_client.setex(ApplicationQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}")
  29. q = queue.Queue()
  30. self._q = q
  31. def listen(self) -> Generator:
  32. """
  33. Listen to queue
  34. :return:
  35. """
  36. # wait for 10 minutes to stop listen
  37. listen_timeout = 600
  38. start_time = time.time()
  39. last_ping_time = 0
  40. while True:
  41. try:
  42. message = self._q.get(timeout=1)
  43. if message is None:
  44. break
  45. yield message
  46. except queue.Empty:
  47. continue
  48. finally:
  49. elapsed_time = time.time() - start_time
  50. if elapsed_time >= listen_timeout or self._is_stopped():
  51. # publish two messages to make sure the client can receive the stop signal
  52. # and stop listening after the stop signal processed
  53. self.publish(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL))
  54. self.stop_listen()
  55. if elapsed_time // 10 > last_ping_time:
  56. self.publish(QueuePingEvent())
  57. last_ping_time = elapsed_time // 10
  58. def stop_listen(self) -> None:
  59. """
  60. Stop listen to queue
  61. :return:
  62. """
  63. self._q.put(None)
  64. def publish_chunk_message(self, chunk: LLMResultChunk) -> None:
  65. """
  66. Publish chunk message to channel
  67. :param chunk: chunk
  68. :return:
  69. """
  70. self.publish(QueueMessageEvent(
  71. chunk=chunk
  72. ))
  73. def publish_message_replace(self, text: str) -> None:
  74. """
  75. Publish message replace
  76. :param text: text
  77. :return:
  78. """
  79. self.publish(QueueMessageReplaceEvent(
  80. text=text
  81. ))
  82. def publish_retriever_resources(self, retriever_resources: list[dict]) -> None:
  83. """
  84. Publish retriever resources
  85. :return:
  86. """
  87. self.publish(QueueRetrieverResourcesEvent(retriever_resources=retriever_resources))
  88. def publish_annotation_reply(self, message_annotation_id: str) -> None:
  89. """
  90. Publish annotation reply
  91. :param message_annotation_id: message annotation id
  92. :return:
  93. """
  94. self.publish(AnnotationReplyEvent(message_annotation_id=message_annotation_id))
  95. def publish_message_end(self, llm_result: LLMResult) -> None:
  96. """
  97. Publish message end
  98. :param llm_result: llm result
  99. :return:
  100. """
  101. self.publish(QueueMessageEndEvent(llm_result=llm_result))
  102. self.stop_listen()
  103. def publish_agent_thought(self, message_agent_thought: MessageAgentThought) -> None:
  104. """
  105. Publish agent thought
  106. :param message_agent_thought: message agent thought
  107. :return:
  108. """
  109. self.publish(QueueAgentThoughtEvent(
  110. agent_thought_id=message_agent_thought.id
  111. ))
  112. def publish_error(self, e) -> None:
  113. """
  114. Publish error
  115. :param e: error
  116. :return:
  117. """
  118. self.publish(QueueErrorEvent(
  119. error=e
  120. ))
  121. self.stop_listen()
  122. def publish(self, event: AppQueueEvent) -> None:
  123. """
  124. Publish event to queue
  125. :param event:
  126. :return:
  127. """
  128. self._check_for_sqlalchemy_models(event.dict())
  129. message = QueueMessage(
  130. task_id=self._task_id,
  131. message_id=self._message_id,
  132. conversation_id=self._conversation_id,
  133. app_mode=self._app_mode,
  134. event=event
  135. )
  136. self._q.put(message)
  137. if isinstance(event, QueueStopEvent):
  138. self.stop_listen()
  139. @classmethod
  140. def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str) -> None:
  141. """
  142. Set task stop flag
  143. :return:
  144. """
  145. result = redis_client.get(cls._generate_task_belong_cache_key(task_id))
  146. if result is None:
  147. return
  148. user_prefix = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user'
  149. if result.decode('utf-8') != f"{user_prefix}-{user_id}":
  150. return
  151. stopped_cache_key = cls._generate_stopped_cache_key(task_id)
  152. redis_client.setex(stopped_cache_key, 600, 1)
  153. def _is_stopped(self) -> bool:
  154. """
  155. Check if task is stopped
  156. :return:
  157. """
  158. stopped_cache_key = ApplicationQueueManager._generate_stopped_cache_key(self._task_id)
  159. result = redis_client.get(stopped_cache_key)
  160. if result is not None:
  161. redis_client.delete(stopped_cache_key)
  162. return True
  163. return False
  164. @classmethod
  165. def _generate_task_belong_cache_key(cls, task_id: str) -> str:
  166. """
  167. Generate task belong cache key
  168. :param task_id: task id
  169. :return:
  170. """
  171. return f"generate_task_belong:{task_id}"
  172. @classmethod
  173. def _generate_stopped_cache_key(cls, task_id: str) -> str:
  174. """
  175. Generate stopped cache key
  176. :param task_id: task id
  177. :return:
  178. """
  179. return f"generate_task_stopped:{task_id}"
  180. def _check_for_sqlalchemy_models(self, data: Any):
  181. # from entity to dict or list
  182. if isinstance(data, dict):
  183. for key, value in data.items():
  184. self._check_for_sqlalchemy_models(value)
  185. elif isinstance(data, list):
  186. for item in data:
  187. self._check_for_sqlalchemy_models(item)
  188. else:
  189. if isinstance(data, DeclarativeMeta) or hasattr(data, '_sa_instance_state'):
  190. raise TypeError("Critical Error: Passing SQLAlchemy Model instances "
  191. "that cause thread safety issues is not allowed.")
  192. class ConversationTaskStoppedException(Exception):
  193. pass