import json import logging import time from collections.abc import Generator from typing import Optional, Union, cast from pydantic import BaseModel from core.app_runner.moderation_handler import ModerationRule, OutputModerationHandler from core.application_queue_manager import ApplicationQueueManager, PublishFrom from core.entities.application_entities import ApplicationGenerateEntity, InvokeFrom from core.entities.queue_entities import ( AnnotationReplyEvent, QueueAgentMessageEvent, QueueAgentThoughtEvent, QueueErrorEvent, QueueMessageEndEvent, QueueMessageEvent, QueueMessageFileEvent, QueueMessageReplaceEvent, QueuePingEvent, QueueRetrieverResourcesEvent, QueueStopEvent, ) from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, ImagePromptMessageContent, PromptMessage, PromptMessageContentType, PromptMessageRole, TextPromptMessageContent, ) from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.prompt_template import PromptTemplateParser from core.tools.tool_file_manager import ToolFileManager from events.message_event import message_was_created from extensions.ext_database import db from models.model import Conversation, Message, MessageAgentThought, MessageFile from services.annotation_service import AppAnnotationService logger = logging.getLogger(__name__) class TaskState(BaseModel): """ TaskState entity """ llm_result: LLMResult metadata: dict = {} class GenerateTaskPipeline: """ GenerateTaskPipeline is a class that generate stream output and state management for Application. """ def __init__(self, application_generate_entity: ApplicationGenerateEntity, queue_manager: ApplicationQueueManager, conversation: Conversation, message: Message) -> None: """ Initialize GenerateTaskPipeline. :param application_generate_entity: application generate entity :param queue_manager: queue manager :param conversation: conversation :param message: message """ self._application_generate_entity = application_generate_entity self._queue_manager = queue_manager self._conversation = conversation self._message = message self._task_state = TaskState( llm_result=LLMResult( model=self._application_generate_entity.app_orchestration_config_entity.model_config.model, prompt_messages=[], message=AssistantPromptMessage(content=""), usage=LLMUsage.empty_usage() ) ) self._start_at = time.perf_counter() self._output_moderation_handler = self._init_output_moderation() def process(self, stream: bool) -> Union[dict, Generator]: """ Process generate task pipeline. :return: """ if stream: return self._process_stream_response() else: return self._process_blocking_response() def _process_blocking_response(self) -> dict: """ Process blocking response. :return: """ for queue_message in self._queue_manager.listen(): event = queue_message.event if isinstance(event, QueueErrorEvent): raise self._handle_error(event) elif isinstance(event, QueueRetrieverResourcesEvent): self._task_state.metadata['retriever_resources'] = event.retriever_resources elif isinstance(event, AnnotationReplyEvent): annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id) if annotation: account = annotation.account self._task_state.metadata['annotation_reply'] = { 'id': annotation.id, 'account': { 'id': annotation.account_id, 'name': account.name if account else 'Dify user' } } self._task_state.llm_result.message.content = annotation.content elif isinstance(event, QueueStopEvent | QueueMessageEndEvent): if isinstance(event, QueueMessageEndEvent): self._task_state.llm_result = event.llm_result else: model_config = self._application_generate_entity.app_orchestration_config_entity.model_config model = model_config.model model_type_instance = model_config.provider_model_bundle.model_type_instance model_type_instance = cast(LargeLanguageModel, model_type_instance) # calculate num tokens prompt_tokens = 0 if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY: prompt_tokens = model_type_instance.get_num_tokens( model, model_config.credentials, self._task_state.llm_result.prompt_messages ) completion_tokens = 0 if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL: completion_tokens = model_type_instance.get_num_tokens( model, model_config.credentials, [self._task_state.llm_result.message] ) credentials = model_config.credentials # transform usage self._task_state.llm_result.usage = model_type_instance._calc_response_usage( model, credentials, prompt_tokens, completion_tokens ) self._task_state.metadata['usage'] = jsonable_encoder(self._task_state.llm_result.usage) # response moderation if self._output_moderation_handler: self._output_moderation_handler.stop_thread() self._task_state.llm_result.message.content = self._output_moderation_handler.moderation_completion( completion=self._task_state.llm_result.message.content, public_event=False ) # Save message self._save_message(self._task_state.llm_result) response = { 'event': 'message', 'task_id': self._application_generate_entity.task_id, 'id': self._message.id, 'message_id': self._message.id, 'mode': self._conversation.mode, 'answer': event.llm_result.message.content, 'metadata': {}, 'created_at': int(self._message.created_at.timestamp()) } if self._conversation.mode == 'chat': response['conversation_id'] = self._conversation.id if self._task_state.metadata: response['metadata'] = self._get_response_metadata() return response else: continue def _process_stream_response(self) -> Generator: """ Process stream response. :return: """ for message in self._queue_manager.listen(): event = message.event if isinstance(event, QueueErrorEvent): data = self._error_to_stream_response_data(self._handle_error(event)) yield self._yield_response(data) break elif isinstance(event, QueueStopEvent | QueueMessageEndEvent): if isinstance(event, QueueMessageEndEvent): self._task_state.llm_result = event.llm_result else: model_config = self._application_generate_entity.app_orchestration_config_entity.model_config model = model_config.model model_type_instance = model_config.provider_model_bundle.model_type_instance model_type_instance = cast(LargeLanguageModel, model_type_instance) # calculate num tokens prompt_tokens = 0 if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY: prompt_tokens = model_type_instance.get_num_tokens( model, model_config.credentials, self._task_state.llm_result.prompt_messages ) completion_tokens = 0 if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL: completion_tokens = model_type_instance.get_num_tokens( model, model_config.credentials, [self._task_state.llm_result.message] ) credentials = model_config.credentials # transform usage self._task_state.llm_result.usage = model_type_instance._calc_response_usage( model, credentials, prompt_tokens, completion_tokens ) self._task_state.metadata['usage'] = jsonable_encoder(self._task_state.llm_result.usage) # response moderation if self._output_moderation_handler: self._output_moderation_handler.stop_thread() self._task_state.llm_result.message.content = self._output_moderation_handler.moderation_completion( completion=self._task_state.llm_result.message.content, public_event=False ) self._output_moderation_handler = None replace_response = { 'event': 'message_replace', 'task_id': self._application_generate_entity.task_id, 'message_id': self._message.id, 'answer': self._task_state.llm_result.message.content, 'created_at': int(self._message.created_at.timestamp()) } if self._conversation.mode == 'chat': replace_response['conversation_id'] = self._conversation.id yield self._yield_response(replace_response) # Save message self._save_message(self._task_state.llm_result) response = { 'event': 'message_end', 'task_id': self._application_generate_entity.task_id, 'id': self._message.id, 'message_id': self._message.id, } if self._conversation.mode == 'chat': response['conversation_id'] = self._conversation.id if self._task_state.metadata: response['metadata'] = self._get_response_metadata() yield self._yield_response(response) elif isinstance(event, QueueRetrieverResourcesEvent): self._task_state.metadata['retriever_resources'] = event.retriever_resources elif isinstance(event, AnnotationReplyEvent): annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id) if annotation: account = annotation.account self._task_state.metadata['annotation_reply'] = { 'id': annotation.id, 'account': { 'id': annotation.account_id, 'name': account.name if account else 'Dify user' } } self._task_state.llm_result.message.content = annotation.content elif isinstance(event, QueueAgentThoughtEvent): agent_thought: MessageAgentThought = ( db.session.query(MessageAgentThought) .filter(MessageAgentThought.id == event.agent_thought_id) .first() ) db.session.refresh(agent_thought) if agent_thought: response = { 'event': 'agent_thought', 'id': agent_thought.id, 'task_id': self._application_generate_entity.task_id, 'message_id': self._message.id, 'position': agent_thought.position, 'thought': agent_thought.thought, 'observation': agent_thought.observation, 'tool': agent_thought.tool, 'tool_labels': agent_thought.tool_labels, 'tool_input': agent_thought.tool_input, 'created_at': int(self._message.created_at.timestamp()), 'message_files': agent_thought.files } if self._conversation.mode == 'chat': response['conversation_id'] = self._conversation.id yield self._yield_response(response) elif isinstance(event, QueueMessageFileEvent): message_file: MessageFile = ( db.session.query(MessageFile) .filter(MessageFile.id == event.message_file_id) .first() ) # get extension if '.' in message_file.url: extension = f'.{message_file.url.split(".")[-1]}' if len(extension) > 10: extension = '.bin' else: extension = '.bin' # add sign url url = ToolFileManager.sign_file(file_id=message_file.id, extension=extension) if message_file: response = { 'event': 'message_file', 'id': message_file.id, 'type': message_file.type, 'belongs_to': message_file.belongs_to or 'user', 'url': url } if self._conversation.mode == 'chat': response['conversation_id'] = self._conversation.id yield self._yield_response(response) elif isinstance(event, QueueMessageEvent | QueueAgentMessageEvent): chunk = event.chunk delta_text = chunk.delta.message.content if delta_text is None: continue if not self._task_state.llm_result.prompt_messages: self._task_state.llm_result.prompt_messages = chunk.prompt_messages if self._output_moderation_handler: if self._output_moderation_handler.should_direct_output(): # stop subscribe new token when output moderation should direct output self._task_state.llm_result.message.content = self._output_moderation_handler.get_final_output() self._queue_manager.publish_chunk_message(LLMResultChunk( model=self._task_state.llm_result.model, prompt_messages=self._task_state.llm_result.prompt_messages, delta=LLMResultChunkDelta( index=0, message=AssistantPromptMessage(content=self._task_state.llm_result.message.content) ) ), PublishFrom.TASK_PIPELINE) self._queue_manager.publish( QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE ) continue else: self._output_moderation_handler.append_new_token(delta_text) self._task_state.llm_result.message.content += delta_text response = self._handle_chunk(delta_text, agent=isinstance(event, QueueAgentMessageEvent)) yield self._yield_response(response) elif isinstance(event, QueueMessageReplaceEvent): response = { 'event': 'message_replace', 'task_id': self._application_generate_entity.task_id, 'message_id': self._message.id, 'answer': event.text, 'created_at': int(self._message.created_at.timestamp()) } if self._conversation.mode == 'chat': response['conversation_id'] = self._conversation.id yield self._yield_response(response) elif isinstance(event, QueuePingEvent): yield "event: ping\n\n" else: continue def _save_message(self, llm_result: LLMResult) -> None: """ Save message. :param llm_result: llm result :return: """ usage = llm_result.usage self._message = db.session.query(Message).filter(Message.id == self._message.id).first() self._message.message = self._prompt_messages_to_prompt_for_saving(self._task_state.llm_result.prompt_messages) self._message.message_tokens = usage.prompt_tokens self._message.message_unit_price = usage.prompt_unit_price self._message.message_price_unit = usage.prompt_price_unit self._message.answer = PromptTemplateParser.remove_template_variables(llm_result.message.content.strip()) \ if llm_result.message.content else '' self._message.answer_tokens = usage.completion_tokens self._message.answer_unit_price = usage.completion_unit_price self._message.answer_price_unit = usage.completion_price_unit self._message.provider_response_latency = time.perf_counter() - self._start_at self._message.total_price = usage.total_price db.session.commit() message_was_created.send( self._message, application_generate_entity=self._application_generate_entity, conversation=self._conversation, is_first_message=self._application_generate_entity.conversation_id is None, extras=self._application_generate_entity.extras ) def _handle_chunk(self, text: str, agent: bool = False) -> dict: """ Handle completed event. :param text: text :return: """ response = { 'event': 'message' if not agent else 'agent_message', 'id': self._message.id, 'task_id': self._application_generate_entity.task_id, 'message_id': self._message.id, 'answer': text, 'created_at': int(self._message.created_at.timestamp()) } if self._conversation.mode == 'chat': response['conversation_id'] = self._conversation.id return response def _handle_error(self, event: QueueErrorEvent) -> Exception: """ Handle error event. :param event: event :return: """ logger.debug("error: %s", event.error) e = event.error if isinstance(e, InvokeAuthorizationError): return InvokeAuthorizationError('Incorrect API key provided') elif isinstance(e, InvokeError) or isinstance(e, ValueError): return e else: return Exception(e.description if getattr(e, 'description', None) is not None else str(e)) def _error_to_stream_response_data(self, e: Exception) -> dict: """ Error to stream response. :param e: exception :return: """ error_responses = { ValueError: {'code': 'invalid_param', 'status': 400}, ProviderTokenNotInitError: {'code': 'provider_not_initialize', 'status': 400}, QuotaExceededError: { 'code': 'provider_quota_exceeded', 'message': "Your quota for Dify Hosted Model Provider has been exhausted. " "Please go to Settings -> Model Provider to complete your own provider credentials.", 'status': 400 }, ModelCurrentlyNotSupportError: {'code': 'model_currently_not_support', 'status': 400}, InvokeError: {'code': 'completion_request_error', 'status': 400} } # Determine the response based on the type of exception data = None for k, v in error_responses.items(): if isinstance(e, k): data = v if data: data.setdefault('message', getattr(e, 'description', str(e))) else: logging.error(e) data = { 'code': 'internal_server_error', 'message': 'Internal Server Error, please contact support.', 'status': 500 } return { 'event': 'error', 'task_id': self._application_generate_entity.task_id, 'message_id': self._message.id, **data } def _get_response_metadata(self) -> dict: """ Get response metadata by invoke from. :return: """ metadata = {} # show_retrieve_source if 'retriever_resources' in self._task_state.metadata: if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]: metadata['retriever_resources'] = self._task_state.metadata['retriever_resources'] else: metadata['retriever_resources'] = [] for resource in self._task_state.metadata['retriever_resources']: metadata['retriever_resources'].append({ 'segment_id': resource['segment_id'], 'position': resource['position'], 'document_name': resource['document_name'], 'score': resource['score'], 'content': resource['content'], }) # show annotation reply if 'annotation_reply' in self._task_state.metadata: if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]: metadata['annotation_reply'] = self._task_state.metadata['annotation_reply'] # show usage if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]: metadata['usage'] = self._task_state.metadata['usage'] return metadata def _yield_response(self, response: dict) -> str: """ Yield response. :param response: response :return: """ return "data: " + json.dumps(response) + "\n\n" def _prompt_messages_to_prompt_for_saving(self, prompt_messages: list[PromptMessage]) -> list[dict]: """ Prompt messages to prompt for saving. :param prompt_messages: prompt messages :return: """ prompts = [] if self._application_generate_entity.app_orchestration_config_entity.model_config.mode == 'chat': for prompt_message in prompt_messages: if prompt_message.role == PromptMessageRole.USER: role = 'user' elif prompt_message.role == PromptMessageRole.ASSISTANT: role = 'assistant' elif prompt_message.role == PromptMessageRole.SYSTEM: role = 'system' else: continue text = '' files = [] if isinstance(prompt_message.content, list): for content in prompt_message.content: if content.type == PromptMessageContentType.TEXT: content = cast(TextPromptMessageContent, content) text += content.data else: content = cast(ImagePromptMessageContent, content) files.append({ "type": 'image', "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:], "detail": content.detail.value }) else: text = prompt_message.content prompts.append({ "role": role, "text": text, "files": files }) else: prompt_message = prompt_messages[0] text = '' files = [] if isinstance(prompt_message.content, list): for content in prompt_message.content: if content.type == PromptMessageContentType.TEXT: content = cast(TextPromptMessageContent, content) text += content.data else: content = cast(ImagePromptMessageContent, content) files.append({ "type": 'image', "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:], "detail": content.detail.value }) else: text = prompt_message.content params = { "role": 'user', "text": text, } if files: params['files'] = files prompts.append(params) return prompts def _init_output_moderation(self) -> Optional[OutputModerationHandler]: """ Init output moderation. :return: """ app_orchestration_config_entity = self._application_generate_entity.app_orchestration_config_entity sensitive_word_avoidance = app_orchestration_config_entity.sensitive_word_avoidance if sensitive_word_avoidance: return OutputModerationHandler( tenant_id=self._application_generate_entity.tenant_id, app_id=self._application_generate_entity.app_id, rule=ModerationRule( type=sensitive_word_avoidance.type, config=sensitive_word_avoidance.config ), on_message_replace_func=self._queue_manager.publish_message_replace )