provider-context.tsx 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. 'use client'
  2. import { createContext, useContext } from 'use-context-selector'
  3. import useSWR from 'swr'
  4. import { useEffect, useState } from 'react'
  5. import { fetchDefaultModal, fetchModelList, fetchSupportRetrievalMethods } from '@/service/common'
  6. import { ModelFeature, ModelType } from '@/app/components/header/account-setting/model-page/declarations'
  7. import type { BackendModel } from '@/app/components/header/account-setting/model-page/declarations'
  8. import type { RETRIEVE_METHOD } from '@/types/app'
  9. import { Plan, type UsagePlanInfo } from '@/app/components/billing/type'
  10. import { fetchCurrentPlanInfo } from '@/service/billing'
  11. import { parseCurrentPlan } from '@/app/components/billing/utils'
  12. import { defaultPlan } from '@/app/components/billing/config'
  /** Plan portion of the context value: a plan type with `usage` and `total` figures. */
  type ProviderContextPlan = {
    type: Plan
    usage: UsagePlanInfo
    total: UsagePlanInfo
  }

  /** Full shape of the value carried by `ProviderContext`. */
  type ProviderContextValue = {
    textGenerationModelList: BackendModel[]
    embeddingsModelList: BackendModel[]
    speech2textModelList: BackendModel[]
    rerankModelList: BackendModel[]
    agentThoughtModelList: BackendModel[]
    updateModelList: (type: ModelType) => void
    textGenerationDefaultModel?: BackendModel
    mutateTextGenerationDefaultModel: () => void
    embeddingsDefaultModel?: BackendModel
    isEmbeddingsDefaultModelValid: boolean
    mutateEmbeddingsDefaultModel: () => void
    speech2textDefaultModel?: BackendModel
    mutateSpeech2textDefaultModel: () => void
    rerankDefaultModel?: BackendModel
    isRerankDefaultModelVaild: boolean
    mutateRerankDefaultModel: () => void
    supportRetrievalMethods: RETRIEVE_METHOD[]
    plan: ProviderContextPlan
    isFetchedPlan: boolean
    enableBilling: boolean
  }

  // Placeholder callback used for every function slot of the default context value.
  const doNothing = () => { }

  // Plan figures used only by the default context value (i.e. when no provider supplies one).
  const fallbackPlan: ProviderContextPlan = {
    type: Plan.sandbox,
    usage: {
      vectorSpace: 32,
      buildApps: 12,
      teamMembers: 1,
      annotatedResponse: 1,
    },
    total: {
      vectorSpace: 200,
      buildApps: 50,
      teamMembers: 1,
      annotatedResponse: 10,
    },
  }

  // Default context value: empty lists, no default models, no-op callbacks, flags off.
  const fallbackContextValue: ProviderContextValue = {
    textGenerationModelList: [],
    embeddingsModelList: [],
    speech2textModelList: [],
    rerankModelList: [],
    agentThoughtModelList: [],
    updateModelList: doNothing,
    textGenerationDefaultModel: undefined,
    mutateTextGenerationDefaultModel: doNothing,
    speech2textDefaultModel: undefined,
    mutateSpeech2textDefaultModel: doNothing,
    embeddingsDefaultModel: undefined,
    isEmbeddingsDefaultModelValid: false,
    mutateEmbeddingsDefaultModel: doNothing,
    rerankDefaultModel: undefined,
    isRerankDefaultModelVaild: false,
    mutateRerankDefaultModel: doNothing,
    supportRetrievalMethods: [],
    plan: fallbackPlan,
    isFetchedPlan: false,
    enableBilling: false,
  }

  /** Context created with use-context-selector's `createContext`, seeded with the default value above. */
  const ProviderContext = createContext<ProviderContextValue>(fallbackContextValue)
  /** Hook returning the current `ProviderContext` value through use-context-selector's `useContext`. */
  export const useProviderContext = () => {
    return useContext(ProviderContext)
  }
  type ProviderContextProviderProps = {
    children: React.ReactNode
  }

  /**
   * Provider component that populates `ProviderContext` for its children.
   *
   * - Default models and model lists (text generation, embeddings, speech2text,
   *   reranking) and the supported retrieval methods are loaded with `useSWR`.
   * - `agentThoughtModelList` is the text-generation list filtered to models whose
   *   `features` include `ModelFeature.agentThought`.
   * - The two validity flags are true when the corresponding default model
   *   (matched by model name and provider name) is present in its model list.
   * - Plan information is fetched once on mount via `fetchCurrentPlanInfo`.
   *
   * @param children - React nodes rendered inside the provider.
   */
  export const ProviderContextProvider = ({
    children,
  }: ProviderContextProviderProps) => {
    const { data: textGenerationDefaultModel, mutate: mutateTextGenerationDefaultModel } = useSWR('/workspaces/current/default-model?model_type=text-generation', fetchDefaultModal)
    const { data: embeddingsDefaultModel, mutate: mutateEmbeddingsDefaultModel } = useSWR('/workspaces/current/default-model?model_type=embeddings', fetchDefaultModal)
    const { data: speech2textDefaultModel, mutate: mutateSpeech2textDefaultModel } = useSWR('/workspaces/current/default-model?model_type=speech2text', fetchDefaultModal)
    const { data: rerankDefaultModel, mutate: mutateRerankDefaultModel } = useSWR('/workspaces/current/default-model?model_type=reranking', fetchDefaultModal)

    const fetchModelListUrlPrefix = '/workspaces/current/models/model-type/'
    const { data: textGenerationModelList, mutate: mutateTextGenerationModelList } = useSWR(`${fetchModelListUrlPrefix}${ModelType.textGeneration}`, fetchModelList)
    const { data: embeddingsModelList, mutate: mutateEmbeddingsModelList } = useSWR(`${fetchModelListUrlPrefix}${ModelType.embeddings}`, fetchModelList)
    const { data: speech2textModelList, mutate: mutateSpeech2textModelList } = useSWR(`${fetchModelListUrlPrefix}${ModelType.speech2text}`, fetchModelList)
    const { data: rerankModelList, mutate: mutateRerankModelList } = useSWR(`${fetchModelListUrlPrefix}${ModelType.reranking}`, fetchModelList)
    const { data: supportRetrievalMethods } = useSWR('/datasets/retrieval-setting', fetchSupportRetrievalMethods)

    const agentThoughtModelList = textGenerationModelList?.filter((item) => {
      return item.features?.includes(ModelFeature.agentThought)
    })

    // A list entry matches the default model when both the model name and the provider name agree.
    const isSameModel = (item: BackendModel, defaultModel?: BackendModel) =>
      item.model_name === defaultModel?.model_name
      && item.model_provider.provider_name === defaultModel?.model_provider.provider_name

    // False while the list is still undefined (not loaded) or when no entry matches.
    const isRerankDefaultModelVaild = rerankModelList?.some(item => isSameModel(item, rerankDefaultModel)) ?? false
    const isEmbeddingsDefaultModelValid = embeddingsModelList?.some(item => isSameModel(item, embeddingsDefaultModel)) ?? false

    // Triggers SWR revalidation of the list that belongs to the given model type.
    const updateModelList = (type: ModelType) => {
      if (type === ModelType.textGeneration)
        mutateTextGenerationModelList()
      if (type === ModelType.embeddings)
        mutateEmbeddingsModelList()
      if (type === ModelType.speech2text)
        mutateSpeech2textModelList()
      if (type === ModelType.reranking)
        mutateRerankModelList()
    }

    const [plan, setPlan] = useState(defaultPlan)
    const [isFetchedPlan, setIsFetchedPlan] = useState(false)
    const [enableBilling, setEnableBilling] = useState(true)

    useEffect(() => {
      // Flipped by the cleanup so a response arriving after unmount (or after the
      // effect has been torn down and re-run) does not write state.
      let cancelled = false

      const loadPlan = async () => {
        try {
          const data = await fetchCurrentPlanInfo()
          if (cancelled)
            return
          const enabled = data.enabled
          setEnableBilling(enabled)
          // The plan and the "fetched" flag are only updated when billing is enabled.
          if (enabled) {
            setPlan(parseCurrentPlan(data))
            setIsFetchedPlan(true)
          }
        }
        catch (error) {
          // Keep the initial state on failure instead of leaving an unhandled rejection.
          console.error('Failed to fetch current plan info', error)
        }
      }
      loadPlan()

      return () => {
        cancelled = true
      }
    }, [])

    return (
      <ProviderContext.Provider value={{
        textGenerationModelList: textGenerationModelList ?? [],
        embeddingsModelList: embeddingsModelList ?? [],
        speech2textModelList: speech2textModelList ?? [],
        rerankModelList: rerankModelList ?? [],
        agentThoughtModelList: agentThoughtModelList ?? [],
        updateModelList,
        textGenerationDefaultModel,
        mutateTextGenerationDefaultModel,
        embeddingsDefaultModel,
        mutateEmbeddingsDefaultModel,
        speech2textDefaultModel,
        mutateSpeech2textDefaultModel,
        rerankDefaultModel,
        isRerankDefaultModelVaild,
        isEmbeddingsDefaultModelValid,
        mutateRerankDefaultModel,
        supportRetrievalMethods: supportRetrievalMethods?.retrieval_method ?? [],
        plan,
        isFetchedPlan,
        enableBilling,
      }}>
        {children}
      </ProviderContext.Provider>
    )
  }

  export default ProviderContext