index.tsx 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264
  1. import type { FC } from 'react'
  2. import { Fragment, useState } from 'react'
  3. import { Popover, Transition } from '@headlessui/react'
  4. import { useTranslation } from 'react-i18next'
  5. import _ from 'lodash-es'
  6. import cn from 'classnames'
  7. import s from './style.module.css'
  8. import type { BackendModel, ProviderEnum } from '@/app/components/header/account-setting/model-page/declarations'
  9. import { ModelType } from '@/app/components/header/account-setting/model-page/declarations'
  10. import { ChevronDown } from '@/app/components/base/icons/src/vender/line/arrows'
  11. import { Check, SearchLg } from '@/app/components/base/icons/src/vender/line/general'
  12. import { XCircle } from '@/app/components/base/icons/src/vender/solid/general'
  13. import { AlertCircle } from '@/app/components/base/icons/src/vender/line/alertsAndFeedback'
  14. import Tooltip from '@/app/components/base/tooltip'
  15. import ModelIcon from '@/app/components/app/configuration/config-model/model-icon'
  16. import ModelName from '@/app/components/app/configuration/config-model/model-name'
  17. import ProviderName from '@/app/components/app/configuration/config-model/provider-name'
  18. import { useProviderContext } from '@/context/provider-context'
  19. import ModelModeTypeLabel from '@/app/components/app/configuration/config-model/model-mode-type-label'
  20. import type { ModelModeType } from '@/types/app'
  21. import { CubeOutline } from '@/app/components/base/icons/src/vender/line/shapes'
  22. import { useModalContext } from '@/context/modal-context'
  23. type Props = {
  24. value: {
  25. providerName: ProviderEnum
  26. modelName: string
  27. } | undefined
  28. modelType: ModelType
  29. isShowModelModeType?: boolean
  30. isShowAddModel?: boolean
  31. supportAgentThought?: boolean
  32. onChange: (value: BackendModel) => void
  33. popClassName?: string
  34. readonly?: boolean
  35. triggerIconSmall?: boolean
  36. }
  37. type ModelOption = {
  38. type: 'model'
  39. value: string
  40. providerName: ProviderEnum
  41. modelDisplayName: string
  42. model_mode: ModelModeType
  43. } | {
  44. type: 'provider'
  45. value: ProviderEnum
  46. }
  47. const ModelSelector: FC<Props> = ({
  48. value,
  49. modelType,
  50. isShowModelModeType,
  51. isShowAddModel,
  52. supportAgentThought,
  53. onChange,
  54. popClassName,
  55. readonly,
  56. triggerIconSmall,
  57. }) => {
  58. const { t } = useTranslation()
  59. const { setShowAccountSettingModal } = useModalContext()
  60. const { textGenerationModelList, embeddingsModelList, speech2textModelList, agentThoughtModelList } = useProviderContext()
  61. const [search, setSearch] = useState('')
  62. const modelList = supportAgentThought
  63. ? agentThoughtModelList
  64. : ({
  65. [ModelType.textGeneration]: textGenerationModelList,
  66. [ModelType.embeddings]: embeddingsModelList,
  67. [ModelType.speech2text]: speech2textModelList,
  68. })[modelType]
  69. const currModel = modelList.find(item => item.model_name === value?.modelName && item.model_provider.provider_name === value.providerName)
  70. const allModelNames = (() => {
  71. if (!search)
  72. return {}
  73. const res: Record<string, string> = {}
  74. modelList.forEach(({ model_name, model_display_name }) => {
  75. res[model_name] = model_display_name
  76. })
  77. return res
  78. })()
  79. const filteredModelList = search
  80. ? modelList.filter(({ model_name }) => {
  81. if (allModelNames[model_name].includes(search))
  82. return true
  83. return false
  84. })
  85. : modelList
  86. const hasRemoved = value && !modelList.find(({ model_name, model_provider }) => model_name === value.modelName && model_provider.provider_name === value.providerName)
  87. const modelOptions: ModelOption[] = (() => {
  88. const providers = _.uniq(filteredModelList.map(item => item.model_provider.provider_name))
  89. const res: ModelOption[] = []
  90. providers.forEach((providerName) => {
  91. res.push({
  92. type: 'provider',
  93. value: providerName,
  94. })
  95. const models = filteredModelList.filter(m => m.model_provider.provider_name === providerName)
  96. models.forEach(({ model_name, model_display_name, model_mode }) => {
  97. res.push({
  98. type: 'model',
  99. providerName,
  100. value: model_name,
  101. modelDisplayName: model_display_name,
  102. model_mode,
  103. })
  104. })
  105. })
  106. return res
  107. })()
  108. return (
  109. <div className=''>
  110. <Popover className='relative'>
  111. <Popover.Button className={cn('flex items-center px-2.5 w-full h-9 rounded-lg', readonly ? '!cursor-auto' : 'bg-gray-100', hasRemoved && '!bg-[#FEF3F2]')}>
  112. {
  113. ({ open }) => (
  114. <>
  115. {
  116. value
  117. ? (
  118. <>
  119. <ModelIcon
  120. className={cn('mr-1.5', !triggerIconSmall && 'w-5 h-5')}
  121. modelId={value.modelName}
  122. providerName={value.providerName}
  123. />
  124. <div className='mr-1.5 grow flex items-center text-left text-sm text-gray-900 truncate'>
  125. <ModelName modelId={value.modelName} modelDisplayName={currModel?.model_display_name} />
  126. {isShowModelModeType && (
  127. <ModelModeTypeLabel className='ml-2' type={currModel?.model_mode as ModelModeType} />
  128. )}
  129. </div>
  130. </>
  131. )
  132. : (
  133. <div className='grow text-left text-sm text-gray-800 opacity-60'>{t('common.modelProvider.selectModel')}</div>
  134. )
  135. }
  136. {
  137. hasRemoved && (
  138. <Tooltip
  139. selector='model-selector-remove-tip'
  140. htmlContent={
  141. <div className='w-[261px] text-gray-500'>{t('common.modelProvider.selector.tip')}</div>
  142. }
  143. >
  144. <AlertCircle className='mr-1 w-4 h-4 text-[#F04438]' />
  145. </Tooltip>
  146. )
  147. }
  148. {!readonly && <ChevronDown className={`w-4 h-4 text-gray-700 ${open ? 'opacity-100' : 'opacity-60'}`} />}
  149. </>
  150. )
  151. }
  152. </Popover.Button>
  153. {!readonly && (
  154. <Transition
  155. as={Fragment}
  156. leave='transition ease-in duration-100'
  157. leaveFrom='opacity-100'
  158. leaveTo='opacity-0'
  159. >
  160. <Popover.Panel className={cn(popClassName, isShowModelModeType ? 'max-w-[312px]' : 'max-w-[260px]', 'absolute top-10 p-1 min-w-[232px] max-h-[366px] bg-white border-[0.5px] border-gray-200 rounded-lg shadow-lg overflow-auto z-10')}>
  161. <div className='px-2 pt-2 pb-1'>
  162. <div className='flex items-center px-2 h-8 bg-gray-100 rounded-lg'>
  163. <div className='mr-1.5 p-[1px]'><SearchLg className='w-[14px] h-[14px] text-gray-400' /></div>
  164. <div className='grow px-0.5'>
  165. <input
  166. value={search}
  167. onChange={e => setSearch(e.target.value)}
  168. className={`
  169. block w-full h-8 bg-transparent text-[13px] text-gray-700
  170. outline-none appearance-none border-none
  171. `}
  172. placeholder={t('common.modelProvider.searchModel') || ''}
  173. />
  174. </div>
  175. {
  176. search && (
  177. <div className='ml-1 p-0.5 cursor-pointer' onClick={() => setSearch('')}>
  178. <XCircle className='w-3 h-3 text-gray-400' />
  179. </div>
  180. )
  181. }
  182. </div>
  183. </div>
  184. {
  185. modelOptions.map((model) => {
  186. if (model.type === 'provider') {
  187. return (
  188. <div
  189. className='px-3 pt-2 pb-1 text-xs font-medium text-gray-500'
  190. key={`${model.type}-${model.value}`}
  191. >
  192. <ProviderName provideName={model.value} />
  193. </div>
  194. )
  195. }
  196. if (model.type === 'model') {
  197. return (
  198. <Popover.Button
  199. key={`${model.providerName}-${model.value}`}
  200. className={`${s.optionItem}
  201. flex items-center px-3 w-full h-8 rounded-lg hover:bg-gray-50
  202. ${!readonly ? 'cursor-pointer' : 'cursor-auto'}
  203. ${(value?.providerName === model.providerName && value?.modelName === model.value) && 'bg-gray-50'}
  204. `}
  205. onClick={() => {
  206. const selectedModel = modelList.find((item) => {
  207. return item.model_name === model.value && item.model_provider.provider_name === model.providerName
  208. })
  209. onChange(selectedModel as BackendModel)
  210. }}
  211. >
  212. <ModelIcon
  213. className='mr-2 shrink-0'
  214. modelId={model.value}
  215. providerName={model.providerName}
  216. />
  217. <div className='mr-2 grow flex items-center text-left text-sm text-gray-900 truncate'>
  218. <ModelName modelId={model.value} modelDisplayName={model.modelDisplayName} />
  219. {isShowModelModeType && (
  220. <ModelModeTypeLabel className={`${s.modelModeLabel} ml-2`} type={model.model_mode} />
  221. )}
  222. </div>
  223. { (value?.providerName === model.providerName && value?.modelName === model.value) && <Check className='shrink-0 w-4 h-4 text-primary-600' /> }
  224. </Popover.Button>
  225. )
  226. }
  227. return null
  228. })
  229. }
  230. {(search && filteredModelList.length === 0) && (
  231. <div className='px-3 pt-1.5 h-[30px] text-center text-xs text-gray-500'>{t('common.modelProvider.noModelFound', { model: search })}</div>
  232. )}
  233. {isShowAddModel && (
  234. <div
  235. className='border-t flex items-center h-9 pl-3 text-xs text-[#155EEF] cursor-pointer'
  236. style={{
  237. borderColor: 'rgba(0, 0, 0, 0.05)',
  238. }}
  239. onClick={() => setShowAccountSettingModal({ payload: 'provider' })}
  240. >
  241. <CubeOutline className='w-4 h-4 mr-2' />
  242. <div>{t('common.model.addMoreModel')}</div>
  243. </div>
  244. )}
  245. </Popover.Panel>
  246. </Transition>
  247. )}
  248. </Popover>
  249. </div>
  250. )
  251. }
  252. export default ModelSelector