123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246 |
- import queue
- import time
- from enum import Enum
- from typing import Any, Generator
- from core.entities.application_entities import InvokeFrom
- from core.entities.queue_entities import (AnnotationReplyEvent, AppQueueEvent, QueueAgentThoughtEvent, QueueErrorEvent,
- QueueMessage, QueueMessageEndEvent, QueueMessageEvent,
- QueueMessageReplaceEvent, QueuePingEvent, QueueRetrieverResourcesEvent,
- QueueStopEvent)
- from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
- from extensions.ext_redis import redis_client
- from models.model import MessageAgentThought
- from sqlalchemy.orm import DeclarativeMeta
class PublishFrom(Enum):
    """Identifies which side of the pipeline published a queue event."""

    # auto() assigns 1 and 2, matching the original explicit values.
    APPLICATION_MANAGER = auto()
    TASK_PIPELINE = auto()
class ApplicationQueueManager:
    """
    Bridges the event producer (application manager / task pipeline) and a
    streaming consumer through an in-process queue.

    Redis is used for cross-process bookkeeping: a "task belongs to user" key
    written at construction time, and a "task stopped" flag that any process
    can set via :meth:`set_stop_flag` and that :meth:`publish` / :meth:`listen`
    poll to abort generation.
    """

    def __init__(self, task_id: str,
                 user_id: str,
                 invoke_from: InvokeFrom,
                 conversation_id: str,
                 app_mode: str,
                 message_id: str) -> None:
        """
        Initialize the queue manager and register task ownership in Redis.

        :param task_id: id of the generation task
        :param user_id: id of the account or end-user that started the task
        :param invoke_from: where the invocation originated (console, API, ...)
        :param conversation_id: conversation id (coerced to str)
        :param app_mode: application mode string
        :param message_id: message id (coerced to str)
        :raises ValueError: if user_id is empty
        """
        if not user_id:
            raise ValueError("user is required")

        self._task_id = task_id
        self._user_id = user_id
        self._invoke_from = invoke_from
        self._conversation_id = str(conversation_id)
        self._app_mode = app_mode
        self._message_id = str(message_id)

        # Record which user owns this task so only that same user can stop it
        # later via set_stop_flag(). The ownership key expires after 30 min.
        user_prefix = self._user_prefix_for(self._invoke_from)
        redis_client.setex(
            ApplicationQueueManager._generate_task_belong_cache_key(self._task_id),
            1800,
            f"{user_prefix}-{self._user_id}"
        )

        self._q: queue.Queue = queue.Queue()

    def listen(self) -> Generator:
        """
        Listen to queue

        Yields queue messages until a stop sentinel (None) is received or the
        10-minute listen window elapses. Also emits a keep-alive ping event
        roughly every 10 seconds.

        :return: generator of queue messages
        """
        # wait for 10 minutes to stop listen
        listen_timeout = 600
        start_time = time.time()
        last_ping_time = 0
        while True:
            try:
                # 1-second timeout so the timeout/stop checks below run even
                # when the producer is idle.
                message = self._q.get(timeout=1)
                if message is None:
                    break

                yield message
            except queue.Empty:
                continue
            finally:
                elapsed_time = time.time() - start_time
                if elapsed_time >= listen_timeout or self._is_stopped():
                    # publish two messages to make sure the client can receive the stop signal
                    # and stop listening after the stop signal processed
                    self.publish(
                        QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL),
                        PublishFrom.TASK_PIPELINE
                    )
                    self.stop_listen()

                # Ping about once per 10s of elapsed time (floor-division bucket).
                if elapsed_time // 10 > last_ping_time:
                    self.publish(QueuePingEvent(), PublishFrom.TASK_PIPELINE)
                    last_ping_time = elapsed_time // 10

    def stop_listen(self) -> None:
        """
        Stop listen to queue by enqueueing the stop sentinel (None).

        :return:
        """
        self._q.put(None)

    def publish_chunk_message(self, chunk: LLMResultChunk, pub_from: PublishFrom) -> None:
        """
        Publish chunk message to channel

        :param chunk: chunk
        :param pub_from: publish from
        :return:
        """
        self.publish(QueueMessageEvent(
            chunk=chunk
        ), pub_from)

    def publish_message_replace(self, text: str, pub_from: PublishFrom) -> None:
        """
        Publish message replace

        :param text: text
        :param pub_from: publish from
        :return:
        """
        self.publish(QueueMessageReplaceEvent(
            text=text
        ), pub_from)

    def publish_retriever_resources(self, retriever_resources: list[dict], pub_from: PublishFrom) -> None:
        """
        Publish retriever resources

        :param retriever_resources: retriever resource dicts
        :param pub_from: publish from
        :return:
        """
        self.publish(QueueRetrieverResourcesEvent(retriever_resources=retriever_resources), pub_from)

    def publish_annotation_reply(self, message_annotation_id: str, pub_from: PublishFrom) -> None:
        """
        Publish annotation reply

        :param message_annotation_id: message annotation id
        :param pub_from: publish from
        :return:
        """
        self.publish(AnnotationReplyEvent(message_annotation_id=message_annotation_id), pub_from)

    def publish_message_end(self, llm_result: LLMResult, pub_from: PublishFrom) -> None:
        """
        Publish message end and stop listening.

        :param llm_result: llm result
        :param pub_from: publish from
        :return:
        """
        self.publish(QueueMessageEndEvent(llm_result=llm_result), pub_from)
        self.stop_listen()

    def publish_agent_thought(self, message_agent_thought: MessageAgentThought, pub_from: PublishFrom) -> None:
        """
        Publish agent thought

        :param message_agent_thought: message agent thought
        :param pub_from: publish from
        :return:
        """
        self.publish(QueueAgentThoughtEvent(
            agent_thought_id=message_agent_thought.id
        ), pub_from)

    def publish_error(self, e, pub_from: PublishFrom) -> None:
        """
        Publish error and stop listening.

        :param e: error (exception instance)
        :param pub_from: publish from
        :return:
        """
        self.publish(QueueErrorEvent(
            error=e
        ), pub_from)
        self.stop_listen()

    def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
        """
        Publish event to queue

        :param event: event to wrap in a QueueMessage
        :param pub_from: publish from
        :raises ConversationTaskStoppedException: when published from the
            application manager after the task has been flagged as stopped,
            so the generation loop aborts promptly
        :return:
        """
        # Guard against SQLAlchemy model instances crossing thread boundaries.
        self._check_for_sqlalchemy_models(event.dict())

        message = QueueMessage(
            task_id=self._task_id,
            message_id=self._message_id,
            conversation_id=self._conversation_id,
            app_mode=self._app_mode,
            event=event
        )
        self._q.put(message)

        if isinstance(event, QueueStopEvent):
            self.stop_listen()

        if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
            raise ConversationTaskStoppedException()

    @classmethod
    def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str) -> None:
        """
        Set task stop flag

        Only takes effect when (user_id, invoke_from) matches the owner
        recorded at construction time; silently does nothing otherwise.

        :param task_id: task id
        :param invoke_from: where the stop request originated
        :param user_id: id of the requesting user
        :return:
        """
        result = redis_client.get(cls._generate_task_belong_cache_key(task_id))
        if result is None:
            return

        # Only the user who started the task may stop it.
        user_prefix = cls._user_prefix_for(invoke_from)
        if result.decode('utf-8') != f"{user_prefix}-{user_id}":
            return

        stopped_cache_key = cls._generate_stopped_cache_key(task_id)
        redis_client.setex(stopped_cache_key, 600, 1)

    @classmethod
    def _user_prefix_for(cls, invoke_from: InvokeFrom) -> str:
        """
        Cache-key prefix identifying the caller type.

        Console-side invocations (explore/debugger) are accounts; everything
        else is treated as an end-user.

        :param invoke_from: where the invocation originated
        :return: 'account' or 'end-user'
        """
        return 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user'

    def _is_stopped(self) -> bool:
        """
        Check if task is stopped

        :return: True if the stop flag exists in Redis
        """
        stopped_cache_key = ApplicationQueueManager._generate_stopped_cache_key(self._task_id)
        return redis_client.get(stopped_cache_key) is not None

    @classmethod
    def _generate_task_belong_cache_key(cls, task_id: str) -> str:
        """
        Generate task belong cache key

        :param task_id: task id
        :return: Redis key for the task ownership record
        """
        return f"generate_task_belong:{task_id}"

    @classmethod
    def _generate_stopped_cache_key(cls, task_id: str) -> str:
        """
        Generate stopped cache key

        :param task_id: task id
        :return: Redis key for the stop flag
        """
        return f"generate_task_stopped:{task_id}"

    def _check_for_sqlalchemy_models(self, data: Any) -> None:
        """
        Recursively reject SQLAlchemy model instances inside event payloads.

        Walks dicts and lists; raises on any model class or instance
        (detected via DeclarativeMeta or the `_sa_instance_state` attribute),
        because sharing ORM objects across threads is unsafe.

        :param data: arbitrary payload (from entity to dict or list)
        :raises TypeError: if a SQLAlchemy model is found anywhere in data
        """
        if isinstance(data, dict):
            for value in data.values():
                self._check_for_sqlalchemy_models(value)
        elif isinstance(data, list):
            for item in data:
                self._check_for_sqlalchemy_models(item)
        elif isinstance(data, DeclarativeMeta) or hasattr(data, '_sa_instance_state'):
            raise TypeError("Critical Error: Passing SQLAlchemy Model instances "
                            "that cause thread safety issues is not allowed.")
class ConversationTaskStoppedException(Exception):
    """Raised when publishing is attempted after the task was flagged stopped."""
|