use-config.ts 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401
  1. import { useCallback, useEffect, useRef, useState } from 'react'
  2. import produce from 'immer'
  3. import useVarList from '../_base/hooks/use-var-list'
  4. import { VarType } from '../../types'
  5. import type { Memory, ValueSelector, Var } from '../../types'
  6. import { useStore } from '../../store'
  7. import {
  8. useIsChatMode,
  9. useNodesReadOnly,
  10. } from '../../hooks'
  11. import useAvailableVarList from '../_base/hooks/use-available-var-list'
  12. import type { LLMNodeType } from './types'
  13. import { Resolution } from '@/types/app'
  14. import { useModelListAndDefaultModelAndCurrentProviderAndModel, useTextGenerationCurrentProviderAndModelAndModelList } from '@/app/components/header/account-setting/model-provider-page/hooks'
  15. import {
  16. ModelFeatureEnum,
  17. ModelTypeEnum,
  18. } from '@/app/components/header/account-setting/model-provider-page/declarations'
  19. import useNodeCrud from '@/app/components/workflow/nodes/_base/hooks/use-node-crud'
  20. import useOneStepRun from '@/app/components/workflow/nodes/_base/hooks/use-one-step-run'
  21. import type { PromptItem } from '@/models/debug'
  22. import { RETRIEVAL_OUTPUT_STRUCT } from '@/app/components/workflow/constants'
  23. import { checkHasContextBlock, checkHasHistoryBlock, checkHasQueryBlock } from '@/app/components/base/prompt-editor/constants'
  24. const useConfig = (id: string, payload: LLMNodeType) => {
  25. const { nodesReadOnly: readOnly } = useNodesReadOnly()
  26. const isChatMode = useIsChatMode()
  27. const defaultConfig = useStore(s => s.nodesDefaultConfigs)[payload.type]
  28. const [defaultRolePrefix, setDefaultRolePrefix] = useState<{ user: string; assistant: string }>({ user: '', assistant: '' })
  29. const { inputs, setInputs: doSetInputs } = useNodeCrud<LLMNodeType>(id, payload)
  30. const setInputs = useCallback((newInputs: LLMNodeType) => {
  31. if (newInputs.memory && !newInputs.memory.role_prefix) {
  32. const newPayload = produce(newInputs, (draft) => {
  33. draft.memory!.role_prefix = defaultRolePrefix
  34. })
  35. doSetInputs(newPayload)
  36. return
  37. }
  38. doSetInputs(newInputs)
  39. }, [doSetInputs, defaultRolePrefix])
  40. const inputRef = useRef(inputs)
  41. useEffect(() => {
  42. inputRef.current = inputs
  43. }, [inputs])
  44. // model
  45. const model = inputs.model
  46. const modelMode = inputs.model?.mode
  47. const isChatModel = modelMode === 'chat'
  48. const isCompletionModel = !isChatModel
  49. const hasSetBlockStatus = (() => {
  50. const promptTemplate = inputs.prompt_template
  51. const hasSetContext = isChatModel ? (promptTemplate as PromptItem[]).some(item => checkHasContextBlock(item.text)) : checkHasContextBlock((promptTemplate as PromptItem).text)
  52. if (!isChatMode) {
  53. return {
  54. history: false,
  55. query: false,
  56. context: hasSetContext,
  57. }
  58. }
  59. if (isChatModel) {
  60. return {
  61. history: false,
  62. query: (promptTemplate as PromptItem[]).some(item => checkHasQueryBlock(item.text)),
  63. context: hasSetContext,
  64. }
  65. }
  66. else {
  67. return {
  68. history: checkHasHistoryBlock((promptTemplate as PromptItem).text),
  69. query: checkHasQueryBlock((promptTemplate as PromptItem).text),
  70. context: hasSetContext,
  71. }
  72. }
  73. })()
  74. const shouldShowContextTip = !hasSetBlockStatus.context && inputs.context.enabled
  75. const appendDefaultPromptConfig = useCallback((draft: LLMNodeType, defaultConfig: any, passInIsChatMode?: boolean) => {
  76. const promptTemplates = defaultConfig.prompt_templates
  77. if (passInIsChatMode === undefined ? isChatModel : passInIsChatMode) {
  78. draft.prompt_template = promptTemplates.chat_model.prompts
  79. }
  80. else {
  81. draft.prompt_template = promptTemplates.completion_model.prompt
  82. setDefaultRolePrefix({
  83. user: promptTemplates.completion_model.conversation_histories_role.user_prefix,
  84. assistant: promptTemplates.completion_model.conversation_histories_role.assistant_prefix,
  85. })
  86. }
  87. }, [isChatModel])
  88. useEffect(() => {
  89. const isReady = defaultConfig && Object.keys(defaultConfig).length > 0
  90. if (isReady && !inputs.prompt_template) {
  91. const newInputs = produce(inputs, (draft) => {
  92. appendDefaultPromptConfig(draft, defaultConfig)
  93. })
  94. setInputs(newInputs)
  95. }
  96. // eslint-disable-next-line react-hooks/exhaustive-deps
  97. }, [defaultConfig, isChatModel])
  98. const [modelChanged, setModelChanged] = useState(false)
  99. const {
  100. currentProvider,
  101. currentModel,
  102. } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.textGeneration)
  103. const handleModelChanged = useCallback((model: { provider: string; modelId: string; mode?: string }) => {
  104. const newInputs = produce(inputRef.current, (draft) => {
  105. draft.model.provider = model.provider
  106. draft.model.name = model.modelId
  107. draft.model.mode = model.mode!
  108. const isModeChange = model.mode !== inputRef.current.model.mode
  109. if (isModeChange && defaultConfig && Object.keys(defaultConfig).length > 0)
  110. appendDefaultPromptConfig(draft, defaultConfig, model.mode === 'chat')
  111. })
  112. setInputs(newInputs)
  113. setModelChanged(true)
  114. }, [setInputs, defaultConfig, appendDefaultPromptConfig])
  115. useEffect(() => {
  116. if (currentProvider?.provider && currentModel?.model && !model.provider) {
  117. handleModelChanged({
  118. provider: currentProvider?.provider,
  119. modelId: currentModel?.model,
  120. mode: currentModel?.model_properties?.mode as string,
  121. })
  122. }
  123. }, [model.provider, currentProvider, currentModel, handleModelChanged])
  124. const handleCompletionParamsChange = useCallback((newParams: Record<string, any>) => {
  125. const newInputs = produce(inputs, (draft) => {
  126. draft.model.completion_params = newParams
  127. })
  128. setInputs(newInputs)
  129. }, [inputs, setInputs])
  130. const {
  131. currentModel: currModel,
  132. } = useTextGenerationCurrentProviderAndModelAndModelList(
  133. {
  134. provider: model.provider,
  135. model: model.name,
  136. },
  137. )
  138. const isShowVisionConfig = !!currModel?.features?.includes(ModelFeatureEnum.vision)
  139. // change to vision model to set vision enabled, else disabled
  140. useEffect(() => {
  141. if (!modelChanged)
  142. return
  143. setModelChanged(false)
  144. if (!isShowVisionConfig) {
  145. const newInputs = produce(inputs, (draft) => {
  146. draft.vision = {
  147. enabled: false,
  148. }
  149. })
  150. setInputs(newInputs)
  151. return
  152. }
  153. if (!inputs.vision?.enabled) {
  154. const newInputs = produce(inputs, (draft) => {
  155. if (!draft.vision?.enabled) {
  156. draft.vision = {
  157. enabled: true,
  158. configs: {
  159. detail: Resolution.high,
  160. },
  161. }
  162. }
  163. })
  164. setInputs(newInputs)
  165. }
  166. // eslint-disable-next-line react-hooks/exhaustive-deps
  167. }, [isShowVisionConfig, modelChanged])
  168. // variables
  169. const { handleVarListChange, handleAddVariable } = useVarList<LLMNodeType>({
  170. inputs,
  171. setInputs,
  172. })
  173. // context
  174. const handleContextVarChange = useCallback((newVar: ValueSelector | string) => {
  175. const newInputs = produce(inputs, (draft) => {
  176. draft.context.variable_selector = newVar as ValueSelector || []
  177. draft.context.enabled = !!(newVar && newVar.length > 0)
  178. })
  179. setInputs(newInputs)
  180. }, [inputs, setInputs])
  181. const handlePromptChange = useCallback((newPrompt: PromptItem[] | PromptItem) => {
  182. const newInputs = produce(inputs, (draft) => {
  183. draft.prompt_template = newPrompt
  184. })
  185. setInputs(newInputs)
  186. }, [inputs, setInputs])
  187. const handleMemoryChange = useCallback((newMemory?: Memory) => {
  188. const newInputs = produce(inputs, (draft) => {
  189. draft.memory = newMemory
  190. })
  191. setInputs(newInputs)
  192. }, [inputs, setInputs])
  193. const handleSyeQueryChange = useCallback((newQuery: string) => {
  194. const newInputs = produce(inputs, (draft) => {
  195. if (!draft.memory) {
  196. draft.memory = {
  197. window: {
  198. enabled: false,
  199. size: 10,
  200. },
  201. query_prompt_template: newQuery,
  202. }
  203. }
  204. else {
  205. draft.memory.query_prompt_template = newQuery
  206. }
  207. })
  208. setInputs(newInputs)
  209. }, [inputs, setInputs])
  210. const handleVisionResolutionEnabledChange = useCallback((enabled: boolean) => {
  211. const newInputs = produce(inputs, (draft) => {
  212. if (!draft.vision) {
  213. draft.vision = {
  214. enabled,
  215. configs: {
  216. detail: Resolution.high,
  217. },
  218. }
  219. }
  220. else {
  221. draft.vision.enabled = enabled
  222. if (!draft.vision.configs) {
  223. draft.vision.configs = {
  224. detail: Resolution.high,
  225. }
  226. }
  227. }
  228. })
  229. setInputs(newInputs)
  230. }, [inputs, setInputs])
  231. const handleVisionResolutionChange = useCallback((newResolution: Resolution) => {
  232. const newInputs = produce(inputs, (draft) => {
  233. if (!draft.vision.configs) {
  234. draft.vision.configs = {
  235. detail: Resolution.high,
  236. }
  237. }
  238. draft.vision.configs.detail = newResolution
  239. })
  240. setInputs(newInputs)
  241. }, [inputs, setInputs])
  242. const filterInputVar = useCallback((varPayload: Var) => {
  243. return [VarType.number, VarType.string].includes(varPayload.type)
  244. }, [])
  245. const filterVar = useCallback((varPayload: Var) => {
  246. return [VarType.arrayObject, VarType.array, VarType.string].includes(varPayload.type)
  247. }, [])
  248. const {
  249. availableVars,
  250. availableNodes,
  251. } = useAvailableVarList(id, {
  252. onlyLeafNodeVar: false,
  253. filterVar,
  254. })
  255. // single run
  256. const {
  257. isShowSingleRun,
  258. hideSingleRun,
  259. getInputVars,
  260. runningStatus,
  261. handleRun,
  262. handleStop,
  263. runInputData,
  264. setRunInputData,
  265. runResult,
  266. } = useOneStepRun<LLMNodeType>({
  267. id,
  268. data: inputs,
  269. defaultRunInputData: {
  270. '#context#': [RETRIEVAL_OUTPUT_STRUCT],
  271. '#files#': [],
  272. },
  273. })
  274. // const handleRun = (submitData: Record<string, any>) => {
  275. // console.log(submitData)
  276. // const res = produce(submitData, (draft) => {
  277. // debugger
  278. // if (draft.contexts) {
  279. // draft['#context#'] = draft.contexts
  280. // delete draft.contexts
  281. // }
  282. // if (draft.visionFiles) {
  283. // draft['#files#'] = draft.visionFiles
  284. // delete draft.visionFiles
  285. // }
  286. // })
  287. // doHandleRun(res)
  288. // }
  289. const inputVarValues = (() => {
  290. const vars: Record<string, any> = {}
  291. Object.keys(runInputData)
  292. .filter(key => !['#context#', '#files#'].includes(key))
  293. .forEach((key) => {
  294. vars[key] = runInputData[key]
  295. })
  296. return vars
  297. })()
  298. const setInputVarValues = useCallback((newPayload: Record<string, any>) => {
  299. const newVars = {
  300. ...newPayload,
  301. '#context#': runInputData['#context#'],
  302. '#files#': runInputData['#files#'],
  303. }
  304. setRunInputData(newVars)
  305. }, [runInputData, setRunInputData])
  306. const contexts = runInputData['#context#']
  307. const setContexts = useCallback((newContexts: string[]) => {
  308. setRunInputData({
  309. ...runInputData,
  310. '#context#': newContexts,
  311. })
  312. }, [runInputData, setRunInputData])
  313. const visionFiles = runInputData['#files#']
  314. const setVisionFiles = useCallback((newFiles: any[]) => {
  315. setRunInputData({
  316. ...runInputData,
  317. '#files#': newFiles,
  318. })
  319. }, [runInputData, setRunInputData])
  320. const allVarStrArr = (() => {
  321. const arr = isChatModel ? (inputs.prompt_template as PromptItem[]).map(item => item.text) : [(inputs.prompt_template as PromptItem).text]
  322. if (isChatMode && isChatModel && !!inputs.memory) {
  323. arr.push('{{#sys.query#}}')
  324. arr.push(inputs.memory.query_prompt_template)
  325. }
  326. return arr
  327. })()
  328. const varInputs = getInputVars(allVarStrArr)
  329. return {
  330. readOnly,
  331. isChatMode,
  332. inputs,
  333. isChatModel,
  334. isCompletionModel,
  335. hasSetBlockStatus,
  336. shouldShowContextTip,
  337. isShowVisionConfig,
  338. handleModelChanged,
  339. handleCompletionParamsChange,
  340. handleVarListChange,
  341. handleAddVariable,
  342. handleContextVarChange,
  343. filterInputVar,
  344. filterVar,
  345. availableVars,
  346. availableNodes,
  347. handlePromptChange,
  348. handleMemoryChange,
  349. handleSyeQueryChange,
  350. handleVisionResolutionEnabledChange,
  351. handleVisionResolutionChange,
  352. isShowSingleRun,
  353. hideSingleRun,
  354. inputVarValues,
  355. setInputVarValues,
  356. visionFiles,
  357. setVisionFiles,
  358. contexts,
  359. setContexts,
  360. varInputs,
  361. runningStatus,
  362. handleRun,
  363. handleStop,
  364. runResult,
  365. }
  366. }
  367. export default useConfig