generate_task_pipeline.py

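"""
Generate task pipeline: consumes events published to the application queue and
turns them into either a single blocking response or a streaming (SSE) response.
"""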
import json
import logging
import time
from typing import Generator, Optional, Union, cast

from pydantic import BaseModel

from core.app_runner.moderation_handler import ModerationRule, OutputModerationHandler
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.entities.application_entities import ApplicationGenerateEntity, InvokeFrom
from core.entities.queue_entities import (AnnotationReplyEvent, QueueAgentMessageEvent, QueueAgentThoughtEvent,
                                          QueueErrorEvent, QueueMessageEndEvent, QueueMessageEvent,
                                          QueueMessageFileEvent, QueueMessageReplaceEvent, QueuePingEvent,
                                          QueueRetrieverResourcesEvent, QueueStopEvent)
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (AssistantPromptMessage, ImagePromptMessageContent,
                                                          PromptMessage, PromptMessageContentType,
                                                          PromptMessageRole, TextPromptMessageContent)
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.prompt_template import PromptTemplateParser
from core.tools.tool_file_manager import ToolFileManager
from core.tools.tool_manager import ToolManager
from events.message_event import message_was_created
from extensions.ext_database import db
from models.model import Conversation, Message, MessageAgentThought, MessageFile
from services.annotation_service import AppAnnotationService

logger = logging.getLogger(__name__)


class TaskState(BaseModel):
    """
    TaskState entity
    """
    llm_result: LLMResult
    metadata: dict = {}


class GenerateTaskPipeline:
    """
    GenerateTaskPipeline generates the stream output and manages the state of an application generation task.
    """

    def __init__(self, application_generate_entity: ApplicationGenerateEntity,
                 queue_manager: ApplicationQueueManager,
                 conversation: Conversation,
                 message: Message) -> None:
        """
        Initialize GenerateTaskPipeline.
        :param application_generate_entity: application generate entity
        :param queue_manager: queue manager
        :param conversation: conversation
        :param message: message
        """
        self._application_generate_entity = application_generate_entity
        self._queue_manager = queue_manager
        self._conversation = conversation
        self._message = message
        self._task_state = TaskState(
            llm_result=LLMResult(
                model=self._application_generate_entity.app_orchestration_config_entity.model_config.model,
                prompt_messages=[],
                message=AssistantPromptMessage(content=""),
                usage=LLMUsage.empty_usage()
            )
        )
        self._start_at = time.perf_counter()
        self._output_moderation_handler = self._init_output_moderation()

    def process(self, stream: bool) -> Union[dict, Generator]:
        """
        Process the generate task pipeline.
        :param stream: whether to return a streaming generator or a blocking dict
        :return:
        """
        if stream:
            return self._process_stream_response()
        else:
            return self._process_blocking_response()
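
    # A minimal usage sketch (hypothetical caller; the real entry point lives in
    # the application manager, and `entity`, `queue_manager`, `conversation` and
    # `message` are assumed to be constructed there):
    #
    #   pipeline = GenerateTaskPipeline(entity, queue_manager, conversation, message)
    #   for sse_chunk in pipeline.process(stream=True):
    #       ...  # forward each "data: {...}\n\n" frame to the HTTP response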

    def _process_blocking_response(self) -> dict:
        """
        Process blocking response.
        :return:
        """
        for queue_message in self._queue_manager.listen():
            event = queue_message.event

            if isinstance(event, QueueErrorEvent):
                raise self._handle_error(event)
            elif isinstance(event, QueueRetrieverResourcesEvent):
                self._task_state.metadata['retriever_resources'] = event.retriever_resources
            elif isinstance(event, AnnotationReplyEvent):
                annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
                if annotation:
                    account = annotation.account
                    self._task_state.metadata['annotation_reply'] = {
                        'id': annotation.id,
                        'account': {
                            'id': annotation.account_id,
                            'name': account.name if account else 'Dify user'
                        }
                    }

                    self._task_state.llm_result.message.content = annotation.content
            elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)):
                if isinstance(event, QueueMessageEndEvent):
                    self._task_state.llm_result = event.llm_result
                else:
                    model_config = self._application_generate_entity.app_orchestration_config_entity.model_config
                    model = model_config.model
                    model_type_instance = model_config.provider_model_bundle.model_type_instance
                    model_type_instance = cast(LargeLanguageModel, model_type_instance)

                    # calculate num tokens
                    prompt_tokens = 0
                    if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY:
                        prompt_tokens = model_type_instance.get_num_tokens(
                            model,
                            model_config.credentials,
                            self._task_state.llm_result.prompt_messages
                        )

                    completion_tokens = 0
                    if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL:
                        completion_tokens = model_type_instance.get_num_tokens(
                            model,
                            model_config.credentials,
                            [self._task_state.llm_result.message]
                        )

                    credentials = model_config.credentials

                    # transform usage
                    self._task_state.llm_result.usage = model_type_instance._calc_response_usage(
                        model,
                        credentials,
                        prompt_tokens,
                        completion_tokens
                    )

                self._task_state.metadata['usage'] = jsonable_encoder(self._task_state.llm_result.usage)

                # response moderation
                if self._output_moderation_handler:
                    self._output_moderation_handler.stop_thread()

                    self._task_state.llm_result.message.content = self._output_moderation_handler.moderation_completion(
                        completion=self._task_state.llm_result.message.content,
                        public_event=False
                    )

                # Save message
                self._save_message(self._task_state.llm_result)

                response = {
                    'event': 'message',
                    'task_id': self._application_generate_entity.task_id,
                    'id': self._message.id,
                    'message_id': self._message.id,
                    'mode': self._conversation.mode,
                    # read the answer from task state: a QueueStopEvent carries no llm_result
                    'answer': self._task_state.llm_result.message.content,
                    'metadata': {},
                    'created_at': int(self._message.created_at.timestamp())
                }

                if self._conversation.mode == 'chat':
                    response['conversation_id'] = self._conversation.id

                if self._task_state.metadata:
                    response['metadata'] = self._get_response_metadata()

                return response
            else:
                continue

    def _process_stream_response(self) -> Generator:
        """
        Process stream response.
        :return:
        """
        for message in self._queue_manager.listen():
            event = message.event

            if isinstance(event, QueueErrorEvent):
                data = self._error_to_stream_response_data(self._handle_error(event))
                yield self._yield_response(data)
                break
            elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)):
                if isinstance(event, QueueMessageEndEvent):
                    self._task_state.llm_result = event.llm_result
                else:
                    model_config = self._application_generate_entity.app_orchestration_config_entity.model_config
                    model = model_config.model
                    model_type_instance = model_config.provider_model_bundle.model_type_instance
                    model_type_instance = cast(LargeLanguageModel, model_type_instance)

                    # calculate num tokens
                    prompt_tokens = 0
                    if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY:
                        prompt_tokens = model_type_instance.get_num_tokens(
                            model,
                            model_config.credentials,
                            self._task_state.llm_result.prompt_messages
                        )

                    completion_tokens = 0
                    if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL:
                        completion_tokens = model_type_instance.get_num_tokens(
                            model,
                            model_config.credentials,
                            [self._task_state.llm_result.message]
                        )

                    credentials = model_config.credentials

                    # transform usage
                    self._task_state.llm_result.usage = model_type_instance._calc_response_usage(
                        model,
                        credentials,
                        prompt_tokens,
                        completion_tokens
                    )

                self._task_state.metadata['usage'] = jsonable_encoder(self._task_state.llm_result.usage)

                # response moderation
                if self._output_moderation_handler:
                    self._output_moderation_handler.stop_thread()

                    self._task_state.llm_result.message.content = self._output_moderation_handler.moderation_completion(
                        completion=self._task_state.llm_result.message.content,
                        public_event=False
                    )

                    self._output_moderation_handler = None

                    replace_response = {
                        'event': 'message_replace',
                        'task_id': self._application_generate_entity.task_id,
                        'message_id': self._message.id,
                        'answer': self._task_state.llm_result.message.content,
                        'created_at': int(self._message.created_at.timestamp())
                    }

                    if self._conversation.mode == 'chat':
                        replace_response['conversation_id'] = self._conversation.id

                    yield self._yield_response(replace_response)

                # Save message
                self._save_message(self._task_state.llm_result)

                response = {
                    'event': 'message_end',
                    'task_id': self._application_generate_entity.task_id,
                    'id': self._message.id,
                    'message_id': self._message.id,
                }

                if self._conversation.mode == 'chat':
                    response['conversation_id'] = self._conversation.id

                if self._task_state.metadata:
                    response['metadata'] = self._get_response_metadata()

                yield self._yield_response(response)
            elif isinstance(event, QueueRetrieverResourcesEvent):
                self._task_state.metadata['retriever_resources'] = event.retriever_resources
            elif isinstance(event, AnnotationReplyEvent):
                annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
                if annotation:
                    account = annotation.account
                    self._task_state.metadata['annotation_reply'] = {
                        'id': annotation.id,
                        'account': {
                            'id': annotation.account_id,
                            'name': account.name if account else 'Dify user'
                        }
                    }

                    self._task_state.llm_result.message.content = annotation.content
            elif isinstance(event, QueueAgentThoughtEvent):
                agent_thought: MessageAgentThought = (
                    db.session.query(MessageAgentThought)
                    .filter(MessageAgentThought.id == event.agent_thought_id)
                    .first()
                )

                if agent_thought:
                    # refresh only after the null check, to pick up fields written by the agent runner
                    db.session.refresh(agent_thought)

                    response = {
                        'event': 'agent_thought',
                        'id': agent_thought.id,
                        'task_id': self._application_generate_entity.task_id,
                        'message_id': self._message.id,
                        'position': agent_thought.position,
                        'thought': agent_thought.thought,
                        'observation': agent_thought.observation,
                        'tool': agent_thought.tool,
                        'tool_labels': agent_thought.tool_labels,
                        'tool_input': agent_thought.tool_input,
                        'created_at': int(self._message.created_at.timestamp()),
                        'message_files': agent_thought.files
                    }

                    if self._conversation.mode == 'chat':
                        response['conversation_id'] = self._conversation.id

                    yield self._yield_response(response)
            elif isinstance(event, QueueMessageFileEvent):
                message_file: MessageFile = (
                    db.session.query(MessageFile)
                    .filter(MessageFile.id == event.message_file_id)
                    .first()
                )

                if message_file:
                    # get extension (guarded by the null check above, since it reads message_file.url)
                    if '.' in message_file.url:
                        extension = f'.{message_file.url.split(".")[-1]}'
                        if len(extension) > 10:
                            extension = '.bin'
                    else:
                        extension = '.bin'

                    # add signed url
                    url = ToolFileManager.sign_file(file_id=message_file.id, extension=extension)

                    response = {
                        'event': 'message_file',
                        'id': message_file.id,
                        'type': message_file.type,
                        'belongs_to': message_file.belongs_to or 'user',
                        'url': url
                    }

                    if self._conversation.mode == 'chat':
                        response['conversation_id'] = self._conversation.id

                    yield self._yield_response(response)
            elif isinstance(event, (QueueMessageEvent, QueueAgentMessageEvent)):
                chunk = event.chunk
                delta_text = chunk.delta.message.content
                if delta_text is None:
                    continue

                if not self._task_state.llm_result.prompt_messages:
                    self._task_state.llm_result.prompt_messages = chunk.prompt_messages

                if self._output_moderation_handler:
                    if self._output_moderation_handler.should_direct_output():
                        # stop subscribing to new tokens once output moderation decides to output directly
                        self._task_state.llm_result.message.content = self._output_moderation_handler.get_final_output()
                        self._queue_manager.publish_chunk_message(LLMResultChunk(
                            model=self._task_state.llm_result.model,
                            prompt_messages=self._task_state.llm_result.prompt_messages,
                            delta=LLMResultChunkDelta(
                                index=0,
                                message=AssistantPromptMessage(content=self._task_state.llm_result.message.content)
                            )
                        ), PublishFrom.TASK_PIPELINE)

                        self._queue_manager.publish(
                            QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION),
                            PublishFrom.TASK_PIPELINE
                        )
                        continue
                    else:
                        self._output_moderation_handler.append_new_token(delta_text)

                self._task_state.llm_result.message.content += delta_text
                response = self._handle_chunk(delta_text, agent=isinstance(event, QueueAgentMessageEvent))
                yield self._yield_response(response)
            elif isinstance(event, QueueMessageReplaceEvent):
                response = {
                    'event': 'message_replace',
                    'task_id': self._application_generate_entity.task_id,
                    'message_id': self._message.id,
                    'answer': event.text,
                    'created_at': int(self._message.created_at.timestamp())
                }

                if self._conversation.mode == 'chat':
                    response['conversation_id'] = self._conversation.id

                yield self._yield_response(response)
            elif isinstance(event, QueuePingEvent):
                yield "event: ping\n\n"
            else:
                continue
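
    # Illustrative frame sequence produced by _process_stream_response for a
    # plain chat completion (field values are examples, not real output):
    #
    #   data: {"event": "message", "answer": "Hel", ...}\n\n
    #   data: {"event": "message", "answer": "lo!", ...}\n\n
    #   data: {"event": "message_end", "metadata": {"usage": {...}}, ...}\n\n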

    def _save_message(self, llm_result: LLMResult) -> None:
        """
        Save message.
        :param llm_result: llm result
        :return:
        """
        usage = llm_result.usage

        self._message = db.session.query(Message).filter(Message.id == self._message.id).first()

        self._message.message = self._prompt_messages_to_prompt_for_saving(self._task_state.llm_result.prompt_messages)
        self._message.message_tokens = usage.prompt_tokens
        self._message.message_unit_price = usage.prompt_unit_price
        self._message.message_price_unit = usage.prompt_price_unit
        self._message.answer = PromptTemplateParser.remove_template_variables(llm_result.message.content.strip()) \
            if llm_result.message.content else ''
        self._message.answer_tokens = usage.completion_tokens
        self._message.answer_unit_price = usage.completion_unit_price
        self._message.answer_price_unit = usage.completion_price_unit
        self._message.provider_response_latency = time.perf_counter() - self._start_at
        self._message.total_price = usage.total_price

        db.session.commit()

        message_was_created.send(
            self._message,
            application_generate_entity=self._application_generate_entity,
            conversation=self._conversation,
            is_first_message=self._application_generate_entity.conversation_id is None,
            extras=self._application_generate_entity.extras
        )

    def _handle_chunk(self, text: str, agent: bool = False) -> dict:
        """
        Build the response payload for a streamed chunk of text.
        :param text: text
        :param agent: whether the chunk was produced by an agent
        :return:
        """
        response = {
            'event': 'message' if not agent else 'agent_message',
            'id': self._message.id,
            'task_id': self._application_generate_entity.task_id,
            'message_id': self._message.id,
            'answer': text,
            'created_at': int(self._message.created_at.timestamp())
        }

        if self._conversation.mode == 'chat':
            response['conversation_id'] = self._conversation.id

        return response

    def _handle_error(self, event: QueueErrorEvent) -> Exception:
        """
        Handle error event.
        :param event: event
        :return:
        """
        logger.debug("error: %s", event.error)
        e = event.error

        if isinstance(e, InvokeAuthorizationError):
            return InvokeAuthorizationError('Incorrect API key provided')
        elif isinstance(e, (InvokeError, ValueError)):
            return e
        else:
            return Exception(e.description if getattr(e, 'description', None) is not None else str(e))

    def _error_to_stream_response_data(self, e: Exception) -> dict:
        """
        Convert an exception to a stream response payload.
        :param e: exception
        :return:
        """
        if isinstance(e, ValueError):
            data = {
                'code': 'invalid_param',
                'message': str(e),
                'status': 400
            }
        elif isinstance(e, ProviderTokenNotInitError):
            data = {
                'code': 'provider_not_initialize',
                'message': e.description,
                'status': 400
            }
        elif isinstance(e, QuotaExceededError):
            data = {
                'code': 'provider_quota_exceeded',
                'message': "Your quota for Dify Hosted Model Provider has been exhausted. "
                           "Please go to Settings -> Model Provider to complete your own provider credentials.",
                'status': 400
            }
        elif isinstance(e, ModelCurrentlyNotSupportError):
            data = {
                'code': 'model_currently_not_support',
                'message': e.description,
                'status': 400
            }
        elif isinstance(e, InvokeError):
            data = {
                'code': 'completion_request_error',
                'message': e.description,
                'status': 400
            }
        else:
            logger.error(e)
            data = {
                'code': 'internal_server_error',
                'message': 'Internal Server Error, please contact support.',
                'status': 500
            }

        return {
            'event': 'error',
            'task_id': self._application_generate_entity.task_id,
            'message_id': self._message.id,
            **data
        }
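
    # Example error frame as sent over the stream (illustrative values):
    #   data: {"event": "error", "task_id": "...", "message_id": "...",
    #          "code": "invalid_param", "message": "...", "status": 400}\n\n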

    def _get_response_metadata(self) -> dict:
        """
        Get response metadata by invoke from.
        :return:
        """
        metadata = {}

        # show_retrieve_source
        if 'retriever_resources' in self._task_state.metadata:
            if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
                metadata['retriever_resources'] = self._task_state.metadata['retriever_resources']
            else:
                metadata['retriever_resources'] = []
                for resource in self._task_state.metadata['retriever_resources']:
                    metadata['retriever_resources'].append({
                        'segment_id': resource['segment_id'],
                        'position': resource['position'],
                        'document_name': resource['document_name'],
                        'score': resource['score'],
                        'content': resource['content'],
                    })

        # show annotation reply
        if 'annotation_reply' in self._task_state.metadata:
            if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
                metadata['annotation_reply'] = self._task_state.metadata['annotation_reply']

        # show usage
        if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
            metadata['usage'] = self._task_state.metadata['usage']

        return metadata

    def _yield_response(self, response: dict) -> str:
        """
        Format a response payload as a Server-Sent Events data frame.
        :param response: response
        :return:
        """
        return "data: " + json.dumps(response) + "\n\n"
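
    # For example:
    #   self._yield_response({'event': 'ping'}) == 'data: {"event": "ping"}\n\n'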

    def _prompt_messages_to_prompt_for_saving(self, prompt_messages: list[PromptMessage]) -> list[dict]:
        """
        Prompt messages to prompt for saving.
        :param prompt_messages: prompt messages
        :return:
        """
        prompts = []
        if self._application_generate_entity.app_orchestration_config_entity.model_config.mode == 'chat':
            for prompt_message in prompt_messages:
                if prompt_message.role == PromptMessageRole.USER:
                    role = 'user'
                elif prompt_message.role == PromptMessageRole.ASSISTANT:
                    role = 'assistant'
                elif prompt_message.role == PromptMessageRole.SYSTEM:
                    role = 'system'
                else:
                    continue

                text = ''
                files = []
                if isinstance(prompt_message.content, list):
                    for content in prompt_message.content:
                        if content.type == PromptMessageContentType.TEXT:
                            content = cast(TextPromptMessageContent, content)
                            text += content.data
                        else:
                            content = cast(ImagePromptMessageContent, content)
                            files.append({
                                "type": 'image',
                                "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:],
                                "detail": content.detail.value
                            })
                else:
                    text = prompt_message.content

                prompts.append({
                    "role": role,
                    "text": text,
                    "files": files
                })
        else:
            prompt_message = prompt_messages[0]
            text = ''
            files = []
            if isinstance(prompt_message.content, list):
                for content in prompt_message.content:
                    if content.type == PromptMessageContentType.TEXT:
                        content = cast(TextPromptMessageContent, content)
                        text += content.data
                    else:
                        content = cast(ImagePromptMessageContent, content)
                        files.append({
                            "type": 'image',
                            "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:],
                            "detail": content.detail.value
                        })
            else:
                text = prompt_message.content

            params = {
                "role": 'user',
                "text": text,
            }

            if files:
                params['files'] = files

            prompts.append(params)

        return prompts
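
    # Example of the saved structure in chat mode (illustrative values):
    #   [{"role": "user", "text": "Hi", "files": []},
    #    {"role": "assistant", "text": "Hello!", "files": []}]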

    def _init_output_moderation(self) -> Optional[OutputModerationHandler]:
        """
        Init output moderation.
        :return:
        """
        app_orchestration_config_entity = self._application_generate_entity.app_orchestration_config_entity
        sensitive_word_avoidance = app_orchestration_config_entity.sensitive_word_avoidance

        if sensitive_word_avoidance:
            return OutputModerationHandler(
                tenant_id=self._application_generate_entity.tenant_id,
                app_id=self._application_generate_entity.app_id,
                rule=ModerationRule(
                    type=sensitive_word_avoidance.type,
                    config=sensitive_word_avoidance.config
                ),
                on_message_replace_func=self._queue_manager.publish_message_replace
            )

        return None