use-workflow-run.ts 21 KB


  1. import { useCallback } from 'react'
  2. import {
  3. getIncomers,
  4. useReactFlow,
  5. useStoreApi,
  6. } from 'reactflow'
  7. import produce from 'immer'
  8. import { v4 as uuidV4 } from 'uuid'
  9. import { usePathname } from 'next/navigation'
  10. import { useWorkflowStore } from '../store'
  11. import { useNodesSyncDraft } from '../hooks'
  12. import type { Node } from '../types'
  13. import {
  14. NodeRunningStatus,
  15. WorkflowRunningStatus,
  16. } from '../types'
  17. import { DEFAULT_ITER_TIMES } from '../constants'
  18. import { useWorkflowUpdate } from './use-workflow-interactions'
  19. import { useStore as useAppStore } from '@/app/components/app/store'
  20. import type { IOtherOptions } from '@/service/base'
  21. import { ssePost } from '@/service/base'
  22. import {
  23. fetchPublishedWorkflow,
  24. stopWorkflowRun,
  25. } from '@/service/workflow'
  26. import { useFeaturesStore } from '@/app/components/base/features/hooks'
  27. import { AudioPlayerManager } from '@/app/components/base/audio-btn/audio.player.manager'
  28. import {
  29. getProcessedFilesFromResponse,
  30. } from '@/app/components/base/file-uploader/utils'
  31. export const useWorkflowRun = () => {
  32. const store = useStoreApi()
  33. const workflowStore = useWorkflowStore()
  34. const reactflow = useReactFlow()
  35. const featuresStore = useFeaturesStore()
  36. const { doSyncWorkflowDraft } = useNodesSyncDraft()
  37. const { handleUpdateWorkflowCanvas } = useWorkflowUpdate()
  38. const pathname = usePathname()
  /**
   * Stores a snapshot of the current canvas (nodes, edges, viewport),
   * features and environment variables as the backup draft, then triggers
   * a draft sync. Does nothing when a backup draft is already stored.
   */
  const handleBackupDraft = useCallback(() => {
    const { getNodes: readNodes, edges: currentEdges } = store.getState()
    const { getViewport: readViewport } = reactflow
    const {
      backupDraft: existingBackup,
      setBackupDraft: saveBackup,
      environmentVariables: currentEnvVars,
    } = workflowStore.getState()
    const { features: currentFeatures } = featuresStore!.getState()

    // Only the first call takes a snapshot; later calls keep the stored one.
    if (existingBackup)
      return

    saveBackup({
      nodes: readNodes(),
      edges: currentEdges,
      viewport: readViewport(),
      features: currentFeatures,
      environmentVariables: currentEnvVars,
    })
    // Not awaited: the sync is started and left to finish on its own.
    doSyncWorkflowDraft()
  }, [reactflow, workflowStore, store, featuresStore, doSyncWorkflowDraft])
  /**
   * Restores the canvas, environment variables and features from the stored
   * backup draft and then clears the backup. Does nothing when no backup
   * draft is stored.
   */
  const handleLoadBackupDraft = useCallback(() => {
    const {
      backupDraft: snapshot,
      setBackupDraft: saveBackup,
      setEnvironmentVariables: applyEnvVars,
    } = workflowStore.getState()

    if (!snapshot)
      return

    handleUpdateWorkflowCanvas({
      nodes: snapshot.nodes,
      edges: snapshot.edges,
      viewport: snapshot.viewport,
    })
    applyEnvVars(snapshot.environmentVariables)
    featuresStore!.setState({ features: snapshot.features })
    // Clear the snapshot so a later backup can be taken again.
    saveBackup(undefined)
  }, [handleUpdateWorkflowCanvas, workflowStore, featuresStore])
  /**
   * Starts a draft run of the current app over SSE and mirrors the streamed
   * events into the workflow store and the canvas.
   *
   * Before the request is sent, every node has `selected` and
   * `_runningStatus` cleared, the draft is synced (awaited), and
   * `workflowRunningData` is reset to a `Running` state with empty tracing.
   *
   * @param params Request body forwarded to `ssePost`. `params.token` and
   *   `params.appId` are also read to pick the text-to-audio URL.
   * @param callback Optional caller handlers. The workflow / node / iteration /
   *   error handlers are invoked after the internal handling of the same
   *   event. Every other entry is spread last into the handler object and
   *   therefore replaces the internal handler of the same name
   *   (e.g. `onTextChunk`).
   */
  const handleRun = useCallback(async (
    params: any,
    callback?: IOtherOptions,
  ) => {
    const {
      getNodes,
      setNodes,
    } = store.getState()
    // Reset selection and run status left on the canvas.
    const newNodes = produce(getNodes(), (draft) => {
      draft.forEach((node) => {
        node.data.selected = false
        node.data._runningStatus = undefined
      })
    })
    setNodes(newNodes)
    await doSyncWorkflowDraft()

    const {
      onWorkflowStarted,
      onWorkflowFinished,
      onNodeStarted,
      onNodeFinished,
      onIterationStart,
      onIterationNext,
      onIterationFinish,
      onError,
      ...restCallback
    } = callback || {}
    workflowStore.setState({ historyWorkflowData: undefined })
    const appDetail = useAppStore.getState().appDetail
    // NOTE(review): assumes an element with id "workflow-container" is mounted;
    // the destructuring below throws otherwise.
    const workflowContainer = document.getElementById('workflow-container')
    const {
      clientWidth,
      clientHeight,
    } = workflowContainer!

    // The run endpoint depends on the app mode; it stays '' for any other mode.
    let url = ''
    if (appDetail?.mode === 'advanced-chat')
      url = `/apps/${appDetail.id}/advanced-chat/workflows/draft/run`
    if (appDetail?.mode === 'workflow')
      url = `/apps/${appDetail.id}/workflows/draft/run`

    // Id of the last finished top-level node or iteration; onIterationStart
    // uses it to find the edge leading into the iteration node.
    let prevNodeId = ''

    const {
      setWorkflowRunningData,
    } = workflowStore.getState()
    setWorkflowRunningData({
      result: {
        status: WorkflowRunningStatus.Running,
      },
      tracing: [],
      resultText: '',
    })

    // Pick the text-to-audio endpoint: token => public endpoint, otherwise an
    // app-scoped endpoint that depends on the current route.
    let ttsUrl = ''
    let ttsIsPublic = false
    if (params.token) {
      ttsUrl = '/text-to-audio'
      ttsIsPublic = true
    }
    else if (params.appId) {
      if (pathname.search('explore/installed') > -1)
        ttsUrl = `/installed-apps/${params.appId}/text-to-audio`
      else
        ttsUrl = `/apps/${params.appId}/text-to-audio`
    }
    const player = AudioPlayerManager.getInstance().getAudioPlayer(ttsUrl, ttsIsPublic, uuidV4(), 'none', 'none', (_: any): any => {})

    ssePost(
      url,
      {
        body: params,
      },
      {
        // Record the task id and run data, and clear the `_run` flag on every edge.
        onWorkflowStarted: (params) => {
          const { task_id, data } = params
          const {
            workflowRunningData,
            setWorkflowRunningData,
            setIterParallelLogMap,
          } = workflowStore.getState()
          const {
            edges,
            setEdges,
          } = store.getState()
          setIterParallelLogMap(new Map())
          setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
            draft.task_id = task_id
            draft.result = {
              ...draft?.result,
              ...data,
              status: WorkflowRunningStatus.Running,
            }
          }))

          const newEdges = produce(edges, (draft) => {
            draft.forEach((edge) => {
              edge.data = {
                ...edge.data,
                _run: false,
              }
            })
          })
          setEdges(newEdges)

          if (onWorkflowStarted)
            onWorkflowStarted(params)
        },
        // Merge the final result; a single string output is also shown as result text.
        onWorkflowFinished: (params) => {
          const { data } = params
          const {
            workflowRunningData,
            setWorkflowRunningData,
          } = workflowStore.getState()

          const isStringOutput = data.outputs && Object.keys(data.outputs).length === 1 && typeof data.outputs[Object.keys(data.outputs)[0]] === 'string'

          setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
            draft.result = {
              ...draft.result,
              ...data,
              files: getProcessedFilesFromResponse(data.files || []),
            } as any
            if (isStringOutput) {
              draft.resultTabActive = true
              draft.resultText = data.outputs[Object.keys(data.outputs)[0]]
            }
          }))

          prevNodeId = ''

          if (onWorkflowFinished)
            onWorkflowFinished(params)
        },
        // Mark the run as failed.
        onError: (params) => {
          const {
            workflowRunningData,
            setWorkflowRunningData,
          } = workflowStore.getState()

          setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
            draft.result = {
              ...draft.result,
              status: WorkflowRunningStatus.Failed,
            }
          }))

          if (onError)
            onError(params)
        },
        onNodeStarted: (params) => {
          const { data } = params
          const {
            workflowRunningData,
            setWorkflowRunningData,
            iterParallelLogMap,
            setIterParallelLogMap,
          } = workflowStore.getState()
          const {
            getNodes,
            setNodes,
            edges,
            setEdges,
            transform,
          } = store.getState()
          const nodes = getNodes()
          const node = nodes.find(node => node.id === data.node_id)
          if (node?.parentId) {
            // Node has a parent: log it under the parent's trace entry
            // instead of the top-level tracing list.
            setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
              const tracing = draft.tracing!
              const iterations = tracing.find(trace => trace.node_id === node?.parentId)
              // Fall back to the last iteration bucket when the indexed one is missing.
              const currIteration = iterations?.details![node.data.iteration_index] || iterations?.details![iterations.details!.length - 1]
              if (!data.parallel_run_id) {
                currIteration?.push({
                  ...data,
                  status: NodeRunningStatus.Running,
                } as any)
              }
              else {
                // Parallel run: logs are grouped per parallel_run_id in the shared map.
                if (!iterParallelLogMap.has(data.parallel_run_id))
                  iterParallelLogMap.set(data.parallel_run_id, [{ ...data, status: NodeRunningStatus.Running } as any])
                else
                  iterParallelLogMap.get(data.parallel_run_id)!.push({ ...data, status: NodeRunningStatus.Running } as any)
                setIterParallelLogMap(iterParallelLogMap)
                if (iterations)
                  iterations.details = Array.from(iterParallelLogMap.values())
              }
            }))
          }
          else {
            setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
              draft.tracing!.push({
                ...data,
                status: NodeRunningStatus.Running,
              } as any)
            }))

            const {
              setViewport,
            } = reactflow
            const currentNodeIndex = nodes.findIndex(node => node.id === data.node_id)
            const currentNode = nodes[currentNodeIndex]
            const position = currentNode.position
            const zoom = transform[2]

            if (!currentNode.parentId) {
              // Center the viewport on the started node.
              // NOTE(review): the 400 offset presumably reserves room for a side panel - confirm.
              setViewport({
                x: (clientWidth - 400 - currentNode.width! * zoom) / 2 - position.x * zoom,
                y: (clientHeight - currentNode.height! * zoom) / 2 - position.y * zoom,
                zoom: transform[2],
              })
            }
            const newNodes = produce(nodes, (draft) => {
              draft[currentNodeIndex].data._runningStatus = NodeRunningStatus.Running
            })
            setNodes(newNodes)
            // Flag edges coming from incomers that have already succeeded.
            const incomeNodesId = getIncomers({ id: data.node_id } as Node, newNodes, edges).filter(node => node.data._runningStatus === NodeRunningStatus.Succeeded).map(node => node.id)
            const newEdges = produce(edges, (draft) => {
              draft.forEach((edge) => {
                if (edge.target === data.node_id && incomeNodesId.includes(edge.source))
                  edge.data = { ...edge.data, _run: true } as any
              })
            })
            setEdges(newEdges)
          }
          if (onNodeStarted)
            onNodeStarted(params)
        },
        onNodeFinished: (params) => {
          const { data } = params
          const {
            workflowRunningData,
            setWorkflowRunningData,
            iterParallelLogMap,
            setIterParallelLogMap,
          } = workflowStore.getState()
          const {
            getNodes,
            setNodes,
          } = store.getState()
          const nodes = getNodes()
          const nodeParentId = nodes.find(node => node.id === data.node_id)!.parentId
          if (nodeParentId) {
            // Optional chaining: execution_metadata may be absent, as every
            // other access in this handler already assumes.
            if (!data.execution_metadata?.parallel_mode_run_id) {
              setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
                const tracing = draft.tracing!
                const iterations = tracing.find(trace => trace.node_id === nodeParentId) // the iteration node

                if (iterations && iterations.details) {
                  const iterationIndex = data.execution_metadata?.iteration_index || 0
                  if (!iterations.details[iterationIndex])
                    iterations.details[iterationIndex] = []

                  const currIteration = iterations.details[iterationIndex]
                  const nodeIndex = currIteration.findIndex(node =>
                    node.node_id === data.node_id && (
                      node.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || node.parallel_id === data.execution_metadata?.parallel_id),
                  )
                  // Update the entry logged at start, or append when none exists.
                  if (nodeIndex !== -1) {
                    currIteration[nodeIndex] = {
                      ...currIteration[nodeIndex],
                      ...data,
                    } as any
                  }
                  else {
                    currIteration.push({
                      ...data,
                    } as any)
                  }
                }
              }))
            }
            else {
              // open parallel mode
              setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
                const tracing = draft.tracing!
                const iterations = tracing.find(trace => trace.node_id === nodeParentId) // the iteration node

                if (iterations && iterations.details) {
                  const iterRunID = data.execution_metadata?.parallel_mode_run_id

                  const currIteration = iterParallelLogMap.get(iterRunID)
                  const nodeIndex = currIteration?.findIndex(node =>
                    node.node_id === data.node_id && (
                      node?.parallel_run_id === data.execution_metadata?.parallel_mode_run_id),
                  )
                  if (currIteration) {
                    if (nodeIndex !== undefined && nodeIndex !== -1) {
                      currIteration[nodeIndex] = {
                        ...currIteration[nodeIndex],
                        ...data,
                      } as any
                    }
                    else {
                      currIteration.push({
                        ...data,
                      } as any)
                    }
                  }
                  setIterParallelLogMap(iterParallelLogMap)
                  iterations.details = Array.from(iterParallelLogMap.values())
                }
              }))
            }
          }
          else {
            setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
              // Match by node id, and also by parallel_id when the trace has one.
              const currentIndex = draft.tracing!.findIndex((trace) => {
                if (!trace.execution_metadata?.parallel_id)
                  return trace.node_id === data.node_id
                return trace.node_id === data.node_id && trace.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id
              })
              if (currentIndex > -1 && draft.tracing) {
                // Replace the trace with the finished data, carrying over only `extras`.
                draft.tracing[currentIndex] = {
                  ...(draft.tracing[currentIndex].extras
                    ? { extras: draft.tracing[currentIndex].extras }
                    : {}),
                  ...data,
                } as any
              }
            }))

            const newNodes = produce(nodes, (draft) => {
              const currentNode = draft.find(node => node.id === data.node_id)!
              currentNode.data._runningStatus = data.status as any
            })
            setNodes(newNodes)

            prevNodeId = data.node_id
          }

          if (onNodeFinished)
            onNodeFinished(params)
        },
        onIterationStart: (params) => {
          const { data } = params
          const {
            workflowRunningData,
            setWorkflowRunningData,
            setIterTimes,
          } = workflowStore.getState()
          const {
            getNodes,
            setNodes,
            edges,
            setEdges,
            transform,
          } = store.getState()
          const nodes = getNodes()
          setIterTimes(DEFAULT_ITER_TIMES)
          // Add a trace entry for the iteration with an empty per-round `details` list.
          setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
            draft.tracing!.push({
              ...data,
              status: NodeRunningStatus.Running,
              details: [],
            } as any)
          }))

          const {
            setViewport,
          } = reactflow
          const currentNodeIndex = nodes.findIndex(node => node.id === data.node_id)
          const currentNode = nodes[currentNodeIndex]
          const position = currentNode.position
          const zoom = transform[2]

          if (!currentNode.parentId) {
            // Center the viewport on the iteration node (same formula as onNodeStarted).
            setViewport({
              x: (clientWidth - 400 - currentNode.width! * zoom) / 2 - position.x * zoom,
              y: (clientHeight - currentNode.height! * zoom) / 2 - position.y * zoom,
              zoom: transform[2],
            })
          }
          const newNodes = produce(nodes, (draft) => {
            draft[currentNodeIndex].data._runningStatus = NodeRunningStatus.Running
            draft[currentNodeIndex].data._iterationLength = data.metadata.iterator_length
          })
          setNodes(newNodes)
          // Flag the edge from the previously finished node into this iteration.
          const newEdges = produce(edges, (draft) => {
            const edge = draft.find(edge => edge.target === data.node_id && edge.source === prevNodeId)

            if (edge)
              edge.data = { ...edge.data, _run: true } as any
          })
          setEdges(newEdges)

          if (onIterationStart)
            onIterationStart(params)
        },
        onIterationNext: (params) => {
          const {
            workflowRunningData,
            setWorkflowRunningData,
            iterTimes,
            setIterTimes,
          } = workflowStore.getState()

          const { data } = params
          const {
            getNodes,
            setNodes,
          } = store.getState()

          setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
            const iteration = draft.tracing!.find(trace => trace.node_id === data.node_id)
            if (iteration) {
              // Never open more rounds than the iterator length.
              if (iteration.details!.length >= iteration.metadata.iterator_length!)
                return
            }
            // Non-parallel rounds get a fresh bucket in `details`.
            if (!data.parallel_mode_run_id)
              iteration?.details!.push([])
          }))
          const nodes = getNodes()
          const newNodes = produce(nodes, (draft) => {
            const currentNode = draft.find(node => node.id === data.node_id)!
            currentNode.data._iterationIndex = iterTimes
          })
          // Store update kept outside the Immer recipe; still runs before setNodes.
          setIterTimes(iterTimes + 1)
          setNodes(newNodes)

          if (onIterationNext)
            onIterationNext(params)
        },
        onIterationFinish: (params) => {
          const { data } = params

          const {
            workflowRunningData,
            setWorkflowRunningData,
            setIterTimes,
          } = workflowStore.getState()
          const {
            getNodes,
            setNodes,
          } = store.getState()
          const nodes = getNodes()
          // The trace entry is always marked Succeeded here, while the canvas
          // node below receives data.status.
          setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
            const tracing = draft.tracing!
            const currIterationNode = tracing.find(trace => trace.node_id === data.node_id)
            if (currIterationNode) {
              Object.assign(currIterationNode, {
                ...data,
                status: NodeRunningStatus.Succeeded,
              })
            }
          }))
          setIterTimes(DEFAULT_ITER_TIMES)
          const newNodes = produce(nodes, (draft) => {
            const currentNode = draft.find(node => node.id === data.node_id)!

            currentNode.data._runningStatus = data.status
          })
          setNodes(newNodes)

          prevNodeId = data.node_id

          if (onIterationFinish)
            onIterationFinish(params)
        },
        // Parallel branch events are accepted but not handled.
        onParallelBranchStarted: (params) => {
          // console.log(params, 'parallel start')
        },
        onParallelBranchFinished: (params) => {
          // console.log(params, 'finished')
        },
        // Append streamed text to the result and activate the result tab.
        onTextChunk: (params) => {
          const { data: { text } } = params
          const {
            workflowRunningData,
            setWorkflowRunningData,
          } = workflowStore.getState()

          setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
            draft.resultTabActive = true
            draft.resultText += text
          }))
        },
        // Replace the whole result text.
        onTextReplace: (params) => {
          const { data: { text } } = params
          const {
            workflowRunningData,
            setWorkflowRunningData,
          } = workflowStore.getState()
          setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
            draft.resultText = text
          }))
        },
        // Forward non-empty audio chunks to the player.
        onTTSChunk: (messageId: string, audio: string, audioType?: string) => {
          if (!audio || audio === '')
            return
          player.playAudioWithAudio(audio, true)
          AudioPlayerManager.getInstance().resetMsgId(messageId)
        },
        onTTSEnd: (messageId: string, audio: string, audioType?: string) => {
          player.playAudioWithAudio(audio, false)
        },
        // Caller-supplied handlers not destructured above replace the defaults.
        ...restCallback,
      },
    )
    // `pathname` is read when building the TTS URL, so it must be a dependency.
  }, [store, reactflow, workflowStore, doSyncWorkflowDraft, pathname])
  /**
   * Requests a stop of the running workflow task for the current app.
   *
   * @param taskId Id of the task to stop; interpolated into the stop URL.
   */
  const handleStopRun = useCallback((taskId: string) => {
    const { appDetail } = useAppStore.getState()
    const stopUrl = `/apps/${appDetail?.id}/workflow-runs/tasks/${taskId}/stop`

    stopWorkflowRun(stopUrl)
  }, [])
  /**
   * Fetches the published workflow of the current app and, when one is
   * returned, loads its graph onto the canvas and applies its features,
   * publish timestamp and environment variables to the stores.
   */
  const handleRestoreFromPublishedWorkflow = useCallback(async () => {
    const { appDetail } = useAppStore.getState()
    const published = await fetchPublishedWorkflow(`/apps/${appDetail?.id}/workflows/publish`)

    if (!published)
      return

    const { graph } = published
    handleUpdateWorkflowCanvas({
      nodes: graph.nodes,
      edges: graph.edges,
      viewport: graph.viewport!,
    })
    featuresStore?.setState({ features: published.features })
    workflowStore.getState().setPublishedAt(published.created_at)
    // Fall back to an empty list when no environment variables are returned.
    workflowStore.getState().setEnvironmentVariables(published.environment_variables || [])
  }, [featuresStore, handleUpdateWorkflowCanvas, workflowStore])
  573. return {
  574. handleBackupDraft,
  575. handleLoadBackupDraft,
  576. handleRun,
  577. handleStopRun,
  578. handleRestoreFromPublishedWorkflow,
  579. }
  580. }