provider-context.tsx 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. 'use client'
  import { createContext, useContext, useContextSelector } from 'use-context-selector'
  import useSWR from 'swr'
  import { useCallback, useEffect, useMemo, useState } from 'react'
  import {
    fetchModelList,
    fetchModelProviders,
    fetchSupportRetrievalMethods,
  } from '@/service/common'
  import {
    ModelStatusEnum,
    ModelTypeEnum,
  } from '@/app/components/header/account-setting/model-provider-page/declarations'
  import type { Model, ModelProvider } from '@/app/components/header/account-setting/model-provider-page/declarations'
  import type { RETRIEVE_METHOD } from '@/types/app'
  import { Plan, type UsagePlanInfo } from '@/app/components/billing/type'
  import { fetchCurrentPlanInfo } from '@/service/billing'
  import { parseCurrentPlan } from '@/app/components/billing/utils'
  import { defaultPlan } from '@/app/components/billing/config'
  /** Billing plan tier together with its current usage and its limits. */
  type ProviderPlanInfo = {
    type: Plan
    usage: UsagePlanInfo
    total: UsagePlanInfo
  }

  /** Shape of the value published through `ProviderContext`. */
  type ProviderContextState = {
    /** Model providers configured for the current workspace. */
    modelProviders: ModelProvider[]
    /** Models of the text-generation type. */
    textGenerationModelList: Model[]
    /** Retrieval methods reported by the retrieval-setting endpoint. */
    supportRetrievalMethods: RETRIEVE_METHOD[]
    /** True when at least one text-generation model is active. */
    isAPIKeySet: boolean
    /** Current plan tier, usage and limits. */
    plan: ProviderPlanInfo
    /** Set once plan details have been fetched with billing enabled. */
    isFetchedPlan: boolean
    /** Mirrors the `billing.enabled` flag from the plan-info response. */
    enableBilling: boolean
    /** Re-fetches plan/billing info; call after the plan may have changed. */
    onPlanInfoChanged: () => void
    /** Mirrors `can_replace_logo` from the plan-info response. */
    enableReplaceWebAppLogo: boolean
    /** Mirrors `model_load_balancing_enabled` from the plan-info response. */
    modelLoadBalancingEnabled: boolean
  }
  /**
   * Value seen by consumers rendered outside `ProviderContextProvider`.
   * Inside the provider every field is replaced by live data.
   *
   * NOTE(review): the plan usage/total numbers below look like placeholders and
   * may not match `defaultPlan`, which the provider starts from; confirm intent.
   */
  const fallbackProviderState: ProviderContextState = {
    modelProviders: [],
    textGenerationModelList: [],
    supportRetrievalMethods: [],
    isAPIKeySet: true,
    plan: {
      type: Plan.sandbox,
      usage: { vectorSpace: 32, buildApps: 12, teamMembers: 1, annotatedResponse: 1 },
      total: { vectorSpace: 200, buildApps: 50, teamMembers: 1, annotatedResponse: 10 },
    },
    isFetchedPlan: false,
    enableBilling: false,
    onPlanInfoChanged: () => {},
    enableReplaceWebAppLogo: false,
    modelLoadBalancingEnabled: false,
  }

  /** Context carrying model-provider, retrieval and billing state for the workspace. */
  const ProviderContext = createContext<ProviderContextState>(fallbackProviderState)
  /** Returns the whole `ProviderContext` value. */
  export function useProviderContext() {
    return useContext(ProviderContext)
  }
  /**
   * Reads a single slice of `ProviderContext` chosen by `selector`.
   * Via `use-context-selector`, the caller re-renders only when the selected
   * value changes, not on every context update.
   *
   * Written as a function declaration so the generic `<T>` is not confused
   * with a JSX tag in a .tsx file.
   */
  export function useProviderContextSelector<T>(selector: (state: ProviderContextState) => T): T {
    return useContextSelector(ProviderContext, selector)
  }
  /** Props for `ProviderContextProvider`: the subtree that receives the context. */
  type ProviderContextProviderProps = { children: React.ReactNode }
  /**
   * Supplies `ProviderContext` to its subtree.
   *
   * - Model providers, text-generation models and supported retrieval methods
   *   are loaded with SWR.
   * - Billing/plan info is fetched once on mount. Consumers can re-fetch it
   *   through `onPlanInfoChanged`.
   *
   * The context value is memoized, so a new value is published only when one
   * of its inputs changes rather than on every render of this component.
   */
  export const ProviderContextProvider = ({
    children,
  }: ProviderContextProviderProps) => {
    const { data: providersData } = useSWR('/workspaces/current/model-providers', fetchModelProviders)
    const fetchModelListUrlPrefix = '/workspaces/current/models/model-types/'
    const { data: textGenerationModelList } = useSWR(`${fetchModelListUrlPrefix}${ModelTypeEnum.textGeneration}`, fetchModelList)
    const { data: supportRetrievalMethods } = useSWR('/datasets/retrieval-setting', fetchSupportRetrievalMethods)

    const [plan, setPlan] = useState(defaultPlan)
    const [isFetchedPlan, setIsFetchedPlan] = useState(false)
    const [enableBilling, setEnableBilling] = useState(true)
    const [enableReplaceWebAppLogo, setEnableReplaceWebAppLogo] = useState(false)
    const [modelLoadBalancingEnabled, setModelLoadBalancingEnabled] = useState(false)

    // Only state setters (stable per React's contract) and module imports are
    // used, so an empty dependency list yields a callback with stable identity.
    // That makes it safe as an effect dependency and as `onPlanInfoChanged`.
    const fetchPlan = useCallback(async () => {
      const data = await fetchCurrentPlanInfo()
      const enabled = data.billing.enabled
      setEnableBilling(enabled)
      setEnableReplaceWebAppLogo(data.can_replace_logo)
      // Plan details are parsed, and `isFetchedPlan` is set, only when billing
      // is enabled. Otherwise the current plan state is left untouched.
      if (enabled) {
        setPlan(parseCurrentPlan(data))
        setIsFetchedPlan(true)
      }
      // Set unconditionally so a re-fetch after a plan change can also switch
      // the flag off, not only on.
      setModelLoadBalancingEnabled(!!data.model_load_balancing_enabled)
    }, [])

    useEffect(() => {
      fetchPlan()
    }, [fetchPlan])

    const value = useMemo<ProviderContextState>(() => ({
      modelProviders: providersData?.data || [],
      textGenerationModelList: textGenerationModelList?.data || [],
      // Derived from whether any text-generation model is currently active.
      isAPIKeySet: !!textGenerationModelList?.data.some(model => model.status === ModelStatusEnum.active),
      supportRetrievalMethods: supportRetrievalMethods?.retrieval_method || [],
      plan,
      isFetchedPlan,
      enableBilling,
      onPlanInfoChanged: fetchPlan,
      enableReplaceWebAppLogo,
      modelLoadBalancingEnabled,
    }), [
      providersData,
      textGenerationModelList,
      supportRetrievalMethods,
      plan,
      isFetchedPlan,
      enableBilling,
      fetchPlan,
      enableReplaceWebAppLogo,
      modelLoadBalancingEnabled,
    ])

    return (
      <ProviderContext.Provider value={value}>
        {children}
      </ProviderContext.Provider>
    )
  }
  // The raw context object is also this module's default export.
  export { ProviderContext as default }