application_queue_manager.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246
  1. import queue
  2. import time
  3. from enum import Enum
  4. from typing import Generator, Any
  5. from sqlalchemy.orm import DeclarativeMeta
  6. from core.entities.application_entities import InvokeFrom
  7. from core.entities.queue_entities import QueueStopEvent, AppQueueEvent, QueuePingEvent, QueueErrorEvent, \
  8. QueueAgentThoughtEvent, QueueMessageEndEvent, QueueRetrieverResourcesEvent, QueueMessageReplaceEvent, \
  9. QueueMessageEvent, QueueMessage, AnnotationReplyEvent
  10. from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
  11. from extensions.ext_redis import redis_client
  12. from models.model import MessageAgentThought
  13. class PublishFrom(Enum):
  14. APPLICATION_MANAGER = 1
  15. TASK_PIPELINE = 2
  16. class ApplicationQueueManager:
  17. def __init__(self, task_id: str,
  18. user_id: str,
  19. invoke_from: InvokeFrom,
  20. conversation_id: str,
  21. app_mode: str,
  22. message_id: str) -> None:
  23. if not user_id:
  24. raise ValueError("user is required")
  25. self._task_id = task_id
  26. self._user_id = user_id
  27. self._invoke_from = invoke_from
  28. self._conversation_id = str(conversation_id)
  29. self._app_mode = app_mode
  30. self._message_id = str(message_id)
  31. user_prefix = 'account' if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user'
  32. redis_client.setex(ApplicationQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}")
  33. q = queue.Queue()
  34. self._q = q
  35. def listen(self) -> Generator:
  36. """
  37. Listen to queue
  38. :return:
  39. """
  40. # wait for 10 minutes to stop listen
  41. listen_timeout = 600
  42. start_time = time.time()
  43. last_ping_time = 0
  44. while True:
  45. try:
  46. message = self._q.get(timeout=1)
  47. if message is None:
  48. break
  49. yield message
  50. except queue.Empty:
  51. continue
  52. finally:
  53. elapsed_time = time.time() - start_time
  54. if elapsed_time >= listen_timeout or self._is_stopped():
  55. # publish two messages to make sure the client can receive the stop signal
  56. # and stop listening after the stop signal processed
  57. self.publish(
  58. QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL),
  59. PublishFrom.TASK_PIPELINE
  60. )
  61. self.stop_listen()
  62. if elapsed_time // 10 > last_ping_time:
  63. self.publish(QueuePingEvent(), PublishFrom.TASK_PIPELINE)
  64. last_ping_time = elapsed_time // 10
  65. def stop_listen(self) -> None:
  66. """
  67. Stop listen to queue
  68. :return:
  69. """
  70. self._q.put(None)
  71. def publish_chunk_message(self, chunk: LLMResultChunk, pub_from: PublishFrom) -> None:
  72. """
  73. Publish chunk message to channel
  74. :param chunk: chunk
  75. :param pub_from: publish from
  76. :return:
  77. """
  78. self.publish(QueueMessageEvent(
  79. chunk=chunk
  80. ), pub_from)
  81. def publish_message_replace(self, text: str, pub_from: PublishFrom) -> None:
  82. """
  83. Publish message replace
  84. :param text: text
  85. :param pub_from: publish from
  86. :return:
  87. """
  88. self.publish(QueueMessageReplaceEvent(
  89. text=text
  90. ), pub_from)
  91. def publish_retriever_resources(self, retriever_resources: list[dict], pub_from: PublishFrom) -> None:
  92. """
  93. Publish retriever resources
  94. :return:
  95. """
  96. self.publish(QueueRetrieverResourcesEvent(retriever_resources=retriever_resources), pub_from)
  97. def publish_annotation_reply(self, message_annotation_id: str, pub_from: PublishFrom) -> None:
  98. """
  99. Publish annotation reply
  100. :param message_annotation_id: message annotation id
  101. :param pub_from: publish from
  102. :return:
  103. """
  104. self.publish(AnnotationReplyEvent(message_annotation_id=message_annotation_id), pub_from)
  105. def publish_message_end(self, llm_result: LLMResult, pub_from: PublishFrom) -> None:
  106. """
  107. Publish message end
  108. :param llm_result: llm result
  109. :param pub_from: publish from
  110. :return:
  111. """
  112. self.publish(QueueMessageEndEvent(llm_result=llm_result), pub_from)
  113. self.stop_listen()
  114. def publish_agent_thought(self, message_agent_thought: MessageAgentThought, pub_from: PublishFrom) -> None:
  115. """
  116. Publish agent thought
  117. :param message_agent_thought: message agent thought
  118. :param pub_from: publish from
  119. :return:
  120. """
  121. self.publish(QueueAgentThoughtEvent(
  122. agent_thought_id=message_agent_thought.id
  123. ), pub_from)
  124. def publish_error(self, e, pub_from: PublishFrom) -> None:
  125. """
  126. Publish error
  127. :param e: error
  128. :param pub_from: publish from
  129. :return:
  130. """
  131. self.publish(QueueErrorEvent(
  132. error=e
  133. ), pub_from)
  134. self.stop_listen()
  135. def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
  136. """
  137. Publish event to queue
  138. :param event:
  139. :param pub_from:
  140. :return:
  141. """
  142. self._check_for_sqlalchemy_models(event.dict())
  143. message = QueueMessage(
  144. task_id=self._task_id,
  145. message_id=self._message_id,
  146. conversation_id=self._conversation_id,
  147. app_mode=self._app_mode,
  148. event=event
  149. )
  150. self._q.put(message)
  151. if isinstance(event, QueueStopEvent):
  152. self.stop_listen()
  153. if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
  154. raise ConversationTaskStoppedException()
  155. @classmethod
  156. def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str) -> None:
  157. """
  158. Set task stop flag
  159. :return:
  160. """
  161. result = redis_client.get(cls._generate_task_belong_cache_key(task_id))
  162. if result is None:
  163. return
  164. user_prefix = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user'
  165. if result.decode('utf-8') != f"{user_prefix}-{user_id}":
  166. return
  167. stopped_cache_key = cls._generate_stopped_cache_key(task_id)
  168. redis_client.setex(stopped_cache_key, 600, 1)
  169. def _is_stopped(self) -> bool:
  170. """
  171. Check if task is stopped
  172. :return:
  173. """
  174. stopped_cache_key = ApplicationQueueManager._generate_stopped_cache_key(self._task_id)
  175. result = redis_client.get(stopped_cache_key)
  176. if result is not None:
  177. return True
  178. return False
  179. @classmethod
  180. def _generate_task_belong_cache_key(cls, task_id: str) -> str:
  181. """
  182. Generate task belong cache key
  183. :param task_id: task id
  184. :return:
  185. """
  186. return f"generate_task_belong:{task_id}"
  187. @classmethod
  188. def _generate_stopped_cache_key(cls, task_id: str) -> str:
  189. """
  190. Generate stopped cache key
  191. :param task_id: task id
  192. :return:
  193. """
  194. return f"generate_task_stopped:{task_id}"
  195. def _check_for_sqlalchemy_models(self, data: Any):
  196. # from entity to dict or list
  197. if isinstance(data, dict):
  198. for key, value in data.items():
  199. self._check_for_sqlalchemy_models(value)
  200. elif isinstance(data, list):
  201. for item in data:
  202. self._check_for_sqlalchemy_models(item)
  203. else:
  204. if isinstance(data, DeclarativeMeta) or hasattr(data, '_sa_instance_state'):
  205. raise TypeError("Critical Error: Passing SQLAlchemy Model instances "
  206. "that cause thread safety issues is not allowed.")
  207. class ConversationTaskStoppedException(Exception):
  208. pass