use-config.ts 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369
  1. import { useCallback, useEffect, useRef, useState } from 'react'
  2. import produce from 'immer'
  3. import useVarList from '../_base/hooks/use-var-list'
  4. import { VarType } from '../../types'
  5. import type { Memory, ValueSelector, Var } from '../../types'
  6. import { useStore } from '../../store'
  7. import {
  8. useIsChatMode,
  9. useNodesReadOnly,
  10. } from '../../hooks'
  11. import type { LLMNodeType } from './types'
  12. import { Resolution } from '@/types/app'
  13. import { useModelListAndDefaultModelAndCurrentProviderAndModel, useTextGenerationCurrentProviderAndModelAndModelList } from '@/app/components/header/account-setting/model-provider-page/hooks'
  14. import {
  15. ModelFeatureEnum,
  16. ModelTypeEnum,
  17. } from '@/app/components/header/account-setting/model-provider-page/declarations'
  18. import useNodeCrud from '@/app/components/workflow/nodes/_base/hooks/use-node-crud'
  19. import useOneStepRun from '@/app/components/workflow/nodes/_base/hooks/use-one-step-run'
  20. import type { PromptItem } from '@/models/debug'
  21. import { RETRIEVAL_OUTPUT_STRUCT } from '@/app/components/workflow/constants'
  22. import { checkHasContextBlock, checkHasHistoryBlock, checkHasQueryBlock } from '@/app/components/base/prompt-editor/constants'
  /**
   * Flattens an LLM node's `prompt_template` into the list of its prompt texts.
   *
   * Chat models keep an array of prompt items and completion models keep a single
   * item. The runtime shape is inspected (rather than trusting the model mode) so
   * callers do not throw while the template is still unset — `useConfig` only seeds
   * it from the default config inside an effect, i.e. after a render — or while it
   * is temporarily out of sync with the selected model's mode (a model change only
   * swaps the template once the default config is available).
   *
   * @param promptTemplate - the node's prompt template, possibly not set yet
   * @returns the `text` of every prompt item, in order; `[]` when unset
   */
  const getPromptTexts = (promptTemplate?: PromptItem[] | PromptItem): string[] => {
    if (!promptTemplate)
      return []
    const items = Array.isArray(promptTemplate) ? promptTemplate : [promptTemplate]
    return items.map(item => item.text)
  }

  /**
   * State and handlers backing the LLM node's configuration panel.
   *
   * Reads and writes node data through `useNodeCrud`, seeds the prompt template from
   * the node's default config, selects the current default text-generation model when
   * the node has no provider, toggles vision after a model change, and exposes the
   * single-step-run inputs (`#context#`, `#files#` and prompt variables).
   *
   * @param id - id of the workflow node being edited
   * @param payload - the node's current data
   */
  const useConfig = (id: string, payload: LLMNodeType) => {
    const { nodesReadOnly: readOnly } = useNodesReadOnly()
    const isChatMode = useIsChatMode()
    const defaultConfig = useStore(s => s.nodesDefaultConfigs)[payload.type]
    // Default conversation-history role prefixes, taken from the completion-model
    // default config. Held in a ref rather than state: appendDefaultPromptConfig and
    // the setInputs call that consumes the prefix run in the same handler, and a
    // state value would still be the one captured at the previous render.
    const defaultRolePrefixRef = useRef<{ user: string; assistant: string }>({ user: '', assistant: '' })
    const { inputs, setInputs: doSetInputs } = useNodeCrud<LLMNodeType>(id, payload)

    // Persist new inputs; memory without a role prefix receives the default prefix.
    const setInputs = useCallback((newInputs: LLMNodeType) => {
      if (newInputs.memory && !newInputs.memory.role_prefix) {
        const newPayload = produce(newInputs, (draft) => {
          // Copy, so immer's auto-freeze (if enabled) never touches the ref's object.
          draft.memory!.role_prefix = { ...defaultRolePrefixRef.current }
        })
        doSetInputs(newPayload)
        return
      }
      doSetInputs(newInputs)
    }, [doSetInputs])

    // Latest inputs, for callbacks that must not rely on a stale render's closure.
    const inputRef = useRef(inputs)
    useEffect(() => {
      inputRef.current = inputs
    }, [inputs])

    // model
    const model = inputs.model
    const modelMode = inputs.model?.mode
    const isChatModel = modelMode === 'chat'
    // Any mode other than 'chat' (including a missing one) is treated as completion.
    const isCompletionModel = !isChatModel
    const promptTexts = getPromptTexts(inputs.prompt_template)

    // Which special blocks the prompt already contains, per the checkHas*Block helpers.
    // Outside chat mode only context is tracked; history is only tracked for
    // completion models in chat mode.
    const hasSetBlockStatus = (() => {
      const hasSetContext = promptTexts.some(text => checkHasContextBlock(text))
      if (!isChatMode)
        return { history: false, query: false, context: hasSetContext }

      const hasSetQuery = promptTexts.some(text => checkHasQueryBlock(text))
      if (isChatModel)
        return { history: false, query: hasSetQuery, context: hasSetContext }

      return {
        history: promptTexts.some(text => checkHasHistoryBlock(text)),
        query: hasSetQuery,
        context: hasSetContext,
      }
    })()

    const shouldShowContextTip = !hasSetBlockStatus.context && inputs.context.enabled

    /**
     * Writes the default prompt template into `draft`: the chat-model prompts, or the
     * completion-model prompt plus its default history role prefixes (stored in
     * defaultRolePrefixRef for the next setInputs call).
     * `passInIsChatMode` overrides the current model's mode when given.
     */
    const appendDefaultPromptConfig = useCallback((draft: LLMNodeType, defaultConfig: any, passInIsChatMode?: boolean) => {
      const promptTemplates = defaultConfig.prompt_templates
      const useChatTemplate = passInIsChatMode === undefined ? isChatModel : passInIsChatMode
      if (useChatTemplate) {
        draft.prompt_template = promptTemplates.chat_model.prompts
        return
      }
      draft.prompt_template = promptTemplates.completion_model.prompt
      const historiesRole = promptTemplates.completion_model.conversation_histories_role
      defaultRolePrefixRef.current = {
        user: historiesRole.user_prefix,
        assistant: historiesRole.assistant_prefix,
      }
    }, [isChatModel])

    // Seed the prompt template once the node's default config has loaded.
    useEffect(() => {
      const isReady = defaultConfig && Object.keys(defaultConfig).length > 0
      if (isReady && !inputs.prompt_template) {
        const newInputs = produce(inputs, (draft) => {
          appendDefaultPromptConfig(draft, defaultConfig)
        })
        setInputs(newInputs)
      }
      // eslint-disable-next-line react-hooks/exhaustive-deps
    }, [defaultConfig, isChatModel])

    // Set by handleModelChanged; consumed by the vision effect below.
    const [modelChanged, setModelChanged] = useState(false)
    const {
      currentProvider,
      currentModel,
    } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.textGeneration)

    // Switch model; when the mode changes and defaults are loaded, reset the prompt template.
    const handleModelChanged = useCallback((model: { provider: string; modelId: string; mode?: string }) => {
      const newInputs = produce(inputRef.current, (draft) => {
        draft.model.provider = model.provider
        draft.model.name = model.modelId
        draft.model.mode = model.mode!
        const isModeChange = model.mode !== inputRef.current.model.mode
        if (isModeChange && defaultConfig && Object.keys(defaultConfig).length > 0)
          appendDefaultPromptConfig(draft, defaultConfig, model.mode === 'chat')
      })
      setInputs(newInputs)
      setModelChanged(true)
    }, [setInputs, defaultConfig, appendDefaultPromptConfig])

    // Fall back to the current default model when the node has no provider yet.
    useEffect(() => {
      if (currentProvider?.provider && currentModel?.model && !model.provider) {
        handleModelChanged({
          provider: currentProvider.provider,
          modelId: currentModel.model,
          mode: currentModel.model_properties?.mode as string,
        })
      }
    }, [model.provider, currentProvider, currentModel, handleModelChanged])

    const handleCompletionParamsChange = useCallback((newParams: Record<string, any>) => {
      const newInputs = produce(inputs, (draft) => {
        draft.model.completion_params = newParams
      })
      setInputs(newInputs)
    }, [inputs, setInputs])

    const {
      currentModel: currModel,
    } = useTextGenerationCurrentProviderAndModelAndModelList(
      {
        provider: model.provider,
        model: model.name,
      },
    )
    const isShowVisionConfig = !!currModel?.features?.includes(ModelFeatureEnum.vision)

    // change to vision model to set vision enabled, else disabled
    useEffect(() => {
      if (!modelChanged)
        return
      setModelChanged(false)
      if (!isShowVisionConfig) {
        const newInputs = produce(inputs, (draft) => {
          draft.vision = {
            enabled: false,
          }
        })
        setInputs(newInputs)
        return
      }
      if (!inputs.vision?.enabled) {
        const newInputs = produce(inputs, (draft) => {
          if (!draft.vision?.enabled) {
            draft.vision = {
              enabled: true,
              configs: {
                detail: Resolution.high,
              },
            }
          }
        })
        setInputs(newInputs)
      }
      // eslint-disable-next-line react-hooks/exhaustive-deps
    }, [isShowVisionConfig, modelChanged])

    // variables
    const { handleVarListChange, handleAddVariable } = useVarList<LLMNodeType>({
      inputs,
      setInputs,
    })

    // context: a non-empty selector enables context, an empty one disables it
    const handleContextVarChange = useCallback((newVar: ValueSelector | string) => {
      const newInputs = produce(inputs, (draft) => {
        draft.context.variable_selector = newVar as ValueSelector || []
        draft.context.enabled = !!(newVar && newVar.length > 0)
      })
      setInputs(newInputs)
    }, [inputs, setInputs])

    const handlePromptChange = useCallback((newPrompt: PromptItem[] | PromptItem) => {
      const newInputs = produce(inputs, (draft) => {
        draft.prompt_template = newPrompt
      })
      setInputs(newInputs)
    }, [inputs, setInputs])

    const handleMemoryChange = useCallback((newMemory?: Memory) => {
      const newInputs = produce(inputs, (draft) => {
        draft.memory = newMemory
      })
      setInputs(newInputs)
    }, [inputs, setInputs])

    // Toggle vision, creating a high-resolution config when none exists yet.
    const handleVisionResolutionEnabledChange = useCallback((enabled: boolean) => {
      const newInputs = produce(inputs, (draft) => {
        if (!draft.vision) {
          draft.vision = {
            enabled,
            configs: {
              detail: Resolution.high,
            },
          }
        }
        else {
          draft.vision.enabled = enabled
          if (!draft.vision.configs) {
            draft.vision.configs = {
              detail: Resolution.high,
            }
          }
        }
      })
      setInputs(newInputs)
    }, [inputs, setInputs])

    const handleVisionResolutionChange = useCallback((newResolution: Resolution) => {
      const newInputs = produce(inputs, (draft) => {
        // Previously this threw when the node had no vision settings at all.
        if (!draft.vision) {
          draft.vision = {
            enabled: false,
          }
        }
        if (!draft.vision.configs) {
          draft.vision.configs = {
            detail: Resolution.high,
          }
        }
        draft.vision.configs.detail = newResolution
      })
      setInputs(newInputs)
    }, [inputs, setInputs])

    // Variable types selectable as prompt input variables.
    const filterInputVar = useCallback((varPayload: Var) => {
      return [VarType.number, VarType.string].includes(varPayload.type)
    }, [])

    // Variable types selectable as the context variable.
    const filterVar = useCallback((varPayload: Var) => {
      return [VarType.arrayObject, VarType.array, VarType.string].includes(varPayload.type)
    }, [])

    // single run
    const {
      isShowSingleRun,
      hideSingleRun,
      getInputVars,
      runningStatus,
      handleRun,
      handleStop,
      runInputData,
      setRunInputData,
      runResult,
    } = useOneStepRun<LLMNodeType>({
      id,
      data: inputs,
      defaultRunInputData: {
        '#context#': [RETRIEVAL_OUTPUT_STRUCT],
        '#files#': [],
      },
    })

    // Plain prompt-variable values: everything in runInputData except the reserved keys.
    const inputVarValues: Record<string, any> = Object.fromEntries(
      Object.entries(runInputData).filter(([key]) => !['#context#', '#files#'].includes(key)),
    )
    // Replace the prompt-variable values while keeping the reserved context/files entries.
    const setInputVarValues = useCallback((newPayload: Record<string, any>) => {
      const newVars = {
        ...newPayload,
        '#context#': runInputData['#context#'],
        '#files#': runInputData['#files#'],
      }
      setRunInputData(newVars)
    }, [runInputData, setRunInputData])

    const contexts = runInputData['#context#']
    const setContexts = useCallback((newContexts: string[]) => {
      setRunInputData({
        ...runInputData,
        '#context#': newContexts,
      })
    }, [runInputData, setRunInputData])

    const visionFiles = runInputData['#files#']
    const setVisionFiles = useCallback((newFiles: any[]) => {
      setRunInputData({
        ...runInputData,
        '#files#': newFiles,
      })
    }, [runInputData, setRunInputData])

    // Texts scanned for input variables; chat models with memory in chat mode also use sys.query.
    const allVarStrArr = (() => {
      const arr = [...promptTexts]
      if (isChatMode && isChatModel && !!inputs.memory)
        arr.push('{{#sys.query#}}')
      return arr
    })()
    const varInputs = getInputVars(allVarStrArr)

    return {
      readOnly,
      isChatMode,
      inputs,
      isChatModel,
      isCompletionModel,
      hasSetBlockStatus,
      shouldShowContextTip,
      isShowVisionConfig,
      handleModelChanged,
      handleCompletionParamsChange,
      handleVarListChange,
      handleAddVariable,
      handleContextVarChange,
      filterInputVar,
      filterVar,
      handlePromptChange,
      handleMemoryChange,
      handleVisionResolutionEnabledChange,
      handleVisionResolutionChange,
      isShowSingleRun,
      hideSingleRun,
      inputVarValues,
      setInputVarValues,
      visionFiles,
      setVisionFiles,
      contexts,
      setContexts,
      varInputs,
      runningStatus,
      handleRun,
      handleStop,
      runResult,
    }
  }

  export default useConfig