# generate_task_pipeline.py

import json
import logging
import time
from typing import Generator, Optional, Union, cast

from pydantic import BaseModel

from core.app_runner.moderation_handler import ModerationRule, OutputModerationHandler
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.entities.application_entities import ApplicationGenerateEntity, InvokeFrom
from core.entities.queue_entities import (
    AnnotationReplyEvent,
    QueueAgentMessageEvent,
    QueueAgentThoughtEvent,
    QueueErrorEvent,
    QueueMessageEndEvent,
    QueueMessageEvent,
    QueueMessageFileEvent,
    QueueMessageReplaceEvent,
    QueuePingEvent,
    QueueRetrieverResourcesEvent,
    QueueStopEvent,
)
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    ImagePromptMessageContent,
    PromptMessage,
    PromptMessageContentType,
    PromptMessageRole,
    TextPromptMessageContent,
)
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.prompt_template import PromptTemplateParser
from core.tools.tool_file_manager import ToolFileManager
from events.message_event import message_was_created
from extensions.ext_database import db
from models.model import Conversation, Message, MessageAgentThought, MessageFile
from services.annotation_service import AppAnnotationService

logger = logging.getLogger(__name__)

class TaskState(BaseModel):
    """
    TaskState entity
    """
    llm_result: LLMResult
    metadata: dict = {}

class GenerateTaskPipeline:
    """
    GenerateTaskPipeline generates stream output and manages state for an Application.
    """

    def __init__(self, application_generate_entity: ApplicationGenerateEntity,
                 queue_manager: ApplicationQueueManager,
                 conversation: Conversation,
                 message: Message) -> None:
        """
        Initialize GenerateTaskPipeline.
        :param application_generate_entity: application generate entity
        :param queue_manager: queue manager
        :param conversation: conversation
        :param message: message
        """
        self._application_generate_entity = application_generate_entity
        self._queue_manager = queue_manager
        self._conversation = conversation
        self._message = message
        self._task_state = TaskState(
            llm_result=LLMResult(
                model=self._application_generate_entity.app_orchestration_config_entity.model_config.model,
                prompt_messages=[],
                message=AssistantPromptMessage(content=""),
                usage=LLMUsage.empty_usage()
            )
        )
        self._start_at = time.perf_counter()
        self._output_moderation_handler = self._init_output_moderation()

    def process(self, stream: bool) -> Union[dict, Generator]:
        """
        Process the generate task pipeline.
        :return:
        """
        if stream:
            return self._process_stream_response()
        else:
            return self._process_blocking_response()

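    # The two methods below implement the `process` contract: the blocking
    # variant drains queue events and returns one final response dict, while
    # the streaming variant yields SSE-formatted strings (see `_yield_response`).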
    def _process_blocking_response(self) -> dict:
        """
        Process blocking response.
        :return:
        """
        for queue_message in self._queue_manager.listen():
            event = queue_message.event

            if isinstance(event, QueueErrorEvent):
                raise self._handle_error(event)
            elif isinstance(event, QueueRetrieverResourcesEvent):
                self._task_state.metadata['retriever_resources'] = event.retriever_resources
            elif isinstance(event, AnnotationReplyEvent):
                annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
                if annotation:
                    account = annotation.account
                    self._task_state.metadata['annotation_reply'] = {
                        'id': annotation.id,
                        'account': {
                            'id': annotation.account_id,
                            'name': account.name if account else 'Dify user'
                        }
                    }

                    self._task_state.llm_result.message.content = annotation.content
            elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)):
                if isinstance(event, QueueMessageEndEvent):
                    self._task_state.llm_result = event.llm_result
                else:
                    model_config = self._application_generate_entity.app_orchestration_config_entity.model_config
                    model = model_config.model
                    model_type_instance = model_config.provider_model_bundle.model_type_instance
                    model_type_instance = cast(LargeLanguageModel, model_type_instance)

                    # calculate num tokens
                    prompt_tokens = 0
                    if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY:
                        prompt_tokens = model_type_instance.get_num_tokens(
                            model,
                            model_config.credentials,
                            self._task_state.llm_result.prompt_messages
                        )

                    completion_tokens = 0
                    if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL:
                        completion_tokens = model_type_instance.get_num_tokens(
                            model,
                            model_config.credentials,
                            [self._task_state.llm_result.message]
                        )

                    credentials = model_config.credentials

                    # transform usage
                    self._task_state.llm_result.usage = model_type_instance._calc_response_usage(
                        model,
                        credentials,
                        prompt_tokens,
                        completion_tokens
                    )

                self._task_state.metadata['usage'] = jsonable_encoder(self._task_state.llm_result.usage)

                # response moderation
                if self._output_moderation_handler:
                    self._output_moderation_handler.stop_thread()

                    self._task_state.llm_result.message.content = self._output_moderation_handler.moderation_completion(
                        completion=self._task_state.llm_result.message.content,
                        public_event=False
                    )

                # Save message
                self._save_message(self._task_state.llm_result)

                response = {
                    'event': 'message',
                    'task_id': self._application_generate_entity.task_id,
                    'id': self._message.id,
                    'message_id': self._message.id,
                    'mode': self._conversation.mode,
                    # read from task state: a QueueStopEvent carries no llm_result
                    'answer': self._task_state.llm_result.message.content,
                    'metadata': {},
                    'created_at': int(self._message.created_at.timestamp())
                }

                if self._conversation.mode == 'chat':
                    response['conversation_id'] = self._conversation.id

                if self._task_state.metadata:
                    response['metadata'] = self._get_response_metadata()

                return response
            else:
                continue

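    # Event types yielded by the streaming loop below: "message" /
    # "agent_message" chunks, "agent_thought", "message_file",
    # "message_replace", "message_end", "error", and "ping".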
    def _process_stream_response(self) -> Generator:
        """
        Process stream response.
        :return:
        """
        for message in self._queue_manager.listen():
            event = message.event

            if isinstance(event, QueueErrorEvent):
                data = self._error_to_stream_response_data(self._handle_error(event))
                yield self._yield_response(data)
                break
            elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)):
                if isinstance(event, QueueMessageEndEvent):
                    self._task_state.llm_result = event.llm_result
                else:
                    model_config = self._application_generate_entity.app_orchestration_config_entity.model_config
                    model = model_config.model
                    model_type_instance = model_config.provider_model_bundle.model_type_instance
                    model_type_instance = cast(LargeLanguageModel, model_type_instance)

                    # calculate num tokens
                    prompt_tokens = 0
                    if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY:
                        prompt_tokens = model_type_instance.get_num_tokens(
                            model,
                            model_config.credentials,
                            self._task_state.llm_result.prompt_messages
                        )

                    completion_tokens = 0
                    if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL:
                        completion_tokens = model_type_instance.get_num_tokens(
                            model,
                            model_config.credentials,
                            [self._task_state.llm_result.message]
                        )

                    credentials = model_config.credentials

                    # transform usage
                    self._task_state.llm_result.usage = model_type_instance._calc_response_usage(
                        model,
                        credentials,
                        prompt_tokens,
                        completion_tokens
                    )

                self._task_state.metadata['usage'] = jsonable_encoder(self._task_state.llm_result.usage)

                # response moderation
                if self._output_moderation_handler:
                    self._output_moderation_handler.stop_thread()

                    self._task_state.llm_result.message.content = self._output_moderation_handler.moderation_completion(
                        completion=self._task_state.llm_result.message.content,
                        public_event=False
                    )

                    self._output_moderation_handler = None

                    replace_response = {
                        'event': 'message_replace',
                        'task_id': self._application_generate_entity.task_id,
                        'message_id': self._message.id,
                        'answer': self._task_state.llm_result.message.content,
                        'created_at': int(self._message.created_at.timestamp())
                    }

                    if self._conversation.mode == 'chat':
                        replace_response['conversation_id'] = self._conversation.id

                    yield self._yield_response(replace_response)

                # Save message
                self._save_message(self._task_state.llm_result)

                response = {
                    'event': 'message_end',
                    'task_id': self._application_generate_entity.task_id,
                    'id': self._message.id,
                    'message_id': self._message.id,
                }

                if self._conversation.mode == 'chat':
                    response['conversation_id'] = self._conversation.id

                if self._task_state.metadata:
                    response['metadata'] = self._get_response_metadata()

                yield self._yield_response(response)
            elif isinstance(event, QueueRetrieverResourcesEvent):
                self._task_state.metadata['retriever_resources'] = event.retriever_resources
            elif isinstance(event, AnnotationReplyEvent):
                annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
                if annotation:
                    account = annotation.account
                    self._task_state.metadata['annotation_reply'] = {
                        'id': annotation.id,
                        'account': {
                            'id': annotation.account_id,
                            'name': account.name if account else 'Dify user'
                        }
                    }

                    self._task_state.llm_result.message.content = annotation.content
            elif isinstance(event, QueueAgentThoughtEvent):
                agent_thought: MessageAgentThought = (
                    db.session.query(MessageAgentThought)
                    .filter(MessageAgentThought.id == event.agent_thought_id)
                    .first()
                )

                if agent_thought:
                    # refresh only after the None check; refreshing a missing row would raise
                    db.session.refresh(agent_thought)

                    response = {
                        'event': 'agent_thought',
                        'id': agent_thought.id,
                        'task_id': self._application_generate_entity.task_id,
                        'message_id': self._message.id,
                        'position': agent_thought.position,
                        'thought': agent_thought.thought,
                        'observation': agent_thought.observation,
                        'tool': agent_thought.tool,
                        'tool_labels': agent_thought.tool_labels,
                        'tool_input': agent_thought.tool_input,
                        'created_at': int(self._message.created_at.timestamp()),
                        'message_files': agent_thought.files
                    }

                    if self._conversation.mode == 'chat':
                        response['conversation_id'] = self._conversation.id

                    yield self._yield_response(response)
            elif isinstance(event, QueueMessageFileEvent):
                message_file: MessageFile = (
                    db.session.query(MessageFile)
                    .filter(MessageFile.id == event.message_file_id)
                    .first()
                )

                if message_file:
                    # get extension (fall back to .bin for missing or oversized suffixes)
                    if '.' in message_file.url:
                        extension = f'.{message_file.url.split(".")[-1]}'
                        if len(extension) > 10:
                            extension = '.bin'
                    else:
                        extension = '.bin'

                    # add signed url
                    url = ToolFileManager.sign_file(file_id=message_file.id, extension=extension)

                    response = {
                        'event': 'message_file',
                        'id': message_file.id,
                        'type': message_file.type,
                        'belongs_to': message_file.belongs_to or 'user',
                        'url': url
                    }

                    if self._conversation.mode == 'chat':
                        response['conversation_id'] = self._conversation.id

                    yield self._yield_response(response)
            elif isinstance(event, (QueueMessageEvent, QueueAgentMessageEvent)):
                chunk = event.chunk
                delta_text = chunk.delta.message.content
                if delta_text is None:
                    continue

                if not self._task_state.llm_result.prompt_messages:
                    self._task_state.llm_result.prompt_messages = chunk.prompt_messages

                if self._output_moderation_handler:
                    if self._output_moderation_handler.should_direct_output():
                        # stop subscribing to new tokens once output moderation decides to direct output
                        self._task_state.llm_result.message.content = self._output_moderation_handler.get_final_output()
                        self._queue_manager.publish_chunk_message(LLMResultChunk(
                            model=self._task_state.llm_result.model,
                            prompt_messages=self._task_state.llm_result.prompt_messages,
                            delta=LLMResultChunkDelta(
                                index=0,
                                message=AssistantPromptMessage(content=self._task_state.llm_result.message.content)
                            )
                        ), PublishFrom.TASK_PIPELINE)

                        self._queue_manager.publish(
                            QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION),
                            PublishFrom.TASK_PIPELINE
                        )
                        continue
                    else:
                        self._output_moderation_handler.append_new_token(delta_text)

                self._task_state.llm_result.message.content += delta_text
                response = self._handle_chunk(delta_text, agent=isinstance(event, QueueAgentMessageEvent))
                yield self._yield_response(response)
            elif isinstance(event, QueueMessageReplaceEvent):
                response = {
                    'event': 'message_replace',
                    'task_id': self._application_generate_entity.task_id,
                    'message_id': self._message.id,
                    'answer': event.text,
                    'created_at': int(self._message.created_at.timestamp())
                }

                if self._conversation.mode == 'chat':
                    response['conversation_id'] = self._conversation.id

                yield self._yield_response(response)
            elif isinstance(event, QueuePingEvent):
                yield "event: ping\n\n"
            else:
                continue

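    # Note: `_save_message` re-reads the Message row, persists the prompt,
    # answer, token usage, pricing and latency, then emits the
    # `message_was_created` signal.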
    def _save_message(self, llm_result: LLMResult) -> None:
        """
        Save message.
        :param llm_result: llm result
        :return:
        """
        usage = llm_result.usage

        self._message = db.session.query(Message).filter(Message.id == self._message.id).first()

        self._message.message = self._prompt_messages_to_prompt_for_saving(self._task_state.llm_result.prompt_messages)
        self._message.message_tokens = usage.prompt_tokens
        self._message.message_unit_price = usage.prompt_unit_price
        self._message.message_price_unit = usage.prompt_price_unit
        self._message.answer = PromptTemplateParser.remove_template_variables(llm_result.message.content.strip()) \
            if llm_result.message.content else ''
        self._message.answer_tokens = usage.completion_tokens
        self._message.answer_unit_price = usage.completion_unit_price
        self._message.answer_price_unit = usage.completion_price_unit
        self._message.provider_response_latency = time.perf_counter() - self._start_at
        self._message.total_price = usage.total_price

        db.session.commit()

        message_was_created.send(
            self._message,
            application_generate_entity=self._application_generate_entity,
            conversation=self._conversation,
            is_first_message=self._application_generate_entity.conversation_id is None,
            extras=self._application_generate_entity.extras
        )

    def _handle_chunk(self, text: str, agent: bool = False) -> dict:
        """
        Handle a message chunk.
        :param text: text
        :return:
        """
        response = {
            'event': 'message' if not agent else 'agent_message',
            'id': self._message.id,
            'task_id': self._application_generate_entity.task_id,
            'message_id': self._message.id,
            'answer': text,
            'created_at': int(self._message.created_at.timestamp())
        }

        if self._conversation.mode == 'chat':
            response['conversation_id'] = self._conversation.id

        return response

    def _handle_error(self, event: QueueErrorEvent) -> Exception:
        """
        Handle error event.
        :param event: event
        :return:
        """
        logger.debug("error: %s", event.error)
        e = event.error

        if isinstance(e, InvokeAuthorizationError):
            return InvokeAuthorizationError('Incorrect API key provided')
        elif isinstance(e, (InvokeError, ValueError)):
            return e
        else:
            return Exception(e.description if getattr(e, 'description', None) is not None else str(e))

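    # The dict built below is framed by `_yield_response`; an illustrative
    # (not verbatim) error event on the wire:
    #   data: {"event": "error", "task_id": "...", "message_id": "...",
    #          "code": "invalid_param", "message": "...", "status": 400}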
    def _error_to_stream_response_data(self, e: Exception) -> dict:
        """
        Map an error to stream response data.
        :param e: exception
        :return:
        """
        error_responses = {
            ValueError: {'code': 'invalid_param', 'status': 400},
            ProviderTokenNotInitError: {'code': 'provider_not_initialize', 'status': 400},
            QuotaExceededError: {
                'code': 'provider_quota_exceeded',
                'message': "Your quota for Dify Hosted Model Provider has been exhausted. "
                           "Please go to Settings -> Model Provider to complete your own provider credentials.",
                'status': 400
            },
            ModelCurrentlyNotSupportError: {'code': 'model_currently_not_support', 'status': 400},
            InvokeError: {'code': 'completion_request_error', 'status': 400}
        }

        # Determine the response based on the type of exception
        data = None
        for k, v in error_responses.items():
            if isinstance(e, k):
                data = v

        if data:
            data.setdefault('message', getattr(e, 'description', str(e)))
        else:
            logger.error(e)
            data = {
                'code': 'internal_server_error',
                'message': 'Internal Server Error, please contact support.',
                'status': 500
            }

        return {
            'event': 'error',
            'task_id': self._application_generate_entity.task_id,
            'message_id': self._message.id,
            **data
        }

    def _get_response_metadata(self) -> dict:
        """
        Get response metadata by invoke from.
        :return:
        """
        metadata = {}

        # show_retrieve_source
        if 'retriever_resources' in self._task_state.metadata:
            if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
                metadata['retriever_resources'] = self._task_state.metadata['retriever_resources']
            else:
                metadata['retriever_resources'] = []
                for resource in self._task_state.metadata['retriever_resources']:
                    metadata['retriever_resources'].append({
                        'segment_id': resource['segment_id'],
                        'position': resource['position'],
                        'document_name': resource['document_name'],
                        'score': resource['score'],
                        'content': resource['content'],
                    })

        # show annotation reply
        if 'annotation_reply' in self._task_state.metadata:
            if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
                metadata['annotation_reply'] = self._task_state.metadata['annotation_reply']

        # show usage
        if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
            metadata['usage'] = self._task_state.metadata['usage']

        return metadata

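    # `_yield_response` frames a payload in the server-sent events convention
    # consumed by clients; illustrative output (example values):
    #   data: {"event": "message", "task_id": "...", "answer": "Hi"}\n\n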
    def _yield_response(self, response: dict) -> str:
        """
        Yield response.
        :param response: response
        :return:
        """
        return "data: " + json.dumps(response) + "\n\n"

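    # The helper below flattens prompt messages into the JSON-serializable
    # structure stored on the Message row; illustrative chat-mode result
    # (example values):
    #   [{"role": "system", "text": "You are ...", "files": []},
    #    {"role": "user", "text": "Hello", "files": []}]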
    def _prompt_messages_to_prompt_for_saving(self, prompt_messages: list[PromptMessage]) -> list[dict]:
        """
        Prompt messages to prompt for saving.
        :param prompt_messages: prompt messages
        :return:
        """
        prompts = []
        if self._application_generate_entity.app_orchestration_config_entity.model_config.mode == 'chat':
            for prompt_message in prompt_messages:
                if prompt_message.role == PromptMessageRole.USER:
                    role = 'user'
                elif prompt_message.role == PromptMessageRole.ASSISTANT:
                    role = 'assistant'
                elif prompt_message.role == PromptMessageRole.SYSTEM:
                    role = 'system'
                else:
                    continue

                text = ''
                files = []
                if isinstance(prompt_message.content, list):
                    for content in prompt_message.content:
                        if content.type == PromptMessageContentType.TEXT:
                            content = cast(TextPromptMessageContent, content)
                            text += content.data
                        else:
                            content = cast(ImagePromptMessageContent, content)
                            files.append({
                                "type": 'image',
                                "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:],
                                "detail": content.detail.value
                            })
                else:
                    text = prompt_message.content

                prompts.append({
                    "role": role,
                    "text": text,
                    "files": files
                })
        else:
            prompt_message = prompt_messages[0]
            text = ''
            files = []
            if isinstance(prompt_message.content, list):
                for content in prompt_message.content:
                    if content.type == PromptMessageContentType.TEXT:
                        content = cast(TextPromptMessageContent, content)
                        text += content.data
                    else:
                        content = cast(ImagePromptMessageContent, content)
                        files.append({
                            "type": 'image',
                            "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:],
                            "detail": content.detail.value
                        })
            else:
                text = prompt_message.content

            params = {
                "role": 'user',
                "text": text,
            }

            if files:
                params['files'] = files

            prompts.append(params)

        return prompts

    def _init_output_moderation(self) -> Optional[OutputModerationHandler]:
        """
        Init output moderation.
        :return:
        """
        app_orchestration_config_entity = self._application_generate_entity.app_orchestration_config_entity
        sensitive_word_avoidance = app_orchestration_config_entity.sensitive_word_avoidance

        if sensitive_word_avoidance:
            return OutputModerationHandler(
                tenant_id=self._application_generate_entity.tenant_id,
                app_id=self._application_generate_entity.app_id,
                rule=ModerationRule(
                    type=sensitive_word_avoidance.type,
                    config=sensitive_word_avoidance.config
                ),
                on_message_replace_func=self._queue_manager.publish_message_replace
            )

        return None
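
# Illustrative usage (a minimal sketch; in practice the surrounding app
# generator wires these dependencies. `app_entity`, `queue_manager`,
# `conversation`, `message` and `send_to_client` are assumed to exist for
# illustration and are not defined in this module):
#
#   pipeline = GenerateTaskPipeline(
#       application_generate_entity=app_entity,
#       queue_manager=queue_manager,
#       conversation=conversation,
#       message=message,
#   )
#
#   # Streaming: each item is an SSE-framed string ("data: {...}\n\n").
#   for sse_chunk in pipeline.process(stream=True):
#       send_to_client(sse_chunk)
#
#   # Blocking: a single dict containing the final answer and metadata.
#   result = pipeline.process(stream=False)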