import type { FC } from 'react'
import React, { Fragment, useEffect, useRef, useState } from 'react'
import useSWR from 'swr'
import { useTranslation } from 'react-i18next'
import _ from 'lodash-es'
import cn from 'classnames'
import ModelModal from '../model-modal'
import cohereConfig from '../configs/cohere'
import s from './style.module.css'
import type { BackendModel, FormValue, ProviderEnum } from '@/app/components/header/account-setting/model-page/declarations'
import { ModelType } from '@/app/components/header/account-setting/model-page/declarations'
import { ChevronDown } from '@/app/components/base/icons/src/vender/line/arrows'
import { Check, LinkExternal01, SearchLg } from '@/app/components/base/icons/src/vender/line/general'
import { XCircle } from '@/app/components/base/icons/src/vender/solid/general'
import { AlertCircle } from '@/app/components/base/icons/src/vender/line/alertsAndFeedback'
import Tooltip from '@/app/components/base/tooltip'
import ModelIcon from '@/app/components/app/configuration/config-model/model-icon'
import ModelName from '@/app/components/app/configuration/config-model/model-name'
import ProviderName from '@/app/components/app/configuration/config-model/provider-name'
import { useProviderContext } from '@/context/provider-context'
import ModelModeTypeLabel from '@/app/components/app/configuration/config-model/model-mode-type-label'
import type { ModelModeType } from '@/types/app'
import { CubeOutline } from '@/app/components/base/icons/src/vender/line/shapes'
import { useModalContext } from '@/context/modal-context'
import { useEventEmitterContextContext } from '@/context/event-emitter'
import { fetchDefaultModal, setModelProvider } from '@/service/common'
import { useToastContext } from '@/app/components/base/toast'
import {
  PortalToFollowElem,
  PortalToFollowElemContent,
  PortalToFollowElemTrigger,
} from '@/app/components/base/portal-to-follow-elem'

/** Props for the `ModelSelector` dropdown. */
type Props = {
  /** Currently selected provider/model pair; `undefined` when nothing is selected. */
  value: {
    providerName: ProviderEnum
    modelName: string
  } | undefined
  /** Which provider-context model list to offer (ignored when `supportAgentThought` is set). */
  modelType: ModelType
  /** Toggles an extra per-model element (presumably `ModelModeTypeLabel`) in the trigger and option rows. */
  isShowModelModeType?: boolean
  /** Renders an "add more model" entry that calls `setShowAccountSettingModal({ payload: 'provider' })`. */
  isShowAddModel?: boolean
  /** Offer `agentThoughtModelList` instead of the list chosen by `modelType`. */
  supportAgentThought?: boolean
  /** Receives the full list entry of the clicked model, or the fetched default rerank model. */
  onChange: (value: BackendModel) => void
  popClassName?: string
  /** Dims the trigger (`opacity-50`, auto cursor) and omits the search box and option list. */
  readonly?: boolean
  triggerIconSmall?: boolean
  /** When nothing is selected, show the rerank setup tip; after a rerank provider save, auto-select the default rerank model. */
  whenEmptyGoToSetting?: boolean
  /** Called after the rerank provider configuration is saved with a 'success' result. */
  onUpdate?: () => void
  widthSameToTrigger?: boolean
}

/**
 * One row of the flattened dropdown list: either a provider header
 * (`type: 'provider'`) or a selectable model under that provider
 * (`type: 'model'`).
 */
type ModelOption = {
  type: 'model'
  value: string
  providerName: ProviderEnum
  modelDisplayName: string
  model_mode: ModelModeType
} | {
  type: 'provider'
  value: ProviderEnum
}

/**
 * Dropdown for picking a model from the provider context.
 *
 * Options are grouped by provider and can be filtered by display name. A
 * selected model that is no longer present in the list is flagged as removed.
 * The component also renders a modal (presumably `ModelModal`) that saves the
 * rerank provider configuration under `cohereConfig.modal.key`. After a
 * successful save, the workspace's default rerank model is fetched and, when
 * `whenEmptyGoToSetting` is set, passed to `onChange`.
 */
const ModelSelector: FC = ({
  value,
  modelType,
  isShowModelModeType,
  isShowAddModel,
  supportAgentThought,
  onChange,
  popClassName,
  readonly,
  triggerIconSmall,
  whenEmptyGoToSetting,
  onUpdate,
  widthSameToTrigger,
}) => {
  const { t } = useTranslation()
  const { setShowAccountSettingModal } = useModalContext()
  const {
    textGenerationModelList,
    embeddingsModelList,
    speech2textModelList,
    rerankModelList,
    agentThoughtModelList,
    updateModelList,
  } = useProviderContext()
  const [search, setSearch] = useState('')

  // `supportAgentThought` takes precedence; otherwise pick the list for `modelType`.
  const modelList = supportAgentThought
    ? agentThoughtModelList
    : ({
        [ModelType.textGeneration]: textGenerationModelList,
        [ModelType.embeddings]: embeddingsModelList,
        [ModelType.speech2text]: speech2textModelList,
        [ModelType.reranking]: rerankModelList,
      })[modelType]

  // List entry matching the current `value`, or undefined.
  const currModel = modelList.find(item => item.model_name === value?.modelName && item.model_provider.provider_name === value.providerName)

  // model_name -> model_display_name lookup for the search filter; only built
  // while a search term is present.
  // NOTE(review): keyed by model_name alone, so models sharing a name under
  // different providers all use whichever display name was written last.
  const allModelNames = (() => {
    if (!search)
      return {}
    const res: Record = {}
    modelList.forEach(({ model_name, model_display_name }) => {
      res[model_name] = model_display_name
    })
    return res
  })()

  // Substring match of the search term against display names.
  // NOTE(review): String#includes is case-sensitive, so e.g. "gpt" will not
  // match "GPT-4" — confirm this is intended.
  const filteredModelList = search
    ? modelList.filter(({ model_name }) => {
      if (allModelNames[model_name].includes(search))
        return true
      return false
    })
    : modelList

  // Truthy when `value` is fully set but names a model absent from
  // `modelList`. It drives the '!bg-[#FEF3F2]' trigger background and the
  // 'common.modelProvider.selector.tip' message.
  const hasRemoved = (value && value.modelName && value.providerName) && !modelList.find(({ model_name, model_provider }) => model_name === value.modelName && model_provider.provider_name === value.providerName)

  // Flatten the filtered models into rows: each provider header (providers in
  // order of first appearance, via _.uniq) followed by that provider's models.
  const modelOptions: ModelOption[] = (() => {
    const providers = _.uniq(filteredModelList.map(item => item.model_provider.provider_name))
    const res: ModelOption[] = []
    providers.forEach((providerName) => {
      res.push({
        type: 'provider',
        value: providerName,
      })
      const models = filteredModelList.filter(m => m.model_provider.provider_name === providerName)
      models.forEach(({ model_name, model_display_name, model_mode }) => {
        res.push({
          type: 'model',
          providerName,
          value: model_name,
          modelDisplayName: model_display_name,
          model_mode,
        })
      })
    })
    return res
  })()

  const { eventEmitter } = useEventEmitterContextContext()
  // Rerank setup flow: `showRerankModal` controls the modal's visibility, and
  // `shouldFetchRerankDefaultModel` enables the SWR request below (a `null`
  // key disables fetching).
  const [showRerankModal, setShowRerankModal] = useState(false)
  const [shouldFetchRerankDefaultModel, setShouldFetchRerankDefaultModel] = useState(false)
  const { notify } = useToastContext()
  const { data: rerankDefaultModel } = useSWR(shouldFetchRerankDefaultModel ? '/workspaces/current/default-model?model_type=reranking' : null, fetchDefaultModal)

  // Opens the rerank setup modal. stopPropagation presumably keeps the click
  // from also reaching an enclosing handler such as the dropdown trigger's
  // toggle.
  const handleOpenRerankModal = (e: React.MouseEvent) => {
    e.stopPropagation()
    setShowRerankModal(true)
  }

  /**
   * Saves the rerank provider configuration submitted from the modal.
   *
   * Does nothing when `originValue` is absent. Emits 'provider-save' before
   * the request and '' afterwards. The '' event is emitted after a 'success'
   * result, after any other result, and from the catch block on error.
   *
   * On a 'success' result it:
   * - shows a success toast;
   * - calls `updateModelList(ModelType.reranking)`;
   * - closes the modal;
   * - enables the default-rerank-model fetch;
   * - calls `onUpdate`.
   *
   * NOTE(review): non-success results and thrown errors produce no user
   * feedback here — confirm the service layer surfaces them.
   */
  const handleRerankModalSave = async (originValue?: FormValue) => {
    if (originValue) {
      try {
        eventEmitter?.emit('provider-save')
        const res = await setModelProvider({
          url: `/workspaces/current/model-providers/${cohereConfig.modal.key}`,
          body: {
            config: originValue,
          },
        })
        if (res.result === 'success') {
          notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') })
          updateModelList(ModelType.reranking)
          setShowRerankModal(false)
          setShouldFetchRerankDefaultModel(true)
          if (onUpdate)
            onUpdate()
        }
        eventEmitter?.emit('')
      }
      catch (e) {
        eventEmitter?.emit('')
      }
    }
  }

  // Dropdown open state, plus a ref presumably attached to the trigger element.
  const [open, setOpen] = useState(false)
  const triggerRef = useRef(null)

  // Once the default rerank model arrives (fetching is enabled only after a
  // successful save), select it when `whenEmptyGoToSetting` is set. Only
  // `rerankDefaultModel` is listed as a dependency.
  useEffect(() => {
    if (rerankDefaultModel && whenEmptyGoToSetting)
      onChange(rerankDefaultModel)
  }, [rerankDefaultModel])

  // The render consists of:
  // - A trigger whose click toggles `open`. It shows the selected model, the
  //   rerank tip, or the "select model" placeholder, plus the removed-model
  //   warning.
  // - Unless `readonly`: a search box and the provider-grouped option list,
  //   with a "no model found" message and the optional "add more model" entry.
  //   Clicking a model row passes its full list entry to `onChange` and closes
  //   the list.
  // - The rerank setup modal.
  return (
setOpen(v => !v)} className={cn('flex items-center px-2.5 w-full h-9 rounded-lg', readonly ? '!cursor-auto bg-gray-100 opacity-50' : 'bg-gray-100', hasRemoved && '!bg-[#FEF3F2]')}> {
{ (value && value.modelName && value.providerName) ? ( <>
{isShowModelModeType && ( )}
) : whenEmptyGoToSetting ? (
{t('common.modelProvider.selector.rerankTip')}
) : (
{t('common.modelProvider.selectModel')}
) } { hasRemoved && ( {t('common.modelProvider.selector.tip')}
} > ) } { !readonly && !whenEmptyGoToSetting && ( ) } { whenEmptyGoToSetting && (value && value.modelName && value.providerName) && ( ) }
} {!readonly && (
setSearch(e.target.value)} className={` block w-full h-8 bg-transparent text-[13px] text-gray-700 outline-none appearance-none border-none `} placeholder={t('common.modelProvider.searchModel') || ''} />
{ search && (
setSearch('')}>
) }
{ modelOptions.map((model) => { if (model.type === 'provider') { return (
) } if (model.type === 'model') { return (
{ const selectedModel = modelList.find((item) => { return item.model_name === model.value && item.model_provider.provider_name === model.providerName }) onChange(selectedModel as BackendModel) setOpen(false) }} >
{isShowModelModeType && ( )}
{(value?.providerName === model.providerName && value?.modelName === model.value) && }
) } return null }) } {modelList.length !== 0 && (search && filteredModelList.length === 0) && (
{t('common.modelProvider.noModelFound', { model: search })}
)} {isShowAddModel && (
setShowAccountSettingModal({ payload: 'provider' })} >
{t('common.model.addMoreModel')}
)}
)} setShowRerankModal(false)} onSave={handleRerankModalSave} mode={'add'} />
) } export default ModelSelector