application_queue_manager.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269
  1. import queue
  2. import time
  3. from enum import Enum
  4. from typing import Any, Generator
  5. from core.entities.application_entities import InvokeFrom
  6. from core.entities.queue_entities import (AnnotationReplyEvent, AppQueueEvent, QueueAgentMessageEvent,
  7. QueueAgentThoughtEvent, QueueErrorEvent, QueueMessage, QueueMessageEndEvent,
  8. QueueMessageEvent, QueueMessageFileEvent, QueueMessageReplaceEvent,
  9. QueuePingEvent, QueueRetrieverResourcesEvent, QueueStopEvent)
  10. from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
  11. from extensions.ext_redis import redis_client
  12. from models.model import MessageAgentThought, MessageFile
  13. from sqlalchemy.orm import DeclarativeMeta
  14. class PublishFrom(Enum):
  15. APPLICATION_MANAGER = 1
  16. TASK_PIPELINE = 2
  17. class ApplicationQueueManager:
  18. def __init__(self, task_id: str,
  19. user_id: str,
  20. invoke_from: InvokeFrom,
  21. conversation_id: str,
  22. app_mode: str,
  23. message_id: str) -> None:
  24. if not user_id:
  25. raise ValueError("user is required")
  26. self._task_id = task_id
  27. self._user_id = user_id
  28. self._invoke_from = invoke_from
  29. self._conversation_id = str(conversation_id)
  30. self._app_mode = app_mode
  31. self._message_id = str(message_id)
  32. user_prefix = 'account' if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user'
  33. redis_client.setex(ApplicationQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}")
  34. q = queue.Queue()
  35. self._q = q
  36. def listen(self) -> Generator:
  37. """
  38. Listen to queue
  39. :return:
  40. """
  41. # wait for 10 minutes to stop listen
  42. listen_timeout = 600
  43. start_time = time.time()
  44. last_ping_time = 0
  45. while True:
  46. try:
  47. message = self._q.get(timeout=1)
  48. if message is None:
  49. break
  50. yield message
  51. except queue.Empty:
  52. continue
  53. finally:
  54. elapsed_time = time.time() - start_time
  55. if elapsed_time >= listen_timeout or self._is_stopped():
  56. # publish two messages to make sure the client can receive the stop signal
  57. # and stop listening after the stop signal processed
  58. self.publish(
  59. QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL),
  60. PublishFrom.TASK_PIPELINE
  61. )
  62. self.stop_listen()
  63. if elapsed_time // 10 > last_ping_time:
  64. self.publish(QueuePingEvent(), PublishFrom.TASK_PIPELINE)
  65. last_ping_time = elapsed_time // 10
  66. def stop_listen(self) -> None:
  67. """
  68. Stop listen to queue
  69. :return:
  70. """
  71. self._q.put(None)
  72. def publish_chunk_message(self, chunk: LLMResultChunk, pub_from: PublishFrom) -> None:
  73. """
  74. Publish chunk message to channel
  75. :param chunk: chunk
  76. :param pub_from: publish from
  77. :return:
  78. """
  79. self.publish(QueueMessageEvent(
  80. chunk=chunk
  81. ), pub_from)
  82. def publish_agent_chunk_message(self, chunk: LLMResultChunk, pub_from: PublishFrom) -> None:
  83. """
  84. Publish agent chunk message to channel
  85. :param chunk: chunk
  86. :param pub_from: publish from
  87. :return:
  88. """
  89. self.publish(QueueAgentMessageEvent(
  90. chunk=chunk
  91. ), pub_from)
  92. def publish_message_replace(self, text: str, pub_from: PublishFrom) -> None:
  93. """
  94. Publish message replace
  95. :param text: text
  96. :param pub_from: publish from
  97. :return:
  98. """
  99. self.publish(QueueMessageReplaceEvent(
  100. text=text
  101. ), pub_from)
  102. def publish_retriever_resources(self, retriever_resources: list[dict], pub_from: PublishFrom) -> None:
  103. """
  104. Publish retriever resources
  105. :return:
  106. """
  107. self.publish(QueueRetrieverResourcesEvent(retriever_resources=retriever_resources), pub_from)
  108. def publish_annotation_reply(self, message_annotation_id: str, pub_from: PublishFrom) -> None:
  109. """
  110. Publish annotation reply
  111. :param message_annotation_id: message annotation id
  112. :param pub_from: publish from
  113. :return:
  114. """
  115. self.publish(AnnotationReplyEvent(message_annotation_id=message_annotation_id), pub_from)
  116. def publish_message_end(self, llm_result: LLMResult, pub_from: PublishFrom) -> None:
  117. """
  118. Publish message end
  119. :param llm_result: llm result
  120. :param pub_from: publish from
  121. :return:
  122. """
  123. self.publish(QueueMessageEndEvent(llm_result=llm_result), pub_from)
  124. self.stop_listen()
  125. def publish_agent_thought(self, message_agent_thought: MessageAgentThought, pub_from: PublishFrom) -> None:
  126. """
  127. Publish agent thought
  128. :param message_agent_thought: message agent thought
  129. :param pub_from: publish from
  130. :return:
  131. """
  132. self.publish(QueueAgentThoughtEvent(
  133. agent_thought_id=message_agent_thought.id
  134. ), pub_from)
  135. def publish_message_file(self, message_file: MessageFile, pub_from: PublishFrom) -> None:
  136. """
  137. Publish agent thought
  138. :param message_file: message file
  139. :param pub_from: publish from
  140. :return:
  141. """
  142. self.publish(QueueMessageFileEvent(
  143. message_file_id=message_file.id
  144. ), pub_from)
  145. def publish_error(self, e, pub_from: PublishFrom) -> None:
  146. """
  147. Publish error
  148. :param e: error
  149. :param pub_from: publish from
  150. :return:
  151. """
  152. self.publish(QueueErrorEvent(
  153. error=e
  154. ), pub_from)
  155. self.stop_listen()
  156. def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
  157. """
  158. Publish event to queue
  159. :param event:
  160. :param pub_from:
  161. :return:
  162. """
  163. self._check_for_sqlalchemy_models(event.dict())
  164. message = QueueMessage(
  165. task_id=self._task_id,
  166. message_id=self._message_id,
  167. conversation_id=self._conversation_id,
  168. app_mode=self._app_mode,
  169. event=event
  170. )
  171. self._q.put(message)
  172. if isinstance(event, QueueStopEvent):
  173. self.stop_listen()
  174. if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
  175. raise ConversationTaskStoppedException()
  176. @classmethod
  177. def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str) -> None:
  178. """
  179. Set task stop flag
  180. :return:
  181. """
  182. result = redis_client.get(cls._generate_task_belong_cache_key(task_id))
  183. if result is None:
  184. return
  185. user_prefix = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user'
  186. if result.decode('utf-8') != f"{user_prefix}-{user_id}":
  187. return
  188. stopped_cache_key = cls._generate_stopped_cache_key(task_id)
  189. redis_client.setex(stopped_cache_key, 600, 1)
  190. def _is_stopped(self) -> bool:
  191. """
  192. Check if task is stopped
  193. :return:
  194. """
  195. stopped_cache_key = ApplicationQueueManager._generate_stopped_cache_key(self._task_id)
  196. result = redis_client.get(stopped_cache_key)
  197. if result is not None:
  198. return True
  199. return False
  200. @classmethod
  201. def _generate_task_belong_cache_key(cls, task_id: str) -> str:
  202. """
  203. Generate task belong cache key
  204. :param task_id: task id
  205. :return:
  206. """
  207. return f"generate_task_belong:{task_id}"
  208. @classmethod
  209. def _generate_stopped_cache_key(cls, task_id: str) -> str:
  210. """
  211. Generate stopped cache key
  212. :param task_id: task id
  213. :return:
  214. """
  215. return f"generate_task_stopped:{task_id}"
  216. def _check_for_sqlalchemy_models(self, data: Any):
  217. # from entity to dict or list
  218. if isinstance(data, dict):
  219. for key, value in data.items():
  220. self._check_for_sqlalchemy_models(value)
  221. elif isinstance(data, list):
  222. for item in data:
  223. self._check_for_sqlalchemy_models(item)
  224. else:
  225. if isinstance(data, DeclarativeMeta) or hasattr(data, '_sa_instance_state'):
  226. raise TypeError("Critical Error: Passing SQLAlchemy Model instances "
  227. "that cause thread safety issues is not allowed.")
  228. class ConversationTaskStoppedException(Exception):
  229. pass