generate_task_pipeline.py
import json
import logging
import time
from typing import Generator, Optional, Union, cast

from pydantic import BaseModel

from core.app_runner.moderation_handler import ModerationRule, OutputModerationHandler
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.entities.application_entities import ApplicationGenerateEntity
from core.entities.queue_entities import (
    AnnotationReplyEvent,
    QueueAgentThoughtEvent,
    QueueErrorEvent,
    QueueMessageEndEvent,
    QueueMessageEvent,
    QueueMessageReplaceEvent,
    QueuePingEvent,
    QueueRetrieverResourcesEvent,
    QueueStopEvent,
)
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    ImagePromptMessageContent,
    PromptMessage,
    PromptMessageContentType,
    PromptMessageRole,
    TextPromptMessageContent,
)
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.prompt_template import PromptTemplateParser
from events.message_event import message_was_created
from extensions.ext_database import db
from models.model import Conversation, Message, MessageAgentThought
from services.annotation_service import AppAnnotationService

logger = logging.getLogger(__name__)

class TaskState(BaseModel):
    """
    Mutable pipeline state: the accumulating LLM result plus any metadata
    (retriever resources, annotation replies) gathered from queue events.
    """
    llm_result: LLMResult
    metadata: dict = {}

class GenerateTaskPipeline:
    """
    GenerateTaskPipeline generates the stream output and manages state for an application task.
    """

    def __init__(self, application_generate_entity: ApplicationGenerateEntity,
                 queue_manager: ApplicationQueueManager,
                 conversation: Conversation,
                 message: Message) -> None:
        """
        Initialize GenerateTaskPipeline.
        :param application_generate_entity: application generate entity
        :param queue_manager: queue manager
        :param conversation: conversation
        :param message: message
        """
        self._application_generate_entity = application_generate_entity
        self._queue_manager = queue_manager
        self._conversation = conversation
        self._message = message
        self._task_state = TaskState(
            llm_result=LLMResult(
                model=self._application_generate_entity.app_orchestration_config_entity.model_config.model,
                prompt_messages=[],
                message=AssistantPromptMessage(content=""),
                usage=LLMUsage.empty_usage()
            )
        )
        self._start_at = time.perf_counter()
        self._output_moderation_handler = self._init_output_moderation()

    def process(self, stream: bool) -> Union[dict, Generator]:
        """
        Process the generate task pipeline.
        :param stream: whether to stream the response
        :return: a full response dict in blocking mode, or a generator of SSE strings in streaming mode
        """
        if stream:
            return self._process_stream_response()
        else:
            return self._process_blocking_response()
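
    # A minimal usage sketch (hypothetical wiring; the app runner that owns the
    # queue and the transport function are assumed, not part of this module):
    #
    #   pipeline = GenerateTaskPipeline(entity, queue_manager, conversation, message)
    #   if streaming:
    #       for sse in pipeline.process(stream=True):  # "data: {...}\n\n" strings
    #           forward_to_client(sse)                 # hypothetical transport
    #   else:
    #       result = pipeline.process(stream=False)    # full response dict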

    def _process_blocking_response(self) -> dict:
        """
        Process the blocking (non-streaming) response.
        :return: response dict
        """
        for queue_message in self._queue_manager.listen():
            event = queue_message.event

            if isinstance(event, QueueErrorEvent):
                raise self._handle_error(event)
            elif isinstance(event, QueueRetrieverResourcesEvent):
                self._task_state.metadata['retriever_resources'] = event.retriever_resources
            elif isinstance(event, AnnotationReplyEvent):
                annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
                if annotation:
                    account = annotation.account
                    self._task_state.metadata['annotation_reply'] = {
                        'id': annotation.id,
                        'account': {
                            'id': annotation.account_id,
                            'name': account.name if account else 'Dify user'
                        }
                    }

                    self._task_state.llm_result.message.content = annotation.content
            elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)):
                if isinstance(event, QueueMessageEndEvent):
                    self._task_state.llm_result = event.llm_result
                else:
                    model_config = self._application_generate_entity.app_orchestration_config_entity.model_config
                    model = model_config.model
                    model_type_instance = model_config.provider_model_bundle.model_type_instance
                    model_type_instance = cast(LargeLanguageModel, model_type_instance)

                    # calculate num tokens
                    prompt_tokens = 0
                    if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY:
                        prompt_tokens = model_type_instance.get_num_tokens(
                            model,
                            model_config.credentials,
                            self._task_state.llm_result.prompt_messages
                        )

                    completion_tokens = 0
                    if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL:
                        completion_tokens = model_type_instance.get_num_tokens(
                            model,
                            model_config.credentials,
                            [self._task_state.llm_result.message]
                        )

                    credentials = model_config.credentials

                    # transform usage
                    self._task_state.llm_result.usage = model_type_instance._calc_response_usage(
                        model,
                        credentials,
                        prompt_tokens,
                        completion_tokens
                    )

                # response moderation
                if self._output_moderation_handler:
                    self._output_moderation_handler.stop_thread()

                    self._task_state.llm_result.message.content = self._output_moderation_handler.moderation_completion(
                        completion=self._task_state.llm_result.message.content,
                        public_event=False
                    )

                # Save message
                self._save_message(self._task_state.llm_result)

                response = {
                    'event': 'message',
                    'task_id': self._application_generate_entity.task_id,
                    'id': self._message.id,
                    'mode': self._conversation.mode,
                    'answer': self._task_state.llm_result.message.content,
                    'metadata': {},
                    'created_at': int(self._message.created_at.timestamp())
                }

                if self._conversation.mode == 'chat':
                    response['conversation_id'] = self._conversation.id

                if self._task_state.metadata:
                    response['metadata'] = self._task_state.metadata

                return response
            else:
                continue
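
    # Shape of the blocking response assembled above ('conversation_id' is only
    # present in 'chat' mode; 'metadata' is filled from TaskState when non-empty):
    #
    #   {
    #       "event": "message",
    #       "task_id": "...",
    #       "id": "<message id>",
    #       "mode": "chat",
    #       "answer": "...",
    #       "metadata": {"retriever_resources": [...]},
    #       "created_at": 1700000000
    #   }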

    def _process_stream_response(self) -> Generator:
        """
        Process the streaming response.
        :return: generator of SSE-formatted strings
        """
        for message in self._queue_manager.listen():
            event = message.event

            if isinstance(event, QueueErrorEvent):
                raise self._handle_error(event)
            elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)):
                if isinstance(event, QueueMessageEndEvent):
                    self._task_state.llm_result = event.llm_result
                else:
                    model_config = self._application_generate_entity.app_orchestration_config_entity.model_config
                    model = model_config.model
                    model_type_instance = model_config.provider_model_bundle.model_type_instance
                    model_type_instance = cast(LargeLanguageModel, model_type_instance)

                    # calculate num tokens
                    prompt_tokens = 0
                    if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY:
                        prompt_tokens = model_type_instance.get_num_tokens(
                            model,
                            model_config.credentials,
                            self._task_state.llm_result.prompt_messages
                        )

                    completion_tokens = 0
                    if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL:
                        completion_tokens = model_type_instance.get_num_tokens(
                            model,
                            model_config.credentials,
                            [self._task_state.llm_result.message]
                        )

                    credentials = model_config.credentials

                    # transform usage
                    self._task_state.llm_result.usage = model_type_instance._calc_response_usage(
                        model,
                        credentials,
                        prompt_tokens,
                        completion_tokens
                    )

                # response moderation
                if self._output_moderation_handler:
                    self._output_moderation_handler.stop_thread()

                    self._task_state.llm_result.message.content = self._output_moderation_handler.moderation_completion(
                        completion=self._task_state.llm_result.message.content,
                        public_event=False
                    )

                    self._output_moderation_handler = None

                    replace_response = {
                        'event': 'message_replace',
                        'task_id': self._application_generate_entity.task_id,
                        'message_id': self._message.id,
                        'answer': self._task_state.llm_result.message.content,
                        'created_at': int(self._message.created_at.timestamp())
                    }

                    if self._conversation.mode == 'chat':
                        replace_response['conversation_id'] = self._conversation.id

                    yield self._yield_response(replace_response)

                # Save message
                self._save_message(self._task_state.llm_result)

                response = {
                    'event': 'message_end',
                    'task_id': self._application_generate_entity.task_id,
                    'id': self._message.id,
                }

                if self._conversation.mode == 'chat':
                    response['conversation_id'] = self._conversation.id

                if self._task_state.metadata:
                    response['metadata'] = self._task_state.metadata

                yield self._yield_response(response)
            elif isinstance(event, QueueRetrieverResourcesEvent):
                self._task_state.metadata['retriever_resources'] = event.retriever_resources
            elif isinstance(event, AnnotationReplyEvent):
                annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
                if annotation:
                    account = annotation.account
                    self._task_state.metadata['annotation_reply'] = {
                        'id': annotation.id,
                        'account': {
                            'id': annotation.account_id,
                            'name': account.name if account else 'Dify user'
                        }
                    }

                    self._task_state.llm_result.message.content = annotation.content
            elif isinstance(event, QueueAgentThoughtEvent):
                agent_thought = (
                    db.session.query(MessageAgentThought)
                    .filter(MessageAgentThought.id == event.agent_thought_id)
                    .first()
                )

                if agent_thought:
                    response = {
                        'event': 'agent_thought',
                        'id': agent_thought.id,
                        'task_id': self._application_generate_entity.task_id,
                        'message_id': self._message.id,
                        'position': agent_thought.position,
                        'thought': agent_thought.thought,
                        'tool': agent_thought.tool,
                        'tool_input': agent_thought.tool_input,
                        'created_at': int(self._message.created_at.timestamp())
                    }

                    if self._conversation.mode == 'chat':
                        response['conversation_id'] = self._conversation.id

                    yield self._yield_response(response)
            elif isinstance(event, QueueMessageEvent):
                chunk = event.chunk
                delta_text = chunk.delta.message.content
                if delta_text is None:
                    continue

                if not self._task_state.llm_result.prompt_messages:
                    self._task_state.llm_result.prompt_messages = chunk.prompt_messages

                if self._output_moderation_handler:
                    if self._output_moderation_handler.should_direct_output():
                        # stop subscribing to new tokens once output moderation decides to output directly
                        self._task_state.llm_result.message.content = self._output_moderation_handler.get_final_output()
                        self._queue_manager.publish_chunk_message(LLMResultChunk(
                            model=self._task_state.llm_result.model,
                            prompt_messages=self._task_state.llm_result.prompt_messages,
                            delta=LLMResultChunkDelta(
                                index=0,
                                message=AssistantPromptMessage(content=self._task_state.llm_result.message.content)
                            )
                        ), PublishFrom.TASK_PIPELINE)

                        self._queue_manager.publish(
                            QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION),
                            PublishFrom.TASK_PIPELINE
                        )
                        continue
                    else:
                        self._output_moderation_handler.append_new_token(delta_text)

                self._task_state.llm_result.message.content += delta_text
                response = self._handle_chunk(delta_text)
                yield self._yield_response(response)
            elif isinstance(event, QueueMessageReplaceEvent):
                response = {
                    'event': 'message_replace',
                    'task_id': self._application_generate_entity.task_id,
                    'message_id': self._message.id,
                    'answer': event.text,
                    'created_at': int(self._message.created_at.timestamp())
                }

                if self._conversation.mode == 'chat':
                    response['conversation_id'] = self._conversation.id

                yield self._yield_response(response)
            elif isinstance(event, QueuePingEvent):
                yield "event: ping\n\n"
            else:
                continue
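
    # Stream event vocabulary produced by the generator above: "message"
    # (incremental chunks), "agent_thought", "message_replace" (moderation or
    # annotation rewrites), "message_end" (final event, carrying metadata), and
    # bare "event: ping\n\n" keep-alives for idle connections.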

    def _save_message(self, llm_result: LLMResult) -> None:
        """
        Save message.
        :param llm_result: llm result
        :return:
        """
        usage = llm_result.usage

        self._message = db.session.query(Message).filter(Message.id == self._message.id).first()

        self._message.message = self._prompt_messages_to_prompt_for_saving(self._task_state.llm_result.prompt_messages)
        self._message.message_tokens = usage.prompt_tokens
        self._message.message_unit_price = usage.prompt_unit_price
        self._message.message_price_unit = usage.prompt_price_unit
        self._message.answer = PromptTemplateParser.remove_template_variables(llm_result.message.content.strip()) \
            if llm_result.message.content else ''
        self._message.answer_tokens = usage.completion_tokens
        self._message.answer_unit_price = usage.completion_unit_price
        self._message.answer_price_unit = usage.completion_price_unit
        self._message.provider_response_latency = time.perf_counter() - self._start_at
        self._message.total_price = usage.total_price

        db.session.commit()

        message_was_created.send(
            self._message,
            application_generate_entity=self._application_generate_entity,
            conversation=self._conversation,
            is_first_message=self._application_generate_entity.conversation_id is None,
            extras=self._application_generate_entity.extras
        )

    def _handle_chunk(self, text: str) -> dict:
        """
        Build a 'message' response for a streamed chunk.
        :param text: chunk text
        :return: response dict
        """
        response = {
            'event': 'message',
            'id': self._message.id,
            'task_id': self._application_generate_entity.task_id,
            'message_id': self._message.id,
            'answer': text,
            'created_at': int(self._message.created_at.timestamp())
        }

        if self._conversation.mode == 'chat':
            response['conversation_id'] = self._conversation.id

        return response

    def _handle_error(self, event: QueueErrorEvent) -> Exception:
        """
        Handle error event.
        :param event: event
        :return:
        """
        logger.debug("error: %s", event.error)
        e = event.error

        if isinstance(e, InvokeAuthorizationError):
            return InvokeAuthorizationError('Incorrect API key provided')
        elif isinstance(e, (InvokeError, ValueError)):
            return e
        else:
            return Exception(e.description if getattr(e, 'description', None) is not None else str(e))

    def _yield_response(self, response: dict) -> str:
        """
        Format a response dict as a server-sent event string.
        :param response: response dict
        :return: SSE-formatted string
        """
        return "data: " + json.dumps(response) + "\n\n"
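
    # For example, _yield_response({'event': 'message', 'answer': 'Hi'}) returns
    # 'data: {"event": "message", "answer": "Hi"}\n\n' -- the standard SSE
    # "data:" framing that the streaming generator emits to clients.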

    def _prompt_messages_to_prompt_for_saving(self, prompt_messages: list[PromptMessage]) -> list[dict]:
        """
        Prompt messages to prompt for saving.
        :param prompt_messages: prompt messages
        :return:
        """
        prompts = []
        if self._application_generate_entity.app_orchestration_config_entity.model_config.mode == 'chat':
            for prompt_message in prompt_messages:
                if prompt_message.role == PromptMessageRole.USER:
                    role = 'user'
                elif prompt_message.role == PromptMessageRole.ASSISTANT:
                    role = 'assistant'
                elif prompt_message.role == PromptMessageRole.SYSTEM:
                    role = 'system'
                else:
                    continue

                text = ''
                files = []
                if isinstance(prompt_message.content, list):
                    for content in prompt_message.content:
                        if content.type == PromptMessageContentType.TEXT:
                            content = cast(TextPromptMessageContent, content)
                            text += content.data
                        else:
                            content = cast(ImagePromptMessageContent, content)
                            files.append({
                                "type": 'image',
                                "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:],
                                "detail": content.detail.value
                            })
                else:
                    text = prompt_message.content

                prompts.append({
                    "role": role,
                    "text": text,
                    "files": files
                })
        else:
            prompt_message = prompt_messages[0]
            text = ''
            files = []
            if isinstance(prompt_message.content, list):
                for content in prompt_message.content:
                    if content.type == PromptMessageContentType.TEXT:
                        content = cast(TextPromptMessageContent, content)
                        text += content.data
                    else:
                        content = cast(ImagePromptMessageContent, content)
                        files.append({
                            "type": 'image',
                            "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:],
                            "detail": content.detail.value
                        })
            else:
                text = prompt_message.content

            params = {
                "role": 'user',
                "text": text,
            }

            if files:
                params['files'] = files

            prompts.append(params)

        return prompts

    def _init_output_moderation(self) -> Optional[OutputModerationHandler]:
        """
        Init output moderation if sensitive word avoidance is configured.
        :return: handler, or None when moderation is disabled
        """
        app_orchestration_config_entity = self._application_generate_entity.app_orchestration_config_entity
        sensitive_word_avoidance = app_orchestration_config_entity.sensitive_word_avoidance

        if sensitive_word_avoidance:
            return OutputModerationHandler(
                tenant_id=self._application_generate_entity.tenant_id,
                app_id=self._application_generate_entity.app_id,
                rule=ModerationRule(
                    type=sensitive_word_avoidance.type,
                    config=sensitive_word_avoidance.config
                ),
                on_message_replace_func=self._queue_manager.publish_message_replace
            )

        return None
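
# Note: when moderation rewrites a streamed answer mid-flight, the handler calls
# back into queue_manager.publish_message_replace, which surfaces as a
# QueueMessageReplaceEvent and is re-emitted to clients as a 'message_replace'
# SSE event by _process_stream_response.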