@@ -5,16 +5,18 @@ from typing import Generator, Optional, Union, cast
 from core.app_runner.moderation_handler import ModerationRule, OutputModerationHandler
 from core.application_queue_manager import ApplicationQueueManager, PublishFrom
-from core.entities.application_entities import ApplicationGenerateEntity
+from core.entities.application_entities import ApplicationGenerateEntity, InvokeFrom
 from core.entities.queue_entities import (AnnotationReplyEvent, QueueAgentThoughtEvent, QueueErrorEvent,
                                           QueueMessageEndEvent, QueueMessageEvent, QueueMessageReplaceEvent,
                                           QueuePingEvent, QueueRetrieverResourcesEvent, QueueStopEvent)
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
 from core.model_runtime.entities.message_entities import (AssistantPromptMessage, ImagePromptMessageContent,
                                                            PromptMessage, PromptMessageContentType, PromptMessageRole,
                                                            TextPromptMessageContent)
 from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.model_runtime.utils.encoders import jsonable_encoder
 from core.prompt.prompt_template import PromptTemplateParser
 from events.message_event import message_was_created
 from extensions.ext_database import db
@@ -135,6 +137,8 @@ class GenerateTaskPipeline:
                         completion_tokens
                     )

+                self._task_state.metadata['usage'] = jsonable_encoder(self._task_state.llm_result.usage)
+
                 # response moderation
                 if self._output_moderation_handler:
                     self._output_moderation_handler.stop_thread()
@@ -145,12 +149,13 @@ class GenerateTaskPipeline:
                     )

                 # Save message
-                self._save_message(event.llm_result)
+                self._save_message(self._task_state.llm_result)

                 response = {
                     'event': 'message',
                     'task_id': self._application_generate_entity.task_id,
                     'id': self._message.id,
+                    'message_id': self._message.id,
                     'mode': self._conversation.mode,
                     'answer': event.llm_result.message.content,
                     'metadata': {},
@@ -161,7 +166,7 @@ class GenerateTaskPipeline:
                     response['conversation_id'] = self._conversation.id

                 if self._task_state.metadata:
-                    response['metadata'] = self._task_state.metadata
+                    response['metadata'] = self._get_response_metadata()

                 return response
             else:
@@ -176,7 +181,9 @@ class GenerateTaskPipeline:
             event = message.event

             if isinstance(event, QueueErrorEvent):
-                raise self._handle_error(event)
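+                # surface the error to the client as a stream 'error' event instead of raising, then stop reading the queue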
+                data = self._error_to_stream_response_data(self._handle_error(event))
+                yield self._yield_response(data)
+                break
             elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)):
                 if isinstance(event, QueueMessageEndEvent):
                     self._task_state.llm_result = event.llm_result
@@ -213,6 +220,8 @@ class GenerateTaskPipeline:
                         completion_tokens
                     )

+                self._task_state.metadata['usage'] = jsonable_encoder(self._task_state.llm_result.usage)
+
                 # response moderation
                 if self._output_moderation_handler:
                     self._output_moderation_handler.stop_thread()
@@ -244,13 +253,14 @@ class GenerateTaskPipeline:
                     'event': 'message_end',
                     'task_id': self._application_generate_entity.task_id,
                     'id': self._message.id,
+                    'message_id': self._message.id,
                 }

                 if self._conversation.mode == 'chat':
                     response['conversation_id'] = self._conversation.id

                 if self._task_state.metadata:
-                    response['metadata'] = self._task_state.metadata
+                    response['metadata'] = self._get_response_metadata()

                 yield self._yield_response(response)
             elif isinstance(event, QueueRetrieverResourcesEvent):
@@ -410,6 +420,86 @@ class GenerateTaskPipeline:
         else:
             return Exception(e.description if getattr(e, 'description', None) is not None else str(e))

+    def _error_to_stream_response_data(self, e: Exception) -> dict:
+        """
+        Error to stream response.
+        :param e: exception
+        :return:
+        """
+        if isinstance(e, ValueError):
+            data = {
+                'code': 'invalid_param',
+                'message': str(e),
+                'status': 400
+            }
+        elif isinstance(e, ProviderTokenNotInitError):
+            data = {
+                'code': 'provider_not_initialize',
+                'message': e.description,
+                'status': 400
+            }
+        elif isinstance(e, QuotaExceededError):
+            data = {
+                'code': 'provider_quota_exceeded',
+                'message': "Your quota for Dify Hosted Model Provider has been exhausted. "
+                           "Please go to Settings -> Model Provider to complete your own provider credentials.",
+                'status': 400
+            }
+        elif isinstance(e, ModelCurrentlyNotSupportError):
+            data = {
+                'code': 'model_currently_not_support',
+                'message': e.description,
+                'status': 400
+            }
+        elif isinstance(e, InvokeError):
+            data = {
+                'code': 'completion_request_error',
+                'message': e.description,
+                'status': 400
+            }
+        else:
+            logging.error(e)
+            data = {
+                'code': 'internal_server_error',
+                'message': 'Internal Server Error, please contact support.',
+                'status': 500
+            }
+
+        return {
+            'event': 'error',
+            'task_id': self._application_generate_entity.task_id,
+            'message_id': self._message.id,
+            **data
+        }
+
+    def _get_response_metadata(self) -> dict:
+        """
+        Get response metadata by invoke from.
+        :return:
+        """
+        metadata = {}
+
+        # show_retrieve_source
+        if 'retriever_resources' in self._task_state.metadata:
+            if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
+                metadata['retriever_resources'] = self._task_state.metadata['retriever_resources']
+            else:
+                metadata['retriever_resources'] = []
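+                # end users only receive a trimmed subset of each retrieved segment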
+                for resource in self._task_state.metadata['retriever_resources']:
+                    metadata['retriever_resources'].append({
+                        'segment_id': resource['segment_id'],
+                        'position': resource['position'],
+                        'document_name': resource['document_name'],
+                        'score': resource['score'],
+                        'content': resource['content'],
+                    })
+
+        # show usage
+        if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
+            metadata['usage'] = self._task_state.metadata['usage']
+
+        return metadata
+
     def _yield_response(self, response: dict) -> str:
         """
         Yield response.