use-config.ts 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462
  1. import { useCallback, useEffect, useRef, useState } from 'react'
  2. import produce from 'immer'
  3. import { EditionType, VarType } from '../../types'
  4. import type { Memory, PromptItem, ValueSelector, Var, Variable } from '../../types'
  5. import { useStore } from '../../store'
  6. import {
  7. useIsChatMode,
  8. useNodesReadOnly,
  9. } from '../../hooks'
  10. import useAvailableVarList from '../_base/hooks/use-available-var-list'
  11. import type { LLMNodeType } from './types'
  12. import { Resolution } from '@/types/app'
  13. import { useModelListAndDefaultModelAndCurrentProviderAndModel, useTextGenerationCurrentProviderAndModelAndModelList } from '@/app/components/header/account-setting/model-provider-page/hooks'
  14. import {
  15. ModelFeatureEnum,
  16. ModelTypeEnum,
  17. } from '@/app/components/header/account-setting/model-provider-page/declarations'
  18. import useNodeCrud from '@/app/components/workflow/nodes/_base/hooks/use-node-crud'
  19. import useOneStepRun from '@/app/components/workflow/nodes/_base/hooks/use-one-step-run'
  20. import { RETRIEVAL_OUTPUT_STRUCT } from '@/app/components/workflow/constants'
  21. import { checkHasContextBlock, checkHasHistoryBlock, checkHasQueryBlock } from '@/app/components/base/prompt-editor/constants'
  /**
   * Flattens an LLM prompt template into a list of prompt items.
   *
   * Chat models keep an array of items and completion models keep a single item;
   * the template can also be missing until the default-template effect in
   * `useConfig` fills it in. The returned array holds the original item objects
   * (not copies), so callers may mutate them inside an Immer draft.
   */
  const toPromptItems = (promptTemplate: PromptItem[] | PromptItem | null | undefined): PromptItem[] => {
    if (!promptTemplate)
      return []
    return Array.isArray(promptTemplate) ? promptTemplate : [promptTemplate]
  }

  /**
   * Rewrites every `{{ oldName }}` reference in a Jinja2 template to `{{ newName }}`.
   *
   * Only the exact single-space form is matched. A replacer function is used so
   * `$` sequences in the new name are inserted literally instead of being read as
   * `String.prototype.replaceAll` substitution patterns.
   */
  const renameJinja2Var = (text: string, oldName: string, newName: string): string => {
    const replacement = `{{ ${newName} }}`
    return text.replaceAll(`{{ ${oldName} }}`, () => replacement)
  }

  /**
   * Ensures `draft.prompt_config.jinja2_variables` exists (creating empty
   * containers as needed) and returns that array so it can be mutated in place.
   */
  const ensureJinja2Variables = (draft: LLMNodeType): Variable[] => {
    if (!draft.prompt_config)
      draft.prompt_config = { jinja2_variables: [] }
    if (!draft.prompt_config.jinja2_variables)
      draft.prompt_config.jinja2_variables = []
    return draft.prompt_config.jinja2_variables
  }

  /**
   * State and handlers backing the LLM node's configuration panel.
   *
   * @param id - id of the workflow node being edited.
   * @param payload - the node's current data; read and written through `useNodeCrud`.
   * @returns read-only/mode flags, derived view state (block status, vision and
   *   variable visibility), change handlers for each panel section, and the
   *   single-step-run state from `useOneStepRun`.
   */
  const useConfig = (id: string, payload: LLMNodeType) => {
    const { nodesReadOnly: readOnly } = useNodesReadOnly()
    const isChatMode = useIsChatMode()
    const defaultConfig = useStore(s => s.nodesDefaultConfigs)[payload.type]
    const [defaultRolePrefix, setDefaultRolePrefix] = useState<{ user: string; assistant: string }>({ user: '', assistant: '' })
    const { inputs, setInputs: doSetInputs } = useNodeCrud<LLMNodeType>(id, payload)

    // Latest inputs for handlers that may issue several updates before React
    // re-renders. `setInputs` writes it synchronously; this effect also picks up
    // changes that did not go through `setInputs`, so handlers never produce
    // their next state from stale data.
    const inputRef = useRef(inputs)
    useEffect(() => {
      inputRef.current = inputs
    }, [inputs])

    // Wraps the CRUD setter: fills in the default role prefixes when memory has none.
    const setInputs = useCallback((newInputs: LLMNodeType) => {
      if (newInputs.memory && !newInputs.memory.role_prefix) {
        const newPayload = produce(newInputs, (draft) => {
          draft.memory!.role_prefix = defaultRolePrefix
        })
        doSetInputs(newPayload)
        inputRef.current = newPayload
        return
      }
      doSetInputs(newInputs)
      inputRef.current = newInputs
    }, [doSetInputs, defaultRolePrefix])

    // model
    const model = inputs.model
    const modelMode = inputs.model?.mode
    const isChatModel = modelMode === 'chat'
    const isCompletionModel = !isChatModel

    // Which special blocks (history / query / context) the prompt already contains.
    // History and query are only tracked in chat mode; history only for completion models.
    const hasSetBlockStatus = (() => {
      const promptItems = toPromptItems(inputs.prompt_template)
      const hasSetContext = promptItems.some(item => checkHasContextBlock(item.text))
      if (!isChatMode) {
        return {
          history: false,
          query: false,
          context: hasSetContext,
        }
      }
      return {
        history: !isChatModel && promptItems.some(item => checkHasHistoryBlock(item.text)),
        query: promptItems.some(item => checkHasQueryBlock(item.text)),
        context: hasSetContext,
      }
    })()

    const shouldShowContextTip = !hasSetBlockStatus.context && inputs.context.enabled

    // Writes the default prompt template for the chosen model mode into `draft`.
    // `passInIsChatMode` overrides the current mode (used while switching models).
    const appendDefaultPromptConfig = useCallback((draft: LLMNodeType, nodeDefaultConfig: any, passInIsChatMode?: boolean) => {
      const promptTemplates = nodeDefaultConfig.prompt_templates
      const useChatTemplate = passInIsChatMode ?? isChatModel
      if (useChatTemplate) {
        draft.prompt_template = promptTemplates.chat_model.prompts
        return
      }
      draft.prompt_template = promptTemplates.completion_model.prompt
      const rolePrefix = {
        user: promptTemplates.completion_model.conversation_histories_role.user_prefix,
        assistant: promptTemplates.completion_model.conversation_histories_role.assistant_prefix,
      }
      setDefaultRolePrefix(rolePrefix)
      // `setInputs` reads `defaultRolePrefix` from its closure, which still holds
      // the previous value until the next render. Apply the new prefixes here so
      // existing memory is not stamped with the stale (initially empty) ones.
      if (draft.memory && !draft.memory.role_prefix)
        draft.memory.role_prefix = rolePrefix
    }, [isChatModel])

    // Fill in the default prompt template once the node defaults are available.
    useEffect(() => {
      const isReady = defaultConfig && Object.keys(defaultConfig).length > 0
      if (isReady && !inputs.prompt_template) {
        const newInputs = produce(inputs, (draft) => {
          appendDefaultPromptConfig(draft, defaultConfig)
        })
        setInputs(newInputs)
      }
    // eslint-disable-next-line react-hooks/exhaustive-deps
    }, [defaultConfig, isChatModel])

    const [modelChanged, setModelChanged] = useState(false)
    const {
      currentProvider,
      currentModel,
    } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.textGeneration)

    // Switches provider/model; swaps in the default template when the mode changes.
    const handleModelChanged = useCallback((model: { provider: string; modelId: string; mode?: string }) => {
      const newInputs = produce(inputRef.current, (draft) => {
        draft.model.provider = model.provider
        draft.model.name = model.modelId
        draft.model.mode = model.mode!
        const isModeChange = model.mode !== inputRef.current.model.mode
        if (isModeChange && defaultConfig && Object.keys(defaultConfig).length > 0)
          appendDefaultPromptConfig(draft, defaultConfig, model.mode === 'chat')
      })
      setInputs(newInputs)
      setModelChanged(true)
    }, [setInputs, defaultConfig, appendDefaultPromptConfig])

    // Preselect the workspace's current text-generation model when none is set.
    useEffect(() => {
      if (currentProvider?.provider && currentModel?.model && !model.provider) {
        handleModelChanged({
          provider: currentProvider?.provider,
          modelId: currentModel?.model,
          mode: currentModel?.model_properties?.mode as string,
        })
      }
    }, [model.provider, currentProvider, currentModel, handleModelChanged])

    const handleCompletionParamsChange = useCallback((newParams: Record<string, any>) => {
      const newInputs = produce(inputRef.current, (draft) => {
        draft.model.completion_params = newParams
      })
      setInputs(newInputs)
    }, [setInputs])

    const {
      currentModel: currModel,
    } = useTextGenerationCurrentProviderAndModelAndModelList(
      {
        provider: model.provider,
        model: model.name,
      },
    )

    const isShowVisionConfig = !!currModel?.features?.includes(ModelFeatureEnum.vision)

    // After a model change: disable vision for non-vision models, otherwise
    // enable it (high detail) unless it is already enabled.
    useEffect(() => {
      if (!modelChanged)
        return
      setModelChanged(false)
      if (!isShowVisionConfig) {
        const newInputs = produce(inputRef.current, (draft) => {
          draft.vision = {
            enabled: false,
          }
        })
        setInputs(newInputs)
        return
      }
      if (inputRef.current.vision?.enabled)
        return
      const newInputs = produce(inputRef.current, (draft) => {
        draft.vision = {
          enabled: true,
          configs: {
            detail: Resolution.high,
          },
        }
      })
      setInputs(newInputs)
    // eslint-disable-next-line react-hooks/exhaustive-deps
    }, [isShowVisionConfig, modelChanged])

    // variables
    // Jinja2 variable editing is shown as soon as any prompt item uses Jinja2.
    const isShowVars = toPromptItems(inputs.prompt_template).some(item => item.edition_type === EditionType.jinja2)

    const handleAddEmptyVariable = useCallback(() => {
      const newInputs = produce(inputRef.current, (draft) => {
        ensureJinja2Variables(draft).push({
          variable: '',
          value_selector: [],
        })
      })
      setInputs(newInputs)
    }, [setInputs])

    const handleAddVariable = useCallback((payload: Variable) => {
      const newInputs = produce(inputRef.current, (draft) => {
        ensureJinja2Variables(draft).push(payload)
      })
      setInputs(newInputs)
    }, [setInputs])

    const handleVarListChange = useCallback((newList: Variable[]) => {
      const newInputs = produce(inputRef.current, (draft) => {
        if (draft.prompt_config)
          draft.prompt_config.jinja2_variables = newList
        else
          draft.prompt_config = { jinja2_variables: newList }
      })
      setInputs(newInputs)
    }, [setInputs])

    // Propagates a Jinja2 variable rename into every Jinja2 prompt item.
    const handleVarNameChange = useCallback((oldName: string, newName: string) => {
      const newInputs = produce(inputRef.current, (draft) => {
        toPromptItems(draft.prompt_template)
          .filter(item => item.edition_type === EditionType.jinja2)
          .forEach((item) => {
            item.jinja2_text = renameJinja2Var(item.jinja2_text || '', oldName, newName)
          })
      })
      setInputs(newInputs)
    }, [setInputs])

    // context
    // Context is enabled exactly when a non-empty variable is selected.
    const handleContextVarChange = useCallback((newVar: ValueSelector | string) => {
      const newInputs = produce(inputRef.current, (draft) => {
        draft.context.variable_selector = newVar as ValueSelector || []
        draft.context.enabled = !!(newVar && newVar.length > 0)
      })
      setInputs(newInputs)
    }, [setInputs])

    const handlePromptChange = useCallback((newPrompt: PromptItem[] | PromptItem) => {
      const newInputs = produce(inputRef.current, (draft) => {
        draft.prompt_template = newPrompt
      })
      setInputs(newInputs)
    }, [setInputs])

    const handleMemoryChange = useCallback((newMemory?: Memory) => {
      const newInputs = produce(inputRef.current, (draft) => {
        draft.memory = newMemory
      })
      setInputs(newInputs)
    }, [setInputs])

    // Updates the memory query template (name spelling kept for existing callers),
    // creating a memory entry with a disabled window of 10 if none exists yet.
    const handleSyeQueryChange = useCallback((newQuery: string) => {
      const newInputs = produce(inputRef.current, (draft) => {
        if (!draft.memory) {
          draft.memory = {
            window: {
              enabled: false,
              size: 10,
            },
            query_prompt_template: newQuery,
          }
        }
        else {
          draft.memory.query_prompt_template = newQuery
        }
      })
      setInputs(newInputs)
    }, [setInputs])

    const handleVisionResolutionEnabledChange = useCallback((enabled: boolean) => {
      const newInputs = produce(inputRef.current, (draft) => {
        if (!draft.vision) {
          draft.vision = {
            enabled,
            configs: {
              detail: Resolution.high,
            },
          }
        }
        else {
          draft.vision.enabled = enabled
          if (!draft.vision.configs) {
            draft.vision.configs = {
              detail: Resolution.high,
            }
          }
        }
      })
      setInputs(newInputs)
    }, [setInputs])

    const handleVisionResolutionChange = useCallback((newResolution: Resolution) => {
      const newInputs = produce(inputRef.current, (draft) => {
        // `vision` may be absent (see `inputs.vision?.enabled` above); an absent
        // entry behaves as disabled, so create it in that state.
        if (!draft.vision)
          draft.vision = { enabled: false }
        if (!draft.vision.configs) {
          draft.vision.configs = {
            detail: Resolution.high,
          }
        }
        draft.vision.configs.detail = newResolution
      })
      setInputs(newInputs)
    }, [setInputs])

    const filterInputVar = useCallback((varPayload: Var) => {
      return [VarType.number, VarType.string, VarType.secret].includes(varPayload.type)
    }, [])

    const filterVar = useCallback((varPayload: Var) => {
      return [VarType.arrayObject, VarType.array, VarType.number, VarType.string, VarType.secret].includes(varPayload.type)
    }, [])

    const {
      availableVars,
      availableNodesWithParent,
    } = useAvailableVarList(id, {
      onlyLeafNodeVar: false,
      filterVar,
    })

    // single run
    const {
      isShowSingleRun,
      hideSingleRun,
      getInputVars,
      runningStatus,
      handleRun,
      handleStop,
      runInputData,
      setRunInputData,
      runResult,
      toVarInputs,
    } = useOneStepRun<LLMNodeType>({
      id,
      data: inputs,
      defaultRunInputData: {
        '#context#': [RETRIEVAL_OUTPUT_STRUCT],
        '#files#': [],
      },
    })

    // User-entered run inputs, excluding the reserved '#context#' / '#files#' keys.
    const inputVarValues = (() => {
      const vars: Record<string, any> = {}
      Object.keys(runInputData)
        .filter(key => !['#context#', '#files#'].includes(key))
        .forEach((key) => {
          vars[key] = runInputData[key]
        })
      return vars
    })()

    const setInputVarValues = useCallback((newPayload: Record<string, any>) => {
      const newVars = {
        ...newPayload,
        '#context#': runInputData['#context#'],
        '#files#': runInputData['#files#'],
      }
      setRunInputData(newVars)
    }, [runInputData, setRunInputData])

    const contexts = runInputData['#context#']
    const setContexts = useCallback((newContexts: string[]) => {
      setRunInputData({
        ...runInputData,
        '#context#': newContexts,
      })
    }, [runInputData, setRunInputData])

    const visionFiles = runInputData['#files#']
    const setVisionFiles = useCallback((newFiles: any[]) => {
      setRunInputData({
        ...runInputData,
        '#files#': newFiles,
      })
    }, [runInputData, setRunInputData])

    // Prompt strings scanned for variable references. Jinja2 chat items are
    // skipped here; their variables come from `jinja2_variables` in `varInputs`.
    const allVarStrArr = (() => {
      const promptItems = toPromptItems(inputs.prompt_template)
      const arr = isChatModel
        ? promptItems.filter(item => item.edition_type !== EditionType.jinja2).map(item => item.text)
        : promptItems.map(item => item.text)
      if (isChatMode && isChatModel && !!inputs.memory) {
        arr.push('{{#sys.query#}}')
        arr.push(inputs.memory.query_prompt_template)
      }
      return arr
    })()

    const varInputs = (() => {
      const vars = getInputVars(allVarStrArr)
      if (isShowVars)
        return [...vars, ...toVarInputs(inputs.prompt_config?.jinja2_variables || [])]
      return vars
    })()

    return {
      readOnly,
      isChatMode,
      inputs,
      isChatModel,
      isCompletionModel,
      hasSetBlockStatus,
      shouldShowContextTip,
      isShowVisionConfig,
      handleModelChanged,
      handleCompletionParamsChange,
      isShowVars,
      handleVarListChange,
      handleVarNameChange,
      handleAddVariable,
      handleAddEmptyVariable,
      handleContextVarChange,
      filterInputVar,
      filterVar,
      availableVars,
      availableNodesWithParent,
      handlePromptChange,
      handleMemoryChange,
      handleSyeQueryChange,
      handleVisionResolutionEnabledChange,
      handleVisionResolutionChange,
      isShowSingleRun,
      hideSingleRun,
      inputVarValues,
      setInputVarValues,
      visionFiles,
      setVisionFiles,
      contexts,
      setContexts,
      varInputs,
      runningStatus,
      handleRun,
      handleStop,
      runResult,
    }
  }

  export default useConfig