generate_task_pipeline.py

import json
import logging
import time
from typing import Generator, Optional, Union, cast

from pydantic import BaseModel

from core.app_runner.moderation_handler import ModerationRule, OutputModerationHandler
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.entities.application_entities import ApplicationGenerateEntity, InvokeFrom
from core.entities.queue_entities import (AnnotationReplyEvent, QueueAgentThoughtEvent, QueueErrorEvent,
                                          QueueMessageEndEvent, QueueMessageEvent, QueueMessageReplaceEvent,
                                          QueuePingEvent, QueueRetrieverResourcesEvent, QueueStopEvent)
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (AssistantPromptMessage, ImagePromptMessageContent,
                                                          PromptMessage, PromptMessageContentType, PromptMessageRole,
                                                          TextPromptMessageContent)
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.prompt_template import PromptTemplateParser
from events.message_event import message_was_created
from extensions.ext_database import db
from models.model import Conversation, Message, MessageAgentThought
from services.annotation_service import AppAnnotationService

logger = logging.getLogger(__name__)


class TaskState(BaseModel):
    """
    TaskState entity
    """
    llm_result: LLMResult
    metadata: dict = {}


class GenerateTaskPipeline:
    """
    GenerateTaskPipeline generates stream output and manages state for an application.
    """

    def __init__(self, application_generate_entity: ApplicationGenerateEntity,
                 queue_manager: ApplicationQueueManager,
                 conversation: Conversation,
                 message: Message) -> None:
        """
        Initialize GenerateTaskPipeline.
        :param application_generate_entity: application generate entity
        :param queue_manager: queue manager
        :param conversation: conversation
        :param message: message
        """
        self._application_generate_entity = application_generate_entity
        self._queue_manager = queue_manager
        self._conversation = conversation
        self._message = message
        self._task_state = TaskState(
            llm_result=LLMResult(
                model=self._application_generate_entity.app_orchestration_config_entity.model_config.model,
                prompt_messages=[],
                message=AssistantPromptMessage(content=""),
                usage=LLMUsage.empty_usage()
            )
        )
        self._start_at = time.perf_counter()
        self._output_moderation_handler = self._init_output_moderation()

    def process(self, stream: bool) -> Union[dict, Generator]:
        """
        Process generate task pipeline.
        :return:
        """
        if stream:
            return self._process_stream_response()
        else:
            return self._process_blocking_response()
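
    # A minimal usage sketch (illustrative only): how a caller such as an app
    # runner might drive `process()`. `forward_to_client` is a hypothetical
    # helper, not part of this module.
    #
    #   pipeline = GenerateTaskPipeline(entity, queue_manager, conversation, message)
    #   if streaming:
    #       for sse_frame in pipeline.process(stream=True):  # yields SSE-formatted strings
    #           forward_to_client(sse_frame)
    #   else:
    #       result = pipeline.process(stream=False)          # a single response dict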

    def _process_blocking_response(self) -> dict:
        """
        Process blocking response.
        :return:
        """
        for queue_message in self._queue_manager.listen():
            event = queue_message.event

            if isinstance(event, QueueErrorEvent):
                raise self._handle_error(event)
            elif isinstance(event, QueueRetrieverResourcesEvent):
                self._task_state.metadata['retriever_resources'] = event.retriever_resources
            elif isinstance(event, AnnotationReplyEvent):
                annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
                if annotation:
                    account = annotation.account
                    self._task_state.metadata['annotation_reply'] = {
                        'id': annotation.id,
                        'account': {
                            'id': annotation.account_id,
                            'name': account.name if account else 'Dify user'
                        }
                    }

                    self._task_state.llm_result.message.content = annotation.content
            elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)):
                if isinstance(event, QueueMessageEndEvent):
                    self._task_state.llm_result = event.llm_result
                else:
                    model_config = self._application_generate_entity.app_orchestration_config_entity.model_config
                    model = model_config.model
                    model_type_instance = model_config.provider_model_bundle.model_type_instance
                    model_type_instance = cast(LargeLanguageModel, model_type_instance)

                    # calculate num tokens
                    prompt_tokens = 0
                    if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY:
                        prompt_tokens = model_type_instance.get_num_tokens(
                            model,
                            model_config.credentials,
                            self._task_state.llm_result.prompt_messages
                        )

                    completion_tokens = 0
                    if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL:
                        completion_tokens = model_type_instance.get_num_tokens(
                            model,
                            model_config.credentials,
                            [self._task_state.llm_result.message]
                        )

                    credentials = model_config.credentials

                    # transform usage
                    self._task_state.llm_result.usage = model_type_instance._calc_response_usage(
                        model,
                        credentials,
                        prompt_tokens,
                        completion_tokens
                    )

                self._task_state.metadata['usage'] = jsonable_encoder(self._task_state.llm_result.usage)

                # response moderation
                if self._output_moderation_handler:
                    self._output_moderation_handler.stop_thread()

                    self._task_state.llm_result.message.content = self._output_moderation_handler.moderation_completion(
                        completion=self._task_state.llm_result.message.content,
                        public_event=False
                    )

                # Save message
                self._save_message(self._task_state.llm_result)

                response = {
                    'event': 'message',
                    'task_id': self._application_generate_entity.task_id,
                    'id': self._message.id,
                    'message_id': self._message.id,
                    'mode': self._conversation.mode,
                    # read the answer from task state: a QueueStopEvent carries no llm_result
                    'answer': self._task_state.llm_result.message.content,
                    'metadata': {},
                    'created_at': int(self._message.created_at.timestamp())
                }

                if self._conversation.mode == 'chat':
                    response['conversation_id'] = self._conversation.id

                if self._task_state.metadata:
                    response['metadata'] = self._get_response_metadata()

                return response
            else:
                continue
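
    # For reference, a blocking response assembled above looks roughly like this
    # (field values are made up for illustration):
    #
    #   {
    #       "event": "message",
    #       "task_id": "a1b2...",
    #       "id": "msg-1",
    #       "message_id": "msg-1",
    #       "mode": "chat",
    #       "answer": "Hello!",
    #       "metadata": {"usage": {...}},
    #       "conversation_id": "conv-1",
    #       "created_at": 1700000000
    #   }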

    def _process_stream_response(self) -> Generator:
        """
        Process stream response.
        :return:
        """
        for message in self._queue_manager.listen():
            event = message.event

            if isinstance(event, QueueErrorEvent):
                data = self._error_to_stream_response_data(self._handle_error(event))
                yield self._yield_response(data)
                break
            elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)):
                if isinstance(event, QueueMessageEndEvent):
                    self._task_state.llm_result = event.llm_result
                else:
                    model_config = self._application_generate_entity.app_orchestration_config_entity.model_config
                    model = model_config.model
                    model_type_instance = model_config.provider_model_bundle.model_type_instance
                    model_type_instance = cast(LargeLanguageModel, model_type_instance)

                    # calculate num tokens
                    prompt_tokens = 0
                    if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY:
                        prompt_tokens = model_type_instance.get_num_tokens(
                            model,
                            model_config.credentials,
                            self._task_state.llm_result.prompt_messages
                        )

                    completion_tokens = 0
                    if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL:
                        completion_tokens = model_type_instance.get_num_tokens(
                            model,
                            model_config.credentials,
                            [self._task_state.llm_result.message]
                        )

                    credentials = model_config.credentials

                    # transform usage
                    self._task_state.llm_result.usage = model_type_instance._calc_response_usage(
                        model,
                        credentials,
                        prompt_tokens,
                        completion_tokens
                    )

                self._task_state.metadata['usage'] = jsonable_encoder(self._task_state.llm_result.usage)

                # response moderation
                if self._output_moderation_handler:
                    self._output_moderation_handler.stop_thread()

                    self._task_state.llm_result.message.content = self._output_moderation_handler.moderation_completion(
                        completion=self._task_state.llm_result.message.content,
                        public_event=False
                    )

                    self._output_moderation_handler = None

                    replace_response = {
                        'event': 'message_replace',
                        'task_id': self._application_generate_entity.task_id,
                        'message_id': self._message.id,
                        'answer': self._task_state.llm_result.message.content,
                        'created_at': int(self._message.created_at.timestamp())
                    }

                    if self._conversation.mode == 'chat':
                        replace_response['conversation_id'] = self._conversation.id

                    yield self._yield_response(replace_response)

                # Save message
                self._save_message(self._task_state.llm_result)

                response = {
                    'event': 'message_end',
                    'task_id': self._application_generate_entity.task_id,
                    'id': self._message.id,
                    'message_id': self._message.id,
                }

                if self._conversation.mode == 'chat':
                    response['conversation_id'] = self._conversation.id

                if self._task_state.metadata:
                    response['metadata'] = self._get_response_metadata()

                yield self._yield_response(response)
            elif isinstance(event, QueueRetrieverResourcesEvent):
                self._task_state.metadata['retriever_resources'] = event.retriever_resources
            elif isinstance(event, AnnotationReplyEvent):
                annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
                if annotation:
                    account = annotation.account
                    self._task_state.metadata['annotation_reply'] = {
                        'id': annotation.id,
                        'account': {
                            'id': annotation.account_id,
                            'name': account.name if account else 'Dify user'
                        }
                    }

                    self._task_state.llm_result.message.content = annotation.content
            elif isinstance(event, QueueAgentThoughtEvent):
                agent_thought = (
                    db.session.query(MessageAgentThought)
                    .filter(MessageAgentThought.id == event.agent_thought_id)
                    .first()
                )

                if agent_thought:
                    response = {
                        'event': 'agent_thought',
                        'id': agent_thought.id,
                        'task_id': self._application_generate_entity.task_id,
                        'message_id': self._message.id,
                        'position': agent_thought.position,
                        'thought': agent_thought.thought,
                        'tool': agent_thought.tool,
                        'tool_input': agent_thought.tool_input,
                        'created_at': int(self._message.created_at.timestamp())
                    }

                    if self._conversation.mode == 'chat':
                        response['conversation_id'] = self._conversation.id

                    yield self._yield_response(response)
            elif isinstance(event, QueueMessageEvent):
                chunk = event.chunk
                delta_text = chunk.delta.message.content
                if delta_text is None:
                    continue

                if not self._task_state.llm_result.prompt_messages:
                    self._task_state.llm_result.prompt_messages = chunk.prompt_messages

                if self._output_moderation_handler:
                    if self._output_moderation_handler.should_direct_output():
                        # stop subscribing to new tokens once output moderation decides to direct output
                        self._task_state.llm_result.message.content = self._output_moderation_handler.get_final_output()
                        self._queue_manager.publish_chunk_message(LLMResultChunk(
                            model=self._task_state.llm_result.model,
                            prompt_messages=self._task_state.llm_result.prompt_messages,
                            delta=LLMResultChunkDelta(
                                index=0,
                                message=AssistantPromptMessage(content=self._task_state.llm_result.message.content)
                            )
                        ), PublishFrom.TASK_PIPELINE)

                        self._queue_manager.publish(
                            QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION),
                            PublishFrom.TASK_PIPELINE
                        )
                        continue
                    else:
                        self._output_moderation_handler.append_new_token(delta_text)

                self._task_state.llm_result.message.content += delta_text
                response = self._handle_chunk(delta_text)
                yield self._yield_response(response)
            elif isinstance(event, QueueMessageReplaceEvent):
                response = {
                    'event': 'message_replace',
                    'task_id': self._application_generate_entity.task_id,
                    'message_id': self._message.id,
                    'answer': event.text,
                    'created_at': int(self._message.created_at.timestamp())
                }

                if self._conversation.mode == 'chat':
                    response['conversation_id'] = self._conversation.id

                yield self._yield_response(response)
            elif isinstance(event, QueuePingEvent):
                yield "event: ping\n\n"
            else:
                continue
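
    # A typical stream for a successful chat generation, as yielded above
    # (sketched; exact fields depend on the events received):
    #
    #   data: {"event": "message", "answer": "Hel", ...}\n\n
    #   data: {"event": "message", "answer": "lo!", ...}\n\n
    #   data: {"event": "message_end", "metadata": {...}, ...}\n\n
    #
    # with "agent_thought", "message_replace", and "event: ping" frames
    # interleaved as the corresponding queue events arrive.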

    def _save_message(self, llm_result: LLMResult) -> None:
        """
        Save message.
        :param llm_result: llm result
        :return:
        """
        usage = llm_result.usage

        self._message = db.session.query(Message).filter(Message.id == self._message.id).first()

        self._message.message = self._prompt_messages_to_prompt_for_saving(self._task_state.llm_result.prompt_messages)
        self._message.message_tokens = usage.prompt_tokens
        self._message.message_unit_price = usage.prompt_unit_price
        self._message.message_price_unit = usage.prompt_price_unit
        self._message.answer = PromptTemplateParser.remove_template_variables(llm_result.message.content.strip()) \
            if llm_result.message.content else ''
        self._message.answer_tokens = usage.completion_tokens
        self._message.answer_unit_price = usage.completion_unit_price
        self._message.answer_price_unit = usage.completion_price_unit
        self._message.provider_response_latency = time.perf_counter() - self._start_at
        self._message.total_price = usage.total_price

        db.session.commit()

        message_was_created.send(
            self._message,
            application_generate_entity=self._application_generate_entity,
            conversation=self._conversation,
            is_first_message=self._application_generate_entity.conversation_id is None,
            extras=self._application_generate_entity.extras
        )

    def _handle_chunk(self, text: str) -> dict:
        """
        Handle a message chunk.
        :param text: text
        :return:
        """
        response = {
            'event': 'message',
            'id': self._message.id,
            'task_id': self._application_generate_entity.task_id,
            'message_id': self._message.id,
            'answer': text,
            'created_at': int(self._message.created_at.timestamp())
        }

        if self._conversation.mode == 'chat':
            response['conversation_id'] = self._conversation.id

        return response

    def _handle_error(self, event: QueueErrorEvent) -> Exception:
        """
        Handle error event.
        :param event: event
        :return:
        """
        logger.debug("error: %s", event.error)
        e = event.error

        if isinstance(e, InvokeAuthorizationError):
            return InvokeAuthorizationError('Incorrect API key provided')
        elif isinstance(e, (InvokeError, ValueError)):
            return e
        else:
            return Exception(e.description if getattr(e, 'description', None) is not None else str(e))

    def _error_to_stream_response_data(self, e: Exception) -> dict:
        """
        Error to stream response.
        :param e: exception
        :return:
        """
        if isinstance(e, ValueError):
            data = {
                'code': 'invalid_param',
                'message': str(e),
                'status': 400
            }
        elif isinstance(e, ProviderTokenNotInitError):
            data = {
                'code': 'provider_not_initialize',
                'message': e.description,
                'status': 400
            }
        elif isinstance(e, QuotaExceededError):
            data = {
                'code': 'provider_quota_exceeded',
                'message': "Your quota for Dify Hosted Model Provider has been exhausted. "
                           "Please go to Settings -> Model Provider to complete your own provider credentials.",
                'status': 400
            }
        elif isinstance(e, ModelCurrentlyNotSupportError):
            data = {
                'code': 'model_currently_not_support',
                'message': e.description,
                'status': 400
            }
        elif isinstance(e, InvokeError):
            data = {
                'code': 'completion_request_error',
                'message': e.description,
                'status': 400
            }
        else:
            logger.error(e)
            data = {
                'code': 'internal_server_error',
                'message': 'Internal Server Error, please contact support.',
                'status': 500
            }

        return {
            'event': 'error',
            'task_id': self._application_generate_entity.task_id,
            'message_id': self._message.id,
            **data
        }
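
    # Example error frame produced from the mapping above, e.g. for a
    # QuotaExceededError (values illustrative):
    #
    #   {
    #       "event": "error",
    #       "task_id": "a1b2...",
    #       "message_id": "msg-1",
    #       "code": "provider_quota_exceeded",
    #       "message": "Your quota for Dify Hosted Model Provider has been exhausted. ...",
    #       "status": 400
    #   }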

    def _get_response_metadata(self) -> dict:
        """
        Get response metadata by invoke from.
        :return:
        """
        metadata = {}

        # show_retrieve_source
        if 'retriever_resources' in self._task_state.metadata:
            if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
                metadata['retriever_resources'] = self._task_state.metadata['retriever_resources']
            else:
                metadata['retriever_resources'] = []
                for resource in self._task_state.metadata['retriever_resources']:
                    metadata['retriever_resources'].append({
                        'segment_id': resource['segment_id'],
                        'position': resource['position'],
                        'document_name': resource['document_name'],
                        'score': resource['score'],
                        'content': resource['content'],
                    })

        # show usage
        if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
            metadata['usage'] = self._task_state.metadata['usage']

        return metadata

    def _yield_response(self, response: dict) -> str:
        """
        Yield response.
        :param response: response
        :return:
        """
        return "data: " + json.dumps(response) + "\n\n"
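
    # `_yield_response` applies standard Server-Sent Events framing: a "data: "
    # prefix and a blank-line terminator. For example:
    #
    #   data: {"event": "message", "answer": "Hi", ...}\n\n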

    def _prompt_messages_to_prompt_for_saving(self, prompt_messages: list[PromptMessage]) -> list[dict]:
        """
        Prompt messages to prompt for saving.
        :param prompt_messages: prompt messages
        :return:
        """
        prompts = []
        if self._application_generate_entity.app_orchestration_config_entity.model_config.mode == 'chat':
            for prompt_message in prompt_messages:
                if prompt_message.role == PromptMessageRole.USER:
                    role = 'user'
                elif prompt_message.role == PromptMessageRole.ASSISTANT:
                    role = 'assistant'
                elif prompt_message.role == PromptMessageRole.SYSTEM:
                    role = 'system'
                else:
                    continue

                text = ''
                files = []
                if isinstance(prompt_message.content, list):
                    for content in prompt_message.content:
                        if content.type == PromptMessageContentType.TEXT:
                            content = cast(TextPromptMessageContent, content)
                            text += content.data
                        else:
                            content = cast(ImagePromptMessageContent, content)
                            files.append({
                                "type": 'image',
                                "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:],
                                "detail": content.detail.value
                            })
                else:
                    text = prompt_message.content

                prompts.append({
                    "role": role,
                    "text": text,
                    "files": files
                })
        else:
            prompt_message = prompt_messages[0]
            text = ''
            files = []
            if isinstance(prompt_message.content, list):
                for content in prompt_message.content:
                    if content.type == PromptMessageContentType.TEXT:
                        content = cast(TextPromptMessageContent, content)
                        text += content.data
                    else:
                        content = cast(ImagePromptMessageContent, content)
                        files.append({
                            "type": 'image',
                            "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:],
                            "detail": content.detail.value
                        })
            else:
                text = prompt_message.content

            params = {
                "role": 'user',
                "text": text,
            }

            if files:
                params['files'] = files

            prompts.append(params)

        return prompts
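
    # The saved structure is a list of role/text/files dicts, e.g. for chat mode
    # (illustrative values; image data is truncated for storage):
    #
    #   [
    #       {"role": "system", "text": "You are ...", "files": []},
    #       {"role": "user", "text": "Describe this image",
    #        "files": [{"type": "image", "data": "iVBORw0KGg...[TRUNCATED]...Jggg==", "detail": "low"}]}
    #   ]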

    def _init_output_moderation(self) -> Optional[OutputModerationHandler]:
        """
        Init output moderation.
        :return:
        """
        app_orchestration_config_entity = self._application_generate_entity.app_orchestration_config_entity
        sensitive_word_avoidance = app_orchestration_config_entity.sensitive_word_avoidance

        if sensitive_word_avoidance:
            return OutputModerationHandler(
                tenant_id=self._application_generate_entity.tenant_id,
                app_id=self._application_generate_entity.app_id,
                rule=ModerationRule(
                    type=sensitive_word_avoidance.type,
                    config=sensitive_word_avoidance.config
                ),
                on_message_replace_func=self._queue_manager.publish_message_replace
            )

        return None