use-config.ts 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281
  1. import {
  2. useCallback,
  3. useEffect,
  4. useRef,
  5. useState,
  6. } from 'react'
  7. import produce from 'immer'
  8. import { isEqual } from 'lodash-es'
  9. import type { ValueSelector, Var } from '../../types'
  10. import { BlockEnum, VarType } from '../../types'
  11. import {
  12. useIsChatMode, useNodesReadOnly,
  13. useWorkflow,
  14. } from '../../hooks'
  15. import type { KnowledgeRetrievalNodeType, MultipleRetrievalConfig } from './types'
  16. import {
  17. getMultipleRetrievalConfig,
  18. getSelectedDatasetsMode,
  19. } from './utils'
  20. import { RETRIEVE_TYPE } from '@/types/app'
  21. import { DATASET_DEFAULT } from '@/config'
  22. import type { DataSet } from '@/models/datasets'
  23. import { fetchDatasets } from '@/service/datasets'
  24. import useNodeCrud from '@/app/components/workflow/nodes/_base/hooks/use-node-crud'
  25. import useOneStepRun from '@/app/components/workflow/nodes/_base/hooks/use-one-step-run'
  26. import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
  27. import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
  /**
   * Builds an empty single-retrieval model config. Used as the starting point
   * when a model field is edited before any single-retrieval config exists.
   */
  const createEmptySingleRetrievalConfig = () => ({
    model: {
      provider: '',
      name: '',
      mode: '',
      completion_params: {},
    },
  })

  /**
   * State and change handlers for the knowledge-retrieval node's config panel.
   *
   * @param id - id of the workflow node being edited.
   * @param payload - the node data, passed to `useNodeCrud`.
   * @returns the node inputs; handlers for the query variable, retrieval mode,
   *   multi-way config, single-way model and completion params; the selected
   *   datasets (entries without a `name` are filtered out); the rerank-model
   *   popup state; and the single-run controls from `useOneStepRun`.
   */
  const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
    const { nodesReadOnly: readOnly } = useNodesReadOnly()
    const isChatMode = useIsChatMode()
    const { getBeforeNodesInSameBranch } = useWorkflow()
    const startNode = getBeforeNodesInSameBranch(id).find(node => node.data.type === BlockEnum.Start)
    const startNodeId = startNode?.id
    const { inputs, setInputs: doSetInputs } = useNodeCrud<KnowledgeRetrievalNodeType>(id, payload)

    // Latest known inputs. Synced from `inputs` after renders where they change,
    // and also written through by `setInputs` below, so several updates issued
    // in the same tick (e.g. the mount effects further down) build on one
    // another instead of each starting from the same stale snapshot.
    const inputRef = useRef(inputs)
    useEffect(() => {
      inputRef.current = inputs
    }, [inputs])

    // Persists `s`, keeping only the config that matches its retrieval mode.
    // The unused config is removed here rather than inside callers' drafts,
    // which (per the original author's note) did not work.
    const setInputs = useCallback((s: KnowledgeRetrievalNodeType) => {
      const newInputs = produce(s, (draft) => {
        if (s.retrieval_mode === RETRIEVE_TYPE.multiWay)
          delete draft.single_retrieval_config
        else
          delete draft.multiple_retrieval_config
      })
      inputRef.current = newInputs
      doSetInputs(newInputs)
    }, [doSetInputs])

    const handleQueryVarChange = useCallback((newVar: ValueSelector | string) => {
      const newInputs = produce(inputs, (draft) => {
        draft.query_variable_selector = newVar as ValueSelector
      })
      setInputs(newInputs)
    }, [inputs, setInputs])

    const {
      currentProvider,
      currentModel,
    } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.textGeneration)

    const {
      defaultModel: rerankDefaultModel,
    } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)

    // Sets the single-retrieval model, creating the config if it is missing.
    const handleModelChanged = useCallback((model: { provider: string; modelId: string; mode?: string }) => {
      const newInputs = produce(inputRef.current, (draft) => {
        if (!draft.single_retrieval_config)
          draft.single_retrieval_config = createEmptySingleRetrievalConfig()
        const draftModel = draft.single_retrieval_config.model
        draftModel.provider = model.provider
        draftModel.name = model.modelId
        draftModel.mode = model.mode!
      })
      setInputs(newInputs)
    }, [setInputs])

    // Sets the single-retrieval completion params, creating the config if it is missing.
    const handleCompletionParamsChange = useCallback((newParams: Record<string, any>) => {
      // Skip no-op updates. Reads the ref, which `setInputs` keeps current, so a
      // model change made earlier in the same tick is not reverted here.
      if (isEqual(newParams, inputRef.current.single_retrieval_config?.model.completion_params))
        return
      const newInputs = produce(inputRef.current, (draft) => {
        if (!draft.single_retrieval_config)
          draft.single_retrieval_config = createEmptySingleRetrievalConfig()
        draft.single_retrieval_config.model.completion_params = newParams
      })
      setInputs(newInputs)
    }, [setInputs])

    // Fills in defaults when the active retrieval mode has no model yet: the
    // current text-generation model for single retrieval, and `top_k` for the
    // multi-way config. Re-runs when the current model or default rerank model changes.
    useEffect(() => {
      const current = inputRef.current
      if (current.retrieval_mode === RETRIEVE_TYPE.multiWay && current.multiple_retrieval_config?.reranking_model?.provider)
        return
      if (current.retrieval_mode === RETRIEVE_TYPE.oneWay && current.single_retrieval_config?.model?.provider)
        return

      const newInput = produce(current, (draft) => {
        if (currentProvider?.provider && currentModel?.model) {
          const hasSetModel = draft.single_retrieval_config?.model?.provider
          if (!hasSetModel) {
            draft.single_retrieval_config = {
              model: {
                provider: currentProvider?.provider,
                name: currentModel?.model,
                mode: currentModel?.model_properties?.mode as string,
                completion_params: {},
              },
            }
          }
        }
        const multipleRetrievalConfig = draft.multiple_retrieval_config
        draft.multiple_retrieval_config = {
          top_k: multipleRetrievalConfig?.top_k || DATASET_DEFAULT.top_k,
          score_threshold: multipleRetrievalConfig?.score_threshold,
          reranking_model: multipleRetrievalConfig?.reranking_model,
        }
      })
      setInputs(newInput)
      // eslint-disable-next-line react-hooks/exhaustive-deps
    }, [currentProvider?.provider, currentModel, rerankDefaultModel])

    const [selectedDatasets, setSelectedDatasets] = useState<DataSet[]>([])
    const [rerankModelOpen, setRerankModelOpen] = useState(false)

    const handleRetrievalModeChange = useCallback((newMode: RETRIEVE_TYPE) => {
      const newInputs = produce(inputs, (draft) => {
        draft.retrieval_mode = newMode
        if (newMode === RETRIEVE_TYPE.multiWay) {
          const multipleRetrievalConfig = draft.multiple_retrieval_config
          draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, selectedDatasets)
        }
        else {
          const hasSetModel = draft.single_retrieval_config?.model?.provider
          if (!hasSetModel) {
            draft.single_retrieval_config = {
              model: {
                provider: currentProvider?.provider || '',
                name: currentModel?.model || '',
                mode: currentModel?.model_properties?.mode as string,
                completion_params: {},
              },
            }
          }
        }
      })
      setInputs(newInputs)
    }, [currentModel?.model, currentModel?.model_properties?.mode, currentProvider?.provider, inputs, setInputs, selectedDatasets])

    const handleMultipleRetrievalConfigChange = useCallback((newConfig: MultipleRetrievalConfig) => {
      const newInputs = produce(inputs, (draft) => {
        draft.multiple_retrieval_config = getMultipleRetrievalConfig(newConfig!, selectedDatasets)
      })
      setInputs(newInputs)
    }, [inputs, setInputs, selectedDatasets])

    // On mount: load details for the datasets already referenced by the node,
    // then persist the inputs through `setInputs` (which strips the config of
    // the inactive retrieval mode).
    useEffect(() => {
      let cancelled = false
      const loadSelectedDatasets = async () => {
        const datasetIds = inputRef.current.dataset_ids
        if (datasetIds?.length > 0) {
          try {
            const { data: dataSetsWithDetail } = await fetchDatasets({ url: '/datasets', params: { page: 1, ids: datasetIds } })
            // Ignore the response if unmounted or if the selection changed meanwhile.
            if (!cancelled && isEqual(inputRef.current.dataset_ids, datasetIds))
              setSelectedDatasets(dataSetsWithDetail)
          }
          catch (e) {
            console.error('Failed to fetch datasets for knowledge retrieval node', e)
          }
        }
        if (cancelled)
          return
        // Re-read the ref: it may have been updated while the request was in
        // flight, and writing the pre-request snapshot would revert that.
        setInputs(inputRef.current)
      }
      loadSelectedDatasets()
      return () => {
        cancelled = true
      }
      // eslint-disable-next-line react-hooks/exhaustive-deps
    }, [])

    // On mount: in chat mode, default an empty query variable to the start
    // node's `sys.query`, then persist the inputs.
    useEffect(() => {
      const current = inputRef.current
      let query_variable_selector: ValueSelector = current.query_variable_selector
      if (isChatMode && current.query_variable_selector.length === 0 && startNodeId)
        query_variable_selector = [startNodeId, 'sys.query']
      setInputs({
        ...current,
        query_variable_selector,
      })
      // eslint-disable-next-line react-hooks/exhaustive-deps
    }, [])

    const handleOnDatasetsChange = useCallback((newDatasets: DataSet[]) => {
      const {
        allEconomic,
        mixtureHighQualityAndEconomic,
        inconsistentEmbeddingModel,
      } = getSelectedDatasetsMode(newDatasets)
      const newInputs = produce(inputs, (draft) => {
        draft.dataset_ids = newDatasets.map(d => d.id)
        // Read the mode from the inputs being updated, as handleRetrievalModeChange does.
        if (draft.retrieval_mode === RETRIEVE_TYPE.multiWay && newDatasets.length > 0) {
          const multipleRetrievalConfig = draft.multiple_retrieval_config
          draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, newDatasets)
        }
      })
      setInputs(newInputs)
      setSelectedDatasets(newDatasets)
      // Open the rerank-model popup for the dataset mixes flagged by getSelectedDatasetsMode.
      if (allEconomic || mixtureHighQualityAndEconomic || inconsistentEmbeddingModel)
        setRerankModelOpen(true)
    }, [inputs, setInputs])

    // Only string variables can be used as the query.
    const filterVar = useCallback((varPayload: Var) => {
      return varPayload.type === VarType.string
    }, [])

    // single run
    const {
      isShowSingleRun,
      hideSingleRun,
      runningStatus,
      handleRun,
      handleStop,
      runInputData,
      setRunInputData,
      runResult,
    } = useOneStepRun<KnowledgeRetrievalNodeType>({
      id,
      data: inputs,
      defaultRunInputData: {
        query: '',
      },
    })

    const query = runInputData.query
    const setQuery = useCallback((newQuery: string) => {
      setRunInputData({
        ...runInputData,
        query: newQuery,
      })
    }, [runInputData, setRunInputData])

    return {
      readOnly,
      inputs,
      handleQueryVarChange,
      filterVar,
      handleRetrievalModeChange,
      handleMultipleRetrievalConfigChange,
      handleModelChanged,
      handleCompletionParamsChange,
      selectedDatasets: selectedDatasets.filter(d => d.name),
      handleOnDatasetsChange,
      isShowSingleRun,
      hideSingleRun,
      runningStatus,
      handleRun,
      handleStop,
      query,
      setQuery,
      runResult,
      rerankModelOpen,
      setRerankModelOpen,
    }
  }

  export default useConfig