use-workflow-run.ts 17 KB


  1. import { useCallback } from 'react'
  2. import {
  3. useReactFlow,
  4. useStoreApi,
  5. } from 'reactflow'
  6. import produce from 'immer'
  7. import { v4 as uuidV4 } from 'uuid'
  8. import { usePathname } from 'next/navigation'
  9. import { useWorkflowStore } from '../store'
  10. import { useNodesSyncDraft } from '../hooks'
  11. import {
  12. NodeRunningStatus,
  13. WorkflowRunningStatus,
  14. } from '../types'
  15. import { useWorkflowUpdate } from './use-workflow-interactions'
  16. import { useStore as useAppStore } from '@/app/components/app/store'
  17. import type { IOtherOptions } from '@/service/base'
  18. import { ssePost } from '@/service/base'
  19. import {
  20. fetchPublishedWorkflow,
  21. stopWorkflowRun,
  22. } from '@/service/workflow'
  23. import { useFeaturesStore } from '@/app/components/base/features/hooks'
  24. import { AudioPlayerManager } from '@/app/components/base/audio-btn/audio.player.manager'
  25. export const useWorkflowRun = () => {
  26. const store = useStoreApi()
  27. const workflowStore = useWorkflowStore()
  28. const reactflow = useReactFlow()
  29. const featuresStore = useFeaturesStore()
  30. const { doSyncWorkflowDraft } = useNodesSyncDraft()
  31. const { handleUpdateWorkflowCanvas } = useWorkflowUpdate()
  32. const pathname = usePathname()
  /**
   * Stores a snapshot of the current draft (nodes, edges, viewport, features
   * and environment variables) in the workflow store and then triggers a
   * draft sync. When a backup is already present nothing is written.
   */
  const handleBackupDraft = useCallback(() => {
    const {
      getNodes: readNodes,
      edges: currentEdges,
    } = store.getState()
    const { getViewport: readViewport } = reactflow
    const {
      backupDraft: existingBackup,
      setBackupDraft: saveBackup,
      environmentVariables,
    } = workflowStore.getState()
    const { features } = featuresStore!.getState()

    // An existing backup is kept untouched.
    if (existingBackup)
      return

    saveBackup({
      nodes: readNodes(),
      edges: currentEdges,
      viewport: readViewport(),
      features,
      environmentVariables,
    })
    // Not awaited: the sync is started and left to finish on its own.
    doSyncWorkflowDraft()
  }, [reactflow, workflowStore, store, featuresStore, doSyncWorkflowDraft])
  /**
   * Restores the draft previously saved in the workflow store: the canvas
   * (nodes, edges, viewport), the environment variables and the features.
   * The stored backup is cleared afterwards. No-op when there is no backup.
   */
  const handleLoadBackupDraft = useCallback(() => {
    const {
      backupDraft: savedDraft,
      setBackupDraft,
      setEnvironmentVariables,
    } = workflowStore.getState()

    if (!savedDraft)
      return

    handleUpdateWorkflowCanvas({
      nodes: savedDraft.nodes,
      edges: savedDraft.edges,
      viewport: savedDraft.viewport,
    })
    setEnvironmentVariables(savedDraft.environmentVariables)
    featuresStore!.setState({ features: savedDraft.features })
    // Drop the backup so a later call to handleBackupDraft can store a new one.
    setBackupDraft(undefined)
  }, [handleUpdateWorkflowCanvas, workflowStore, featuresStore])
  /**
   * Starts a run of the draft workflow and mirrors the streamed run events
   * onto the canvas (node/edge running state, viewport) and into the
   * workflow store (`workflowRunningData`).
   *
   * Before the request is sent, every node is deselected and its running
   * status cleared, the draft is synced (awaited), `historyWorkflowData` is
   * reset and the run is marked as `Running`.
   *
   * @param params Request body passed to the run endpoint. `params.token`
   *   and `params.appId` are additionally read to choose the text-to-audio
   *   URL handed to the audio player.
   * @param callback Optional handlers. The workflow / node / iteration /
   *   error handlers are called after the built-in handling of the same
   *   event. All remaining entries are spread last into the `ssePost`
   *   options, so they replace a built-in handler of the same name.
   */
  const handleRun = useCallback(async (
    params: any,
    callback?: IOtherOptions,
  ) => {
    const {
      getNodes,
      setNodes,
    } = store.getState()
    // Reset selection and any running status left over from a previous run.
    const newNodes = produce(getNodes(), (draft) => {
      draft.forEach((node) => {
        node.data.selected = false
        node.data._runningStatus = undefined
      })
    })
    setNodes(newNodes)
    await doSyncWorkflowDraft()

    const {
      onWorkflowStarted,
      onWorkflowFinished,
      onNodeStarted,
      onNodeFinished,
      onIterationStart,
      onIterationNext,
      onIterationFinish,
      onError,
      ...restCallback
    } = callback || {}
    workflowStore.setState({ historyWorkflowData: undefined })
    const appDetail = useAppStore.getState().appDetail

    // The container size is only used to center the viewport on the running
    // node, so a missing element falls back to 0 instead of throwing.
    const workflowContainer = document.getElementById('workflow-container')
    const clientWidth = workflowContainer?.clientWidth ?? 0
    const clientHeight = workflowContainer?.clientHeight ?? 0

    let url = ''
    if (appDetail?.mode === 'advanced-chat')
      url = `/apps/${appDetail.id}/advanced-chat/workflows/draft/run`
    if (appDetail?.mode === 'workflow')
      url = `/apps/${appDetail.id}/workflows/draft/run`

    // Id of the node that finished most recently; used to find the edge
    // leading into the next started node.
    let prevNodeId = ''

    const {
      setWorkflowRunningData,
    } = workflowStore.getState()
    setWorkflowRunningData({
      result: {
        status: WorkflowRunningStatus.Running,
      },
      tracing: [],
      resultText: '',
    })

    // While true, node events are recorded inside the last tracing entry
    // (the iteration) instead of at the top level of the tracing list.
    let isInIteration = false
    let iterationLength = 0

    let ttsUrl = ''
    let ttsIsPublic = false
    if (params.token) {
      ttsUrl = '/text-to-audio'
      ttsIsPublic = true
    }
    else if (params.appId) {
      if (pathname.includes('explore/installed'))
        ttsUrl = `/installed-apps/${params.appId}/text-to-audio`
      else
        ttsUrl = `/apps/${params.appId}/text-to-audio`
    }
    const player = AudioPlayerManager.getInstance().getAudioPlayer(ttsUrl, ttsIsPublic, uuidV4(), 'none', 'none', (_: any): any => {})

    // Centers the viewport on a node, keeping the current zoom. Nodes that
    // have a parent are skipped.
    const focusOnNode = (node: ReturnType<typeof getNodes>[number]) => {
      if (node.parentId)
        return
      const { transform } = store.getState()
      const zoom = transform[2]
      const { setViewport } = reactflow
      setViewport({
        x: (clientWidth - 400 - node.width! * zoom) / 2 - node.position.x * zoom,
        y: (clientHeight - node.height! * zoom) / 2 - node.position.y * zoom,
        zoom,
      })
    }

    // Flags the edge going from the previously finished node into
    // `targetNodeId` as runned, if such an edge exists.
    const markIncomingEdgeAsRun = (targetNodeId: string) => {
      const { edges, setEdges } = store.getState()
      const newEdges = produce(edges, (draft) => {
        const incoming = draft.find(edge => edge.target === targetNodeId && edge.source === prevNodeId)
        if (incoming)
          incoming.data = { ...incoming.data, _runned: true } as any
      })
      setEdges(newEdges)
    }

    ssePost(
      url,
      {
        body: params,
      },
      {
        onWorkflowStarted: (params) => {
          const { task_id, data } = params
          const {
            workflowRunningData,
            setWorkflowRunningData,
          } = workflowStore.getState()
          const {
            edges,
            setEdges,
          } = store.getState()
          setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
            draft.task_id = task_id
            draft.result = {
              ...draft?.result,
              ...data,
              status: WorkflowRunningStatus.Running,
            }
          }))
          // A new run starts with no edge marked as runned.
          const newEdges = produce(edges, (draft) => {
            draft.forEach((edge) => {
              edge.data = {
                ...edge.data,
                _runned: false,
              }
            })
          })
          setEdges(newEdges)
          if (onWorkflowStarted)
            onWorkflowStarted(params)
        },
        onWorkflowFinished: (params) => {
          const { data } = params
          const {
            workflowRunningData,
            setWorkflowRunningData,
          } = workflowStore.getState()
          // A single string output is also shown as the result text.
          const outputKeys = data.outputs ? Object.keys(data.outputs) : []
          const isStringOutput = outputKeys.length === 1 && typeof data.outputs[outputKeys[0]] === 'string'
          setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
            draft.result = {
              ...draft.result,
              ...data,
            } as any
            if (isStringOutput) {
              draft.resultTabActive = true
              draft.resultText = data.outputs[outputKeys[0]]
            }
          }))
          prevNodeId = ''
          if (onWorkflowFinished)
            onWorkflowFinished(params)
        },
        onError: (params) => {
          const {
            workflowRunningData,
            setWorkflowRunningData,
          } = workflowStore.getState()
          setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
            draft.result = {
              ...draft.result,
              status: WorkflowRunningStatus.Failed,
            }
          }))
          if (onError)
            onError(params)
        },
        onNodeStarted: (params) => {
          const { data } = params
          const {
            workflowRunningData,
            setWorkflowRunningData,
          } = workflowStore.getState()
          const {
            getNodes,
            setNodes,
          } = store.getState()
          if (isInIteration) {
            // Append the node to the current (last) round of the iteration.
            setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
              const tracing = draft.tracing!
              const iterations = tracing[tracing.length - 1]
              const currIteration = iterations.details![iterations.details!.length - 1]
              currIteration.push({
                ...data,
                status: NodeRunningStatus.Running,
              } as any)
            }))
          }
          else {
            const nodes = getNodes()
            setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
              draft.tracing!.push({
                ...data,
                status: NodeRunningStatus.Running,
              } as any)
            }))
            const currentNodeIndex = nodes.findIndex(node => node.id === data.node_id)
            const currentNode = nodes[currentNodeIndex]
            // The node may not be on the canvas; in that case only the
            // tracing entry above is recorded.
            if (currentNode) {
              focusOnNode(currentNode)
              const newNodes = produce(nodes, (draft) => {
                draft[currentNodeIndex].data._runningStatus = NodeRunningStatus.Running
              })
              setNodes(newNodes)
            }
            markIncomingEdgeAsRun(data.node_id)
          }
          if (onNodeStarted)
            onNodeStarted(params)
        },
        onNodeFinished: (params) => {
          const { data } = params
          const {
            workflowRunningData,
            setWorkflowRunningData,
          } = workflowStore.getState()
          const {
            getNodes,
            setNodes,
          } = store.getState()
          if (isInIteration) {
            // Merge the finished data into the last node of the current round.
            setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
              const tracing = draft.tracing!
              const iterations = tracing[tracing.length - 1]
              const currIteration = iterations.details![iterations.details!.length - 1]
              const nodeInfo = currIteration[currIteration.length - 1]
              currIteration[currIteration.length - 1] = {
                ...nodeInfo,
                ...data,
                status: NodeRunningStatus.Succeeded,
              } as any
            }))
          }
          else {
            const nodes = getNodes()
            setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
              const currentIndex = draft.tracing!.findIndex(trace => trace.node_id === data.node_id)
              if (currentIndex > -1 && draft.tracing) {
                // Replace the entry with the finished data, keeping only
                // the `extras` of the entry recorded at start.
                draft.tracing[currentIndex] = {
                  ...(draft.tracing[currentIndex].extras
                    ? { extras: draft.tracing[currentIndex].extras }
                    : {}),
                  ...data,
                } as any
              }
            }))
            const newNodes = produce(nodes, (draft) => {
              const currentNode = draft.find(node => node.id === data.node_id)
              if (currentNode)
                currentNode.data._runningStatus = data.status as any
            })
            setNodes(newNodes)
            prevNodeId = data.node_id
          }
          if (onNodeFinished)
            onNodeFinished(params)
        },
        onIterationStart: (params) => {
          const { data } = params
          const {
            workflowRunningData,
            setWorkflowRunningData,
          } = workflowStore.getState()
          const {
            getNodes,
            setNodes,
          } = store.getState()
          const nodes = getNodes()
          // The iteration gets its own tracing entry; `details` collects
          // one list of node traces per round.
          setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
            draft.tracing!.push({
              ...data,
              status: NodeRunningStatus.Running,
              details: [],
            } as any)
          }))
          isInIteration = true
          iterationLength = data.metadata.iterator_length
          const currentNodeIndex = nodes.findIndex(node => node.id === data.node_id)
          const currentNode = nodes[currentNodeIndex]
          if (currentNode) {
            focusOnNode(currentNode)
            const newNodes = produce(nodes, (draft) => {
              draft[currentNodeIndex].data._runningStatus = NodeRunningStatus.Running
              draft[currentNodeIndex].data._iterationLength = data.metadata.iterator_length
            })
            setNodes(newNodes)
          }
          markIncomingEdgeAsRun(data.node_id)
          if (onIterationStart)
            onIterationStart(params)
        },
        onIterationNext: (params) => {
          const {
            workflowRunningData,
            setWorkflowRunningData,
          } = workflowStore.getState()
          const { data } = params
          const {
            getNodes,
            setNodes,
          } = store.getState()
          setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
            const iteration = draft.tracing![draft.tracing!.length - 1]
            // Never open more rounds than the iteration announced at start.
            if (iteration.details!.length >= iterationLength)
              return
            iteration.details!.push([])
          }))
          const nodes = getNodes()
          const newNodes = produce(nodes, (draft) => {
            const currentNode = draft.find(node => node.id === data.node_id)
            if (currentNode)
              currentNode.data._iterationIndex = data.index > 0 ? data.index : 1
          })
          setNodes(newNodes)
          if (onIterationNext)
            onIterationNext(params)
        },
        onIterationFinish: (params) => {
          const { data } = params
          const {
            workflowRunningData,
            setWorkflowRunningData,
          } = workflowStore.getState()
          const {
            getNodes,
            setNodes,
          } = store.getState()
          const nodes = getNodes()
          setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
            const tracing = draft.tracing!
            tracing[tracing.length - 1] = {
              ...tracing[tracing.length - 1],
              ...data,
              status: NodeRunningStatus.Succeeded,
            } as any
          }))
          isInIteration = false
          const newNodes = produce(nodes, (draft) => {
            const currentNode = draft.find(node => node.id === data.node_id)
            if (currentNode)
              currentNode.data._runningStatus = data.status
          })
          setNodes(newNodes)
          prevNodeId = data.node_id
          if (onIterationFinish)
            onIterationFinish(params)
        },
        onTextChunk: (params) => {
          const { data: { text } } = params
          const {
            workflowRunningData,
            setWorkflowRunningData,
          } = workflowStore.getState()
          setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
            draft.resultTabActive = true
            draft.resultText += text
          }))
        },
        onTextReplace: (params) => {
          const { data: { text } } = params
          const {
            workflowRunningData,
            setWorkflowRunningData,
          } = workflowStore.getState()
          setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
            draft.resultText = text
          }))
        },
        onTTSChunk: (messageId: string, audio: string, audioType?: string) => {
          if (!audio || audio === '')
            return
          player.playAudioWithAudio(audio, true)
          AudioPlayerManager.getInstance().resetMsgId(messageId)
        },
        onTTSEnd: (messageId: string, audio: string, audioType?: string) => {
          player.playAudioWithAudio(audio, false)
        },
        ...restCallback,
      },
    )
    // `pathname` is read above to pick the TTS URL, so it must be listed
    // here or the memoized callback keeps using an outdated path.
  }, [store, reactflow, workflowStore, doSyncWorkflowDraft, pathname])
  /**
   * Asks the backend to stop the run task `taskId` of the app currently held
   * in the app store. The request is fired without awaiting its result.
   *
   * @param taskId Id of the running task to stop.
   */
  const handleStopRun = useCallback((taskId: string) => {
    const appId = useAppStore.getState().appDetail?.id
    // With no app loaded the URL would contain the literal text "undefined",
    // so no request is sent in that case.
    if (!appId)
      return
    stopWorkflowRun(`/apps/${appId}/workflow-runs/tasks/${taskId}/stop`)
  }, [])
  /**
   * Fetches the published workflow of the current app and, when one is
   * returned, applies it to the editor: canvas (nodes, edges, viewport),
   * features, published timestamp and environment variables (an empty list
   * when the published workflow has none).
   */
  const handleRestoreFromPublishedWorkflow = useCallback(async () => {
    const currentApp = useAppStore.getState().appDetail
    const published = await fetchPublishedWorkflow(`/apps/${currentApp?.id}/workflows/publish`)

    if (!published)
      return

    const { nodes, edges } = published.graph
    const viewport = published.graph.viewport!
    handleUpdateWorkflowCanvas({ nodes, edges, viewport })

    featuresStore?.setState({ features: published.features })

    workflowStore.getState().setPublishedAt(published.created_at)
    workflowStore.getState().setEnvironmentVariables(published.environment_variables || [])
  }, [featuresStore, handleUpdateWorkflowCanvas, workflowStore])
  482. return {
  483. handleBackupDraft,
  484. handleLoadBackupDraft,
  485. handleRun,
  486. handleStopRun,
  487. handleRestoreFromPublishedWorkflow,
  488. }
  489. }