provider-context.tsx 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. 'use client'
  2. import { createContext, useContext } from 'use-context-selector'
  3. import useSWR from 'swr'
  4. import { fetchModelList, fetchTenantInfo } from '@/service/common'
  5. import { ModelFeature, ModelType } from '@/app/components/header/account-setting/model-page/declarations'
  6. import type { BackendModel } from '@/app/components/header/account-setting/model-page/declarations'
  /** Value published by ProviderContext to its consumers. */
  type ProviderContextValue = {
    /** Selected provider record; `null` in the default value and may also be `undefined`. */
    currentProvider: {
      provider: string
      provider_name: string
      token_is_set: boolean
      is_valid: boolean
      token_is_valid: boolean
    } | null | undefined
    /** Models available for text generation. */
    textGenerationModelList: BackendModel[]
    /** Models available for embeddings. */
    embeddingsModelList: BackendModel[]
    /** Models available for speech-to-text. */
    speech2textModelList: BackendModel[]
    /** Models flagged with the agent-thought feature. */
    agentThoughtModelList: BackendModel[]
    /** Requests a refresh of the model list for the given model type. */
    updateModelList: (type: ModelType) => void
  }

  /** Fallback value used when no ProviderContextProvider is mounted above a consumer. */
  const defaultProviderContextValue: ProviderContextValue = {
    currentProvider: null,
    textGenerationModelList: [],
    embeddingsModelList: [],
    speech2textModelList: [],
    agentThoughtModelList: [],
    updateModelList: () => {},
  }

  const ProviderContext = createContext<ProviderContextValue>(defaultProviderContextValue)
  /** Returns the complete value currently published by ProviderContext. */
  export function useProviderContext() {
    return useContext(ProviderContext)
  }
  /** Props accepted by ProviderContextProvider. */
  interface ProviderContextProviderProps {
    /** Subtree that will be able to read ProviderContext. */
    children: React.ReactNode
  }
  /** SWR key prefix; the model type is appended to build each model-list request key. */
  const fetchModelListUrlPrefix = '/workspaces/current/models/model-type/'

  /**
   * Supplies ProviderContext to its subtree.
   *
   * Loads tenant info (`/info`) and the text-generation, embeddings and
   * speech-to-text model lists through SWR. It also derives the subset of
   * text-generation models that advertise the agent-thought feature. While a
   * list has not loaded yet, the context exposes an empty array for it.
   *
   * `updateModelList(type)` re-fetches the list for `type`. All three fetched
   * model types are handled; the agent-thought list follows the
   * text-generation list because it is derived from it.
   */
  export const ProviderContextProvider = ({
    children,
  }: ProviderContextProviderProps) => {
    const { data: userInfo } = useSWR({ url: '/info' }, fetchTenantInfo)
    // First provider whose token is set and valid, restricted to openai / azure_openai.
    const currentProvider = userInfo?.providers?.find(({ token_is_set, is_valid, provider_name }) =>
      token_is_set && is_valid && (provider_name === 'openai' || provider_name === 'azure_openai'))

    const { data: textGenerationModelList, mutate: mutateTextGenerationModelList } = useSWR(`${fetchModelListUrlPrefix}${ModelType.textGeneration}`, fetchModelList)
    const { data: embeddingsModelList, mutate: mutateEmbeddingsModelList } = useSWR(`${fetchModelListUrlPrefix}${ModelType.embeddings}`, fetchModelList)
    const { data: speech2textModelList, mutate: mutateSpeech2textModelList } = useSWR(`${fetchModelListUrlPrefix}${ModelType.speech2text}`, fetchModelList)

    // Derived from the text-generation list, so refreshing that list refreshes this one as well.
    const agentThoughtModelList = textGenerationModelList?.filter(item => item.features?.includes(ModelFeature.agentThought))

    const updateModelList = (type: ModelType) => {
      if (type === ModelType.textGeneration)
        mutateTextGenerationModelList()
      else if (type === ModelType.embeddings)
        mutateEmbeddingsModelList()
      // Previously missing: speech-to-text is fetched above but could not be refreshed.
      else if (type === ModelType.speech2text)
        mutateSpeech2textModelList()
    }

    return (
      <ProviderContext.Provider value={{
        currentProvider,
        textGenerationModelList: textGenerationModelList ?? [],
        embeddingsModelList: embeddingsModelList ?? [],
        speech2textModelList: speech2textModelList ?? [],
        agentThoughtModelList: agentThoughtModelList ?? [],
        updateModelList,
      }}>
        {children}
      </ProviderContext.Provider>
    )
  }
  // Default export: the raw context object, for consumers that need it directly.
  export { ProviderContext as default }