Browse Source

Feat/enterprise sso (#3602)

Garfield Dai 1 year ago
parent
commit
4481906be2

+ 4 - 1
api/app.py

@@ -115,7 +115,7 @@ def initialize_extensions(app):
 @login_manager.request_loader
 def load_user_from_request(request_from_flask_login):
     """Load user based on the request."""
-    if request.blueprint == 'console':
+    if request.blueprint in ['console', 'inner_api']:
         # Check if the user_id contains a dot, indicating the old format
         auth_header = request.headers.get('Authorization', '')
         if not auth_header:
@@ -153,6 +153,7 @@ def register_blueprints(app):
     from controllers.files import bp as files_bp
     from controllers.service_api import bp as service_api_bp
     from controllers.web import bp as web_bp
+    from controllers.inner_api import bp as inner_api_bp
 
     CORS(service_api_bp,
          allow_headers=['Content-Type', 'Authorization', 'X-App-Code'],
@@ -188,6 +189,8 @@ def register_blueprints(app):
          )
     app.register_blueprint(files_bp)
 
+    app.register_blueprint(inner_api_bp)
+
 
 # create app
 app = create_app()

+ 9 - 0
api/config.py

@@ -69,6 +69,8 @@ DEFAULTS = {
     'TOOL_ICON_CACHE_MAX_AGE': 3600,
     'MILVUS_DATABASE': 'default',
     'KEYWORD_DATA_SOURCE_TYPE': 'database',
+    'INNER_API': 'False',
+    'ENTERPRISE_ENABLED': 'False',
 }
 
 
@@ -133,6 +135,11 @@ class Config:
         # Alternatively you can set it with `SECRET_KEY` environment variable.
         self.SECRET_KEY = get_env('SECRET_KEY')
 
+        # Enable or disable the inner API.
+        self.INNER_API = get_bool_env('INNER_API')
+        # The inner API key is used to authenticate the inner API.
+        self.INNER_API_KEY = get_env('INNER_API_KEY')
+
         # cors settings
         self.CONSOLE_CORS_ALLOW_ORIGINS = get_cors_allow_origins(
             'CONSOLE_CORS_ALLOW_ORIGINS', self.CONSOLE_WEB_URL)
@@ -327,6 +334,8 @@ class Config:
         self.TOOL_ICON_CACHE_MAX_AGE = get_env('TOOL_ICON_CACHE_MAX_AGE')
 
         self.KEYWORD_DATA_SOURCE_TYPE = get_env('KEYWORD_DATA_SOURCE_TYPE')
+        self.ENTERPRISE_ENABLED = get_bool_env('ENTERPRISE_ENABLED')
+
 
 class CloudEditionConfig(Config):
 

+ 3 - 1
api/controllers/console/__init__.py

@@ -19,4 +19,6 @@ from .datasets import data_source, datasets, datasets_document, datasets_segment
 from .explore import (audio, completion, conversation, installed_app, message, parameter, recommended_app,
                       saved_message, workflow)
 # Import workspace controllers
-from .workspace import account, members, model_providers, models, tool_providers, workspace
+from .workspace import account, members, model_providers, models, tool_providers, workspace
+# Import enterprise controllers
+from .enterprise import enterprise_sso

+ 6 - 3
api/controllers/console/auth/login.py

@@ -26,10 +26,13 @@ class LoginApi(Resource):
 
         try:
             account = AccountService.authenticate(args['email'], args['password'])
-        except services.errors.account.AccountLoginError:
-            return {'code': 'unauthorized', 'message': 'Invalid email or password'}, 401
+        except services.errors.account.AccountLoginError as e:
+            return {'code': 'unauthorized', 'message': str(e)}, 401
 
-        TenantService.create_owner_tenant_if_not_exist(account)
+        # SELF_HOSTED only have one workspace
+        tenants = TenantService.get_join_tenants(account)
+        if len(tenants) == 0:
+            return {'result': 'fail', 'data': 'workspace not found, please contact system admin to invite you to join in a workspace'}
 
         AccountService.update_last_login(account, request)
 

+ 0 - 0
api/controllers/console/enterprise/__init__.py


+ 59 - 0
api/controllers/console/enterprise/enterprise_sso.py

@@ -0,0 +1,59 @@
+from flask import current_app, redirect
+from flask_restful import Resource, reqparse
+
+from controllers.console import api
+from controllers.console.setup import setup_required
+from services.enterprise.enterprise_sso_service import EnterpriseSSOService
+
+
+class EnterpriseSSOSamlLogin(Resource):
+
+    @setup_required
+    def get(self):
+        return EnterpriseSSOService.get_sso_saml_login()
+
+
+class EnterpriseSSOSamlAcs(Resource):
+
+    @setup_required
+    def post(self):
+        parser = reqparse.RequestParser()
+        parser.add_argument('SAMLResponse', type=str, required=True, location='form')
+        args = parser.parse_args()
+        saml_response = args['SAMLResponse']
+
+        try:
+            token = EnterpriseSSOService.post_sso_saml_acs(saml_response)
+            return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?console_token={token}')
+        except Exception as e:
+            return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?message={str(e)}')
+
+
+class EnterpriseSSOOidcLogin(Resource):
+
+    @setup_required
+    def get(self):
+        return EnterpriseSSOService.get_sso_oidc_login()
+
+
+class EnterpriseSSOOidcCallback(Resource):
+
+    @setup_required
+    def get(self):
+        parser = reqparse.RequestParser()
+        parser.add_argument('state', type=str, required=True, location='args')
+        parser.add_argument('code', type=str, required=True, location='args')
+        parser.add_argument('oidc-state', type=str, required=True, location='cookies')
+        args = parser.parse_args()
+
+        try:
+            token = EnterpriseSSOService.get_sso_oidc_callback(args)
+            return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?console_token={token}')
+        except Exception as e:
+            return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?message={str(e)}')
+
+
+api.add_resource(EnterpriseSSOSamlLogin, '/enterprise/sso/saml/login')
+api.add_resource(EnterpriseSSOSamlAcs, '/enterprise/sso/saml/acs')
+api.add_resource(EnterpriseSSOOidcLogin, '/enterprise/sso/oidc/login')
+api.add_resource(EnterpriseSSOOidcCallback, '/enterprise/sso/oidc/callback')

+ 7 - 0
api/controllers/console/feature.py

@@ -1,6 +1,7 @@
 from flask_login import current_user
 from flask_restful import Resource
 
+from services.enterprise.enterprise_feature_service import EnterpriseFeatureService
 from services.feature_service import FeatureService
 
 from . import api
@@ -14,4 +15,10 @@ class FeatureApi(Resource):
         return FeatureService.get_features(current_user.current_tenant_id).dict()
 
 
+class EnterpriseFeatureApi(Resource):
+    def get(self):
+        return EnterpriseFeatureService.get_enterprise_features().dict()
+
+
 api.add_resource(FeatureApi, '/features')
+api.add_resource(EnterpriseFeatureApi, '/enterprise-features')

+ 2 - 0
api/controllers/console/setup.py

@@ -58,6 +58,8 @@ class SetupApi(Resource):
             password=args['password']
         )
 
+        TenantService.create_owner_tenant_if_not_exist(account)
+
         setup()
         AccountService.update_last_login(account, request)
 

+ 12 - 1
api/controllers/console/workspace/workspace.py

@@ -3,6 +3,7 @@ import logging
 from flask import request
 from flask_login import current_user
 from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse
+from werkzeug.exceptions import Unauthorized
 
 import services
 from controllers.console import api
@@ -19,7 +20,7 @@ from controllers.console.wraps import account_initialization_required, cloud_edi
 from extensions.ext_database import db
 from libs.helper import TimestampField
 from libs.login import login_required
-from models.account import Tenant
+from models.account import Tenant, TenantStatus
 from services.account_service import TenantService
 from services.file_service import FileService
 from services.workspace_service import WorkspaceService
@@ -116,6 +117,16 @@ class TenantApi(Resource):
 
         tenant = current_user.current_tenant
 
+        if tenant.status == TenantStatus.ARCHIVE:
+            tenants = TenantService.get_join_tenants(current_user)
+            # if there is any tenant, switch to the first one
+            if len(tenants) > 0:
+                TenantService.switch_tenant(current_user, tenants[0].id)
+                tenant = tenants[0]
+            # else, raise Unauthorized
+            else:
+                raise Unauthorized('workspace is archived')
+
         return WorkspaceService.get_tenant_info(tenant), 200
 
 

+ 8 - 0
api/controllers/inner_api/__init__.py

@@ -0,0 +1,8 @@
+from flask import Blueprint
+from libs.external_api import ExternalApi
+
+bp = Blueprint('inner_api', __name__, url_prefix='/inner/api')
+api = ExternalApi(bp)
+
+from .workspace import workspace
+

+ 0 - 0
api/controllers/inner_api/workspace/__init__.py


+ 37 - 0
api/controllers/inner_api/workspace/workspace.py

@@ -0,0 +1,37 @@
+from flask_restful import Resource, reqparse
+
+from controllers.console.setup import setup_required
+from controllers.inner_api import api
+from controllers.inner_api.wraps import inner_api_only
+from events.tenant_event import tenant_was_created
+from models.account import Account
+from services.account_service import TenantService
+
+
+class EnterpriseWorkspace(Resource):
+
+    @setup_required
+    @inner_api_only
+    def post(self):
+        parser = reqparse.RequestParser()
+        parser.add_argument('name', type=str, required=True, location='json')
+        parser.add_argument('owner_email', type=str, required=True, location='json')
+        args = parser.parse_args()
+
+        account = Account.query.filter_by(email=args['owner_email']).first()
+        if account is None:
+            return {
+                'message': 'owner account not found.'
+            }, 404
+
+        tenant = TenantService.create_tenant(args['name'])
+        TenantService.create_tenant_member(tenant, account, role='owner')
+
+        tenant_was_created.send(tenant)
+
+        return {
+            'message': 'enterprise workspace created.'
+        }
+
+
+api.add_resource(EnterpriseWorkspace, '/enterprise/workspace')

+ 61 - 0
api/controllers/inner_api/wraps.py

@@ -0,0 +1,61 @@
+from base64 import b64encode
+from functools import wraps
+from hashlib import sha1
+from hmac import new as hmac_new
+
+from flask import abort, current_app, request
+
+from extensions.ext_database import db
+from models.model import EndUser
+
+
+def inner_api_only(view):
+    @wraps(view)
+    def decorated(*args, **kwargs):
+        if not current_app.config['INNER_API']:
+            abort(404)
+
+        # get header 'X-Inner-Api-Key'
+        inner_api_key = request.headers.get('X-Inner-Api-Key')
+        if not inner_api_key or inner_api_key != current_app.config['INNER_API_KEY']:
+            abort(404)
+
+        return view(*args, **kwargs)
+
+    return decorated
+
+
+def inner_api_user_auth(view):
+    @wraps(view)
+    def decorated(*args, **kwargs):
+        if not current_app.config['INNER_API']:
+            return view(*args, **kwargs)
+
+        # get header 'X-Inner-Api-Key'
+        authorization = request.headers.get('Authorization')
+        if not authorization:
+            return view(*args, **kwargs)
+
+        parts = authorization.split(':')
+        if len(parts) != 2:
+            return view(*args, **kwargs)
+
+        user_id, token = parts
+        if ' ' in user_id:
+            user_id = user_id.split(' ')[1]
+
+        inner_api_key = request.headers.get('X-Inner-Api-Key')
+
+        data_to_sign = f'DIFY {user_id}'
+
+        signature = hmac_new(inner_api_key.encode('utf-8'), data_to_sign.encode('utf-8'), sha1)
+        signature = b64encode(signature.digest()).decode('utf-8')
+
+        if signature != token:
+            return view(*args, **kwargs)
+
+        kwargs['user'] = db.session.query(EndUser).filter(EndUser.id == user_id).first()
+
+        return view(*args, **kwargs)
+
+    return decorated

+ 6 - 1
api/controllers/service_api/wraps.py

@@ -12,7 +12,7 @@ from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
 
 from extensions.ext_database import db
 from libs.login import _get_user
-from models.account import Account, Tenant, TenantAccountJoin
+from models.account import Account, Tenant, TenantAccountJoin, TenantStatus
 from models.model import ApiToken, App, EndUser
 from services.feature_service import FeatureService
 
@@ -47,6 +47,10 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
             if not app_model.enable_api:
                 raise NotFound()
 
+            tenant = db.session.query(Tenant).filter(Tenant.id == app_model.tenant_id).first()
+            if tenant.status == TenantStatus.ARCHIVE:
+                raise NotFound()
+
             kwargs['app_model'] = app_model
 
             if fetch_user_arg:
@@ -137,6 +141,7 @@ def validate_dataset_token(view=None):
                 .filter(Tenant.id == api_token.tenant_id) \
                 .filter(TenantAccountJoin.tenant_id == Tenant.id) \
                 .filter(TenantAccountJoin.role.in_(['owner'])) \
+                .filter(Tenant.status == TenantStatus.NORMAL) \
                 .one_or_none() # TODO: only owner information is required, so only one is returned.
             if tenant_account_join:
                 tenant, ta = tenant_account_join

+ 4 - 0
api/controllers/web/site.py

@@ -6,6 +6,7 @@ from werkzeug.exceptions import Forbidden
 from controllers.web import api
 from controllers.web.wraps import WebApiResource
 from extensions.ext_database import db
+from models.account import TenantStatus
 from models.model import Site
 from services.feature_service import FeatureService
 
@@ -54,6 +55,9 @@ class AppSiteApi(WebApiResource):
         if not site:
             raise Forbidden()
 
+        if app_model.tenant.status == TenantStatus.ARCHIVE:
+            raise Forbidden()
+
         can_replace_logo = FeatureService.get_features(app_model.tenant_id).can_replace_logo
 
         return AppSiteInfo(app_model.tenant, app_model, site, end_user.id, can_replace_logo)

+ 6 - 0
api/models/account.py

@@ -105,6 +105,12 @@ class Account(UserMixin, db.Model):
     def is_admin_or_owner(self):
         return self._current_tenant.current_role in ['admin', 'owner']
 
+
+class TenantStatus(str, enum.Enum):
+    NORMAL = 'normal'
+    ARCHIVE = 'archive'
+
+
 class Tenant(db.Model):
     __tablename__ = 'tenants'
     __table_args__ = (

+ 9 - 4
api/services/account_service.py

@@ -8,7 +8,7 @@ from typing import Any, Optional
 
 from flask import current_app
 from sqlalchemy import func
-from werkzeug.exceptions import Forbidden
+from werkzeug.exceptions import Unauthorized
 
 from constants.languages import language_timezone_mapping, languages
 from events.tenant_event import tenant_was_created
@@ -44,7 +44,7 @@ class AccountService:
             return None
 
         if account.status in [AccountStatus.BANNED.value, AccountStatus.CLOSED.value]:
-            raise Forbidden('Account is banned or closed.')
+            raise Unauthorized("Account is banned or closed.")
 
         current_tenant = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first()
         if current_tenant:
@@ -255,7 +255,7 @@ class TenantService:
         """Get account join tenants"""
         return db.session.query(Tenant).join(
             TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id
-        ).filter(TenantAccountJoin.account_id == account.id).all()
+        ).filter(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL).all()
 
     @staticmethod
     def get_current_tenant_by_account(account: Account):
@@ -279,7 +279,12 @@ class TenantService:
         if tenant_id is None:
             raise ValueError("Tenant ID must be provided.")
 
-        tenant_account_join = TenantAccountJoin.query.filter_by(account_id=account.id, tenant_id=tenant_id).first()
+        tenant_account_join = db.session.query(TenantAccountJoin).join(Tenant, TenantAccountJoin.tenant_id == Tenant.id).filter(
+            TenantAccountJoin.account_id == account.id,
+            TenantAccountJoin.tenant_id == tenant_id,
+            Tenant.status == TenantStatus.NORMAL,
+        ).first()
+
         if not tenant_account_join:
             raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.")
         else:

+ 0 - 0
api/services/enterprise/__init__.py


+ 20 - 0
api/services/enterprise/base.py

@@ -0,0 +1,20 @@
+import os
+
+import requests
+
+
+class EnterpriseRequest:
+    base_url = os.environ.get('ENTERPRISE_API_URL', 'ENTERPRISE_API_URL')
+    secret_key = os.environ.get('ENTERPRISE_API_SECRET_KEY', 'ENTERPRISE_API_SECRET_KEY')
+
+    @classmethod
+    def send_request(cls, method, endpoint, json=None, params=None):
+        headers = {
+            "Content-Type": "application/json",
+            "Enterprise-Api-Secret-Key": cls.secret_key
+        }
+
+        url = f"{cls.base_url}{endpoint}"
+        response = requests.request(method, url, json=json, params=params, headers=headers)
+
+        return response.json()

+ 28 - 0
api/services/enterprise/enterprise_feature_service.py

@@ -0,0 +1,28 @@
+from flask import current_app
+from pydantic import BaseModel
+
+from services.enterprise.enterprise_service import EnterpriseService
+
+
+class EnterpriseFeatureModel(BaseModel):
+    sso_enforced_for_signin: bool = False
+    sso_enforced_for_signin_protocol: str = ''
+
+
+class EnterpriseFeatureService:
+
+    @classmethod
+    def get_enterprise_features(cls) -> EnterpriseFeatureModel:
+        features = EnterpriseFeatureModel()
+
+        if current_app.config['ENTERPRISE_ENABLED']:
+            cls._fulfill_params_from_enterprise(features)
+
+        return features
+
+    @classmethod
+    def _fulfill_params_from_enterprise(cls, features):
+        enterprise_info = EnterpriseService.get_info()
+
+        features.sso_enforced_for_signin = enterprise_info['sso_enforced_for_signin']
+        features.sso_enforced_for_signin_protocol = enterprise_info['sso_enforced_for_signin_protocol']

+ 8 - 0
api/services/enterprise/enterprise_service.py

@@ -0,0 +1,8 @@
+from services.enterprise.base import EnterpriseRequest
+
+
+class EnterpriseService:
+
+    @classmethod
+    def get_info(cls):
+        return EnterpriseRequest.send_request('GET', '/info')

+ 60 - 0
api/services/enterprise/enterprise_sso_service.py

@@ -0,0 +1,60 @@
+import logging
+
+from models.account import Account, AccountStatus
+from services.account_service import AccountService, TenantService
+from services.enterprise.base import EnterpriseRequest
+
+logger = logging.getLogger(__name__)
+
+
+class EnterpriseSSOService:
+
+    @classmethod
+    def get_sso_saml_login(cls) -> str:
+        return EnterpriseRequest.send_request('GET', '/sso/saml/login')
+
+    @classmethod
+    def post_sso_saml_acs(cls, saml_response: str) -> str:
+        response = EnterpriseRequest.send_request('POST', '/sso/saml/acs', json={'SAMLResponse': saml_response})
+        if 'email' not in response or response['email'] is None:
+            logger.exception(response)
+            raise Exception('Saml response is invalid')
+
+        return cls.login_with_email(response.get('email'))
+
+    @classmethod
+    def get_sso_oidc_login(cls):
+        return EnterpriseRequest.send_request('GET', '/sso/oidc/login')
+
+    @classmethod
+    def get_sso_oidc_callback(cls, args: dict):
+        state_from_query = args['state']
+        code_from_query = args['code']
+        state_from_cookies = args['oidc-state']
+
+        if state_from_cookies != state_from_query:
+            raise Exception('invalid state or code')
+
+        response = EnterpriseRequest.send_request('GET', '/sso/oidc/callback', params={'code': code_from_query})
+        if 'email' not in response or response['email'] is None:
+            logger.exception(response)
+            raise Exception('OIDC response is invalid')
+
+        return cls.login_with_email(response.get('email'))
+
+    @classmethod
+    def login_with_email(cls, email: str) -> str:
+        account = Account.query.filter_by(email=email).first()
+        if account is None:
+            raise Exception('account not found, please contact system admin to invite you to join in a workspace')
+
+        if account.status == AccountStatus.BANNED:
+            raise Exception('account is banned, please contact system admin')
+
+        tenants = TenantService.get_join_tenants(account)
+        if len(tenants) == 0:
+            raise Exception("workspace not found, please contact system admin to invite you to join in a workspace")
+
+        token = AccountService.get_account_jwt_token(account)
+
+        return token

+ 4 - 0
web/app/components/header/account-dropdown/index.tsx

@@ -39,6 +39,10 @@ export default function AppSelector({ isMobile }: IAppSelecotr) {
       url: '/logout',
       params: {},
     })
+
+    if (localStorage?.getItem('console_token'))
+      localStorage.removeItem('console_token')
+
     router.push('/signin')
   }
 

+ 0 - 3
web/app/signin/_header.tsx

@@ -10,9 +10,6 @@ import LogoSite from '@/app/components/base/logo/logo-site'
 const Header = () => {
   const { locale, setLocaleOnClient } = useContext(I18n)
 
-  if (localStorage?.getItem('console_token'))
-    localStorage.removeItem('console_token')
-
   return <div className='flex items-center justify-between p-6 w-full'>
     <LogoSite />
     <Select

+ 87 - 0
web/app/signin/enterpriseSSOForm.tsx

@@ -0,0 +1,87 @@
+'use client'
+import cn from 'classnames'
+import { useRouter, useSearchParams } from 'next/navigation'
+import type { FC } from 'react'
+import { useEffect, useState } from 'react'
+import { useTranslation } from 'react-i18next'
+import Toast from '@/app/components/base/toast'
+import { getOIDCSSOUrl, getSAMLSSOUrl } from '@/service/enterprise'
+import Button from '@/app/components/base/button'
+
+type EnterpriseSSOFormProps = {
+  protocol: string
+}
+
+const EnterpriseSSOForm: FC<EnterpriseSSOFormProps> = ({
+  protocol,
+}) => {
+  const searchParams = useSearchParams()
+  const consoleToken = searchParams.get('console_token')
+  const message = searchParams.get('message')
+
+  const router = useRouter()
+  const { t } = useTranslation()
+
+  const [isLoading, setIsLoading] = useState(false)
+
+  useEffect(() => {
+    if (consoleToken) {
+      localStorage.setItem('console_token', consoleToken)
+      router.replace('/apps')
+    }
+
+    if (message) {
+      Toast.notify({
+        type: 'error',
+        message,
+      })
+    }
+  }, [])
+
+  const handleSSOLogin = () => {
+    setIsLoading(true)
+    if (protocol === 'saml') {
+      getSAMLSSOUrl().then((res) => {
+        router.push(res.url)
+      }).finally(() => {
+        setIsLoading(false)
+      })
+    }
+    else {
+      getOIDCSSOUrl().then((res) => {
+        document.cookie = `oidc-state=${res.state}`
+        router.push(res.url)
+      }).finally(() => {
+        setIsLoading(false)
+      })
+    }
+  }
+
+  return (
+    <div className={
+      cn(
+        'flex flex-col items-center w-full grow items-center justify-center',
+        'px-6',
+        'md:px-[108px]',
+      )
+    }>
+      <div className='flex flex-col md:w-[400px]'>
+        <div className="w-full mx-auto">
+          <h2 className="text-[32px] font-bold text-gray-900">{t('login.pageTitle')}</h2>
+        </div>
+        <div className="w-full mx-auto mt-10">
+          <Button
+            tabIndex={0}
+            type='primary'
+            onClick={() => { handleSSOLogin() }}
+            disabled={isLoading}
+            className="w-full !fone-medium !text-sm"
+          >{t('login.sso')}
+          </Button>
+        </div>
+      </div>
+    </div>
+  )
+}
+
+export default EnterpriseSSOForm

+ 11 - 2
web/app/signin/normalForm.tsx

@@ -96,8 +96,17 @@ const NormalForm = () => {
           remember_me: true,
         },
       })
-      localStorage.setItem('console_token', res.data)
-      router.replace('/apps')
+
+      if (res.result === 'success') {
+        localStorage.setItem('console_token', res.data)
+        router.replace('/apps')
+      }
+      else {
+        Toast.notify({
+          type: 'error',
+          message: res.data,
+        })
+      }
     }
     finally {
       setIsLoading(false)

+ 43 - 5
web/app/signin/page.tsx

@@ -1,12 +1,29 @@
-import React from 'react'
+'use client'
+import React, { useEffect, useState } from 'react'
 import cn from 'classnames'
 import Script from 'next/script'
+import Loading from '../components/base/loading'
 import Forms from './forms'
 import Header from './_header'
 import style from './page.module.css'
+import EnterpriseSSOForm from './enterpriseSSOForm'
 import { IS_CE_EDITION } from '@/config'
+import { getEnterpriseFeatures } from '@/service/enterprise'
+import type { EnterpriseFeatures } from '@/types/enterprise'
+import { defaultEnterpriseFeatures } from '@/types/enterprise'
 
 const SignIn = () => {
+  const [loading, setLoading] = useState<boolean>(true)
+  const [enterpriseFeatures, setEnterpriseFeatures] = useState<EnterpriseFeatures>(defaultEnterpriseFeatures)
+
+  useEffect(() => {
+    getEnterpriseFeatures().then((res) => {
+      setEnterpriseFeatures(res)
+    }).finally(() => {
+      setLoading(false)
+    })
+  }, [])
+
   return (
     <>
       {!IS_CE_EDITION && (
@@ -40,10 +57,31 @@ gtag('config', 'AW-11217955271"');
           )
         }>
           <Header />
-          <Forms />
-          <div className='px-8 py-6 text-sm font-normal text-gray-500'>
-            © {new Date().getFullYear()} LangGenius, Inc. All rights reserved.
-          </div>
+
+          {loading && (
+            <div className={
+              cn(
+                'flex flex-col items-center w-full grow items-center justify-center',
+                'px-6',
+                'md:px-[108px]',
+              )
+            }>
+              <Loading type='area' />
+            </div>
+          )}
+
+          {!loading && !enterpriseFeatures.sso_enforced_for_signin && (
+            <>
+              <Forms />
+              <div className='px-8 py-6 text-sm font-normal text-gray-500'>
+                © {new Date().getFullYear()} LangGenius, Inc. All rights reserved.
+              </div>
+            </>
+          )}
+
+          {!loading && enterpriseFeatures.sso_enforced_for_signin && (
+            <EnterpriseSSOForm protocol={enterpriseFeatures.sso_enforced_for_signin_protocol} />
+          )}
         </div>
 
       </div>

+ 1 - 0
web/i18n/en-US/login.ts

@@ -9,6 +9,7 @@ const translation = {
   namePlaceholder: 'Your username',
   forget: 'Forgot your password?',
   signBtn: 'Sign in',
+  sso: 'Continue with SSO',
   installBtn: 'Set up',
   setAdminAccount: 'Setting up an admin account',
   setAdminAccountDesc: 'Maximum privileges for admin account, which can be used to create applications and manage LLM providers, etc.',

+ 14 - 0
web/service/enterprise.ts

@@ -0,0 +1,14 @@
+import { get } from './base'
+import type { EnterpriseFeatures } from '@/types/enterprise'
+
+export const getEnterpriseFeatures = () => {
+  return get<EnterpriseFeatures>('/enterprise-features')
+}
+
+export const getSAMLSSOUrl = () => {
+  return get<{ url: string }>('/enterprise/sso/saml/login')
+}
+
+export const getOIDCSSOUrl = () => {
+  return get<{ url: string; state: string }>('/enterprise/sso/oidc/login')
+}

+ 9 - 0
web/types/enterprise.ts

@@ -0,0 +1,9 @@
+export type EnterpriseFeatures = {
+  sso_enforced_for_signin: boolean
+  sso_enforced_for_signin_protocol: string
+}
+
+export const defaultEnterpriseFeatures: EnterpriseFeatures = {
+  sso_enforced_for_signin: false,
+  sso_enforced_for_signin_protocol: '',
+}