index.tsx 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. import type { FC } from 'react'
  2. import { Fragment, useState } from 'react'
  3. import { Popover, Transition } from '@headlessui/react'
  4. import { useTranslation } from 'react-i18next'
  5. import _ from 'lodash-es'
  6. import cn from 'classnames'
  7. import type { BackendModel, ProviderEnum } from '@/app/components/header/account-setting/model-page/declarations'
  8. import { ModelType } from '@/app/components/header/account-setting/model-page/declarations'
  9. import { ChevronDown } from '@/app/components/base/icons/src/vender/line/arrows'
  10. import { Check, SearchLg } from '@/app/components/base/icons/src/vender/line/general'
  11. import { XCircle } from '@/app/components/base/icons/src/vender/solid/general'
  12. import { AlertCircle } from '@/app/components/base/icons/src/vender/line/alertsAndFeedback'
  13. import Tooltip from '@/app/components/base/tooltip'
  14. import ModelIcon from '@/app/components/app/configuration/config-model/model-icon'
  15. import ModelName, { supportI18nModelName } from '@/app/components/app/configuration/config-model/model-name'
  16. import ProviderName from '@/app/components/app/configuration/config-model/provider-name'
  17. import { useProviderContext } from '@/context/provider-context'
  /** Identifies a model by the provider that serves it and the model's name. */
  type SelectedModel = {
    providerName: ProviderEnum
    modelName: string
  }

  /** Props accepted by the `ModelSelector` component. */
  type Props = {
    /** Current selection; when `undefined` the trigger shows a "select model" placeholder. */
    value: SelectedModel | undefined
    /** Chooses which model list from the provider context is offered. */
    modelType: ModelType
    /** When truthy, the agent-thought model list is offered instead of the list for `modelType`. */
    supportAgentThought?: boolean
    /** Invoked when the user clicks a model row in the dropdown. */
    onChange: (value: BackendModel) => void
    /** Extra class names applied to the dropdown panel. */
    popClassName?: string
    /** When truthy, the chevron and the dropdown panel are not rendered. */
    readonly?: boolean
    /** When falsy, the model icon in the trigger gets the `w-5 h-5` size classes. */
    triggerIconSmall?: boolean
  }
  /** A row of the dropdown list: a provider heading, or a selectable model listed beneath it. */
  type ModelOption =
    | { type: 'provider'; value: ProviderEnum }
    | { type: 'model'; providerName: ProviderEnum; value: string }

  /**
   * Dropdown for picking a model, grouped by provider, with a search box.
   *
   * The list of models comes from the provider context: the agent-thought list
   * when `supportAgentThought` is set, otherwise the list matching `modelType`.
   * When the current `value` is no longer present in that list, the trigger is
   * tinted red and a warning tooltip is shown.
   *
   * @param value - Current selection, or `undefined` to show a placeholder.
   * @param modelType - Which model list from the provider context to offer.
   * @param supportAgentThought - Offer the agent-thought model list instead.
   * @param onChange - Called with the list entry matching the clicked row.
   * @param popClassName - Extra class names for the dropdown panel.
   * @param readonly - Hide the chevron and the dropdown panel.
   * @param triggerIconSmall - Skip the `w-5 h-5` sizing of the trigger icon.
   */
  const ModelSelector: FC<Props> = ({
    value,
    modelType,
    supportAgentThought,
    onChange,
    popClassName,
    readonly,
    triggerIconSmall,
  }) => {
    const { t } = useTranslation()
    const { textGenerationModelList, embeddingsModelList, speech2textModelList, agentThoughtModelList } = useProviderContext()
    const [search, setSearch] = useState('')

    // Fall back to an empty list so a `modelType` without a mapped list
    // cannot crash the `.filter` / `.some` calls below.
    const modelList = (supportAgentThought
      ? agentThoughtModelList
      : ({
        [ModelType.textGeneration]: textGenerationModelList,
        [ModelType.embeddings]: embeddingsModelList,
        [ModelType.speech2text]: speech2textModelList,
      })[modelType]) ?? []

    // Name to search against: the translated label for models listed in
    // `supportI18nModelName`, the raw model name otherwise.
    const getSearchableName = (modelName: string): string =>
      supportI18nModelName.includes(modelName) ? t(`common.modelName.${modelName}`) : modelName

    // Case-insensitive match so that e.g. "GPT" finds "gpt-4".
    const keyword = search.toLowerCase()
    const filteredModelList = keyword
      ? modelList.filter(({ model_name }) => getSearchableName(model_name).toLowerCase().includes(keyword))
      : modelList

    // A model is identified by provider AND name; the same model name may be
    // offered by several providers.
    const isSelected = (providerName: ProviderEnum, modelName: string) =>
      value?.providerName === providerName && value?.modelName === modelName

    // The selection counts as removed when no entry of the (unfiltered) list
    // matches both its provider and its model name.
    const hasRemoved = !!value && !modelList.some(({ model_name, model_provider }) =>
      isSelected(model_provider.provider_name, model_name))

    // Flatten the filtered list into rows: each provider heading is followed
    // by that provider's models, providers in order of first appearance.
    const modelOptions: ModelOption[] = []
    _.uniq(filteredModelList.map(item => item.model_provider.provider_name)).forEach((providerName) => {
      modelOptions.push({ type: 'provider', value: providerName })
      filteredModelList.forEach(({ model_name, model_provider }) => {
        if (model_provider.provider_name === providerName)
          modelOptions.push({ type: 'model', providerName, value: model_name })
      })
    })

    const handleSelect = (providerName: ProviderEnum, modelName: string) => {
      const selectedModel = modelList.find(item =>
        item.model_name === modelName && item.model_provider.provider_name === providerName)
      // Never hand `undefined` to the caller disguised as a model.
      if (selectedModel)
        onChange(selectedModel)
    }

    return (
      <div className=''>
        <Popover className='relative'>
          <Popover.Button className={cn('flex items-center px-2.5 w-full h-9 rounded-lg', readonly ? '!cursor-auto' : 'bg-gray-100', hasRemoved && '!bg-[#FEF3F2]')}>
            {
              ({ open }) => (
                <>
                  {
                    value
                      ? (
                        <>
                          <ModelIcon
                            className={cn('mr-1.5', !triggerIconSmall && 'w-5 h-5')}
                            modelId={value.modelName}
                            providerName={value.providerName}
                          />
                          <div className='mr-1.5 grow text-left text-sm text-gray-900 truncate'><ModelName modelId={value.modelName} /></div>
                        </>
                      )
                      : (
                        <div className='grow text-left text-sm text-gray-800 opacity-60'>{t('common.modelProvider.selectModel')}</div>
                      )
                  }
                  {
                    hasRemoved && (
                      <Tooltip
                        selector='model-selector-remove-tip'
                        htmlContent={
                          <div className='w-[261px] text-gray-500'>{t('common.modelProvider.selector.tip')}</div>
                        }
                      >
                        <AlertCircle className='mr-1 w-4 h-4 text-[#F04438]' />
                      </Tooltip>
                    )
                  }
                  {!readonly && <ChevronDown className={`w-4 h-4 text-gray-700 ${open ? 'opacity-100' : 'opacity-60'}`} />}
                </>
              )
            }
          </Popover.Button>
          {!readonly && (
            <Transition
              as={Fragment}
              leave='transition ease-in duration-100'
              leaveFrom='opacity-100'
              leaveTo='opacity-0'
            >
              <Popover.Panel className={cn(popClassName, 'absolute top-10 p-1 min-w-[232px] max-w-[260px] max-h-[366px] bg-white border-[0.5px] border-gray-200 rounded-lg shadow-lg overflow-auto z-10')}>
                <div className='px-2 pt-2 pb-1'>
                  <div className='flex items-center px-2 h-8 bg-gray-100 rounded-lg'>
                    <div className='mr-1.5 p-[1px]'><SearchLg className='w-[14px] h-[14px] text-gray-400' /></div>
                    <div className='grow px-0.5'>
                      <input
                        value={search}
                        onChange={e => setSearch(e.target.value)}
                        className={`
                          block w-full h-8 bg-transparent text-[13px] text-gray-700
                          outline-none appearance-none border-none
                        `}
                        placeholder={t('common.modelProvider.searchModel') || ''}
                      />
                    </div>
                    {
                      search && (
                        <div className='ml-1 p-0.5 cursor-pointer' onClick={() => setSearch('')}>
                          <XCircle className='w-3 h-3 text-gray-400' />
                        </div>
                      )
                    }
                  </div>
                </div>
                {
                  modelOptions.map((option) => {
                    if (option.type === 'provider') {
                      return (
                        <div
                          className='px-3 pt-2 pb-1 text-xs font-medium text-gray-500'
                          key={`${option.type}-${option.value}`}
                        >
                          <ProviderName provideName={option.value} />
                        </div>
                      )
                    }
                    const selected = isSelected(option.providerName, option.value)
                    return (
                      <Popover.Button
                        key={`${option.providerName}-${option.value}`}
                        // `cn` drops falsy arguments, so no stray "false" class is emitted.
                        className={cn(
                          'flex items-center px-3 w-full h-8 rounded-lg hover:bg-gray-50',
                          readonly ? 'cursor-auto' : 'cursor-pointer',
                          selected && 'bg-gray-50',
                        )}
                        onClick={() => handleSelect(option.providerName, option.value)}
                      >
                        <ModelIcon
                          className='mr-2 shrink-0'
                          modelId={option.value}
                          providerName={option.providerName}
                        />
                        <div className='grow text-left text-sm text-gray-900 truncate'><ModelName modelId={option.value} /></div>
                        {selected && <Check className='shrink-0 w-4 h-4 text-primary-600' />}
                      </Popover.Button>
                    )
                  })
                }
                {(search && filteredModelList.length === 0) && (
                  <div className='px-3 pt-1.5 h-[30px] text-center text-xs text-gray-500'>{t('common.modelProvider.noModelFound', { model: search })}</div>
                )}
              </Popover.Panel>
            </Transition>
          )}
        </Popover>
      </div>
    )
  }
  export default ModelSelector