retrieval-config.tsx 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. 'use client'
  2. import type { FC } from 'react'
  3. import React, { useCallback, useState } from 'react'
  4. import { useTranslation } from 'react-i18next'
  5. import { RiArrowDownSLine } from '@remixicon/react'
  6. import type { MultipleRetrievalConfig, SingleRetrievalConfig } from '../types'
  7. import type { ModelConfig } from '../../../types'
  8. import cn from '@/utils/classnames'
  9. import {
  10. PortalToFollowElem,
  11. PortalToFollowElemContent,
  12. PortalToFollowElemTrigger,
  13. } from '@/app/components/base/portal-to-follow-elem'
  14. import ConfigRetrievalContent from '@/app/components/app/configuration/dataset-config/params-config/config-content'
  15. import { RETRIEVE_TYPE } from '@/types/app'
  16. import { DATASET_DEFAULT } from '@/config'
  17. import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
  18. import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
  19. import type {
  20. DatasetConfigs,
  21. } from '@/models/debug'
  /** Props accepted by the `RetrievalConfig` retrieval-mode selector. */
  type Props = {
    /**
     * Current retrieval settings. `retrieval_mode` selects the title shown on the
     * trigger, and `multiple_retrieval_config` seeds the popover editor.
     */
    payload: {
      retrieval_mode: RETRIEVE_TYPE
      multiple_retrieval_config?: MultipleRetrievalConfig
      single_retrieval_config?: SingleRetrievalConfig
    }
    /** Called with the newly selected mode when the editor reports a retrieval-mode switch. */
    onRetrievalModeChange: (mode: RETRIEVE_TYPE) => void
    /** Called with rebuilt multi-way settings (top_k, score threshold, rerank model) after any other edit. */
    onMultipleRetrievalConfigChange: (config: MultipleRetrievalConfig) => void
    /** Passed through unchanged to the popover editor. */
    singleRetrievalModelConfig?: ModelConfig
    /** Passed through unchanged to the popover editor. */
    onSingleRetrievalModelChange?: (config: ModelConfig) => void
    /** Passed through unchanged to the popover editor. */
    onSingleRetrievalModelParamsChange?: (config: ModelConfig) => void
    /** When true, the popover cannot be opened and the dropdown arrow is hidden. */
    readonly?: boolean
  }
  /**
   * Compact retrieval-mode selector backed by a popover editor.
   *
   * The trigger displays the title of `payload.retrieval_mode`. Clicking it (unless
   * `readonly`) toggles a popover that renders `ConfigRetrievalContent` with
   * `isInWorkflow`. Edits coming back from that editor are mapped onto this
   * component's payload shape and reported via `onRetrievalModeChange` /
   * `onMultipleRetrievalConfigChange`. The single-retrieval model props are
   * forwarded to the editor unchanged.
   */
  const RetrievalConfig: FC<Props> = ({
    payload,
    onRetrievalModeChange,
    onMultipleRetrievalConfigChange,
    singleRetrievalModelConfig,
    onSingleRetrievalModelChange,
    onSingleRetrievalModelParamsChange,
    readonly,
  }) => {
    const { t } = useTranslation()
    const [open, setOpen] = useState(false)
    const {
      defaultModel: rerankDefaultModel,
    } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
    const { multiple_retrieval_config } = payload

    // Fallback rerank model, used whenever no reranking provider has been chosen.
    // Computed once so the editor's input and the change handler agree on it.
    const defaultRerankProvider = rerankDefaultModel?.provider?.provider || ''
    const defaultRerankModel = rerankDefaultModel?.model || ''

    /**
     * Handles edits reported by `ConfigRetrievalContent`.
     *
     * @param configs - the editor's full settings after the edit
     * @param isRetrievalModeChange - when true, only the new retrieval mode is forwarded
     */
    const handleChange = useCallback((configs: DatasetConfigs, isRetrievalModeChange?: boolean) => {
      if (isRetrievalModeChange) {
        onRetrievalModeChange(configs.retrieval_model)
        return
      }

      // One-way retrieval sends no reranking model. Otherwise use the picked
      // provider/model, or the default rerank model when no provider is picked.
      const pickedProvider = configs.reranking_model?.reranking_provider_name
      const reranking_model = payload.retrieval_mode === RETRIEVE_TYPE.oneWay
        ? undefined
        : pickedProvider
          ? { provider: pickedProvider, model: configs.reranking_model?.reranking_model_name }
          : { provider: defaultRerankProvider, model: defaultRerankModel }

      onMultipleRetrievalConfigChange({
        top_k: configs.top_k,
        // `??` rather than `||`: 0 is a legitimate threshold and must not be
        // replaced by the default. The default only fills a null/undefined value.
        score_threshold: configs.score_threshold_enabled
          ? (configs.score_threshold ?? DATASET_DEFAULT.score_threshold)
          : null,
        reranking_model,
      })
    }, [onRetrievalModeChange, onMultipleRetrievalConfigChange, payload.retrieval_mode, defaultRerankProvider, defaultRerankModel])

    const modeTitle = payload.retrieval_mode === RETRIEVE_TYPE.oneWay
      ? t('appDebug.datasetConfig.retrieveOneWay.title')
      : t('appDebug.datasetConfig.retrieveMultiWay.title')

    return (
      <PortalToFollowElem
        open={open}
        onOpenChange={setOpen}
        placement='bottom-end'
        offset={{ crossAxis: -2 }}
      >
        <PortalToFollowElemTrigger
          onClick={() => {
            // In read-only mode the popover can never be opened.
            if (readonly)
              return
            setOpen(v => !v)
          }}
        >
          <div className={cn(!readonly && 'cursor-pointer', open && 'bg-gray-100', 'flex items-center h-6 px-2 rounded-md hover:bg-gray-100 group select-none')}>
            <div className={cn(open ? 'text-gray-700' : 'text-gray-500', 'leading-[18px] text-xs font-medium group-hover:bg-gray-100')}>{modeTitle}</div>
            {!readonly && <RiArrowDownSLine className='w-3 h-3 ml-1' />}
          </div>
        </PortalToFollowElemTrigger>
        <PortalToFollowElemContent style={{ zIndex: 1001 }}>
          <div className='w-[404px] pt-3 pb-4 px-4 shadow-xl rounded-2xl border border-gray-200 bg-white'>
            <ConfigRetrievalContent
              datasetConfigs={{
                retrieval_model: payload.retrieval_mode,
                // Show the saved rerank model when a provider is set, otherwise the default one.
                reranking_model: multiple_retrieval_config?.reranking_model?.provider
                  ? {
                    reranking_provider_name: multiple_retrieval_config?.reranking_model?.provider || '',
                    reranking_model_name: multiple_retrieval_config?.reranking_model?.model || '',
                  }
                  : {
                    reranking_provider_name: defaultRerankProvider,
                    reranking_model_name: defaultRerankModel,
                  },
                top_k: multiple_retrieval_config?.top_k || DATASET_DEFAULT.top_k,
                // handleChange stores null when the threshold is disabled, so a
                // null/undefined threshold is shown as "disabled".
                score_threshold_enabled: multiple_retrieval_config?.score_threshold != null,
                score_threshold: multiple_retrieval_config?.score_threshold,
                datasets: {
                  datasets: [],
                },
              }}
              onChange={handleChange}
              isInWorkflow
              singleRetrievalModelConfig={singleRetrievalModelConfig}
              onSingleRetrievalModelChange={onSingleRetrievalModelChange}
              onSingleRetrievalModelParamsChange={onSingleRetrievalModelParamsChange}
            />
          </div>
        </PortalToFollowElemContent>
      </PortalToFollowElem>
    )
  }
  export default React.memo(RetrievalConfig)