index.tsx 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. import type { FC } from 'react'
  2. import { useState } from 'react'
  3. import { useTranslation } from 'react-i18next'
  4. import ModelSelector from '../model-selector'
  5. import {
  6. useModelList,
  7. useSystemDefaultModelAndModelList,
  8. useUpdateModelList,
  9. } from '../hooks'
  10. import type {
  11. DefaultModel,
  12. DefaultModelResponse,
  13. } from '../declarations'
  14. import { ModelTypeEnum } from '../declarations'
  15. import Tooltip from '@/app/components/base/tooltip'
  16. import { HelpCircle, Settings01 } from '@/app/components/base/icons/src/vender/line/general'
  17. import {
  18. PortalToFollowElem,
  19. PortalToFollowElemContent,
  20. PortalToFollowElemTrigger,
  21. } from '@/app/components/base/portal-to-follow-elem'
  22. import Button from '@/app/components/base/button'
  23. import { useProviderContext } from '@/context/provider-context'
  24. import { updateDefaultModel } from '@/service/common'
  25. import { useToastContext } from '@/app/components/base/toast'
  /** A persisted default-model value for one model type; `undefined` when none is available. */
  type MaybeDefaultModel = DefaultModelResponse | undefined

  /**
   * Props for the system model settings popover: the workspace's current default
   * model for each system model type it manages.
   */
  type SystemModelSelectorProps = {
    /** Default for text generation (rendered as the "system reasoning model" row). */
    textGenerationDefaultModel: MaybeDefaultModel
    /** Default for text embedding. */
    embeddingsDefaultModel: MaybeDefaultModel
    /** Default for rerank. */
    rerankDefaultModel: MaybeDefaultModel
    /** Default for speech-to-text. */
    speech2textDefaultModel: MaybeDefaultModel
  }
  /** System model types managed by this panel, in the order they are sent when saving. */
  const SYSTEM_MODEL_TYPES = [
    ModelTypeEnum.textGeneration,
    ModelTypeEnum.textEmbedding,
    ModelTypeEnum.rerank,
    ModelTypeEnum.speech2text,
  ]

  /**
   * "System model settings" popover.
   *
   * Renders one model selector per system model type (text generation / "system
   * reasoning model", embedding, rerank, speech-to-text). Selections live in local
   * state and are persisted only when Save is clicked. Save posts all four defaults
   * to `/workspaces/current/default-model` in one request, then calls
   * `updateModelList` for each type the user changed.
   *
   * NOTE(review): Cancel only closes the popover; unsaved selections remain in local
   * state and reappear on the next open — confirm this is intended.
   */
  const SystemModel: FC<SystemModelSelectorProps> = ({
    textGenerationDefaultModel,
    embeddingsDefaultModel,
    rerankDefaultModel,
    speech2textDefaultModel,
  }) => {
    const { t } = useTranslation()
    const { notify } = useToastContext()
    const { textGenerationModelList } = useProviderContext()
    const updateModelList = useUpdateModelList()
    // NOTE(review): 2 / 3 / 4 are numeric model-type indexes passed to `useModelList`;
    // judging by the variable names they select the embedding, rerank and speech-to-text
    // lists — confirm against ../hooks and prefer a named constant if one exists.
    const { data: embeddingModelList } = useModelList(2)
    const { data: rerankModelList } = useModelList(3)
    const { data: speech2textModelList } = useModelList(4)
    // Model types the user has re-selected since the last successful save.
    const [changedModelTypes, setChangedModelTypes] = useState<ModelTypeEnum[]>([])
    // Editable selection per type, derived by the hook from the prop default and that type's list.
    const [currentTextGenerationDefaultModel, changeCurrentTextGenerationDefaultModel] = useSystemDefaultModelAndModelList(textGenerationDefaultModel, textGenerationModelList)
    const [currentEmbeddingsDefaultModel, changeCurrentEmbeddingsDefaultModel] = useSystemDefaultModelAndModelList(embeddingsDefaultModel, embeddingModelList)
    const [currentRerankDefaultModel, changeCurrentRerankDefaultModel] = useSystemDefaultModelAndModelList(rerankDefaultModel, rerankModelList)
    const [currentSpeech2textDefaultModel, changeCurrentSpeech2textDefaultModel] = useSystemDefaultModelAndModelList(speech2textDefaultModel, speech2textModelList)
    const [open, setOpen] = useState(false)

    /** Returns the current local selection for `modelType`, or `undefined` for unmanaged types. */
    const getCurrentDefaultModelByModelType = (modelType: ModelTypeEnum) => {
      switch (modelType) {
        case ModelTypeEnum.textGeneration:
          return currentTextGenerationDefaultModel
        case ModelTypeEnum.textEmbedding:
          return currentEmbeddingsDefaultModel
        case ModelTypeEnum.rerank:
          return currentRerankDefaultModel
        case ModelTypeEnum.speech2text:
          return currentSpeech2textDefaultModel
        default:
          return undefined
      }
    }

    /** Stores a new local selection for `modelType` and records that the type changed. */
    const handleChangeDefaultModel = (modelType: ModelTypeEnum, model: DefaultModel) => {
      switch (modelType) {
        case ModelTypeEnum.textGeneration:
          changeCurrentTextGenerationDefaultModel(model)
          break
        case ModelTypeEnum.textEmbedding:
          changeCurrentEmbeddingsDefaultModel(model)
          break
        case ModelTypeEnum.rerank:
          changeCurrentRerankDefaultModel(model)
          break
        case ModelTypeEnum.speech2text:
          changeCurrentSpeech2textDefaultModel(model)
          break
        default:
          break
      }
      // Functional update: build on the latest state, not this render's snapshot,
      // and return the same array when nothing changes so no re-render is queued.
      setChangedModelTypes(prev => (prev.includes(modelType) ? prev : [...prev, modelType]))
    }

    /**
     * Persists the current selection of every managed model type in one request.
     * On success: shows a toast, closes the popover, calls `updateModelList` for each
     * type changed since the last save, then clears that set so a later save does not
     * refresh those types again.
     */
    const handleSave = async () => {
      const res = await updateDefaultModel({
        url: '/workspaces/current/default-model',
        body: {
          model_settings: SYSTEM_MODEL_TYPES.map((modelType) => {
            const currentModel = getCurrentDefaultModelByModelType(modelType)
            return {
              model_type: modelType,
              provider: currentModel?.provider,
              model: currentModel?.model,
            }
          }),
        },
      })
      if (res.result === 'success') {
        notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') })
        setOpen(false)
        // Only the four managed types are ever recorded (see handleChangeDefaultModel's callers).
        changedModelTypes.forEach(modelType => updateModelList(modelType))
        setChangedModelTypes([])
      }
    }

    // One entry per selector row. i18n keys and tooltip selectors stay as literals so
    // key-extraction tooling can still find them.
    const sections = [
      {
        modelType: ModelTypeEnum.textGeneration,
        title: t('common.modelProvider.systemReasoningModel.key'),
        tip: t('common.modelProvider.systemReasoningModel.tip'),
        tooltipSelector: 'model-page-system-reasoning-model-tip',
        defaultModel: currentTextGenerationDefaultModel,
        modelList: textGenerationModelList,
      },
      {
        modelType: ModelTypeEnum.textEmbedding,
        title: t('common.modelProvider.embeddingModel.key'),
        tip: t('common.modelProvider.embeddingModel.tip'),
        tooltipSelector: 'model-page-system-embedding-model-tip',
        defaultModel: currentEmbeddingsDefaultModel,
        modelList: embeddingModelList,
      },
      {
        modelType: ModelTypeEnum.rerank,
        title: t('common.modelProvider.rerankModel.key'),
        tip: t('common.modelProvider.rerankModel.tip'),
        tooltipSelector: 'model-page-system-rerankModel-model-tip',
        defaultModel: currentRerankDefaultModel,
        modelList: rerankModelList,
      },
      {
        modelType: ModelTypeEnum.speech2text,
        title: t('common.modelProvider.speechToTextModel.key'),
        tip: t('common.modelProvider.speechToTextModel.tip'),
        tooltipSelector: 'model-page-system-speechToText-model-tip',
        defaultModel: currentSpeech2textDefaultModel,
        modelList: speech2textModelList,
      },
    ]

    return (
      <PortalToFollowElem
        open={open}
        onOpenChange={setOpen}
        placement='bottom-end'
        offset={{
          mainAxis: 4,
          crossAxis: 8,
        }}
      >
        <PortalToFollowElemTrigger onClick={() => setOpen(v => !v)}>
          {/* A ternary, not `open && …`, so a closed popover does not add a literal "false" class. */}
          <div className={`
            flex items-center px-2 h-6 text-xs text-gray-700 cursor-pointer bg-white rounded-md border-[0.5px] border-gray-200 shadow-xs
            hover:bg-gray-100 hover:shadow-none
            ${open ? 'bg-gray-100 shadow-none' : ''}
          `}>
            <Settings01 className='mr-1 w-3 h-3 text-gray-500' />
            {t('common.modelProvider.systemModelSettings')}
          </div>
        </PortalToFollowElemTrigger>
        <PortalToFollowElemContent className='z-50'>
          <div className='pt-4 w-[360px] rounded-xl border-[0.5px] border-black/5 bg-white shadow-xl'>
            {sections.map(({ modelType, title, tip, tooltipSelector, defaultModel, modelList }) => (
              <div key={modelType} className='px-6 py-1'>
                <div className='flex items-center h-8 text-[13px] font-medium text-gray-900'>
                  {title}
                  <Tooltip
                    selector={tooltipSelector}
                    htmlContent={
                      <div className='w-[261px] text-gray-500'>{tip}</div>
                    }
                  >
                    <HelpCircle className='ml-0.5 w-[14px] h-[14px] text-gray-400' />
                  </Tooltip>
                </div>
                <div>
                  <ModelSelector
                    defaultModel={defaultModel}
                    modelList={modelList}
                    onSelect={model => handleChangeDefaultModel(modelType, model)}
                  />
                </div>
              </div>
            ))}
            <div className='flex items-center justify-end px-6 py-4'>
              <Button
                className='mr-2 !h-8 !text-[13px]'
                onClick={() => setOpen(false)}
              >
                {t('common.operation.cancel')}
              </Button>
              <Button
                type='primary'
                className='!h-8 !text-[13px]'
                onClick={handleSave}
              >
                {t('common.operation.save')}
              </Button>
            </div>
          </div>
        </PortalToFollowElemContent>
      </PortalToFollowElem>
    )
  }

  export default SystemModel