|  | @@ -7,6 +7,7 @@ from datetime import datetime, timedelta, timezone
 | 
	
		
			
				|  |  |  from hashlib import sha256
 | 
	
		
			
				|  |  |  from typing import Any, Optional
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +from pydantic import BaseModel
 | 
	
		
			
				|  |  |  from sqlalchemy import func
 | 
	
		
			
				|  |  |  from werkzeug.exceptions import Unauthorized
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -49,9 +50,39 @@ from tasks.mail_invite_member_task import send_invite_member_mail_task
 | 
	
		
			
				|  |  |  from tasks.mail_reset_password_task import send_reset_password_mail_task
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +class TokenPair(BaseModel):
 | 
	
		
			
				|  |  | +    access_token: str
 | 
	
		
			
				|  |  | +    refresh_token: str
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +REFRESH_TOKEN_PREFIX = "refresh_token:"
 | 
	
		
			
				|  |  | +ACCOUNT_REFRESH_TOKEN_PREFIX = "account_refresh_token:"
 | 
	
		
			
				|  |  | +REFRESH_TOKEN_EXPIRY = timedelta(days=30)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  class AccountService:
 | 
	
		
			
				|  |  |      reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=5, time_window=60 * 60)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +    @staticmethod
 | 
	
		
			
				|  |  | +    def _get_refresh_token_key(refresh_token: str) -> str:
 | 
	
		
			
				|  |  | +        return f"{REFRESH_TOKEN_PREFIX}{refresh_token}"
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    @staticmethod
 | 
	
		
			
				|  |  | +    def _get_account_refresh_token_key(account_id: str) -> str:
 | 
	
		
			
				|  |  | +        return f"{ACCOUNT_REFRESH_TOKEN_PREFIX}{account_id}"
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    @staticmethod
 | 
	
		
			
				|  |  | +    def _store_refresh_token(refresh_token: str, account_id: str) -> None:
 | 
	
		
			
				|  |  | +        redis_client.setex(AccountService._get_refresh_token_key(refresh_token), REFRESH_TOKEN_EXPIRY, account_id)
 | 
	
		
			
				|  |  | +        redis_client.setex(
 | 
	
		
			
				|  |  | +            AccountService._get_account_refresh_token_key(account_id), REFRESH_TOKEN_EXPIRY, refresh_token
 | 
	
		
			
				|  |  | +        )
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    @staticmethod
 | 
	
		
			
				|  |  | +    def _delete_refresh_token(refresh_token: str, account_id: str) -> None:
 | 
	
		
			
				|  |  | +        redis_client.delete(AccountService._get_refresh_token_key(refresh_token))
 | 
	
		
			
				|  |  | +        redis_client.delete(AccountService._get_account_refresh_token_key(account_id))
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |      @staticmethod
 | 
	
		
			
				|  |  |      def load_user(user_id: str) -> None | Account:
 | 
	
		
			
				|  |  |          account = Account.query.filter_by(id=user_id).first()
 | 
	
	
		
			
				|  | @@ -61,9 +92,7 @@ class AccountService:
 | 
	
		
			
				|  |  |          if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}:
 | 
	
		
			
				|  |  |              raise Unauthorized("Account is banned or closed.")
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -        current_tenant: TenantAccountJoin = TenantAccountJoin.query.filter_by(
 | 
	
		
			
				|  |  | -            account_id=account.id, current=True
 | 
	
		
			
				|  |  | -        ).first()
 | 
	
		
			
				|  |  | +        current_tenant = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first()
 | 
	
		
			
				|  |  |          if current_tenant:
 | 
	
		
			
				|  |  |              account.current_tenant_id = current_tenant.tenant_id
 | 
	
		
			
				|  |  |          else:
 | 
	
	
		
			
				|  | @@ -84,10 +113,12 @@ class AccountService:
 | 
	
		
			
				|  |  |          return account
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      @staticmethod
 | 
	
		
			
				|  |  | -    def get_account_jwt_token(account, *, exp: timedelta = timedelta(days=30)):
 | 
	
		
			
				|  |  | +    def get_account_jwt_token(account: Account) -> str:
 | 
	
		
			
				|  |  | +        exp_dt = datetime.now(timezone.utc) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES)
 | 
	
		
			
				|  |  | +        exp = int(exp_dt.timestamp())
 | 
	
		
			
				|  |  |          payload = {
 | 
	
		
			
				|  |  |              "user_id": account.id,
 | 
	
		
			
				|  |  | -            "exp": datetime.now(timezone.utc).replace(tzinfo=None) + exp,
 | 
	
		
			
				|  |  | +            "exp": exp,
 | 
	
		
			
				|  |  |              "iss": dify_config.EDITION,
 | 
	
		
			
				|  |  |              "sub": "Console API Passport",
 | 
	
		
			
				|  |  |          }
 | 
	
	
		
			
				|  | @@ -213,7 +244,7 @@ class AccountService:
 | 
	
		
			
				|  |  |          return account
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      @staticmethod
 | 
	
		
			
				|  |  | -    def update_last_login(account: Account, *, ip_address: str) -> None:
 | 
	
		
			
				|  |  | +    def update_login_info(account: Account, *, ip_address: str) -> None:
 | 
	
		
			
				|  |  |          """Update last login time and ip"""
 | 
	
		
			
				|  |  |          account.last_login_at = datetime.now(timezone.utc).replace(tzinfo=None)
 | 
	
		
			
				|  |  |          account.last_login_ip = ip_address
 | 
	
	
		
			
				|  | @@ -221,22 +252,45 @@ class AccountService:
 | 
	
		
			
				|  |  |          db.session.commit()
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      @staticmethod
 | 
	
		
			
				|  |  | -    def login(account: Account, *, ip_address: Optional[str] = None):
 | 
	
		
			
				|  |  | +    def login(account: Account, *, ip_address: Optional[str] = None) -> TokenPair:
 | 
	
		
			
				|  |  |          if ip_address:
 | 
	
		
			
				|  |  | -            AccountService.update_last_login(account, ip_address=ip_address)
 | 
	
		
			
				|  |  | -        exp = timedelta(days=30)
 | 
	
		
			
				|  |  | -        token = AccountService.get_account_jwt_token(account, exp=exp)
 | 
	
		
			
				|  |  | -        redis_client.set(_get_login_cache_key(account_id=account.id, token=token), "1", ex=int(exp.total_seconds()))
 | 
	
		
			
				|  |  | -        return token
 | 
	
		
			
				|  |  | +            AccountService.update_login_info(account=account, ip_address=ip_address)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        access_token = AccountService.get_account_jwt_token(account=account)
 | 
	
		
			
				|  |  | +        refresh_token = _generate_refresh_token()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        AccountService._store_refresh_token(refresh_token, account.id)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        return TokenPair(access_token=access_token, refresh_token=refresh_token)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      @staticmethod
 | 
	
		
			
				|  |  | -    def logout(*, account: Account, token: str):
 | 
	
		
			
				|  |  | -        redis_client.delete(_get_login_cache_key(account_id=account.id, token=token))
 | 
	
		
			
				|  |  | +    def logout(*, account: Account) -> None:
 | 
	
		
			
				|  |  | +        refresh_token = redis_client.get(AccountService._get_account_refresh_token_key(account.id))
 | 
	
		
			
				|  |  | +        if refresh_token:
 | 
	
		
			
				|  |  | +            AccountService._delete_refresh_token(refresh_token.decode("utf-8"), account.id)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      @staticmethod
 | 
	
		
			
				|  |  | -    def load_logged_in_account(*, account_id: str, token: str):
 | 
	
		
			
				|  |  | -        if not redis_client.get(_get_login_cache_key(account_id=account_id, token=token)):
 | 
	
		
			
				|  |  | -            return None
 | 
	
		
			
				|  |  | +    def refresh_token(refresh_token: str) -> TokenPair:
 | 
	
		
			
				|  |  | +        # Verify the refresh token
 | 
	
		
			
				|  |  | +        account_id = redis_client.get(AccountService._get_refresh_token_key(refresh_token))
 | 
	
		
			
				|  |  | +        if not account_id:
 | 
	
		
			
				|  |  | +            raise ValueError("Invalid refresh token")
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        account = AccountService.load_user(account_id.decode("utf-8"))
 | 
	
		
			
				|  |  | +        if not account:
 | 
	
		
			
				|  |  | +            raise ValueError("Invalid account")
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        # Generate new access token and refresh token
 | 
	
		
			
				|  |  | +        new_access_token = AccountService.get_account_jwt_token(account)
 | 
	
		
			
				|  |  | +        new_refresh_token = _generate_refresh_token()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        AccountService._delete_refresh_token(refresh_token, account.id)
 | 
	
		
			
				|  |  | +        AccountService._store_refresh_token(new_refresh_token, account.id)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        return TokenPair(access_token=new_access_token, refresh_token=new_refresh_token)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    @staticmethod
 | 
	
		
			
				|  |  | +    def load_logged_in_account(*, account_id: str):
 | 
	
		
			
				|  |  |          return AccountService.load_user(account_id)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      @classmethod
 | 
	
	
		
			
				|  | @@ -258,10 +312,6 @@ class AccountService:
 | 
	
		
			
				|  |  |          return TokenManager.get_token_data(token, "reset_password")
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -def _get_login_cache_key(*, account_id: str, token: str):
 | 
	
		
			
				|  |  | -    return f"account_login:{account_id}:{token}"
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  |  class TenantService:
 | 
	
		
			
				|  |  |      @staticmethod
 | 
	
		
			
				|  |  |      def create_tenant(name: str) -> Tenant:
 | 
	
	
		
			
				|  | @@ -698,3 +748,8 @@ class RegisterService:
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |              invitation = json.loads(data)
 | 
	
		
			
				|  |  |              return invitation
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +def _generate_refresh_token(length: int = 64):
 | 
	
		
			
				|  |  | +    token = secrets.token_hex(length)
 | 
	
		
			
				|  |  | +    return token
 |