generate_task_pipeline.py

import json
import logging
import time
from collections.abc import Generator
from typing import Optional, Union, cast

from pydantic import BaseModel

from core.app_runner.moderation_handler import ModerationRule, OutputModerationHandler
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.entities.application_entities import ApplicationGenerateEntity, InvokeFrom
from core.entities.queue_entities import (
    AnnotationReplyEvent,
    QueueAgentMessageEvent,
    QueueAgentThoughtEvent,
    QueueErrorEvent,
    QueueMessageEndEvent,
    QueueMessageEvent,
    QueueMessageFileEvent,
    QueueMessageReplaceEvent,
    QueuePingEvent,
    QueueRetrieverResourcesEvent,
    QueueStopEvent,
)
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    ImagePromptMessageContent,
    PromptMessage,
    PromptMessageContentType,
    PromptMessageRole,
    TextPromptMessageContent,
)
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.prompt_template import PromptTemplateParser
from core.tools.tool_file_manager import ToolFileManager
from events.message_event import message_was_created
from extensions.ext_database import db
from models.model import Conversation, Message, MessageAgentThought, MessageFile
from services.annotation_service import AppAnnotationService

logger = logging.getLogger(__name__)

class TaskState(BaseModel):
    """
    TaskState entity
    """
    llm_result: LLMResult
    metadata: dict = {}

class GenerateTaskPipeline:
    """
    GenerateTaskPipeline generates stream output and manages state for an application.
    """

    def __init__(self, application_generate_entity: ApplicationGenerateEntity,
                 queue_manager: ApplicationQueueManager,
                 conversation: Conversation,
                 message: Message) -> None:
        """
        Initialize GenerateTaskPipeline.
        :param application_generate_entity: application generate entity
        :param queue_manager: queue manager
        :param conversation: conversation
        :param message: message
        """
        self._application_generate_entity = application_generate_entity
        self._queue_manager = queue_manager
        self._conversation = conversation
        self._message = message
        self._task_state = TaskState(
            llm_result=LLMResult(
                model=self._application_generate_entity.app_orchestration_config_entity.model_config.model,
                prompt_messages=[],
                message=AssistantPromptMessage(content=""),
                usage=LLMUsage.empty_usage()
            )
        )
        self._start_at = time.perf_counter()
        self._output_moderation_handler = self._init_output_moderation()

    def process(self, stream: bool) -> Union[dict, Generator]:
        """
        Process the generate task pipeline.
        :param stream: whether to stream the response
        :return: a response dict when blocking, otherwise a generator of SSE frames
        """
        if stream:
            return self._process_stream_response()
        else:
            return self._process_blocking_response()

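    # A minimal usage sketch (illustrative only: the surrounding app runner,
    # web framework wiring, and variable names below are assumptions, not part
    # of this module). A caller typically forwards the generator as an SSE
    # response and returns the dict directly in blocking mode:
    #
    #     pipeline = GenerateTaskPipeline(application_generate_entity,
    #                                     queue_manager, conversation, message)
    #     if streaming:
    #         # each yielded item is already a "data: {...}\n\n" SSE frame
    #         return Response(pipeline.process(stream=True),
    #                         mimetype='text/event-stream')
    #     else:
    #         return jsonify(pipeline.process(stream=False))
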
    def _process_blocking_response(self) -> dict:
        """
        Process blocking response.
        :return:
        """
        for queue_message in self._queue_manager.listen():
            event = queue_message.event

            if isinstance(event, QueueErrorEvent):
                raise self._handle_error(event)
            elif isinstance(event, QueueRetrieverResourcesEvent):
                self._task_state.metadata['retriever_resources'] = event.retriever_resources
            elif isinstance(event, AnnotationReplyEvent):
                annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
                if annotation:
                    account = annotation.account
                    self._task_state.metadata['annotation_reply'] = {
                        'id': annotation.id,
                        'account': {
                            'id': annotation.account_id,
                            'name': account.name if account else 'Dify user'
                        }
                    }

                    self._task_state.llm_result.message.content = annotation.content
            elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
                if isinstance(event, QueueMessageEndEvent):
                    self._task_state.llm_result = event.llm_result
                else:
                    model_config = self._application_generate_entity.app_orchestration_config_entity.model_config
                    model = model_config.model
                    model_type_instance = model_config.provider_model_bundle.model_type_instance
                    model_type_instance = cast(LargeLanguageModel, model_type_instance)

                    # calculate num tokens
                    prompt_tokens = 0
                    if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY:
                        prompt_tokens = model_type_instance.get_num_tokens(
                            model,
                            model_config.credentials,
                            self._task_state.llm_result.prompt_messages
                        )

                    completion_tokens = 0
                    if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL:
                        completion_tokens = model_type_instance.get_num_tokens(
                            model,
                            model_config.credentials,
                            [self._task_state.llm_result.message]
                        )

                    credentials = model_config.credentials

                    # transform usage
                    self._task_state.llm_result.usage = model_type_instance._calc_response_usage(
                        model,
                        credentials,
                        prompt_tokens,
                        completion_tokens
                    )

                self._task_state.metadata['usage'] = jsonable_encoder(self._task_state.llm_result.usage)

                # response moderation
                if self._output_moderation_handler:
                    self._output_moderation_handler.stop_thread()

                    self._task_state.llm_result.message.content = self._output_moderation_handler.moderation_completion(
                        completion=self._task_state.llm_result.message.content,
                        public_event=False
                    )

                # Save message
                self._save_message(self._task_state.llm_result)

                response = {
                    'event': 'message',
                    'task_id': self._application_generate_entity.task_id,
                    'id': self._message.id,
                    'message_id': self._message.id,
                    'mode': self._conversation.mode,
                    # read from task state: a QueueStopEvent carries no llm_result
                    'answer': self._task_state.llm_result.message.content,
                    'metadata': {},
                    'created_at': int(self._message.created_at.timestamp())
                }

                if self._conversation.mode == 'chat':
                    response['conversation_id'] = self._conversation.id

                if self._task_state.metadata:
                    response['metadata'] = self._get_response_metadata()

                return response
            else:
                continue

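    # A blocking response, for reference (field values illustrative; metadata
    # is filled from _get_response_metadata when present):
    #
    #     {"event": "message", "task_id": "...", "id": "...", "message_id": "...",
    #      "mode": "chat", "answer": "...", "metadata": {"usage": {...}},
    #      "created_at": 1700000000, "conversation_id": "..."}
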
    def _process_stream_response(self) -> Generator:
        """
        Process stream response.
        :return:
        """
        for message in self._queue_manager.listen():
            event = message.event

            if isinstance(event, QueueErrorEvent):
                data = self._error_to_stream_response_data(self._handle_error(event))
                yield self._yield_response(data)
                break
            elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
                if isinstance(event, QueueMessageEndEvent):
                    self._task_state.llm_result = event.llm_result
                else:
                    model_config = self._application_generate_entity.app_orchestration_config_entity.model_config
                    model = model_config.model
                    model_type_instance = model_config.provider_model_bundle.model_type_instance
                    model_type_instance = cast(LargeLanguageModel, model_type_instance)

                    # calculate num tokens
                    prompt_tokens = 0
                    if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY:
                        prompt_tokens = model_type_instance.get_num_tokens(
                            model,
                            model_config.credentials,
                            self._task_state.llm_result.prompt_messages
                        )

                    completion_tokens = 0
                    if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL:
                        completion_tokens = model_type_instance.get_num_tokens(
                            model,
                            model_config.credentials,
                            [self._task_state.llm_result.message]
                        )

                    credentials = model_config.credentials

                    # transform usage
                    self._task_state.llm_result.usage = model_type_instance._calc_response_usage(
                        model,
                        credentials,
                        prompt_tokens,
                        completion_tokens
                    )

                self._task_state.metadata['usage'] = jsonable_encoder(self._task_state.llm_result.usage)

                # response moderation
                if self._output_moderation_handler:
                    self._output_moderation_handler.stop_thread()

                    self._task_state.llm_result.message.content = self._output_moderation_handler.moderation_completion(
                        completion=self._task_state.llm_result.message.content,
                        public_event=False
                    )

                    self._output_moderation_handler = None

                    replace_response = {
                        'event': 'message_replace',
                        'task_id': self._application_generate_entity.task_id,
                        'message_id': self._message.id,
                        'answer': self._task_state.llm_result.message.content,
                        'created_at': int(self._message.created_at.timestamp())
                    }

                    if self._conversation.mode == 'chat':
                        replace_response['conversation_id'] = self._conversation.id

                    yield self._yield_response(replace_response)

                # Save message
                self._save_message(self._task_state.llm_result)

                response = {
                    'event': 'message_end',
                    'task_id': self._application_generate_entity.task_id,
                    'id': self._message.id,
                    'message_id': self._message.id,
                }

                if self._conversation.mode == 'chat':
                    response['conversation_id'] = self._conversation.id

                if self._task_state.metadata:
                    response['metadata'] = self._get_response_metadata()

                yield self._yield_response(response)
            elif isinstance(event, QueueRetrieverResourcesEvent):
                self._task_state.metadata['retriever_resources'] = event.retriever_resources
            elif isinstance(event, AnnotationReplyEvent):
                annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
                if annotation:
                    account = annotation.account
                    self._task_state.metadata['annotation_reply'] = {
                        'id': annotation.id,
                        'account': {
                            'id': annotation.account_id,
                            'name': account.name if account else 'Dify user'
                        }
                    }

                    self._task_state.llm_result.message.content = annotation.content
            elif isinstance(event, QueueAgentThoughtEvent):
                agent_thought: MessageAgentThought = (
                    db.session.query(MessageAgentThought)
                    .filter(MessageAgentThought.id == event.agent_thought_id)
                    .first()
                )

                if agent_thought:
                    # refresh only after confirming the row exists
                    db.session.refresh(agent_thought)

                    response = {
                        'event': 'agent_thought',
                        'id': agent_thought.id,
                        'task_id': self._application_generate_entity.task_id,
                        'message_id': self._message.id,
                        'position': agent_thought.position,
                        'thought': agent_thought.thought,
                        'observation': agent_thought.observation,
                        'tool': agent_thought.tool,
                        'tool_labels': agent_thought.tool_labels,
                        'tool_input': agent_thought.tool_input,
                        'created_at': int(self._message.created_at.timestamp()),
                        'message_files': agent_thought.files
                    }

                    if self._conversation.mode == 'chat':
                        response['conversation_id'] = self._conversation.id

                    yield self._yield_response(response)
            elif isinstance(event, QueueMessageFileEvent):
                message_file: MessageFile = (
                    db.session.query(MessageFile)
                    .filter(MessageFile.id == event.message_file_id)
                    .first()
                )

                if message_file:
                    # get extension; only touch .url once we know the row exists
                    if '.' in message_file.url:
                        extension = f'.{message_file.url.split(".")[-1]}'
                        if len(extension) > 10:
                            extension = '.bin'
                    else:
                        extension = '.bin'

                    # add sign url
                    url = ToolFileManager.sign_file(file_id=message_file.id, extension=extension)

                    response = {
                        'event': 'message_file',
                        'id': message_file.id,
                        'type': message_file.type,
                        'belongs_to': message_file.belongs_to or 'user',
                        'url': url
                    }

                    if self._conversation.mode == 'chat':
                        response['conversation_id'] = self._conversation.id

                    yield self._yield_response(response)
            elif isinstance(event, QueueMessageEvent | QueueAgentMessageEvent):
                chunk = event.chunk
                delta_text = chunk.delta.message.content
                if delta_text is None:
                    continue

                if not self._task_state.llm_result.prompt_messages:
                    self._task_state.llm_result.prompt_messages = chunk.prompt_messages

                if self._output_moderation_handler:
                    if self._output_moderation_handler.should_direct_output():
                        # stop subscribing to new tokens when output moderation dictates direct output
                        self._task_state.llm_result.message.content = self._output_moderation_handler.get_final_output()
                        self._queue_manager.publish_chunk_message(LLMResultChunk(
                            model=self._task_state.llm_result.model,
                            prompt_messages=self._task_state.llm_result.prompt_messages,
                            delta=LLMResultChunkDelta(
                                index=0,
                                message=AssistantPromptMessage(content=self._task_state.llm_result.message.content)
                            )
                        ), PublishFrom.TASK_PIPELINE)

                        self._queue_manager.publish(
                            QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION),
                            PublishFrom.TASK_PIPELINE
                        )
                        continue
                    else:
                        self._output_moderation_handler.append_new_token(delta_text)

                self._task_state.llm_result.message.content += delta_text
                response = self._handle_chunk(delta_text, agent=isinstance(event, QueueAgentMessageEvent))
                yield self._yield_response(response)
            elif isinstance(event, QueueMessageReplaceEvent):
                response = {
                    'event': 'message_replace',
                    'task_id': self._application_generate_entity.task_id,
                    'message_id': self._message.id,
                    'answer': event.text,
                    'created_at': int(self._message.created_at.timestamp())
                }

                if self._conversation.mode == 'chat':
                    response['conversation_id'] = self._conversation.id

                yield self._yield_response(response)
            elif isinstance(event, QueuePingEvent):
                yield "event: ping\n\n"
            else:
                continue

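    # The generator above emits Server-Sent Events. A typical (illustrative)
    # stream for a chat app looks like:
    #
    #     data: {"event": "message", "task_id": "...", "message_id": "...", "answer": "Hel", ...}
    #
    #     data: {"event": "message", "task_id": "...", "message_id": "...", "answer": "lo", ...}
    #
    #     data: {"event": "message_end", "task_id": "...", "message_id": "...", "metadata": {...}}
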
    def _save_message(self, llm_result: LLMResult) -> None:
        """
        Save message.
        :param llm_result: llm result
        :return:
        """
        usage = llm_result.usage

        self._message = db.session.query(Message).filter(Message.id == self._message.id).first()

        self._message.message = self._prompt_messages_to_prompt_for_saving(self._task_state.llm_result.prompt_messages)
        self._message.message_tokens = usage.prompt_tokens
        self._message.message_unit_price = usage.prompt_unit_price
        self._message.message_price_unit = usage.prompt_price_unit
        self._message.answer = PromptTemplateParser.remove_template_variables(llm_result.message.content.strip()) \
            if llm_result.message.content else ''
        self._message.answer_tokens = usage.completion_tokens
        self._message.answer_unit_price = usage.completion_unit_price
        self._message.answer_price_unit = usage.completion_price_unit
        self._message.provider_response_latency = time.perf_counter() - self._start_at
        self._message.total_price = usage.total_price

        db.session.commit()

        message_was_created.send(
            self._message,
            application_generate_entity=self._application_generate_entity,
            conversation=self._conversation,
            is_first_message=self._application_generate_entity.conversation_id is None,
            extras=self._application_generate_entity.extras
        )

    def _handle_chunk(self, text: str, agent: bool = False) -> dict:
        """
        Handle message chunk.
        :param text: text
        :param agent: whether this chunk comes from an agent message
        :return:
        """
        response = {
            'event': 'message' if not agent else 'agent_message',
            'id': self._message.id,
            'task_id': self._application_generate_entity.task_id,
            'message_id': self._message.id,
            'answer': text,
            'created_at': int(self._message.created_at.timestamp())
        }

        if self._conversation.mode == 'chat':
            response['conversation_id'] = self._conversation.id

        return response

    def _handle_error(self, event: QueueErrorEvent) -> Exception:
        """
        Handle error event.
        :param event: event
        :return:
        """
        logger.debug("error: %s", event.error)
        e = event.error

        if isinstance(e, InvokeAuthorizationError):
            return InvokeAuthorizationError('Incorrect API key provided')
        elif isinstance(e, (InvokeError, ValueError)):
            return e
        else:
            return Exception(e.description if getattr(e, 'description', None) is not None else str(e))

    def _error_to_stream_response_data(self, e: Exception) -> dict:
        """
        Error to stream response.
        :param e: exception
        :return:
        """
        error_responses = {
            ValueError: {'code': 'invalid_param', 'status': 400},
            ProviderTokenNotInitError: {'code': 'provider_not_initialize', 'status': 400},
            QuotaExceededError: {
                'code': 'provider_quota_exceeded',
                'message': "Your quota for Dify Hosted Model Provider has been exhausted. "
                           "Please go to Settings -> Model Provider to complete your own provider credentials.",
                'status': 400
            },
            ModelCurrentlyNotSupportError: {'code': 'model_currently_not_support', 'status': 400},
            InvokeError: {'code': 'completion_request_error', 'status': 400}
        }

        # Determine the response based on the type of exception
        data = None
        for k, v in error_responses.items():
            if isinstance(e, k):
                data = v

        if data:
            data.setdefault('message', getattr(e, 'description', str(e)))
        else:
            logger.error(e)
            data = {
                'code': 'internal_server_error',
                'message': 'Internal Server Error, please contact support.',
                'status': 500
            }

        return {
            'event': 'error',
            'task_id': self._application_generate_entity.task_id,
            'message_id': self._message.id,
            **data
        }

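    # Once passed through _yield_response, an error frame looks like
    # (illustrative):
    #
    #     data: {"event": "error", "task_id": "...", "message_id": "...",
    #            "code": "invalid_param", "status": 400, "message": "..."}
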
    def _get_response_metadata(self) -> dict:
        """
        Get response metadata by invoke from.
        :return:
        """
        metadata = {}

        # show_retrieve_source
        if 'retriever_resources' in self._task_state.metadata:
            if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
                metadata['retriever_resources'] = self._task_state.metadata['retriever_resources']
            else:
                metadata['retriever_resources'] = []
                for resource in self._task_state.metadata['retriever_resources']:
                    metadata['retriever_resources'].append({
                        'segment_id': resource['segment_id'],
                        'position': resource['position'],
                        'document_name': resource['document_name'],
                        'score': resource['score'],
                        'content': resource['content'],
                    })

        # show annotation reply
        if 'annotation_reply' in self._task_state.metadata:
            if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
                metadata['annotation_reply'] = self._task_state.metadata['annotation_reply']

        # show usage
        if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
            metadata['usage'] = self._task_state.metadata['usage']

        return metadata

    def _yield_response(self, response: dict) -> str:
        """
        Serialize a response dict into a Server-Sent Events data frame.
        :param response: response
        :return:
        """
        return "data: " + json.dumps(response) + "\n\n"

    def _prompt_messages_to_prompt_for_saving(self, prompt_messages: list[PromptMessage]) -> list[dict]:
        """
        Prompt messages to prompt for saving.
        :param prompt_messages: prompt messages
        :return:
        """
        prompts = []
        if self._application_generate_entity.app_orchestration_config_entity.model_config.mode == 'chat':
            for prompt_message in prompt_messages:
                if prompt_message.role == PromptMessageRole.USER:
                    role = 'user'
                elif prompt_message.role == PromptMessageRole.ASSISTANT:
                    role = 'assistant'
                elif prompt_message.role == PromptMessageRole.SYSTEM:
                    role = 'system'
                else:
                    continue

                text = ''
                files = []
                if isinstance(prompt_message.content, list):
                    for content in prompt_message.content:
                        if content.type == PromptMessageContentType.TEXT:
                            content = cast(TextPromptMessageContent, content)
                            text += content.data
                        else:
                            content = cast(ImagePromptMessageContent, content)
                            files.append({
                                "type": 'image',
                                "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:],
                                "detail": content.detail.value
                            })
                else:
                    text = prompt_message.content

                prompts.append({
                    "role": role,
                    "text": text,
                    "files": files
                })
        else:
            prompt_message = prompt_messages[0]
            text = ''
            files = []
            if isinstance(prompt_message.content, list):
                for content in prompt_message.content:
                    if content.type == PromptMessageContentType.TEXT:
                        content = cast(TextPromptMessageContent, content)
                        text += content.data
                    else:
                        content = cast(ImagePromptMessageContent, content)
                        files.append({
                            "type": 'image',
                            "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:],
                            "detail": content.detail.value
                        })
            else:
                text = prompt_message.content

            params = {
                "role": 'user',
                "text": text,
            }

            if files:
                params['files'] = files

            prompts.append(params)

        return prompts

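    # For chat-mode apps the saved prompt is a list of role/text/files dicts,
    # e.g. (values illustrative, image data truncated as in the code above):
    #
    #     [{"role": "system", "text": "You are a helpful assistant.", "files": []},
    #      {"role": "user", "text": "Describe this image.",
    #       "files": [{"type": "image", "data": "iVBORw0KGg...[TRUNCATED]...AAElFTkSu", "detail": "low"}]}]
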
    def _init_output_moderation(self) -> Optional[OutputModerationHandler]:
        """
        Init output moderation.
        :return:
        """
        app_orchestration_config_entity = self._application_generate_entity.app_orchestration_config_entity
        sensitive_word_avoidance = app_orchestration_config_entity.sensitive_word_avoidance

        if sensitive_word_avoidance:
            return OutputModerationHandler(
                tenant_id=self._application_generate_entity.tenant_id,
                app_id=self._application_generate_entity.app_id,
                rule=ModerationRule(
                    type=sensitive_word_avoidance.type,
                    config=sensitive_word_avoidance.config
                ),
                on_message_replace_func=self._queue_manager.publish_message_replace
            )
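
    # Note: the returned handler receives streamed tokens via append_new_token()
    # and can replace the visible answer through the queue manager's
    # message-replace publishing (on_message_replace_func); the stream and
    # blocking paths stop it and run a final moderation_completion() pass when
    # generation ends. See _process_stream_response above.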