portal-select.tsx 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358
  1. import type { FC } from 'react'
  2. import React, { Fragment, useEffect, useRef, useState } from 'react'
  3. import useSWR from 'swr'
  4. import { useTranslation } from 'react-i18next'
  5. import _ from 'lodash-es'
  6. import cn from 'classnames'
  7. import ModelModal from '../model-modal'
  8. import cohereConfig from '../configs/cohere'
  9. import s from './style.module.css'
  10. import type { BackendModel, FormValue, ProviderEnum } from '@/app/components/header/account-setting/model-page/declarations'
  11. import { ModelType } from '@/app/components/header/account-setting/model-page/declarations'
  12. import { ChevronDown } from '@/app/components/base/icons/src/vender/line/arrows'
  13. import { Check, LinkExternal01, SearchLg } from '@/app/components/base/icons/src/vender/line/general'
  14. import { XCircle } from '@/app/components/base/icons/src/vender/solid/general'
  15. import { AlertCircle } from '@/app/components/base/icons/src/vender/line/alertsAndFeedback'
  16. import Tooltip from '@/app/components/base/tooltip'
  17. import ModelIcon from '@/app/components/app/configuration/config-model/model-icon'
  18. import ModelName from '@/app/components/app/configuration/config-model/model-name'
  19. import ProviderName from '@/app/components/app/configuration/config-model/provider-name'
  20. import { useProviderContext } from '@/context/provider-context'
  21. import ModelModeTypeLabel from '@/app/components/app/configuration/config-model/model-mode-type-label'
  22. import type { ModelModeType } from '@/types/app'
  23. import { CubeOutline } from '@/app/components/base/icons/src/vender/line/shapes'
  24. import { useModalContext } from '@/context/modal-context'
  25. import { useEventEmitterContextContext } from '@/context/event-emitter'
  26. import { fetchDefaultModal, setModelProvider } from '@/service/common'
  27. import { useToastContext } from '@/app/components/base/toast'
  28. import {
  29. PortalToFollowElem,
  30. PortalToFollowElemContent,
  31. PortalToFollowElemTrigger,
  32. } from '@/app/components/base/portal-to-follow-elem'
  33. type Props = {
  34. value: {
  35. providerName: ProviderEnum
  36. modelName: string
  37. } | undefined
  38. modelType: ModelType
  39. isShowModelModeType?: boolean
  40. isShowAddModel?: boolean
  41. supportAgentThought?: boolean
  42. onChange: (value: BackendModel) => void
  43. popClassName?: string
  44. readonly?: boolean
  45. triggerIconSmall?: boolean
  46. whenEmptyGoToSetting?: boolean
  47. onUpdate?: () => void
  48. widthSameToTrigger?: boolean
  49. }
  50. type ModelOption = {
  51. type: 'model'
  52. value: string
  53. providerName: ProviderEnum
  54. modelDisplayName: string
  55. model_mode: ModelModeType
  56. } | {
  57. type: 'provider'
  58. value: ProviderEnum
  59. }
  60. const ModelSelector: FC<Props> = ({
  61. value,
  62. modelType,
  63. isShowModelModeType,
  64. isShowAddModel,
  65. supportAgentThought,
  66. onChange,
  67. popClassName,
  68. readonly,
  69. triggerIconSmall,
  70. whenEmptyGoToSetting,
  71. onUpdate,
  72. widthSameToTrigger,
  73. }) => {
  74. const { t } = useTranslation()
  75. const { setShowAccountSettingModal } = useModalContext()
  76. const {
  77. textGenerationModelList,
  78. embeddingsModelList,
  79. speech2textModelList,
  80. rerankModelList,
  81. agentThoughtModelList,
  82. updateModelList,
  83. } = useProviderContext()
  84. const [search, setSearch] = useState('')
  85. const modelList = supportAgentThought
  86. ? agentThoughtModelList
  87. : ({
  88. [ModelType.textGeneration]: textGenerationModelList,
  89. [ModelType.embeddings]: embeddingsModelList,
  90. [ModelType.speech2text]: speech2textModelList,
  91. [ModelType.reranking]: rerankModelList,
  92. })[modelType]
  93. const currModel = modelList.find(item => item.model_name === value?.modelName && item.model_provider.provider_name === value.providerName)
  94. const allModelNames = (() => {
  95. if (!search)
  96. return {}
  97. const res: Record<string, string> = {}
  98. modelList.forEach(({ model_name, model_display_name }) => {
  99. res[model_name] = model_display_name
  100. })
  101. return res
  102. })()
  103. const filteredModelList = search
  104. ? modelList.filter(({ model_name }) => {
  105. if (allModelNames[model_name].includes(search))
  106. return true
  107. return false
  108. })
  109. : modelList
  110. const hasRemoved = (value && value.modelName && value.providerName) && !modelList.find(({ model_name, model_provider }) => model_name === value.modelName && model_provider.provider_name === value.providerName)
  111. const modelOptions: ModelOption[] = (() => {
  112. const providers = _.uniq(filteredModelList.map(item => item.model_provider.provider_name))
  113. const res: ModelOption[] = []
  114. providers.forEach((providerName) => {
  115. res.push({
  116. type: 'provider',
  117. value: providerName,
  118. })
  119. const models = filteredModelList.filter(m => m.model_provider.provider_name === providerName)
  120. models.forEach(({ model_name, model_display_name, model_mode }) => {
  121. res.push({
  122. type: 'model',
  123. providerName,
  124. value: model_name,
  125. modelDisplayName: model_display_name,
  126. model_mode,
  127. })
  128. })
  129. })
  130. return res
  131. })()
  132. const { eventEmitter } = useEventEmitterContextContext()
  133. const [showRerankModal, setShowRerankModal] = useState(false)
  134. const [shouldFetchRerankDefaultModel, setShouldFetchRerankDefaultModel] = useState(false)
  135. const { notify } = useToastContext()
  136. const { data: rerankDefaultModel } = useSWR(shouldFetchRerankDefaultModel ? '/workspaces/current/default-model?model_type=reranking' : null, fetchDefaultModal)
  137. const handleOpenRerankModal = (e: React.MouseEvent<HTMLDivElement>) => {
  138. e.stopPropagation()
  139. setShowRerankModal(true)
  140. }
  141. const handleRerankModalSave = async (originValue?: FormValue) => {
  142. if (originValue) {
  143. try {
  144. eventEmitter?.emit('provider-save')
  145. const res = await setModelProvider({
  146. url: `/workspaces/current/model-providers/${cohereConfig.modal.key}`,
  147. body: {
  148. config: originValue,
  149. },
  150. })
  151. if (res.result === 'success') {
  152. notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') })
  153. updateModelList(ModelType.reranking)
  154. setShowRerankModal(false)
  155. setShouldFetchRerankDefaultModel(true)
  156. if (onUpdate)
  157. onUpdate()
  158. }
  159. eventEmitter?.emit('')
  160. }
  161. catch (e) {
  162. eventEmitter?.emit('')
  163. }
  164. }
  165. }
  166. const [open, setOpen] = useState(false)
  167. const triggerRef = useRef<HTMLDivElement>(null)
  168. useEffect(() => {
  169. if (rerankDefaultModel && whenEmptyGoToSetting)
  170. onChange(rerankDefaultModel)
  171. }, [rerankDefaultModel])
  172. return (
  173. <PortalToFollowElem
  174. open={open}
  175. onOpenChange={setOpen}
  176. placement='bottom-start'
  177. offset={4}
  178. >
  179. <div className='relative'>
  180. <PortalToFollowElemTrigger onClick={() => setOpen(v => !v)} className={cn('flex items-center px-2.5 w-full h-9 rounded-lg', readonly ? '!cursor-auto bg-gray-100 opacity-50' : 'bg-gray-100', hasRemoved && '!bg-[#FEF3F2]')}>
  181. {
  182. <div ref={triggerRef} className='flex items-center w-full cursor-pointer'>
  183. {
  184. (value && value.modelName && value.providerName)
  185. ? (
  186. <>
  187. <ModelIcon
  188. className={cn('mr-1.5', !triggerIconSmall && 'w-5 h-5')}
  189. modelId={value.modelName}
  190. providerName={value.providerName}
  191. />
  192. <div className='mr-1.5 grow flex items-center text-left text-sm text-gray-900 truncate'>
  193. <ModelName modelId={value.modelName} modelDisplayName={currModel?.model_display_name || value.modelName} />
  194. {isShowModelModeType && (
  195. <ModelModeTypeLabel className='ml-2' type={currModel?.model_mode as ModelModeType} />
  196. )}
  197. </div>
  198. </>
  199. )
  200. : whenEmptyGoToSetting
  201. ? (
  202. <div className='grow flex items-center h-9 justify-between' onClick={handleOpenRerankModal}>
  203. <div className='flex items-center text-[13px] font-medium text-primary-500'>
  204. <CubeOutline className='mr-1.5 w-4 h-4' />
  205. {t('common.modelProvider.selector.rerankTip')}
  206. </div>
  207. <LinkExternal01 className='w-3 h-3 text-gray-500' />
  208. </div>
  209. )
  210. : (
  211. <div className='grow text-left text-sm text-gray-800 opacity-60'>{t('common.modelProvider.selectModel')}</div>
  212. )
  213. }
  214. {
  215. hasRemoved && (
  216. <Tooltip
  217. selector='model-selector-remove-tip'
  218. htmlContent={
  219. <div className='w-[261px] text-gray-500'>{t('common.modelProvider.selector.tip')}</div>
  220. }
  221. >
  222. <AlertCircle className='mr-1 w-4 h-4 text-[#F04438]' />
  223. </Tooltip>
  224. )
  225. }
  226. {
  227. !readonly && !whenEmptyGoToSetting && (
  228. <ChevronDown className={`w-4 h-4 text-gray-700 ${open ? 'opacity-100' : 'opacity-60'}`} />
  229. )
  230. }
  231. {
  232. whenEmptyGoToSetting && (value && value.modelName && value.providerName) && (
  233. <ChevronDown className={`w-4 h-4 text-gray-700 ${open ? 'opacity-100' : 'opacity-60'}`} />
  234. )
  235. }
  236. </div>
  237. }
  238. </PortalToFollowElemTrigger>
  239. {!readonly && (
  240. <PortalToFollowElemContent
  241. className={cn(popClassName, !widthSameToTrigger && (isShowModelModeType ? 'max-w-[312px]' : 'max-w-[260px]'), 'absolute top-10 p-1 min-w-[232px] max-h-[366px] bg-white border-[0.5px] border-gray-200 rounded-lg shadow-lg overflow-auto z-[999]')}
  242. style={{
  243. width: (widthSameToTrigger && triggerRef.current?.offsetWidth) ? `${triggerRef.current?.offsetWidth}px` : 'auto',
  244. }}
  245. >
  246. <div className='px-2 pt-2 pb-1'>
  247. <div className='flex items-center px-2 h-8 bg-gray-100 rounded-lg'>
  248. <div className='mr-1.5 p-[1px]'><SearchLg className='w-[14px] h-[14px] text-gray-400' /></div>
  249. <div className='grow px-0.5'>
  250. <input
  251. value={search}
  252. onChange={e => setSearch(e.target.value)}
  253. className={`
  254. block w-full h-8 bg-transparent text-[13px] text-gray-700
  255. outline-none appearance-none border-none
  256. `}
  257. placeholder={t('common.modelProvider.searchModel') || ''}
  258. />
  259. </div>
  260. {
  261. search && (
  262. <div className='ml-1 p-0.5 cursor-pointer' onClick={() => setSearch('')}>
  263. <XCircle className='w-3 h-3 text-gray-400' />
  264. </div>
  265. )
  266. }
  267. </div>
  268. </div>
  269. {
  270. modelOptions.map((model) => {
  271. if (model.type === 'provider') {
  272. return (
  273. <div
  274. className='px-3 pt-2 pb-1 text-xs font-medium text-gray-500'
  275. key={`${model.type}-${model.value}`}
  276. >
  277. <ProviderName provideName={model.value} />
  278. </div>
  279. )
  280. }
  281. if (model.type === 'model') {
  282. return (
  283. <div
  284. key={`${model.providerName}-${model.value}`}
  285. className={`${s.optionItem}
  286. flex items-center px-3 w-full h-8 rounded-lg hover:bg-gray-50
  287. ${!readonly ? 'cursor-pointer' : 'cursor-auto'}
  288. ${(value?.providerName === model.providerName && value?.modelName === model.value) && 'bg-gray-50'}
  289. `}
  290. onClick={() => {
  291. const selectedModel = modelList.find((item) => {
  292. return item.model_name === model.value && item.model_provider.provider_name === model.providerName
  293. })
  294. onChange(selectedModel as BackendModel)
  295. setOpen(false)
  296. }}
  297. >
  298. <ModelIcon
  299. className='mr-2 shrink-0'
  300. modelId={model.value}
  301. providerName={model.providerName}
  302. />
  303. <div className='mr-2 grow flex items-center text-left text-sm text-gray-900 truncate'>
  304. <ModelName modelId={model.value} modelDisplayName={model.modelDisplayName} />
  305. {isShowModelModeType && (
  306. <ModelModeTypeLabel className={`${s.modelModeLabel} ml-2`} type={model.model_mode} />
  307. )}
  308. </div>
  309. {(value?.providerName === model.providerName && value?.modelName === model.value) && <Check className='shrink-0 w-4 h-4 text-primary-600' />}
  310. </div>
  311. )
  312. }
  313. return null
  314. })
  315. }
  316. {modelList.length !== 0 && (search && filteredModelList.length === 0) && (
  317. <div className='px-3 pt-1.5 h-[30px] text-center text-xs text-gray-500'>{t('common.modelProvider.noModelFound', { model: search })}</div>
  318. )}
  319. {isShowAddModel && (
  320. <div
  321. className='border-t flex items-center h-9 pl-3 text-xs text-[#155EEF] cursor-pointer'
  322. style={{
  323. borderColor: 'rgba(0, 0, 0, 0.05)',
  324. }}
  325. onClick={() => setShowAccountSettingModal({ payload: 'provider' })}
  326. >
  327. <CubeOutline className='w-4 h-4 mr-2' />
  328. <div>{t('common.model.addMoreModel')}</div>
  329. </div>
  330. )}
  331. </PortalToFollowElemContent>
  332. )}
  333. </div>
  334. <ModelModal
  335. isShow={showRerankModal}
  336. modelModal={cohereConfig.modal}
  337. onCancel={() => setShowRerankModal(false)}
  338. onSave={handleRerankModalSave}
  339. mode={'add'}
  340. />
  341. </PortalToFollowElem>
  342. )
  343. }
  344. export default ModelSelector