Browse Source

refactor the logic of refreshing access_token (#10068)

NFish 5 months ago
parent
commit
302f4407f6

+ 3 - 2
web/app/account/avatar.tsx

@@ -23,8 +23,9 @@ export default function AppSelector() {
       params: {},
     })
 
-    if (localStorage?.getItem('console_token'))
-      localStorage.removeItem('console_token')
+    localStorage.removeItem('setup_status')
+    localStorage.removeItem('console_token')
+    localStorage.removeItem('refresh_token')
 
     router.push('/signin')
   }

+ 3 - 2
web/app/components/header/account-dropdown/index.tsx

@@ -47,8 +47,9 @@ export default function AppSelector({ isMobile }: IAppSelector) {
       params: {},
     })
 
-    if (localStorage?.getItem('console_token'))
-      localStorage.removeItem('console_token')
+    localStorage.removeItem('setup_status')
+    localStorage.removeItem('console_token')
+    localStorage.removeItem('refresh_token')
 
     router.push('/signin')
   }

+ 12 - 27
web/app/components/swr-initor.tsx

@@ -4,7 +4,6 @@ import { SWRConfig } from 'swr'
 import { useCallback, useEffect, useState } from 'react'
 import type { ReactNode } from 'react'
 import { usePathname, useRouter, useSearchParams } from 'next/navigation'
-import useRefreshToken from '@/hooks/use-refresh-token'
 import { fetchSetupStatus } from '@/service/common'
 
 type SwrInitorProps = {
@@ -15,12 +14,11 @@ const SwrInitor = ({
 }: SwrInitorProps) => {
   const router = useRouter()
   const searchParams = useSearchParams()
-  const pathname = usePathname()
-  const { getNewAccessToken } = useRefreshToken()
-  const consoleToken = searchParams.get('access_token')
-  const refreshToken = searchParams.get('refresh_token')
+  const consoleToken = decodeURIComponent(searchParams.get('access_token') || '')
+  const refreshToken = decodeURIComponent(searchParams.get('refresh_token') || '')
   const consoleTokenFromLocalStorage = localStorage?.getItem('console_token')
   const refreshTokenFromLocalStorage = localStorage?.getItem('refresh_token')
+  const pathname = usePathname()
   const [init, setInit] = useState(false)
 
   const isSetupFinished = useCallback(async () => {
@@ -41,25 +39,6 @@ const SwrInitor = ({
     }
   }, [])
 
-  const setRefreshToken = useCallback(async () => {
-    try {
-      if (!(consoleToken || refreshToken || consoleTokenFromLocalStorage || refreshTokenFromLocalStorage))
-        return Promise.reject(new Error('No token found'))
-
-      if (consoleTokenFromLocalStorage && refreshTokenFromLocalStorage)
-        await getNewAccessToken()
-
-      if (consoleToken && refreshToken) {
-        localStorage.setItem('console_token', consoleToken)
-        localStorage.setItem('refresh_token', refreshToken)
-        await getNewAccessToken()
-      }
-    }
-    catch (error) {
-      return Promise.reject(error)
-    }
-  }, [consoleToken, refreshToken, consoleTokenFromLocalStorage, refreshTokenFromLocalStorage, getNewAccessToken])
-
   useEffect(() => {
     (async () => {
       try {
@@ -68,9 +47,15 @@ const SwrInitor = ({
           router.replace('/install')
           return
         }
-        await setRefreshToken()
-        if (searchParams.has('access_token') || searchParams.has('refresh_token'))
+        if (!((consoleToken && refreshToken) || (consoleTokenFromLocalStorage && refreshTokenFromLocalStorage))) {
+          router.replace('/signin')
+          return
+        }
+        if (searchParams.has('access_token') || searchParams.has('refresh_token')) {
+          consoleToken && localStorage.setItem('console_token', consoleToken)
+          refreshToken && localStorage.setItem('refresh_token', refreshToken)
           router.replace(pathname)
+        }
 
         setInit(true)
       }
@@ -78,7 +63,7 @@ const SwrInitor = ({
         router.replace('/signin')
       }
     })()
-  }, [isSetupFinished, setRefreshToken, router, pathname, searchParams])
+  }, [isSetupFinished, router, pathname, searchParams, consoleToken, refreshToken, consoleTokenFromLocalStorage, refreshTokenFromLocalStorage])
 
   return init
     ? (

+ 1 - 4
web/app/signin/normalForm.tsx

@@ -12,11 +12,9 @@ import cn from '@/utils/classnames'
 import { getSystemFeatures, invitationCheck } from '@/service/common'
 import { defaultSystemFeatures } from '@/types/feature'
 import Toast from '@/app/components/base/toast'
-import useRefreshToken from '@/hooks/use-refresh-token'
 import { IS_CE_EDITION } from '@/config'
 
 const NormalForm = () => {
-  const { getNewAccessToken } = useRefreshToken()
   const { t } = useTranslation()
   const router = useRouter()
   const searchParams = useSearchParams()
@@ -38,7 +36,6 @@ const NormalForm = () => {
       if (consoleToken && refreshToken) {
         localStorage.setItem('console_token', consoleToken)
         localStorage.setItem('refresh_token', refreshToken)
-        getNewAccessToken()
         router.replace('/apps')
         return
       }
@@ -71,7 +68,7 @@ const NormalForm = () => {
       setSystemFeatures(defaultSystemFeatures)
     }
     finally { setIsLoading(false) }
-  }, [consoleToken, refreshToken, message, router, invite_token, isInviteLink, getNewAccessToken])
+  }, [consoleToken, refreshToken, message, router, invite_token, isInviteLink])
   useEffect(() => {
     init()
   }, [init])

+ 0 - 99
web/hooks/use-refresh-token.ts

@@ -1,99 +0,0 @@
-'use client'
-import { useCallback, useEffect, useRef } from 'react'
-import { jwtDecode } from 'jwt-decode'
-import dayjs from 'dayjs'
-import utc from 'dayjs/plugin/utc'
-import { useRouter } from 'next/navigation'
-import type { CommonResponse } from '@/models/common'
-import { fetchNewToken } from '@/service/common'
-import { fetchWithRetry } from '@/utils'
-
-dayjs.extend(utc)
-
-const useRefreshToken = () => {
-  const router = useRouter()
-  const timer = useRef<NodeJS.Timeout>()
-  const advanceTime = useRef<number>(5 * 60 * 1000)
-
-  const getExpireTime = useCallback((token: string) => {
-    if (!token)
-      return 0
-    const decoded = jwtDecode(token)
-    return (decoded.exp || 0) * 1000
-  }, [])
-
-  const getCurrentTimeStamp = useCallback(() => {
-    return dayjs.utc().valueOf()
-  }, [])
-
-  const handleError = useCallback(() => {
-    localStorage?.removeItem('is_refreshing')
-    localStorage?.removeItem('console_token')
-    localStorage?.removeItem('refresh_token')
-    router.replace('/signin')
-  }, [])
-
-  const getNewAccessToken = useCallback(async () => {
-    const currentAccessToken = localStorage?.getItem('console_token')
-    const currentRefreshToken = localStorage?.getItem('refresh_token')
-    if (!currentAccessToken || !currentRefreshToken) {
-      handleError()
-      return new Error('No access token or refresh token found')
-    }
-    if (localStorage?.getItem('is_refreshing') === '1') {
-      clearTimeout(timer.current)
-      timer.current = setTimeout(() => {
-        getNewAccessToken()
-      }, 1000)
-      return null
-    }
-    const currentTokenExpireTime = getExpireTime(currentAccessToken)
-    if (getCurrentTimeStamp() + advanceTime.current > currentTokenExpireTime) {
-      localStorage?.setItem('is_refreshing', '1')
-      const [e, res] = await fetchWithRetry(fetchNewToken({
-        body: { refresh_token: currentRefreshToken },
-      }) as Promise<CommonResponse & { data: { access_token: string; refresh_token: string } }>)
-      if (e) {
-        handleError()
-        return e
-      }
-      const { access_token, refresh_token } = res.data
-      localStorage?.setItem('is_refreshing', '0')
-      localStorage?.setItem('console_token', access_token)
-      localStorage?.setItem('refresh_token', refresh_token)
-      const newTokenExpireTime = getExpireTime(access_token)
-      clearTimeout(timer.current)
-      timer.current = setTimeout(() => {
-        getNewAccessToken()
-      }, newTokenExpireTime - advanceTime.current - getCurrentTimeStamp())
-    }
-    else {
-      const newTokenExpireTime = getExpireTime(currentAccessToken)
-      clearTimeout(timer.current)
-      timer.current = setTimeout(() => {
-        getNewAccessToken()
-      }, newTokenExpireTime - advanceTime.current - getCurrentTimeStamp())
-    }
-    return null
-  }, [getExpireTime, getCurrentTimeStamp, handleError])
-
-  const handleVisibilityChange = useCallback(() => {
-    if (document.visibilityState === 'visible')
-      getNewAccessToken()
-  }, [])
-
-  useEffect(() => {
-    window.addEventListener('visibilitychange', handleVisibilityChange)
-    return () => {
-      window.removeEventListener('visibilitychange', handleVisibilityChange)
-      clearTimeout(timer.current)
-      localStorage?.removeItem('is_refreshing')
-    }
-  }, [])
-
-  return {
-    getNewAccessToken,
-  }
-}
-
-export default useRefreshToken

+ 78 - 52
web/service/base.ts

@@ -1,3 +1,4 @@
+import { refreshAccessTokenOrRelogin } from './refresh-token'
 import { API_PREFIX, IS_CE_EDITION, PUBLIC_API_PREFIX } from '@/config'
 import Toast from '@/app/components/base/toast'
 import type { AnnotationReply, MessageEnd, MessageReplace, ThoughtItem } from '@/app/components/base/chat/chat/type'
@@ -356,39 +357,8 @@ const baseFetch = <T>(
           if (!/^(2|3)\d{2}$/.test(String(res.status))) {
             const bodyJson = res.json()
             switch (res.status) {
-              case 401: {
-                if (isPublicAPI) {
-                  return bodyJson.then((data: ResponseError) => {
-                    if (data.code === 'web_sso_auth_required')
-                      requiredWebSSOLogin()
-
-                    if (data.code === 'unauthorized') {
-                      removeAccessToken()
-                      globalThis.location.reload()
-                    }
-
-                    return Promise.reject(data)
-                  })
-                }
-                const loginUrl = `${globalThis.location.origin}/signin`
-                bodyJson.then((data: ResponseError) => {
-                  if (data.code === 'init_validate_failed' && IS_CE_EDITION && !silent)
-                    Toast.notify({ type: 'error', message: data.message, duration: 4000 })
-                  else if (data.code === 'not_init_validated' && IS_CE_EDITION)
-                    globalThis.location.href = `${globalThis.location.origin}/init`
-                  else if (data.code === 'not_setup' && IS_CE_EDITION)
-                    globalThis.location.href = `${globalThis.location.origin}/install`
-                  else if (location.pathname !== '/signin' || !IS_CE_EDITION)
-                    globalThis.location.href = loginUrl
-                  else if (!silent)
-                    Toast.notify({ type: 'error', message: data.message })
-                }).catch(() => {
-                  // Handle any other errors
-                  globalThis.location.href = loginUrl
-                })
-
-                break
-              }
+              case 401:
+                return Promise.reject(resClone)
               case 403:
                 bodyJson.then((data: ResponseError) => {
                   if (!silent)
@@ -484,7 +454,9 @@ export const upload = (options: any, isPublicAPI?: boolean, url?: string, search
 export const ssePost = (
   url: string,
   fetchOptions: FetchOptionType,
-  {
+  otherOptions: IOtherOptions,
+) => {
+  const {
     isPublicAPI = false,
     onData,
     onCompleted,
@@ -507,8 +479,7 @@ export const ssePost = (
     onTextReplace,
     onError,
     getAbortController,
-  }: IOtherOptions,
-) => {
+  } = otherOptions
   const abortController = new AbortController()
 
   const options = Object.assign({}, baseOptions, {
@@ -532,21 +503,29 @@ export const ssePost = (
   globalThis.fetch(urlWithPrefix, options as RequestInit)
     .then((res) => {
       if (!/^(2|3)\d{2}$/.test(String(res.status))) {
-        res.json().then((data: any) => {
-          if (isPublicAPI) {
-            if (data.code === 'web_sso_auth_required')
-              requiredWebSSOLogin()
-
-            if (data.code === 'unauthorized') {
-              removeAccessToken()
-              globalThis.location.reload()
-            }
-            if (res.status === 401)
-              return
-          }
-          Toast.notify({ type: 'error', message: data.message || 'Server Error' })
-        })
-        onError?.('Server Error')
+        if (res.status === 401) {
+          refreshAccessTokenOrRelogin(TIME_OUT).then(() => {
+            ssePost(url, fetchOptions, otherOptions)
+          }).catch(() => {
+            res.json().then((data: any) => {
+              if (isPublicAPI) {
+                if (data.code === 'web_sso_auth_required')
+                  requiredWebSSOLogin()
+
+                if (data.code === 'unauthorized') {
+                  removeAccessToken()
+                  globalThis.location.reload()
+                }
+              }
+            })
+          })
+        }
+        else {
+          res.json().then((data) => {
+            Toast.notify({ type: 'error', message: data.message || 'Server Error' })
+          })
+          onError?.('Server Error')
+        }
         return
       }
       return handleStream(res, (str: string, isFirstMessage: boolean, moreInfo: IOnDataMoreInfo) => {
@@ -568,7 +547,54 @@ export const ssePost = (
 
 // base request
 export const request = <T>(url: string, options = {}, otherOptions?: IOtherOptions) => {
-  return baseFetch<T>(url, options, otherOptions || {})
+  return new Promise<T>((resolve, reject) => {
+    const otherOptionsForBaseFetch = otherOptions || {}
+    baseFetch<T>(url, options, otherOptionsForBaseFetch).then(resolve).catch((errResp) => {
+      if (errResp?.status === 401) {
+        return refreshAccessTokenOrRelogin(TIME_OUT).then(() => {
+          baseFetch<T>(url, options, otherOptionsForBaseFetch).then(resolve).catch(reject)
+        }).catch(() => {
+          const {
+            isPublicAPI = false,
+            silent,
+          } = otherOptionsForBaseFetch
+          const bodyJson = errResp.json()
+          if (isPublicAPI) {
+            return bodyJson.then((data: ResponseError) => {
+              if (data.code === 'web_sso_auth_required')
+                requiredWebSSOLogin()
+
+              if (data.code === 'unauthorized') {
+                removeAccessToken()
+                globalThis.location.reload()
+              }
+
+              return Promise.reject(data)
+            })
+          }
+          const loginUrl = `${globalThis.location.origin}/signin`
+          bodyJson.then((data: ResponseError) => {
+            if (data.code === 'init_validate_failed' && IS_CE_EDITION && !silent)
+              Toast.notify({ type: 'error', message: data.message, duration: 4000 })
+            else if (data.code === 'not_init_validated' && IS_CE_EDITION)
+              globalThis.location.href = `${globalThis.location.origin}/init`
+            else if (data.code === 'not_setup' && IS_CE_EDITION)
+              globalThis.location.href = `${globalThis.location.origin}/install`
+            else if (location.pathname !== '/signin' || !IS_CE_EDITION)
+              globalThis.location.href = loginUrl
+            else if (!silent)
+              Toast.notify({ type: 'error', message: data.message })
+          }).catch(() => {
+            // Handle any other errors
+            globalThis.location.href = loginUrl
+          })
+        })
+      }
+      else {
+        reject(errResp)
+      }
+    })
+  })
 }
 
 // request methods

+ 75 - 0
web/service/refresh-token.ts

@@ -0,0 +1,75 @@
+import { apiPrefix } from '@/config'
+import { fetchWithRetry } from '@/utils'
+
+let isRefreshing = false
+function waitUntilTokenRefreshed() {
+  return new Promise<void>((resolve, reject) => {
+    function _check() {
+      const isRefreshingSign = localStorage.getItem('is_refreshing')
+      if ((isRefreshingSign && isRefreshingSign === '1') || isRefreshing) {
+        setTimeout(() => {
+          _check()
+        }, 1000)
+      }
+      else {
+        resolve()
+      }
+    }
+    _check()
+  })
+}
+
+// only one request can send
+async function getNewAccessToken(): Promise<void> {
+  try {
+    const isRefreshingSign = localStorage.getItem('is_refreshing')
+    if ((isRefreshingSign && isRefreshingSign === '1') || isRefreshing) {
+      await waitUntilTokenRefreshed()
+    }
+    else {
+      globalThis.localStorage.setItem('is_refreshing', '1')
+      isRefreshing = true
+      const refresh_token = globalThis.localStorage.getItem('refresh_token')
+
+      // Do not use baseFetch to refresh tokens.
+      // If a 401 response occurs and baseFetch itself attempts to refresh the token,
+      // it can lead to an infinite loop if the refresh attempt also returns 401.
+      // To avoid this, handle token refresh separately in a dedicated function
+      // that does not call baseFetch and uses a single retry mechanism.
+      const [error, ret] = await fetchWithRetry(globalThis.fetch(`${apiPrefix}/refresh-token`, {
+        method: 'POST',
+        headers: {
+          'Content-Type': 'application/json;utf-8',
+        },
+        body: JSON.stringify({ refresh_token }),
+      }))
+      if (error) {
+        return Promise.reject(error)
+      }
+      else {
+        if (ret.status === 401)
+          return Promise.reject(ret)
+
+        const { data } = await ret.json()
+        globalThis.localStorage.setItem('console_token', data.access_token)
+        globalThis.localStorage.setItem('refresh_token', data.refresh_token)
+      }
+    }
+  }
+  catch (error) {
+    console.error(error)
+    return Promise.reject(error)
+  }
+  finally {
+    isRefreshing = false
+    globalThis.localStorage.removeItem('is_refreshing')
+  }
+}
+
+export async function refreshAccessTokenOrRelogin(timeout: number) {
+  return Promise.race([new Promise<void>((resolve, reject) => setTimeout(() => {
+    isRefreshing = false
+    globalThis.localStorage.removeItem('is_refreshing')
+    reject(new Error('request timeout'))
+  }, timeout)), getNewAccessToken()])
+}