retrieval-config.tsx 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. 'use client'
  2. import type { FC } from 'react'
  3. import React, { useCallback, useState } from 'react'
  4. import { RiEqualizer2Line } from '@remixicon/react'
  5. import { useTranslation } from 'react-i18next'
  6. import type { MultipleRetrievalConfig, SingleRetrievalConfig } from '../types'
  7. import type { ModelConfig } from '../../../types'
  8. import cn from '@/utils/classnames'
  9. import {
  10. PortalToFollowElem,
  11. PortalToFollowElemContent,
  12. PortalToFollowElemTrigger,
  13. } from '@/app/components/base/portal-to-follow-elem'
  14. import ConfigRetrievalContent from '@/app/components/app/configuration/dataset-config/params-config/config-content'
  15. import { RETRIEVE_TYPE } from '@/types/app'
  16. import { DATASET_DEFAULT } from '@/config'
  17. import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
  18. import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
  19. import Button from '@/app/components/base/button'
  20. import type { DatasetConfigs } from '@/models/debug'
  21. import type { DataSet } from '@/models/datasets'
  /**
   * Retrieval settings handed to `RetrievalConfig` by its parent.
   */
  type RetrievalPayload = {
    /** Current retrieval mode; forwarded to the settings panel as `retrieval_model`. */
    retrieval_mode: RETRIEVE_TYPE
    /** Source of the top-k, score-threshold, reranking and weights values shown in the panel. */
    multiple_retrieval_config?: MultipleRetrievalConfig
    /** Part of the payload shape; not read by `RetrievalConfig` itself. */
    single_retrieval_config?: SingleRetrievalConfig
  }

  /**
   * Props accepted by the `RetrievalConfig` popover button.
   */
  type Props = {
    /** Retrieval settings to display. */
    payload: RetrievalPayload
    /** Called with the new mode when the panel reports a retrieval-mode change. */
    onRetrievalModeChange: (mode: RETRIEVE_TYPE) => void
    /** Called with a rebuilt multiple-retrieval config on any other panel change. */
    onMultipleRetrievalConfigChange: (config: MultipleRetrievalConfig) => void
    /** Passed through unchanged to `ConfigRetrievalContent`. */
    singleRetrievalModelConfig?: ModelConfig
    /** Passed through unchanged to `ConfigRetrievalContent`. */
    onSingleRetrievalModelChange?: (config: ModelConfig) => void
    /** Passed through unchanged to `ConfigRetrievalContent`. */
    onSingleRetrievalModelParamsChange?: (config: ModelConfig) => void
    /** When true, the trigger button is disabled and clicks on it are ignored. */
    readonly?: boolean
    /** When defined, overrides the component's internal open state. */
    openFromProps?: boolean
    /** Notified with the requested open state every time it changes. */
    onOpenFromPropsChange?: (openFromProps: boolean) => void
    /** Passed through unchanged to `ConfigRetrievalContent`. */
    selectedDatasets: DataSet[]
  }
  /**
   * Ghost button labelled with `dataset.retrievalSettings` that toggles a
   * popover containing `ConfigRetrievalContent`.
   *
   * The popover's open state is kept internally but is overridden by
   * `openFromProps` whenever that prop is defined. Every requested open-state
   * change is also reported through `onOpenFromPropsChange`.
   *
   * Changes coming back from the panel are split in two:
   * - retrieval-mode changes are forwarded to `onRetrievalModeChange`;
   * - everything else is rebuilt into a `MultipleRetrievalConfig` and sent to
   *   `onMultipleRetrievalConfigChange`.
   */
  const RetrievalConfig: FC<Props> = ({
    payload,
    onRetrievalModeChange,
    onMultipleRetrievalConfigChange,
    singleRetrievalModelConfig,
    onSingleRetrievalModelChange,
    onSingleRetrievalModelParamsChange,
    readonly,
    openFromProps,
    onOpenFromPropsChange,
    selectedDatasets,
  }) => {
    const { t } = useTranslation()

    // Internal open state; only authoritative while `openFromProps` is undefined.
    const [open, setOpen] = useState(false)
    const mergedOpen = openFromProps !== undefined ? openFromProps : open

    // Keep the internal state and the parent (if it listens) in sync.
    const handleOpen = useCallback((newOpen: boolean) => {
      setOpen(newOpen)
      onOpenFromPropsChange?.(newOpen)
    }, [onOpenFromPropsChange])

    const {
      defaultModel: rerankDefaultModel,
    } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)

    const { multiple_retrieval_config } = payload

    const handleChange = useCallback((configs: DatasetConfigs, isRetrievalModeChange?: boolean) => {
      // A mode switch is reported on its own; the multiple-retrieval config is left untouched.
      if (isRetrievalModeChange) {
        onRetrievalModeChange(configs.retrieval_model)
        return
      }

      onMultipleRetrievalConfigChange({
        top_k: configs.top_k,
        // `null` is sent when the threshold is switched off; an enabled but
        // unset threshold falls back to the dataset default.
        score_threshold: configs.score_threshold_enabled ? (configs.score_threshold ?? DATASET_DEFAULT.score_threshold) : null,
        // No reranking model in one-way mode. Otherwise use the model chosen in
        // the panel, falling back to the default rerank model (or empty strings
        // when there is none) if no provider was selected.
        reranking_model: payload.retrieval_mode === RETRIEVE_TYPE.oneWay
          ? undefined
          : (!configs.reranking_model?.reranking_provider_name
            ? {
              provider: rerankDefaultModel?.provider?.provider || '',
              model: rerankDefaultModel?.model || '',
            }
            : {
              provider: configs.reranking_model?.reranking_provider_name,
              model: configs.reranking_model?.reranking_model_name,
            }),
        reranking_mode: configs.reranking_mode,
        // NOTE(review): the cast presumably papers over a mismatch between the
        // two `weights` types — confirm and replace with a proper mapping.
        weights: configs.weights as any,
        reranking_enable: configs.reranking_enable,
      })
    }, [onMultipleRetrievalConfigChange, payload.retrieval_mode, rerankDefaultModel?.provider?.provider, rerankDefaultModel?.model, onRetrievalModeChange])

    return (
      <PortalToFollowElem
        open={mergedOpen}
        onOpenChange={handleOpen}
        placement='bottom-end'
        offset={{
          crossAxis: -2,
        }}
      >
        <PortalToFollowElemTrigger
          onClick={() => {
            // Read-only mode: the trigger never toggles the popover.
            if (readonly)
              return
            handleOpen(!mergedOpen)
          }}
        >
          {/* Highlight follows the effective open state so it also matches
              a popover that is controlled through `openFromProps`. */}
          <Button
            variant='ghost'
            size='small'
            disabled={readonly}
            className={cn(mergedOpen && 'bg-components-button-ghost-bg-hover')}
          >
            <RiEqualizer2Line className='mr-1 w-3.5 h-3.5' />
            {t('dataset.retrievalSettings')}
          </Button>
        </PortalToFollowElemTrigger>
        <PortalToFollowElemContent style={{ zIndex: 1001 }}>
          <div className='w-[404px] pt-3 pb-4 px-4 shadow-xl rounded-2xl border border-gray-200 bg-white'>
            <ConfigRetrievalContent
              datasetConfigs={
                {
                  retrieval_model: payload.retrieval_mode,
                  // Translate the stored { provider, model } pair into the
                  // panel's field names; empty strings when nothing is stored.
                  reranking_model: multiple_retrieval_config?.reranking_model?.provider
                    ? {
                      reranking_provider_name: multiple_retrieval_config.reranking_model?.provider,
                      reranking_model_name: multiple_retrieval_config.reranking_model?.model,
                    }
                    : {
                      reranking_provider_name: '',
                      reranking_model_name: '',
                    },
                  top_k: multiple_retrieval_config?.top_k || DATASET_DEFAULT.top_k,
                  // The threshold counts as enabled whenever a non-null value is stored.
                  score_threshold_enabled: !(multiple_retrieval_config?.score_threshold === undefined || multiple_retrieval_config.score_threshold === null),
                  score_threshold: multiple_retrieval_config?.score_threshold,
                  // Always passed as an empty list from this component.
                  datasets: {
                    datasets: [],
                  },
                  reranking_mode: multiple_retrieval_config?.reranking_mode,
                  weights: multiple_retrieval_config?.weights,
                  reranking_enable: multiple_retrieval_config?.reranking_enable,
                }
              }
              onChange={handleChange}
              isInWorkflow
              singleRetrievalModelConfig={singleRetrievalModelConfig}
              onSingleRetrievalModelChange={onSingleRetrievalModelChange}
              onSingleRetrievalModelParamsChange={onSingleRetrievalModelParamsChange}
              selectedDatasets={selectedDatasets}
            />
          </div>
        </PortalToFollowElemContent>
      </PortalToFollowElem>
    )
  }

  // Memoized so the component re-renders only when its props change.
  export default React.memo(RetrievalConfig)