utils.ts 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481
  1. import {
  2. Position,
  3. getConnectedEdges,
  4. getOutgoers,
  5. } from 'reactflow'
  6. import dagre from '@dagrejs/dagre'
  7. import { v4 as uuid4 } from 'uuid'
  8. import {
  9. cloneDeep,
  10. uniqBy,
  11. } from 'lodash-es'
  12. import type {
  13. Edge,
  14. InputVar,
  15. Node,
  16. ToolWithProvider,
  17. ValueSelector,
  18. } from './types'
  19. import { BlockEnum } from './types'
  20. import {
  21. CUSTOM_NODE,
  22. ITERATION_NODE_Z_INDEX,
  23. NODE_WIDTH_X_OFFSET,
  24. START_INITIAL_POSITION,
  25. } from './constants'
  26. import type { QuestionClassifierNodeType } from './nodes/question-classifier/types'
  27. import type { IfElseNodeType } from './nodes/if-else/types'
  28. import { branchNameCorrect } from './nodes/if-else/utils'
  29. import type { ToolNodeType } from './nodes/tool/types'
  30. import { CollectionType } from '@/app/components/tools/types'
  31. import { toolParametersToFormSchemas } from '@/app/components/tools/utils/to-form-schema'
  /**
   * Tarjan's strongly-connected-components algorithm.
   *
   * @param nodeIds ids used as DFS roots, tried in order
   * @param adjacency outgoing neighbour ids keyed by node id; only keys are treated as graph nodes
   * @returns map from node id to the index of the strongly connected component containing it
   */
  const findStronglyConnectedComponents = (nodeIds: string[], adjacency: Map<string, string[]>) => {
    const componentOf = new Map<string, number>()
    const discoveryIndex = new Map<string, number>()
    const lowLink = new Map<string, number>()
    const onStack = new Set<string>()
    const stack: string[] = []
    let nextIndex = 0
    let componentCount = 0

    const visit = (nodeId: string) => {
      discoveryIndex.set(nodeId, nextIndex)
      lowLink.set(nodeId, nextIndex)
      nextIndex++
      stack.push(nodeId)
      onStack.add(nodeId)

      for (const childId of adjacency.get(nodeId) ?? []) {
        if (!discoveryIndex.has(childId)) {
          visit(childId)
          lowLink.set(nodeId, Math.min(lowLink.get(nodeId)!, lowLink.get(childId)!))
        }
        else if (onStack.has(childId)) {
          // Edge back into the component that is still being assembled.
          lowLink.set(nodeId, Math.min(lowLink.get(nodeId)!, discoveryIndex.get(childId)!))
        }
      }

      // nodeId roots a component: every id above it on the stack belongs to that component.
      if (lowLink.get(nodeId) === discoveryIndex.get(nodeId)) {
        let memberId: string | undefined
        do {
          memberId = stack.pop()
          if (memberId === undefined)
            break
          onStack.delete(memberId)
          componentOf.set(memberId, componentCount)
        } while (memberId !== nodeId)
        componentCount++
      }
    }

    for (const nodeId of nodeIds) {
      if (!discoveryIndex.has(nodeId))
        visit(nodeId)
    }
    return componentOf
  }

  /**
   * Return every edge that lies on at least one directed cycle.
   *
   * An edge `u -> v` is on a cycle exactly when it is a self-loop or when `u` and `v` belong to
   * the same strongly connected component. Edges leading into or out of a cycle, or joining two
   * separate cycles, are therefore NOT reported. Edges whose source or target is not one of
   * `nodes` are ignored. The result preserves the order of `edges`.
   */
  const getCycleEdges = (nodes: Node[], edges: Edge[]) => {
    // A Map (not a plain object) so ids such as 'constructor' cannot hit prototype members.
    const adjacency = new Map<string, string[]>()
    for (const node of nodes)
      adjacency.set(node.id, [])
    for (const edge of edges) {
      if (adjacency.has(edge.target))
        adjacency.get(edge.source)?.push(edge.target)
    }

    const componentOf = findStronglyConnectedComponents(nodes.map(node => node.id), adjacency)

    return edges.filter(edge =>
      adjacency.has(edge.source)
      && adjacency.has(edge.target)
      && (edge.source === edge.target || componentOf.get(edge.source) === componentOf.get(edge.target)),
    )
  }
  /**
   * Prepare workflow nodes for rendering.
   *
   * Operates on deep copies of both inputs, so the originals are left untouched. When the first
   * node has no `position`, every node is placed on a single horizontal row starting at
   * `START_INITIAL_POSITION`, spaced by `NODE_WIDTH_X_OFFSET`. Then, for each node:
   * - `type` falls back to `CUSTOM_NODE`;
   * - `_connectedSourceHandleIds` / `_connectedTargetHandleIds` list the handles that currently
   *   have an edge attached (a missing handle is recorded as 'source' / 'target');
   * - if-else nodes: legacy data (`logical_operator` + `conditions` without `cases`) is upgraded to
   *   a single case with id 'true', and `_targetBranches` lists every case plus 'false';
   * - question-classifier nodes: `_targetBranches` is a shallow copy of `classes`;
   * - iteration nodes: `_children` lists the ids of nodes whose `parentId` is this node.
   */
  export const initialNodes = (originNodes: Node[], originEdges: Edge[]) => {
    const nodes = cloneDeep(originNodes)
    const edges = cloneDeep(originEdges)

    if (!nodes[0]?.position) {
      for (let index = 0; index < nodes.length; index++) {
        nodes[index].position = {
          x: START_INITIAL_POSITION.x + index * NODE_WIDTH_X_OFFSET,
          y: START_INITIAL_POSITION.y,
        }
      }
    }

    // parent id -> ids of its child nodes, in node order.
    const childIdsByParent = {} as Record<string, string[]>
    for (const node of nodes) {
      const { parentId } = node
      if (!parentId)
        continue
      if (childIdsByParent[parentId])
        childIdsByParent[parentId].push(node.id)
      else
        childIdsByParent[parentId] = [node.id]
    }

    return nodes.map((node) => {
      if (!node.type)
        node.type = CUSTOM_NODE

      const { data } = node
      const connectedEdges = getConnectedEdges([node], edges)
      data._connectedSourceHandleIds = connectedEdges
        .filter(edge => edge.source === node.id)
        .map(edge => edge.sourceHandle || 'source')
      data._connectedTargetHandleIds = connectedEdges
        .filter(edge => edge.target === node.id)
        .map(edge => edge.targetHandle || 'target')

      switch (data.type) {
        case BlockEnum.IfElse: {
          const ifElseData = data as IfElseNodeType
          if (!ifElseData.cases && ifElseData.logical_operator && ifElseData.conditions) {
            ifElseData.cases = [
              {
                case_id: 'true',
                logical_operator: ifElseData.logical_operator,
                conditions: ifElseData.conditions,
              },
            ]
          }
          data._targetBranches = branchNameCorrect([
            ...ifElseData.cases.map(({ case_id }) => ({ id: case_id, name: '' })),
            { id: 'false', name: '' },
          ])
          break
        }
        case BlockEnum.QuestionClassifier:
          data._targetBranches = (data as QuestionClassifierNodeType).classes.slice()
          break
        case BlockEnum.Iteration:
          data._children = childIdsByParent[node.id] || []
          break
      }

      return node
    })
  }
  /**
   * Normalise workflow edges for rendering.
   *
   * Works on a deep copy of `originEdges`; neither input is mutated. Edges reported by
   * `getCycleEdges` are dropped (matched by source and target). Every remaining edge gets:
   * - `type = 'custom'`;
   * - default handles 'source' / 'target' when missing;
   * - `data.sourceType` / `data.targetType` from the endpoint nodes' `data.type` when missing and
   *   the endpoint node is known;
   * - when some node has `data.selected`, `data._connectedNodeIsSelected` saying whether the edge
   *   touches that node. If several nodes are selected, the last one wins.
   */
  export const initialEdges = (originEdges: Edge[], originNodes: Node[]) => {
    // Edges are mutated below and therefore cloned; nodes are only read, so no deep copy is needed.
    const edges = cloneDeep(originEdges)
    const nodesMap = {} as Record<string, Node>
    let selectedNode: Node | null = null
    for (const node of originNodes) {
      nodesMap[node.id] = node
      if (node.data?.selected)
        selectedNode = node
    }

    // source id -> target ids of cycle edges, so each edge is checked in O(1)
    // instead of scanning the whole cycle-edge list.
    const cycleTargetsBySource = new Map<string, Set<string>>()
    getCycleEdges(originNodes, edges).forEach(({ source, target }) => {
      const targets = cycleTargetsBySource.get(source)
      if (targets)
        targets.add(target)
      else
        cycleTargetsBySource.set(source, new Set([target]))
    })

    return edges
      .filter(edge => !cycleTargetsBySource.get(edge.source)?.has(edge.target))
      .map((edge) => {
        edge.type = 'custom'
        if (!edge.sourceHandle)
          edge.sourceHandle = 'source'
        if (!edge.targetHandle)
          edge.targetHandle = 'target'

        if (!edge.data?.sourceType && edge.source && nodesMap[edge.source]) {
          edge.data = {
            ...edge.data,
            sourceType: nodesMap[edge.source].data.type!,
          } as any
        }
        if (!edge.data?.targetType && edge.target && nodesMap[edge.target]) {
          edge.data = {
            ...edge.data,
            targetType: nodesMap[edge.target].data.type!,
          } as any
        }
        if (selectedNode) {
          edge.data = {
            ...edge.data,
            _connectedNodeIsSelected: edge.source === selectedNode.id || edge.target === selectedNode.id,
          } as any
        }
        return edge
      })
  }
  /**
   * Run dagre's automatic left-to-right layout over the top-level canvas.
   *
   * Only nodes without a `parentId` whose `type` is `CUSTOM_NODE`, and edges not flagged with
   * `data.isInIteration`, are added to the graph. Node sizes come from `width` / `height`, which
   * are asserted non-null. NOTE(review): this assumes the nodes have already been measured;
   * confirm with callers.
   *
   * @returns the dagre graph after `dagre.layout` has run on it
   */
  export const getLayoutByDagre = (originNodes: Node[], originEdges: Edge[]) => {
    const dagreGraph = new dagre.graphlib.Graph()
    dagreGraph.setDefaultEdgeLabel(() => ({}))
    // Only ids and sizes are read, and dagre receives fresh label objects below, so the inputs
    // need no (potentially large) deep copy.
    const nodes = originNodes.filter(node => !node.parentId && node.type === CUSTOM_NODE)
    const edges = originEdges.filter(edge => !edge.data?.isInIteration)
    dagreGraph.setGraph({
      rankdir: 'LR',
      align: 'UL',
      nodesep: 40,
      ranksep: 60,
      ranker: 'tight-tree',
      marginx: 30,
      marginy: 200,
    })
    nodes.forEach((node) => {
      dagreGraph.setNode(node.id, {
        width: node.width!,
        height: node.height!,
      })
    })
    edges.forEach((edge) => {
      dagreGraph.setEdge(edge.source, edge.target)
    })
    dagre.layout(dagreGraph)
    return dagreGraph
  }
  // Block types for which `canRunBySingle` returns true.
  const SINGLE_RUNNABLE_BLOCK_TYPES: BlockEnum[] = [
    BlockEnum.LLM,
    BlockEnum.KnowledgeRetrieval,
    BlockEnum.Code,
    BlockEnum.TemplateTransform,
    BlockEnum.QuestionClassifier,
    BlockEnum.HttpRequest,
    BlockEnum.Tool,
    BlockEnum.ParameterExtractor,
    BlockEnum.Iteration,
  ]

  /**
   * Whether a node of the given block type can be run on its own.
   * (Presumably backs a per-node "run" action; inferred from the name, so confirm with callers.)
   */
  export const canRunBySingle = (nodeType: BlockEnum) => SINGLE_RUNNABLE_BLOCK_TYPES.includes(nodeType)
  type ConnectedSourceOrTargetNodesChange = {
    type: string
    edge: Edge
  }[]

  /**
   * Compute updated connected-handle lists for every node touched by a batch of edge changes.
   *
   * @param changes edge changes; only `type === 'add'` and `type === 'remove'` have an effect
   * @param nodes current nodes, used to look up each edge's source and target by id
   * @returns map from node id to `{ _connectedSourceHandleIds, _connectedTargetHandleIds }`,
   *   fresh copies of the node's lists with all changes applied in order. Only nodes that are an
   *   endpoint of some change appear. The input nodes are not mutated.
   *
   * A missing `sourceHandle` / `targetHandle` is treated as 'source' / 'target' for both 'add' and
   * 'remove', so a removal matches what an addition recorded. Removing a handle id that is not in
   * the list is a no-op.
   */
  export const getNodesConnectedSourceOrTargetHandleIdsMap = (changes: ConnectedSourceOrTargetNodesChange, nodes: Node[]) => {
    const nodesConnectedSourceOrTargetHandleIdsMap = {} as Record<string, any>

    // First node per id (same as `nodes.find`), without a linear scan for every change.
    const nodesById = new Map<string, Node>()
    for (const node of nodes) {
      if (!nodesById.has(node.id))
        nodesById.set(node.id, node)
    }

    // Copy a node's current handle lists the first time the node is touched.
    const getEntry = (node: Node) => {
      if (!nodesConnectedSourceOrTargetHandleIdsMap[node.id]) {
        nodesConnectedSourceOrTargetHandleIdsMap[node.id] = {
          _connectedSourceHandleIds: [...(node.data._connectedSourceHandleIds || [])],
          _connectedTargetHandleIds: [...(node.data._connectedTargetHandleIds || [])],
        }
      }
      return nodesConnectedSourceOrTargetHandleIdsMap[node.id]
    }

    // Remove one occurrence of handleId. Guarding the -1 case matters: splice(-1, 1) would
    // silently drop the last (unrelated) handle.
    const removeHandleId = (handleIds: string[], handleId: string) => {
      const index = handleIds.indexOf(handleId)
      if (index !== -1)
        handleIds.splice(index, 1)
    }

    changes.forEach(({ edge, type }) => {
      const sourceNode = nodesById.get(edge.source)
      const targetNode = nodesById.get(edge.target)
      const sourceHandle = edge.sourceHandle || 'source'
      const targetHandle = edge.targetHandle || 'target'

      if (sourceNode) {
        const { _connectedSourceHandleIds } = getEntry(sourceNode)
        if (type === 'remove')
          removeHandleId(_connectedSourceHandleIds, sourceHandle)
        if (type === 'add')
          _connectedSourceHandleIds.push(sourceHandle)
      }
      if (targetNode) {
        const { _connectedTargetHandleIds } = getEntry(targetNode)
        if (type === 'remove')
          removeHandleId(_connectedTargetHandleIds, targetHandle)
        if (type === 'add')
          _connectedTargetHandleIds.push(targetHandle)
      }
    })

    return nodesConnectedSourceOrTargetHandleIdsMap
  }
  /**
   * Build a new workflow node object.
   *
   * - `id` defaults to the current timestamp (`Date.now()` as a string).
   * - `type` defaults to `CUSTOM_NODE`.
   * - Target handles sit on the left and source handles on the right.
   * - Iteration nodes always get `ITERATION_NODE_Z_INDEX`; any other node keeps the given `zIndex`.
   * - All remaining properties are spread last, so they may override the handle positions.
   */
  export const generateNewNode = ({ data, position, id, zIndex, type, ...rest }: Omit<Node, 'id'> & { id?: string }) => {
    const isIteration = data.type === BlockEnum.Iteration
    const node = {
      id: id || `${Date.now()}`,
      type: type || CUSTOM_NODE,
      data,
      position,
      targetPosition: Position.Left,
      sourcePosition: Position.Right,
      zIndex: isIteration ? ITERATION_NODE_Z_INDEX : zIndex,
      ...rest,
    }
    return node as Node
  }
  // "<title> (<n>)" with optional whitespace before the parenthesis and after it.
  const TITLE_WITH_COUNTER = /^(.+?)\s*\((\d+)\)\s*$/

  /**
   * Derive a title for a copied node.
   * "Name (n)" becomes "Name (n+1)"; any other title gets " (1)" appended.
   */
  export const genNewNodeTitleFromOld = (oldTitle: string) => {
    const match = TITLE_WITH_COUNTER.exec(oldTitle)
    if (!match)
      return `${oldTitle} (1)`

    const [, baseTitle, counter] = match
    const nextCounter = Number.parseInt(counter, 10) + 1
    return `${baseTitle} (${nextCounter})`
  }
  /**
   * Collect the nodes reachable from the Start node, plus the depth of the longest path.
   *
   * `validNodes` holds the following, de-duplicated by id in first-visit order:
   * - the Start node;
   * - every node reachable from it via `getOutgoers`;
   * - the child nodes (`parentId`) of every reached iteration node.
   *
   * `maxDepth` is the number of nodes on the longest path starting at Start (1 when Start has no
   * outgoers). Without a Start node the result is `{ validNodes: [], maxDepth: 0 }`.
   *
   * Each node is expanded once and its depth memoised. Re-walking shared downstream branches
   * once per path would be exponential in the number of merge points. An edge back into a node
   * on the current path (a cycle) is ignored instead of recursing forever.
   */
  export const getValidTreeNodes = (nodes: Node[], edges: Edge[]) => {
    const startNode = nodes.find(node => node.data.type === BlockEnum.Start)
    if (!startNode) {
      return {
        validNodes: [],
        maxDepth: 0,
      }
    }

    const list: Node[] = [startNode]
    // node id -> number of nodes on the longest path starting at that node (once fully explored).
    const depthFrom = new Map<string, number>()
    // Ids on the current DFS path; reaching one again means the graph has a cycle.
    const inProgress = new Set<string>()

    const pushWithIterationChildren = (node: Node) => {
      list.push(node)
      if (node.data.type === BlockEnum.Iteration)
        list.push(...nodes.filter(child => child.parentId === node.id))
    }

    const traverse = (root: Node): number => {
      const cached = depthFrom.get(root.id)
      if (cached !== undefined)
        return cached
      if (inProgress.has(root.id))
        return 0

      inProgress.add(root.id)
      let deepestChild = 0
      const outgoers = getOutgoers(root, nodes, edges)
      if (outgoers.length) {
        outgoers.forEach((outgoer) => {
          pushWithIterationChildren(outgoer)
          deepestChild = Math.max(deepestChild, traverse(outgoer))
        })
      }
      else {
        pushWithIterationChildren(root)
      }
      inProgress.delete(root.id)

      const depth = deepestChild + 1
      depthFrom.set(root.id, depth)
      return depth
    }

    const maxDepth = traverse(startNode)
    return {
      validNodes: uniqBy(list, 'id'),
      maxDepth,
    }
  }
  /**
   * Gather what is needed to check a tool node's configuration.
   *
   * 1. Picks the tool list for the node's provider type: built-in, custom, or otherwise workflow.
   * 2. Finds the provider collection by `provider_id`, then the tool by `tool_name`.
   * 3. Splits the tool's form schemas into two groups:
   *    - `form === 'llm'`: returned as `toolInputsSchema`, with each label localised to
   *      `language` and falling back to `en_US`;
   *    - everything else: returned as `toolSettingSchema`.
   *
   * `notAuthed` is true only for a built-in collection that allows deletion but is not
   * team-authorised. `language` is echoed back unchanged.
   */
  export const getToolCheckParams = (
    toolData: ToolNodeType,
    buildInTools: ToolWithProvider[],
    customTools: ToolWithProvider[],
    workflowTools: ToolWithProvider[],
    language: string,
  ) => {
    const { provider_id, provider_type, tool_name } = toolData
    const isBuiltIn = provider_type === CollectionType.builtIn

    let currentTools: ToolWithProvider[]
    if (isBuiltIn)
      currentTools = buildInTools
    else if (provider_type === CollectionType.custom)
      currentTools = customTools
    else
      currentTools = workflowTools

    const currCollection = currentTools.find(item => item.id === provider_id)
    const currTool = currCollection?.tools.find(tool => tool.name === tool_name)
    const formSchemas = currTool ? toolParametersToFormSchemas(currTool.parameters) : []

    const isLlmForm = (item: any) => item.form === 'llm'
    const toolInputVarSchema = formSchemas.filter(isLlmForm)
    const toolSettingSchema = formSchemas.filter((item: any) => !isLlmForm(item))

    const toolInputsSchema: InputVar[] = toolInputVarSchema.map((item: any) => ({
      label: item.label[language] || item.label.en_US,
      variable: item.variable,
      type: item.type,
      required: item.required,
    }))

    return {
      toolInputsSchema,
      notAuthed: isBuiltIn && !!currCollection?.allow_delete && !currCollection?.is_team_authorization,
      toolSettingSchema,
      language,
    }
  }
  /**
   * Give every node a fresh UUID v4 id and point edge endpoints at the new ids.
   *
   * Returns `[newNodes, newEdges]` built from shallow copies; the inputs are not mutated. Only node
   * `id` and edge `source` / `target` are rewritten. An edge endpoint that is not one of `nodes`
   * becomes `undefined`.
   * NOTE(review): `parentId`, edge `id` and any ids stored in `data` keep their old values;
   * confirm that is intended.
   */
  export const changeNodesAndEdgesId = (nodes: Node[], edges: Edge[]) => {
    const idMap = {} as Record<string, string>
    for (const node of nodes)
      idMap[node.id] = uuid4()

    const newNodes = nodes.map(node => ({ ...node, id: idMap[node.id] }))
    const newEdges = edges.map(edge => ({
      ...edge,
      source: idMap[edge.source],
      target: idMap[edge.target],
    }))

    return [newNodes, newEdges] as [Node[], Edge[]]
  }
  /**
   * Whether the user agent looks like macOS, i.e. contains "MAC" in any letter case.
   * Requires a global `navigator` (browser environment).
   */
  export const isMac = () => {
    const userAgent = navigator.userAgent.toUpperCase()
    return userAgent.indexOf('MAC') !== -1
  }
  // Glyphs shown for modifier keys on macOS.
  const specialKeysNameMap: Record<string, string | undefined> = {
    ctrl: '⌘',
    alt: '⌥',
  }

  /**
   * Display label for a keyboard key on the current platform.
   * On macOS, 'ctrl' is shown as '⌘' and 'alt' as '⌥'. Keys without a mapping, and every key on
   * other platforms, are returned unchanged.
   */
  export const getKeyboardKeyNameBySystem = (key: string) => {
    const macName = isMac() ? specialKeysNameMap[key] : undefined
    return macName || key
  }
  // On macOS, the 'ctrl' modifier maps to the 'meta' (Command) key.
  const specialKeysCodeMap: Record<string, string | undefined> = {
    ctrl: 'meta',
  }

  /**
   * Platform-specific key code for a key name (presumably used when binding shortcuts; confirm
   * with callers). On macOS 'ctrl' becomes 'meta'. Other keys, and all keys on other platforms,
   * pass through unchanged.
   */
  export const getKeyboardKeyCodeBySystem = (key: string) => {
    if (!isMac())
      return key
    return specialKeysCodeMap[key] || key
  }
  /**
   * Smallest x and smallest y over all node positions. The two may come from different nodes.
   * Returns `{ x: Infinity, y: Infinity }` for an empty list.
   */
  export const getTopLeftNodePosition = (nodes: Node[]) => {
    return nodes.reduce(
      (topLeft, { position }) => ({
        x: position.x < topLeft.x ? position.x : topLeft.x,
        y: position.y < topLeft.y ? position.y : topLeft.y,
      }),
      { x: Infinity, y: Infinity },
    )
  }
  /**
   * Whether an event target is a text-entry area: an `<input>`, a `<textarea>`, or editable
   * content. NOTE(review): presumably used to suppress keyboard shortcuts while typing; confirm
   * with callers.
   *
   * Always returns a boolean.
   */
  export const isEventTargetInputArea = (target: HTMLElement) => {
    if (target.tagName === 'INPUT' || target.tagName === 'TEXTAREA')
      return true
    // `contentEditable` is 'inherit' on descendants of an editing host and may be
    // 'plaintext-only'; `isContentEditable` reports the effective editability.
    return target.contentEditable === 'true' || target.isContentEditable === true
  }
  /**
   * Convert between the two representations of a variable reference:
   * - a selector array such as ['node', 'output'] is joined with '.' and wrapped as
   *   '{{#node.output#}}';
   * - a string has a leading '{{#' and a trailing '#}}' stripped (when present) and is split on
   *   '.' into a selector.
   */
  export const variableTransformer = (v: ValueSelector | string) => {
    if (typeof v !== 'string')
      return `{{#${v.join('.')}#}}`

    const withoutWrapper = v.replace(/^{{#|#}}$/g, '')
    return withoutWrapper.split('.')
  }