# generate_task_pipeline.py

import json
import logging
import time
from typing import Union, Generator, cast, Optional

from pydantic import BaseModel

from core.app_runner.moderation_handler import OutputModerationHandler, ModerationRule
from core.entities.application_entities import ApplicationGenerateEntity
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.entities.queue_entities import QueueErrorEvent, QueueStopEvent, QueueMessageEndEvent, \
    QueueRetrieverResourcesEvent, QueueAgentThoughtEvent, QueuePingEvent, QueueMessageEvent, \
    QueueMessageReplaceEvent, AnnotationReplyEvent
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessageRole, \
    TextPromptMessageContent, PromptMessageContentType, ImagePromptMessageContent, PromptMessage
from core.model_runtime.errors.invoke import InvokeError, InvokeAuthorizationError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.prompt_template import PromptTemplateParser
from events.message_event import message_was_created
from extensions.ext_database import db
from models.model import Message, Conversation, MessageAgentThought
from services.annotation_service import AppAnnotationService

logger = logging.getLogger(__name__)


class TaskState(BaseModel):
    """
    TaskState entity
    """
    llm_result: LLMResult
    metadata: dict = {}


class GenerateTaskPipeline:
    """
    GenerateTaskPipeline generates the stream output and manages state for an application task.
    """

    def __init__(self, application_generate_entity: ApplicationGenerateEntity,
                 queue_manager: ApplicationQueueManager,
                 conversation: Conversation,
                 message: Message) -> None:
        """
        Initialize GenerateTaskPipeline.
        :param application_generate_entity: application generate entity
        :param queue_manager: queue manager
        :param conversation: conversation
        :param message: message
        """
        self._application_generate_entity = application_generate_entity
        self._queue_manager = queue_manager
        self._conversation = conversation
        self._message = message
        self._task_state = TaskState(
            llm_result=LLMResult(
                model=self._application_generate_entity.app_orchestration_config_entity.model_config.model,
                prompt_messages=[],
                message=AssistantPromptMessage(content=""),
                usage=LLMUsage.empty_usage()
            )
        )
        self._start_at = time.perf_counter()
        self._output_moderation_handler = self._init_output_moderation()

    def process(self, stream: bool) -> Union[dict, Generator]:
        """
        Process the generate task pipeline.
        :param stream: whether to stream the response
        :return: a response dict (blocking) or a generator of SSE strings (streaming)
        """
        if stream:
            return self._process_stream_response()
        else:
            return self._process_blocking_response()
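
    # A minimal usage sketch (illustrative only; the queue manager, entities and
    # app context are assumed to be set up by the caller elsewhere):
    #
    #   pipeline = GenerateTaskPipeline(application_generate_entity, queue_manager,
    #                                   conversation, message)
    #   if streaming:
    #       for sse_chunk in pipeline.process(stream=True):
    #           yield sse_chunk  # each item is an "event: ..." or "data: {...}\n\n" string
    #   else:
    #       response = pipeline.process(stream=False)  # a plain dict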

    def _process_blocking_response(self) -> dict:
        """
        Process the blocking response.
        :return: response dict
        """
        for queue_message in self._queue_manager.listen():
            event = queue_message.event

            if isinstance(event, QueueErrorEvent):
                raise self._handle_error(event)
            elif isinstance(event, QueueRetrieverResourcesEvent):
                self._task_state.metadata['retriever_resources'] = event.retriever_resources
            elif isinstance(event, AnnotationReplyEvent):
                annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
                if annotation:
                    account = annotation.account
                    self._task_state.metadata['annotation_reply'] = {
                        'id': annotation.id,
                        'account': {
                            'id': annotation.account_id,
                            'name': account.name if account else 'Dify user'
                        }
                    }

                    self._task_state.llm_result.message.content = annotation.content
            elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)):
                if isinstance(event, QueueMessageEndEvent):
                    self._task_state.llm_result = event.llm_result
                else:
                    model_config = self._application_generate_entity.app_orchestration_config_entity.model_config
                    model = model_config.model
                    model_type_instance = model_config.provider_model_bundle.model_type_instance
                    model_type_instance = cast(LargeLanguageModel, model_type_instance)

                    # calculate num tokens
                    prompt_tokens = 0
                    if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY:
                        prompt_tokens = model_type_instance.get_num_tokens(
                            model,
                            model_config.credentials,
                            self._task_state.llm_result.prompt_messages
                        )

                    completion_tokens = 0
                    if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL:
                        completion_tokens = model_type_instance.get_num_tokens(
                            model,
                            model_config.credentials,
                            [self._task_state.llm_result.message]
                        )

                    credentials = model_config.credentials

                    # transform usage
                    self._task_state.llm_result.usage = model_type_instance._calc_response_usage(
                        model,
                        credentials,
                        prompt_tokens,
                        completion_tokens
                    )

                # response moderation
                if self._output_moderation_handler:
                    self._output_moderation_handler.stop_thread()

                    self._task_state.llm_result.message.content = self._output_moderation_handler.moderation_completion(
                        completion=self._task_state.llm_result.message.content,
                        public_event=False
                    )

                # Save message
                # Note: use the task state here; a QueueStopEvent carries no llm_result.
                self._save_message(self._task_state.llm_result)

                response = {
                    'event': 'message',
                    'task_id': self._application_generate_entity.task_id,
                    'id': self._message.id,
                    'mode': self._conversation.mode,
                    'answer': self._task_state.llm_result.message.content,
                    'metadata': {},
                    'created_at': int(self._message.created_at.timestamp())
                }

                if self._conversation.mode == 'chat':
                    response['conversation_id'] = self._conversation.id

                if self._task_state.metadata:
                    response['metadata'] = self._task_state.metadata

                return response
            else:
                continue
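
    # The blocking response is a single JSON-serializable dict, for example
    # (values illustrative):
    #
    #   {
    #       'event': 'message',
    #       'task_id': '...',
    #       'id': '...',                  # message id
    #       'mode': 'chat',
    #       'answer': 'Hello!',
    #       'metadata': {'retriever_resources': [...]},
    #       'created_at': 1700000000,
    #       'conversation_id': '...'      # chat mode only
    #   }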

    def _process_stream_response(self) -> Generator:
        """
        Process the stream response.
        :return: generator of SSE-formatted response strings
        """
        for message in self._queue_manager.listen():
            event = message.event

            if isinstance(event, QueueErrorEvent):
                raise self._handle_error(event)
            elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)):
                if isinstance(event, QueueMessageEndEvent):
                    self._task_state.llm_result = event.llm_result
                else:
                    model_config = self._application_generate_entity.app_orchestration_config_entity.model_config
                    model = model_config.model
                    model_type_instance = model_config.provider_model_bundle.model_type_instance
                    model_type_instance = cast(LargeLanguageModel, model_type_instance)

                    # calculate num tokens
                    prompt_tokens = 0
                    if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY:
                        prompt_tokens = model_type_instance.get_num_tokens(
                            model,
                            model_config.credentials,
                            self._task_state.llm_result.prompt_messages
                        )

                    completion_tokens = 0
                    if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL:
                        completion_tokens = model_type_instance.get_num_tokens(
                            model,
                            model_config.credentials,
                            [self._task_state.llm_result.message]
                        )

                    credentials = model_config.credentials

                    # transform usage
                    self._task_state.llm_result.usage = model_type_instance._calc_response_usage(
                        model,
                        credentials,
                        prompt_tokens,
                        completion_tokens
                    )

                # response moderation
                if self._output_moderation_handler:
                    self._output_moderation_handler.stop_thread()

                    self._task_state.llm_result.message.content = self._output_moderation_handler.moderation_completion(
                        completion=self._task_state.llm_result.message.content,
                        public_event=False
                    )

                    self._output_moderation_handler = None

                    replace_response = {
                        'event': 'message_replace',
                        'task_id': self._application_generate_entity.task_id,
                        'message_id': self._message.id,
                        'answer': self._task_state.llm_result.message.content,
                        'created_at': int(self._message.created_at.timestamp())
                    }

                    if self._conversation.mode == 'chat':
                        replace_response['conversation_id'] = self._conversation.id

                    yield self._yield_response(replace_response)

                # Save message
                self._save_message(self._task_state.llm_result)

                response = {
                    'event': 'message_end',
                    'task_id': self._application_generate_entity.task_id,
                    'id': self._message.id,
                }

                if self._conversation.mode == 'chat':
                    response['conversation_id'] = self._conversation.id

                if self._task_state.metadata:
                    response['metadata'] = self._task_state.metadata

                yield self._yield_response(response)
            elif isinstance(event, QueueRetrieverResourcesEvent):
                self._task_state.metadata['retriever_resources'] = event.retriever_resources
            elif isinstance(event, AnnotationReplyEvent):
                annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
                if annotation:
                    account = annotation.account
                    self._task_state.metadata['annotation_reply'] = {
                        'id': annotation.id,
                        'account': {
                            'id': annotation.account_id,
                            'name': account.name if account else 'Dify user'
                        }
                    }

                    self._task_state.llm_result.message.content = annotation.content
            elif isinstance(event, QueueAgentThoughtEvent):
                agent_thought = (
                    db.session.query(MessageAgentThought)
                    .filter(MessageAgentThought.id == event.agent_thought_id)
                    .first()
                )

                if agent_thought:
                    response = {
                        'event': 'agent_thought',
                        'id': agent_thought.id,
                        'task_id': self._application_generate_entity.task_id,
                        'message_id': self._message.id,
                        'position': agent_thought.position,
                        'thought': agent_thought.thought,
                        'tool': agent_thought.tool,
                        'tool_input': agent_thought.tool_input,
                        'created_at': int(self._message.created_at.timestamp())
                    }

                    if self._conversation.mode == 'chat':
                        response['conversation_id'] = self._conversation.id

                    yield self._yield_response(response)
            elif isinstance(event, QueueMessageEvent):
                chunk = event.chunk
                delta_text = chunk.delta.message.content
                if delta_text is None:
                    continue

                if not self._task_state.llm_result.prompt_messages:
                    self._task_state.llm_result.prompt_messages = chunk.prompt_messages

                if self._output_moderation_handler:
                    if self._output_moderation_handler.should_direct_output():
                        # stop subscribing to new tokens once output moderation decides to direct output
                        self._task_state.llm_result.message.content = self._output_moderation_handler.get_final_output()
                        self._queue_manager.publish_chunk_message(LLMResultChunk(
                            model=self._task_state.llm_result.model,
                            prompt_messages=self._task_state.llm_result.prompt_messages,
                            delta=LLMResultChunkDelta(
                                index=0,
                                message=AssistantPromptMessage(content=self._task_state.llm_result.message.content)
                            )
                        ), PublishFrom.TASK_PIPELINE)

                        self._queue_manager.publish(
                            QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION),
                            PublishFrom.TASK_PIPELINE
                        )
                        continue
                    else:
                        self._output_moderation_handler.append_new_token(delta_text)

                self._task_state.llm_result.message.content += delta_text

                response = self._handle_chunk(delta_text)
                yield self._yield_response(response)
            elif isinstance(event, QueueMessageReplaceEvent):
                response = {
                    'event': 'message_replace',
                    'task_id': self._application_generate_entity.task_id,
                    'message_id': self._message.id,
                    'answer': event.text,
                    'created_at': int(self._message.created_at.timestamp())
                }

                if self._conversation.mode == 'chat':
                    response['conversation_id'] = self._conversation.id

                yield self._yield_response(response)
            elif isinstance(event, QueuePingEvent):
                yield "event: ping\n\n"
            else:
                continue
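
    # Streamed items are Server-Sent Events strings, e.g. (illustrative):
    #
    #   event: ping
    #
    #   data: {"event": "message", "task_id": "...", "answer": "Hel", ...}
    #
    #   data: {"event": "message", "task_id": "...", "answer": "lo!", ...}
    #
    #   data: {"event": "message_end", "task_id": "...", "metadata": {...}}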

    def _save_message(self, llm_result: LLMResult) -> None:
        """
        Save message.
        :param llm_result: llm result
        :return:
        """
        usage = llm_result.usage

        self._message = db.session.query(Message).filter(Message.id == self._message.id).first()

        self._message.message = self._prompt_messages_to_prompt_for_saving(self._task_state.llm_result.prompt_messages)
        self._message.message_tokens = usage.prompt_tokens
        self._message.message_unit_price = usage.prompt_unit_price
        self._message.message_price_unit = usage.prompt_price_unit
        self._message.answer = PromptTemplateParser.remove_template_variables(llm_result.message.content.strip()) \
            if llm_result.message.content else ''
        self._message.answer_tokens = usage.completion_tokens
        self._message.answer_unit_price = usage.completion_unit_price
        self._message.answer_price_unit = usage.completion_price_unit
        self._message.provider_response_latency = time.perf_counter() - self._start_at
        self._message.total_price = usage.total_price

        db.session.commit()

        message_was_created.send(
            self._message,
            application_generate_entity=self._application_generate_entity,
            conversation=self._conversation,
            is_first_message=self._application_generate_entity.conversation_id is None,
            extras=self._application_generate_entity.extras
        )

    def _handle_chunk(self, text: str) -> dict:
        """
        Build the message event response for a streamed chunk.
        :param text: chunk text
        :return: response dict
        """
        response = {
            'event': 'message',
            'id': self._message.id,
            'task_id': self._application_generate_entity.task_id,
            'message_id': self._message.id,
            'answer': text,
            'created_at': int(self._message.created_at.timestamp())
        }

        if self._conversation.mode == 'chat':
            response['conversation_id'] = self._conversation.id

        return response

    def _handle_error(self, event: QueueErrorEvent) -> Exception:
        """
        Handle error event.
        :param event: event
        :return:
        """
        logger.debug("error: %s", event.error)
        e = event.error

        if isinstance(e, InvokeAuthorizationError):
            return InvokeAuthorizationError('Incorrect API key provided')
        elif isinstance(e, (InvokeError, ValueError)):
            return e
        else:
            return Exception(e.description if getattr(e, 'description', None) is not None else str(e))

    def _yield_response(self, response: dict) -> str:
        """
        Format a response dict as a Server-Sent Events data line.
        :param response: response dict
        :return: SSE-formatted string
        """
        return "data: " + json.dumps(response) + "\n\n"
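
    # For example (values illustrative):
    #
    #   self._yield_response({'event': 'ping'})
    #   -> 'data: {"event": "ping"}\n\n'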

    def _prompt_messages_to_prompt_for_saving(self, prompt_messages: list[PromptMessage]) -> list[dict]:
        """
        Prompt messages to prompt for saving.
        :param prompt_messages: prompt messages
        :return:
        """
        prompts = []
        if self._application_generate_entity.app_orchestration_config_entity.model_config.mode == 'chat':
            for prompt_message in prompt_messages:
                if prompt_message.role == PromptMessageRole.USER:
                    role = 'user'
                elif prompt_message.role == PromptMessageRole.ASSISTANT:
                    role = 'assistant'
                elif prompt_message.role == PromptMessageRole.SYSTEM:
                    role = 'system'
                else:
                    continue

                text = ''
                files = []
                if isinstance(prompt_message.content, list):
                    for content in prompt_message.content:
                        if content.type == PromptMessageContentType.TEXT:
                            content = cast(TextPromptMessageContent, content)
                            text += content.data
                        else:
                            content = cast(ImagePromptMessageContent, content)
                            files.append({
                                "type": 'image',
                                "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:],
                                "detail": content.detail.value
                            })
                else:
                    text = prompt_message.content

                prompts.append({
                    "role": role,
                    "text": text,
                    "files": files
                })
        else:
            prompts.append({
                "role": 'user',
                "text": prompt_messages[0].content
            })

        return prompts
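
    # Example of the saved structure in chat mode (values illustrative):
    #
    #   [
    #       {"role": "system", "text": "You are a helpful assistant.", "files": []},
    #       {"role": "user", "text": "Describe this image.",
    #        "files": [{"type": "image", "data": "iVBORw0KGg...[TRUNCATED]...5CYII=", "detail": "low"}]}
    #   ]
    #
    # In completion mode, only the first prompt message is saved, as a single
    # user entry.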

    def _init_output_moderation(self) -> Optional[OutputModerationHandler]:
        """
        Init output moderation.
        :return: OutputModerationHandler if sensitive word avoidance is configured, otherwise None
        """
        app_orchestration_config_entity = self._application_generate_entity.app_orchestration_config_entity
        sensitive_word_avoidance = app_orchestration_config_entity.sensitive_word_avoidance

        if sensitive_word_avoidance:
            return OutputModerationHandler(
                tenant_id=self._application_generate_entity.tenant_id,
                app_id=self._application_generate_entity.app_id,
                rule=ModerationRule(
                    type=sensitive_word_avoidance.type,
                    config=sensitive_word_avoidance.config
                ),
                on_message_replace_func=self._queue_manager.publish_message_replace
            )

        return None
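
    # Moderation flow sketch (as implemented above): streamed tokens are fed to
    # the handler via append_new_token(); if the handler decides to replace the
    # output, should_direct_output() becomes true, the pipeline republishes the
    # handler's final output as a single chunk, and the task is stopped with
    # QueueStopEvent.StopBy.OUTPUT_MODERATION.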