generate_task_pipeline.py

import json
import logging
import time
from typing import Generator, Optional, Union, cast

from pydantic import BaseModel

from core.app_runner.moderation_handler import ModerationRule, OutputModerationHandler
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.entities.application_entities import ApplicationGenerateEntity, InvokeFrom
from core.entities.queue_entities import (
    AnnotationReplyEvent,
    QueueAgentMessageEvent,
    QueueAgentThoughtEvent,
    QueueErrorEvent,
    QueueMessageEndEvent,
    QueueMessageEvent,
    QueueMessageFileEvent,
    QueueMessageReplaceEvent,
    QueuePingEvent,
    QueueRetrieverResourcesEvent,
    QueueStopEvent,
)
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    ImagePromptMessageContent,
    PromptMessage,
    PromptMessageContentType,
    PromptMessageRole,
    TextPromptMessageContent,
)
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.prompt_template import PromptTemplateParser
from core.tools.tool_file_manager import ToolFileManager
from events.message_event import message_was_created
from extensions.ext_database import db
from models.model import Conversation, Message, MessageAgentThought, MessageFile
from services.annotation_service import AppAnnotationService

logger = logging.getLogger(__name__)
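
# This module implements the generate-task pipeline that sits between the
# application queue and the client: it consumes queue events, accumulates the
# LLM result and metadata in TaskState, and renders either a single blocking
# response dict or a stream of SSE-framed events.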


class TaskState(BaseModel):
    """
    TaskState entity
    """
    llm_result: LLMResult
    metadata: dict = {}


class GenerateTaskPipeline:
    """
    GenerateTaskPipeline generates the stream output and manages state for an Application.
    """

    def __init__(self, application_generate_entity: ApplicationGenerateEntity,
                 queue_manager: ApplicationQueueManager,
                 conversation: Conversation,
                 message: Message) -> None:
        """
        Initialize GenerateTaskPipeline.
        :param application_generate_entity: application generate entity
        :param queue_manager: queue manager
        :param conversation: conversation
        :param message: message
        """
        self._application_generate_entity = application_generate_entity
        self._queue_manager = queue_manager
        self._conversation = conversation
        self._message = message
        self._task_state = TaskState(
            llm_result=LLMResult(
                model=self._application_generate_entity.app_orchestration_config_entity.model_config.model,
                prompt_messages=[],
                message=AssistantPromptMessage(content=""),
                usage=LLMUsage.empty_usage()
            )
        )
        self._start_at = time.perf_counter()
        self._output_moderation_handler = self._init_output_moderation()

    def process(self, stream: bool) -> Union[dict, Generator]:
        """
        Process generate task pipeline.
        :param stream: whether to stream the response
        :return:
        """
        if stream:
            return self._process_stream_response()
        else:
            return self._process_blocking_response()

    def _process_blocking_response(self) -> dict:
        """
        Process blocking response.
        :return:
        """
        for queue_message in self._queue_manager.listen():
            event = queue_message.event

            if isinstance(event, QueueErrorEvent):
                raise self._handle_error(event)
            elif isinstance(event, QueueRetrieverResourcesEvent):
                self._task_state.metadata['retriever_resources'] = event.retriever_resources
            elif isinstance(event, AnnotationReplyEvent):
                annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
                if annotation:
                    account = annotation.account
                    self._task_state.metadata['annotation_reply'] = {
                        'id': annotation.id,
                        'account': {
                            'id': annotation.account_id,
                            'name': account.name if account else 'Dify user'
                        }
                    }

                    self._task_state.llm_result.message.content = annotation.content
            elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)):
                if isinstance(event, QueueMessageEndEvent):
                    self._task_state.llm_result = event.llm_result
                else:
                    model_config = self._application_generate_entity.app_orchestration_config_entity.model_config
                    model = model_config.model
                    model_type_instance = model_config.provider_model_bundle.model_type_instance
                    model_type_instance = cast(LargeLanguageModel, model_type_instance)

                    # calculate num tokens
                    prompt_tokens = 0
                    if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY:
                        prompt_tokens = model_type_instance.get_num_tokens(
                            model,
                            model_config.credentials,
                            self._task_state.llm_result.prompt_messages
                        )

                    completion_tokens = 0
                    if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL:
                        completion_tokens = model_type_instance.get_num_tokens(
                            model,
                            model_config.credentials,
                            [self._task_state.llm_result.message]
                        )

                    credentials = model_config.credentials

                    # transform usage
                    self._task_state.llm_result.usage = model_type_instance._calc_response_usage(
                        model,
                        credentials,
                        prompt_tokens,
                        completion_tokens
                    )

                self._task_state.metadata['usage'] = jsonable_encoder(self._task_state.llm_result.usage)

                # response moderation
                if self._output_moderation_handler:
                    self._output_moderation_handler.stop_thread()

                    self._task_state.llm_result.message.content = self._output_moderation_handler.moderation_completion(
                        completion=self._task_state.llm_result.message.content,
                        public_event=False
                    )

                # Save message
                self._save_message(self._task_state.llm_result)

                response = {
                    'event': 'message',
                    'task_id': self._application_generate_entity.task_id,
                    'id': self._message.id,
                    'message_id': self._message.id,
                    'mode': self._conversation.mode,
                    # read from the accumulated task state: a QueueStopEvent carries no llm_result
                    'answer': self._task_state.llm_result.message.content,
                    'metadata': {},
                    'created_at': int(self._message.created_at.timestamp())
                }

                if self._conversation.mode == 'chat':
                    response['conversation_id'] = self._conversation.id

                if self._task_state.metadata:
                    response['metadata'] = self._get_response_metadata()

                return response
            else:
                continue

    def _process_stream_response(self) -> Generator:
        """
        Process stream response.
        :return:
        """
        for message in self._queue_manager.listen():
            event = message.event

            if isinstance(event, QueueErrorEvent):
                data = self._error_to_stream_response_data(self._handle_error(event))
                yield self._yield_response(data)
                break
            elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)):
                if isinstance(event, QueueMessageEndEvent):
                    self._task_state.llm_result = event.llm_result
                else:
                    model_config = self._application_generate_entity.app_orchestration_config_entity.model_config
                    model = model_config.model
                    model_type_instance = model_config.provider_model_bundle.model_type_instance
                    model_type_instance = cast(LargeLanguageModel, model_type_instance)

                    # calculate num tokens
                    prompt_tokens = 0
                    if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY:
                        prompt_tokens = model_type_instance.get_num_tokens(
                            model,
                            model_config.credentials,
                            self._task_state.llm_result.prompt_messages
                        )

                    completion_tokens = 0
                    if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL:
                        completion_tokens = model_type_instance.get_num_tokens(
                            model,
                            model_config.credentials,
                            [self._task_state.llm_result.message]
                        )

                    credentials = model_config.credentials

                    # transform usage
                    self._task_state.llm_result.usage = model_type_instance._calc_response_usage(
                        model,
                        credentials,
                        prompt_tokens,
                        completion_tokens
                    )

                self._task_state.metadata['usage'] = jsonable_encoder(self._task_state.llm_result.usage)

                # response moderation
                if self._output_moderation_handler:
                    self._output_moderation_handler.stop_thread()

                    self._task_state.llm_result.message.content = self._output_moderation_handler.moderation_completion(
                        completion=self._task_state.llm_result.message.content,
                        public_event=False
                    )

                    self._output_moderation_handler = None

                    replace_response = {
                        'event': 'message_replace',
                        'task_id': self._application_generate_entity.task_id,
                        'message_id': self._message.id,
                        'answer': self._task_state.llm_result.message.content,
                        'created_at': int(self._message.created_at.timestamp())
                    }

                    if self._conversation.mode == 'chat':
                        replace_response['conversation_id'] = self._conversation.id

                    yield self._yield_response(replace_response)

                # Save message
                self._save_message(self._task_state.llm_result)

                response = {
                    'event': 'message_end',
                    'task_id': self._application_generate_entity.task_id,
                    'id': self._message.id,
                    'message_id': self._message.id,
                }

                if self._conversation.mode == 'chat':
                    response['conversation_id'] = self._conversation.id

                if self._task_state.metadata:
                    response['metadata'] = self._get_response_metadata()

                yield self._yield_response(response)
            elif isinstance(event, QueueRetrieverResourcesEvent):
                self._task_state.metadata['retriever_resources'] = event.retriever_resources
            elif isinstance(event, AnnotationReplyEvent):
                annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
                if annotation:
                    account = annotation.account
                    self._task_state.metadata['annotation_reply'] = {
                        'id': annotation.id,
                        'account': {
                            'id': annotation.account_id,
                            'name': account.name if account else 'Dify user'
                        }
                    }

                    self._task_state.llm_result.message.content = annotation.content
            elif isinstance(event, QueueAgentThoughtEvent):
                agent_thought = (
                    db.session.query(MessageAgentThought)
                    .filter(MessageAgentThought.id == event.agent_thought_id)
                    .first()
                )

                if agent_thought:
                    # refresh only after the None check, so a missing row doesn't crash
                    db.session.refresh(agent_thought)

                    response = {
                        'event': 'agent_thought',
                        'id': agent_thought.id,
                        'task_id': self._application_generate_entity.task_id,
                        'message_id': self._message.id,
                        'position': agent_thought.position,
                        'thought': agent_thought.thought,
                        'observation': agent_thought.observation,
                        'tool': agent_thought.tool,
                        'tool_input': agent_thought.tool_input,
                        'created_at': int(self._message.created_at.timestamp()),
                        'message_files': agent_thought.files
                    }

                    if self._conversation.mode == 'chat':
                        response['conversation_id'] = self._conversation.id

                    yield self._yield_response(response)
            elif isinstance(event, QueueMessageFileEvent):
                message_file: MessageFile = (
                    db.session.query(MessageFile)
                    .filter(MessageFile.id == event.message_file_id)
                    .first()
                )

                # guard against a missing row before touching message_file.url
                if message_file:
                    # get extension
                    if '.' in message_file.url:
                        extension = f'.{message_file.url.split(".")[-1]}'
                        if len(extension) > 10:
                            extension = '.bin'
                    else:
                        extension = '.bin'

                    # add sign url
                    url = ToolFileManager.sign_file(file_id=message_file.id, extension=extension)

                    response = {
                        'event': 'message_file',
                        'id': message_file.id,
                        'type': message_file.type,
                        'belongs_to': message_file.belongs_to or 'user',
                        'url': url
                    }

                    if self._conversation.mode == 'chat':
                        response['conversation_id'] = self._conversation.id

                    yield self._yield_response(response)
            elif isinstance(event, (QueueMessageEvent, QueueAgentMessageEvent)):
                chunk = event.chunk
                delta_text = chunk.delta.message.content
                if delta_text is None:
                    continue

                if not self._task_state.llm_result.prompt_messages:
                    self._task_state.llm_result.prompt_messages = chunk.prompt_messages

                if self._output_moderation_handler:
                    if self._output_moderation_handler.should_direct_output():
                        # stop subscribing to new tokens when output moderation should direct output
                        self._task_state.llm_result.message.content = self._output_moderation_handler.get_final_output()
                        self._queue_manager.publish_chunk_message(LLMResultChunk(
                            model=self._task_state.llm_result.model,
                            prompt_messages=self._task_state.llm_result.prompt_messages,
                            delta=LLMResultChunkDelta(
                                index=0,
                                message=AssistantPromptMessage(content=self._task_state.llm_result.message.content)
                            )
                        ), PublishFrom.TASK_PIPELINE)

                        self._queue_manager.publish(
                            QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION),
                            PublishFrom.TASK_PIPELINE
                        )
                        continue
                    else:
                        self._output_moderation_handler.append_new_token(delta_text)

                self._task_state.llm_result.message.content += delta_text
                response = self._handle_chunk(delta_text, agent=isinstance(event, QueueAgentMessageEvent))
                yield self._yield_response(response)
            elif isinstance(event, QueueMessageReplaceEvent):
                response = {
                    'event': 'message_replace',
                    'task_id': self._application_generate_entity.task_id,
                    'message_id': self._message.id,
                    'answer': event.text,
                    'created_at': int(self._message.created_at.timestamp())
                }

                if self._conversation.mode == 'chat':
                    response['conversation_id'] = self._conversation.id

                yield self._yield_response(response)
            elif isinstance(event, QueuePingEvent):
                yield "event: ping\n\n"
            else:
                continue

    def _save_message(self, llm_result: LLMResult) -> None:
        """
        Save message.
        :param llm_result: llm result
        :return:
        """
        usage = llm_result.usage

        self._message = db.session.query(Message).filter(Message.id == self._message.id).first()

        self._message.message = self._prompt_messages_to_prompt_for_saving(self._task_state.llm_result.prompt_messages)
        self._message.message_tokens = usage.prompt_tokens
        self._message.message_unit_price = usage.prompt_unit_price
        self._message.message_price_unit = usage.prompt_price_unit
        self._message.answer = PromptTemplateParser.remove_template_variables(llm_result.message.content.strip()) \
            if llm_result.message.content else ''
        self._message.answer_tokens = usage.completion_tokens
        self._message.answer_unit_price = usage.completion_unit_price
        self._message.answer_price_unit = usage.completion_price_unit
        self._message.provider_response_latency = time.perf_counter() - self._start_at
        self._message.total_price = usage.total_price

        db.session.commit()

        message_was_created.send(
            self._message,
            application_generate_entity=self._application_generate_entity,
            conversation=self._conversation,
            is_first_message=self._application_generate_entity.conversation_id is None,
            extras=self._application_generate_entity.extras
        )

    def _handle_chunk(self, text: str, agent: bool = False) -> dict:
        """
        Handle a streamed chunk and build its response payload.
        :param text: text
        :return:
        """
        response = {
            'event': 'message' if not agent else 'agent_message',
            'id': self._message.id,
            'task_id': self._application_generate_entity.task_id,
            'message_id': self._message.id,
            'answer': text,
            'created_at': int(self._message.created_at.timestamp())
        }

        if self._conversation.mode == 'chat':
            response['conversation_id'] = self._conversation.id

        return response

    def _handle_error(self, event: QueueErrorEvent) -> Exception:
        """
        Handle error event.
        :param event: event
        :return:
        """
        logger.debug("error: %s", event.error)
        e = event.error

        if isinstance(e, InvokeAuthorizationError):
            return InvokeAuthorizationError('Incorrect API key provided')
        elif isinstance(e, (InvokeError, ValueError)):
            return e
        else:
            return Exception(e.description if getattr(e, 'description', None) is not None else str(e))

    def _error_to_stream_response_data(self, e: Exception) -> dict:
        """
        Error to stream response.
        :param e: exception
        :return:
        """
        if isinstance(e, ValueError):
            data = {
                'code': 'invalid_param',
                'message': str(e),
                'status': 400
            }
        elif isinstance(e, ProviderTokenNotInitError):
            data = {
                'code': 'provider_not_initialize',
                'message': e.description,
                'status': 400
            }
        elif isinstance(e, QuotaExceededError):
            data = {
                'code': 'provider_quota_exceeded',
                'message': "Your quota for Dify Hosted Model Provider has been exhausted. "
                           "Please go to Settings -> Model Provider to complete your own provider credentials.",
                'status': 400
            }
        elif isinstance(e, ModelCurrentlyNotSupportError):
            data = {
                'code': 'model_currently_not_support',
                'message': e.description,
                'status': 400
            }
        elif isinstance(e, InvokeError):
            data = {
                'code': 'completion_request_error',
                'message': e.description,
                'status': 400
            }
        else:
            logger.error(e)
            data = {
                'code': 'internal_server_error',
                'message': 'Internal Server Error, please contact support.',
                'status': 500
            }

        return {
            'event': 'error',
            'task_id': self._application_generate_entity.task_id,
            'message_id': self._message.id,
            **data
        }

    def _get_response_metadata(self) -> dict:
        """
        Get response metadata by invoke from.
        :return:
        """
        metadata = {}

        # show_retrieve_source
        if 'retriever_resources' in self._task_state.metadata:
            if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
                metadata['retriever_resources'] = self._task_state.metadata['retriever_resources']
            else:
                metadata['retriever_resources'] = []
                for resource in self._task_state.metadata['retriever_resources']:
                    metadata['retriever_resources'].append({
                        'segment_id': resource['segment_id'],
                        'position': resource['position'],
                        'document_name': resource['document_name'],
                        'score': resource['score'],
                        'content': resource['content'],
                    })

        # show annotation reply
        if 'annotation_reply' in self._task_state.metadata:
            if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
                metadata['annotation_reply'] = self._task_state.metadata['annotation_reply']

        # show usage
        if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
            metadata['usage'] = self._task_state.metadata['usage']

        return metadata

    def _yield_response(self, response: dict) -> str:
        """
        Serialize a response payload as a Server-Sent Events data line.
        :param response: response
        :return:
        """
        return "data: " + json.dumps(response) + "\n\n"

    def _prompt_messages_to_prompt_for_saving(self, prompt_messages: list[PromptMessage]) -> list[dict]:
        """
        Prompt messages to prompt for saving.
        :param prompt_messages: prompt messages
        :return:
        """
        prompts = []
        if self._application_generate_entity.app_orchestration_config_entity.model_config.mode == 'chat':
            for prompt_message in prompt_messages:
                if prompt_message.role == PromptMessageRole.USER:
                    role = 'user'
                elif prompt_message.role == PromptMessageRole.ASSISTANT:
                    role = 'assistant'
                elif prompt_message.role == PromptMessageRole.SYSTEM:
                    role = 'system'
                else:
                    continue

                text = ''
                files = []
                if isinstance(prompt_message.content, list):
                    for content in prompt_message.content:
                        if content.type == PromptMessageContentType.TEXT:
                            content = cast(TextPromptMessageContent, content)
                            text += content.data
                        else:
                            content = cast(ImagePromptMessageContent, content)
                            files.append({
                                "type": 'image',
                                "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:],
                                "detail": content.detail.value
                            })
                else:
                    text = prompt_message.content

                prompts.append({
                    "role": role,
                    "text": text,
                    "files": files
                })
        else:
            prompt_message = prompt_messages[0]
            text = ''
            files = []
            if isinstance(prompt_message.content, list):
                for content in prompt_message.content:
                    if content.type == PromptMessageContentType.TEXT:
                        content = cast(TextPromptMessageContent, content)
                        text += content.data
                    else:
                        content = cast(ImagePromptMessageContent, content)
                        files.append({
                            "type": 'image',
                            "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:],
                            "detail": content.detail.value
                        })
            else:
                text = prompt_message.content

            params = {
                "role": 'user',
                "text": text,
            }

            if files:
                params['files'] = files

            prompts.append(params)

        return prompts

    def _init_output_moderation(self) -> Optional[OutputModerationHandler]:
        """
        Init output moderation.
        :return:
        """
        app_orchestration_config_entity = self._application_generate_entity.app_orchestration_config_entity
        sensitive_word_avoidance = app_orchestration_config_entity.sensitive_word_avoidance

        if sensitive_word_avoidance:
            return OutputModerationHandler(
                tenant_id=self._application_generate_entity.tenant_id,
                app_id=self._application_generate_entity.app_id,
                rule=ModerationRule(
                    type=sensitive_word_avoidance.type,
                    config=sensitive_word_avoidance.config
                ),
                on_message_replace_func=self._queue_manager.publish_message_replace
            )

        return None