Browse Source

chore(api/services): apply ruff reformatting (#7599)

Co-authored-by: -LAN- <laipz8200@outlook.com>
Bowen Liang 7 months ago
parent
commit
17fd773a30
49 changed files with 2426 additions and 2468 deletions
  1. 0 1
      api/pyproject.toml
  2. 1 1
      api/services/__init__.py
  3. 138 150
      api/services/account_service.py
  4. 41 23
      api/services/advanced_prompt_template_service.py
  5. 71 61
      api/services/agent_service.py
  6. 193 179
      api/services/annotation_service.py
  7. 23 16
      api/services/api_based_extension_service.py
  8. 92 93
      api/services/app_dsl_service.py
  9. 66 73
      api/services/app_generate_service.py
  10. 0 1
      api/services/app_model_config_service.py
  11. 77 88
      api/services/app_service.py
  12. 26 30
      api/services/audio_service.py
  13. 2 4
      api/services/auth/api_key_auth_factory.py
  14. 36 32
      api/services/auth/api_key_auth_service.py
  15. 16 25
      api/services/auth/firecrawl.py
  16. 23 40
      api/services/billing_service.py
  17. 9 6
      api/services/code_based_extension_service.py
  18. 46 33
      api/services/conversation_service.py
  19. 263 278
      api/services/dataset_service.py
  20. 3 6
      api/services/enterprise/base.py
  21. 2 3
      api/services/enterprise/enterprise_service.py
  22. 18 20
      api/services/entities/model_provider_entities.py
  23. 0 1
      api/services/errors/account.py
  24. 1 1
      api/services/errors/base.py
  25. 34 34
      api/services/feature_service.py
  26. 40 28
      api/services/file_service.py
  27. 40 39
      api/services/hit_testing_service.py
  28. 99 93
      api/services/message_service.py
  29. 119 110
      api/services/model_load_balancing_service.py
  30. 79 121
      api/services/model_provider_service.py
  31. 5 4
      api/services/moderation_service.py
  32. 10 13
      api/services/operation_service.py
  33. 33 19
      api/services/ops_service.py
  34. 62 63
      api/services/recommended_app_service.py
  35. 37 31
      api/services/saved_message_service.py
  36. 58 62
      api/services/tag_service.py
  37. 178 171
      api/services/tools/api_tools_manage_service.py
  38. 80 72
      api/services/tools/builtin_tools_manage_service.py
  39. 1 1
      api/services/tools/tool_labels_service.py
  40. 3 6
      api/services/tools/tools_manage_service.py
  41. 47 63
      api/services/tools/tools_transform_service.py
  42. 127 120
      api/services/tools/workflow_tools_manage_service.py
  43. 9 13
      api/services/vector_service.py
  44. 44 27
      api/services/web_conversation_service.py
  45. 59 94
      api/services/website_service.py
  46. 9 23
      api/services/workflow_app_service.py
  47. 39 31
      api/services/workflow_run_service.py
  48. 37 40
      api/services/workflow_service.py
  49. 30 25
      api/services/workspace_service.py

+ 0 - 1
api/pyproject.toml

@@ -74,7 +74,6 @@ exclude = [
     "controllers/**/*.py",
     "models/**/*.py",
     "migrations/**/*",
-    "services/**/*.py",
 ]
 
 [tool.pytest_env]

+ 1 - 1
api/services/__init__.py

@@ -1,3 +1,3 @@
 from . import errors
 
-__all__ = ['errors']
+__all__ = ["errors"]

+ 138 - 150
api/services/account_service.py

@@ -39,12 +39,7 @@ from tasks.mail_reset_password_task import send_reset_password_mail_task
 
 
 class AccountService:
-
-    reset_password_rate_limiter = RateLimiter(
-        prefix="reset_password_rate_limit",
-        max_attempts=5,
-        time_window=60 * 60
-    )
+    reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=5, time_window=60 * 60)
 
     @staticmethod
     def load_user(user_id: str) -> None | Account:
@@ -55,12 +50,15 @@ class AccountService:
         if account.status in [AccountStatus.BANNED.value, AccountStatus.CLOSED.value]:
             raise Unauthorized("Account is banned or closed.")
 
-        current_tenant: TenantAccountJoin = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first()
+        current_tenant: TenantAccountJoin = TenantAccountJoin.query.filter_by(
+            account_id=account.id, current=True
+        ).first()
         if current_tenant:
             account.current_tenant_id = current_tenant.tenant_id
         else:
-            available_ta = TenantAccountJoin.query.filter_by(account_id=account.id) \
-                .order_by(TenantAccountJoin.id.asc()).first()
+            available_ta = (
+                TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first()
+            )
             if not available_ta:
                 return None
 
@@ -74,14 +72,13 @@ class AccountService:
 
         return account
 
-
     @staticmethod
     def get_account_jwt_token(account, *, exp: timedelta = timedelta(days=30)):
         payload = {
             "user_id": account.id,
             "exp": datetime.now(timezone.utc).replace(tzinfo=None) + exp,
             "iss": dify_config.EDITION,
-            "sub": 'Console API Passport',
+            "sub": "Console API Passport",
         }
 
         token = PassportService().issue(payload)
@@ -93,10 +90,10 @@ class AccountService:
 
         account = Account.query.filter_by(email=email).first()
         if not account:
-            raise AccountLoginError('Invalid email or password.')
+            raise AccountLoginError("Invalid email or password.")
 
         if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value:
-            raise AccountLoginError('Account is banned or closed.')
+            raise AccountLoginError("Account is banned or closed.")
 
         if account.status == AccountStatus.PENDING.value:
             account.status = AccountStatus.ACTIVE.value
@@ -104,7 +101,7 @@ class AccountService:
             db.session.commit()
 
         if account.password is None or not compare_password(password, account.password, account.password_salt):
-            raise AccountLoginError('Invalid email or password.')
+            raise AccountLoginError("Invalid email or password.")
         return account
 
     @staticmethod
@@ -129,11 +126,9 @@ class AccountService:
         return account
 
     @staticmethod
-    def create_account(email: str,
-                       name: str,
-                       interface_language: str,
-                       password: Optional[str] = None,
-                       interface_theme: str = 'light') -> Account:
+    def create_account(
+        email: str, name: str, interface_language: str, password: Optional[str] = None, interface_theme: str = "light"
+    ) -> Account:
         """create account"""
         account = Account()
         account.email = email
@@ -155,7 +150,7 @@ class AccountService:
         account.interface_theme = interface_theme
 
         # Set timezone based on language
-        account.timezone = language_timezone_mapping.get(interface_language, 'UTC')
+        account.timezone = language_timezone_mapping.get(interface_language, "UTC")
 
         db.session.add(account)
         db.session.commit()
@@ -166,8 +161,9 @@ class AccountService:
         """Link account integrate"""
         try:
             # Query whether there is an existing binding record for the same provider
-            account_integrate: Optional[AccountIntegrate] = AccountIntegrate.query.filter_by(account_id=account.id,
-                                                                                             provider=provider).first()
+            account_integrate: Optional[AccountIntegrate] = AccountIntegrate.query.filter_by(
+                account_id=account.id, provider=provider
+            ).first()
 
             if account_integrate:
                 # If it exists, update the record
@@ -176,15 +172,16 @@ class AccountService:
                 account_integrate.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
             else:
                 # If it does not exist, create a new record
-                account_integrate = AccountIntegrate(account_id=account.id, provider=provider, open_id=open_id,
-                                                     encrypted_token="")
+                account_integrate = AccountIntegrate(
+                    account_id=account.id, provider=provider, open_id=open_id, encrypted_token=""
+                )
                 db.session.add(account_integrate)
 
             db.session.commit()
-            logging.info(f'Account {account.id} linked {provider} account {open_id}.')
+            logging.info(f"Account {account.id} linked {provider} account {open_id}.")
         except Exception as e:
-            logging.exception(f'Failed to link {provider} account {open_id} to Account {account.id}')
-            raise LinkAccountIntegrateError('Failed to link account.') from e
+            logging.exception(f"Failed to link {provider} account {open_id} to Account {account.id}")
+            raise LinkAccountIntegrateError("Failed to link account.") from e
 
     @staticmethod
     def close_account(account: Account) -> None:
@@ -218,7 +215,7 @@ class AccountService:
             AccountService.update_last_login(account, ip_address=ip_address)
         exp = timedelta(days=30)
         token = AccountService.get_account_jwt_token(account, exp=exp)
-        redis_client.set(_get_login_cache_key(account_id=account.id, token=token), '1', ex=int(exp.total_seconds()))
+        redis_client.set(_get_login_cache_key(account_id=account.id, token=token), "1", ex=int(exp.total_seconds()))
         return token
 
     @staticmethod
@@ -236,22 +233,18 @@ class AccountService:
         if cls.reset_password_rate_limiter.is_rate_limited(account.email):
             raise RateLimitExceededError(f"Rate limit exceeded for email: {account.email}. Please try again later.")
 
-        token = TokenManager.generate_token(account, 'reset_password')
-        send_reset_password_mail_task.delay(
-            language=account.interface_language,
-            to=account.email,
-            token=token
-        )
+        token = TokenManager.generate_token(account, "reset_password")
+        send_reset_password_mail_task.delay(language=account.interface_language, to=account.email, token=token)
         cls.reset_password_rate_limiter.increment_rate_limit(account.email)
         return token
 
     @classmethod
     def revoke_reset_password_token(cls, token: str):
-        TokenManager.revoke_token(token, 'reset_password')
+        TokenManager.revoke_token(token, "reset_password")
 
     @classmethod
     def get_reset_password_data(cls, token: str) -> Optional[dict[str, Any]]:
-        return TokenManager.get_token_data(token, 'reset_password')
+        return TokenManager.get_token_data(token, "reset_password")
 
 
 def _get_login_cache_key(*, account_id: str, token: str):
@@ -259,7 +252,6 @@ def _get_login_cache_key(*, account_id: str, token: str):
 
 
 class TenantService:
-
     @staticmethod
     def create_tenant(name: str) -> Tenant:
         """Create tenant"""
@@ -275,31 +267,28 @@ class TenantService:
     @staticmethod
     def create_owner_tenant_if_not_exist(account: Account):
         """Create owner tenant if not exist"""
-        available_ta = TenantAccountJoin.query.filter_by(account_id=account.id) \
-            .order_by(TenantAccountJoin.id.asc()).first()
+        available_ta = (
+            TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first()
+        )
 
         if available_ta:
             return
 
         tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
-        TenantService.create_tenant_member(tenant, account, role='owner')
+        TenantService.create_tenant_member(tenant, account, role="owner")
         account.current_tenant = tenant
         db.session.commit()
         tenant_was_created.send(tenant)
 
     @staticmethod
-    def create_tenant_member(tenant: Tenant, account: Account, role: str = 'normal') -> TenantAccountJoin:
+    def create_tenant_member(tenant: Tenant, account: Account, role: str = "normal") -> TenantAccountJoin:
         """Create tenant member"""
         if role == TenantAccountJoinRole.OWNER.value:
             if TenantService.has_roles(tenant, [TenantAccountJoinRole.OWNER]):
-                logging.error(f'Tenant {tenant.id} has already an owner.')
-                raise Exception('Tenant already has an owner.')
+                logging.error(f"Tenant {tenant.id} has already an owner.")
+                raise Exception("Tenant already has an owner.")
 
-        ta = TenantAccountJoin(
-            tenant_id=tenant.id,
-            account_id=account.id,
-            role=role
-        )
+        ta = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, role=role)
         db.session.add(ta)
         db.session.commit()
         return ta
@@ -307,9 +296,12 @@ class TenantService:
     @staticmethod
     def get_join_tenants(account: Account) -> list[Tenant]:
         """Get account join tenants"""
-        return db.session.query(Tenant).join(
-            TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id
-        ).filter(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL).all()
+        return (
+            db.session.query(Tenant)
+            .join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id)
+            .filter(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL)
+            .all()
+        )
 
     @staticmethod
     def get_current_tenant_by_account(account: Account):
@@ -333,16 +325,23 @@ class TenantService:
         if tenant_id is None:
             raise ValueError("Tenant ID must be provided.")
 
-        tenant_account_join = db.session.query(TenantAccountJoin).join(Tenant, TenantAccountJoin.tenant_id == Tenant.id).filter(
-            TenantAccountJoin.account_id == account.id,
-            TenantAccountJoin.tenant_id == tenant_id,
-            Tenant.status == TenantStatus.NORMAL,
-        ).first()
+        tenant_account_join = (
+            db.session.query(TenantAccountJoin)
+            .join(Tenant, TenantAccountJoin.tenant_id == Tenant.id)
+            .filter(
+                TenantAccountJoin.account_id == account.id,
+                TenantAccountJoin.tenant_id == tenant_id,
+                Tenant.status == TenantStatus.NORMAL,
+            )
+            .first()
+        )
 
         if not tenant_account_join:
             raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.")
         else:
-            TenantAccountJoin.query.filter(TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id).update({'current': False})
+            TenantAccountJoin.query.filter(
+                TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id
+            ).update({"current": False})
             tenant_account_join.current = True
             # Set the current tenant for the account
             account.current_tenant_id = tenant_account_join.tenant_id
@@ -354,9 +353,7 @@ class TenantService:
         query = (
             db.session.query(Account, TenantAccountJoin.role)
             .select_from(Account)
-            .join(
-                TenantAccountJoin, Account.id == TenantAccountJoin.account_id
-            )
+            .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id)
             .filter(TenantAccountJoin.tenant_id == tenant.id)
         )
 
@@ -375,11 +372,9 @@ class TenantService:
         query = (
             db.session.query(Account, TenantAccountJoin.role)
             .select_from(Account)
-            .join(
-                TenantAccountJoin, Account.id == TenantAccountJoin.account_id
-            )
+            .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id)
             .filter(TenantAccountJoin.tenant_id == tenant.id)
-            .filter(TenantAccountJoin.role == 'dataset_operator')
+            .filter(TenantAccountJoin.role == "dataset_operator")
         )
 
         # Initialize an empty list to store the updated accounts
@@ -395,20 +390,25 @@ class TenantService:
     def has_roles(tenant: Tenant, roles: list[TenantAccountJoinRole]) -> bool:
         """Check if user has any of the given roles for a tenant"""
         if not all(isinstance(role, TenantAccountJoinRole) for role in roles):
-            raise ValueError('all roles must be TenantAccountJoinRole')
+            raise ValueError("all roles must be TenantAccountJoinRole")
 
-        return db.session.query(TenantAccountJoin).filter(
-            TenantAccountJoin.tenant_id == tenant.id,
-            TenantAccountJoin.role.in_([role.value for role in roles])
-        ).first() is not None
+        return (
+            db.session.query(TenantAccountJoin)
+            .filter(
+                TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role.in_([role.value for role in roles])
+            )
+            .first()
+            is not None
+        )
 
     @staticmethod
     def get_user_role(account: Account, tenant: Tenant) -> Optional[TenantAccountJoinRole]:
         """Get the role of the current account for a given tenant"""
-        join = db.session.query(TenantAccountJoin).filter(
-            TenantAccountJoin.tenant_id == tenant.id,
-            TenantAccountJoin.account_id == account.id
-        ).first()
+        join = (
+            db.session.query(TenantAccountJoin)
+            .filter(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
+            .first()
+        )
         return join.role if join else None
 
     @staticmethod
@@ -420,29 +420,26 @@ class TenantService:
     def check_member_permission(tenant: Tenant, operator: Account, member: Account, action: str) -> None:
         """Check member permission"""
         perms = {
-            'add': [TenantAccountRole.OWNER, TenantAccountRole.ADMIN],
-            'remove': [TenantAccountRole.OWNER],
-            'update': [TenantAccountRole.OWNER]
+            "add": [TenantAccountRole.OWNER, TenantAccountRole.ADMIN],
+            "remove": [TenantAccountRole.OWNER],
+            "update": [TenantAccountRole.OWNER],
         }
-        if action not in ['add', 'remove', 'update']:
+        if action not in ["add", "remove", "update"]:
             raise InvalidActionError("Invalid action.")
 
         if member:
             if operator.id == member.id:
                 raise CannotOperateSelfError("Cannot operate self.")
 
-        ta_operator = TenantAccountJoin.query.filter_by(
-            tenant_id=tenant.id,
-            account_id=operator.id
-        ).first()
+        ta_operator = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=operator.id).first()
 
         if not ta_operator or ta_operator.role not in perms[action]:
-            raise NoPermissionError(f'No permission to {action} member.')
+            raise NoPermissionError(f"No permission to {action} member.")
 
     @staticmethod
     def remove_member_from_tenant(tenant: Tenant, account: Account, operator: Account) -> None:
         """Remove member from tenant"""
-        if operator.id == account.id and TenantService.check_member_permission(tenant, operator, account, 'remove'):
+        if operator.id == account.id and TenantService.check_member_permission(tenant, operator, account, "remove"):
             raise CannotOperateSelfError("Cannot operate self.")
 
         ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first()
@@ -455,23 +452,17 @@ class TenantService:
     @staticmethod
     def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: Account) -> None:
         """Update member role"""
-        TenantService.check_member_permission(tenant, operator, member, 'update')
+        TenantService.check_member_permission(tenant, operator, member, "update")
 
-        target_member_join = TenantAccountJoin.query.filter_by(
-            tenant_id=tenant.id,
-            account_id=member.id
-        ).first()
+        target_member_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=member.id).first()
 
         if target_member_join.role == new_role:
             raise RoleAlreadyAssignedError("The provided role is already assigned to the member.")
 
-        if new_role == 'owner':
+        if new_role == "owner":
             # Find the current owner and change their role to 'admin'
-            current_owner_join = TenantAccountJoin.query.filter_by(
-                tenant_id=tenant.id,
-                role='owner'
-            ).first()
-            current_owner_join.role = 'admin'
+            current_owner_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, role="owner").first()
+            current_owner_join.role = "admin"
 
         # Update the role of the target member
         target_member_join.role = new_role
@@ -480,8 +471,8 @@ class TenantService:
     @staticmethod
     def dissolve_tenant(tenant: Tenant, operator: Account) -> None:
         """Dissolve tenant"""
-        if not TenantService.check_member_permission(tenant, operator, operator, 'remove'):
-            raise NoPermissionError('No permission to dissolve tenant.')
+        if not TenantService.check_member_permission(tenant, operator, operator, "remove"):
+            raise NoPermissionError("No permission to dissolve tenant.")
         db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id).delete()
         db.session.delete(tenant)
         db.session.commit()
@@ -494,10 +485,9 @@ class TenantService:
 
 
 class RegisterService:
-
     @classmethod
     def _get_invitation_token_key(cls, token: str) -> str:
-        return f'member_invite:token:{token}'
+        return f"member_invite:token:{token}"
 
     @classmethod
     def setup(cls, email: str, name: str, password: str, ip_address: str) -> None:
@@ -523,9 +513,7 @@ class RegisterService:
 
             TenantService.create_owner_tenant_if_not_exist(account)
 
-            dify_setup = DifySetup(
-                version=dify_config.CURRENT_VERSION
-            )
+            dify_setup = DifySetup(version=dify_config.CURRENT_VERSION)
             db.session.add(dify_setup)
             db.session.commit()
         except Exception as e:
@@ -535,34 +523,35 @@ class RegisterService:
             db.session.query(Tenant).delete()
             db.session.commit()
 
-            logging.exception(f'Setup failed: {e}')
-            raise ValueError(f'Setup failed: {e}')
+            logging.exception(f"Setup failed: {e}")
+            raise ValueError(f"Setup failed: {e}")
 
     @classmethod
-    def register(cls, email, name,
-                 password: Optional[str] = None,
-                 open_id: Optional[str] = None,
-                 provider: Optional[str] = None,
-                 language: Optional[str] = None,
-                 status: Optional[AccountStatus] = None) -> Account:
+    def register(
+        cls,
+        email,
+        name,
+        password: Optional[str] = None,
+        open_id: Optional[str] = None,
+        provider: Optional[str] = None,
+        language: Optional[str] = None,
+        status: Optional[AccountStatus] = None,
+    ) -> Account:
         db.session.begin_nested()
         """Register account"""
         try:
             account = AccountService.create_account(
-                email=email,
-                name=name,
-                interface_language=language if language else languages[0],
-                password=password
+                email=email, name=name, interface_language=language if language else languages[0], password=password
             )
             account.status = AccountStatus.ACTIVE.value if not status else status.value
             account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None)
 
             if open_id is not None or provider is not None:
                 AccountService.link_account_integrate(provider, open_id, account)
-            if dify_config.EDITION != 'SELF_HOSTED':
+            if dify_config.EDITION != "SELF_HOSTED":
                 tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
 
-                TenantService.create_tenant_member(tenant, account, role='owner')
+                TenantService.create_tenant_member(tenant, account, role="owner")
                 account.current_tenant = tenant
 
                 tenant_was_created.send(tenant)
@@ -570,30 +559,29 @@ class RegisterService:
             db.session.commit()
         except Exception as e:
             db.session.rollback()
-            logging.error(f'Register failed: {e}')
-            raise AccountRegisterError(f'Registration failed: {e}') from e
+            logging.error(f"Register failed: {e}")
+            raise AccountRegisterError(f"Registration failed: {e}") from e
 
         return account
 
     @classmethod
-    def invite_new_member(cls, tenant: Tenant, email: str, language: str, role: str = 'normal', inviter: Account = None) -> str:
+    def invite_new_member(
+        cls, tenant: Tenant, email: str, language: str, role: str = "normal", inviter: Account = None
+    ) -> str:
         """Invite new member"""
         account = Account.query.filter_by(email=email).first()
 
         if not account:
-            TenantService.check_member_permission(tenant, inviter, None, 'add')
-            name = email.split('@')[0]
+            TenantService.check_member_permission(tenant, inviter, None, "add")
+            name = email.split("@")[0]
 
             account = cls.register(email=email, name=name, language=language, status=AccountStatus.PENDING)
             # Create new tenant member for invited tenant
             TenantService.create_tenant_member(tenant, account, role)
             TenantService.switch_tenant(account, tenant.id)
         else:
-            TenantService.check_member_permission(tenant, inviter, account, 'add')
-            ta = TenantAccountJoin.query.filter_by(
-                tenant_id=tenant.id,
-                account_id=account.id
-            ).first()
+            TenantService.check_member_permission(tenant, inviter, account, "add")
+            ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first()
 
             if not ta:
                 TenantService.create_tenant_member(tenant, account, role)
@@ -609,7 +597,7 @@ class RegisterService:
             language=account.interface_language,
             to=email,
             token=token,
-            inviter_name=inviter.name if inviter else 'Dify',
+            inviter_name=inviter.name if inviter else "Dify",
             workspace_name=tenant.name,
         )
 
@@ -619,23 +607,19 @@ class RegisterService:
     def generate_invite_token(cls, tenant: Tenant, account: Account) -> str:
         token = str(uuid.uuid4())
         invitation_data = {
-            'account_id': account.id,
-            'email': account.email,
-            'workspace_id': tenant.id,
+            "account_id": account.id,
+            "email": account.email,
+            "workspace_id": tenant.id,
         }
         expiryHours = dify_config.INVITE_EXPIRY_HOURS
-        redis_client.setex(
-            cls._get_invitation_token_key(token),
-            expiryHours * 60 * 60,
-            json.dumps(invitation_data)
-        )
+        redis_client.setex(cls._get_invitation_token_key(token), expiryHours * 60 * 60, json.dumps(invitation_data))
         return token
 
     @classmethod
     def revoke_token(cls, workspace_id: str, email: str, token: str):
         if workspace_id and email:
             email_hash = sha256(email.encode()).hexdigest()
-            cache_key = 'member_invite_token:{}, {}:{}'.format(workspace_id, email_hash, token)
+            cache_key = "member_invite_token:{}, {}:{}".format(workspace_id, email_hash, token)
             redis_client.delete(cache_key)
         else:
             redis_client.delete(cls._get_invitation_token_key(token))
@@ -646,17 +630,21 @@ class RegisterService:
         if not invitation_data:
             return None
 
-        tenant = db.session.query(Tenant).filter(
-            Tenant.id == invitation_data['workspace_id'],
-            Tenant.status == 'normal'
-        ).first()
+        tenant = (
+            db.session.query(Tenant)
+            .filter(Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal")
+            .first()
+        )
 
         if not tenant:
             return None
 
-        tenant_account = db.session.query(Account, TenantAccountJoin.role).join(
-            TenantAccountJoin, Account.id == TenantAccountJoin.account_id
-        ).filter(Account.email == invitation_data['email'], TenantAccountJoin.tenant_id == tenant.id).first()
+        tenant_account = (
+            db.session.query(Account, TenantAccountJoin.role)
+            .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id)
+            .filter(Account.email == invitation_data["email"], TenantAccountJoin.tenant_id == tenant.id)
+            .first()
+        )
 
         if not tenant_account:
             return None
@@ -665,29 +653,29 @@ class RegisterService:
         if not account:
             return None
 
-        if invitation_data['account_id'] != str(account.id):
+        if invitation_data["account_id"] != str(account.id):
             return None
 
         return {
-            'account': account,
-            'data': invitation_data,
-            'tenant': tenant,
+            "account": account,
+            "data": invitation_data,
+            "tenant": tenant,
         }
 
     @classmethod
     def _get_invitation_by_token(cls, token: str, workspace_id: str, email: str) -> Optional[dict[str, str]]:
         if workspace_id is not None and email is not None:
             email_hash = sha256(email.encode()).hexdigest()
-            cache_key = f'member_invite_token:{workspace_id}, {email_hash}:{token}'
+            cache_key = f"member_invite_token:{workspace_id}, {email_hash}:{token}"
             account_id = redis_client.get(cache_key)
 
             if not account_id:
                 return None
 
             return {
-                'account_id': account_id.decode('utf-8'),
-                'email': email,
-                'workspace_id': workspace_id,
+                "account_id": account_id.decode("utf-8"),
+                "email": email,
+                "workspace_id": workspace_id,
             }
         else:
             data = redis_client.get(cls._get_invitation_token_key(token))

+ 41 - 23
api/services/advanced_prompt_template_service.py

@@ -1,4 +1,3 @@
-
 import copy
 
 from core.prompt.prompt_templates.advanced_prompt_templates import (
@@ -17,59 +16,78 @@ from models.model import AppMode
 
 
 class AdvancedPromptTemplateService:
-
     @classmethod
     def get_prompt(cls, args: dict) -> dict:
-        app_mode = args['app_mode']
-        model_mode = args['model_mode']
-        model_name = args['model_name']
-        has_context = args['has_context']
+        app_mode = args["app_mode"]
+        model_mode = args["model_mode"]
+        model_name = args["model_name"]
+        has_context = args["has_context"]
 
-        if 'baichuan' in model_name.lower():
+        if "baichuan" in model_name.lower():
             return cls.get_baichuan_prompt(app_mode, model_mode, has_context)
         else:
             return cls.get_common_prompt(app_mode, model_mode, has_context)
 
     @classmethod
-    def get_common_prompt(cls, app_mode: str, model_mode:str, has_context: str) -> dict:
+    def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict:
         context_prompt = copy.deepcopy(CONTEXT)
 
         if app_mode == AppMode.CHAT.value:
             if model_mode == "completion":
-                return cls.get_completion_prompt(copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt)
+                return cls.get_completion_prompt(
+                    copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt
+                )
             elif model_mode == "chat":
                 return cls.get_chat_prompt(copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt)
         elif app_mode == AppMode.COMPLETION.value:
             if model_mode == "completion":
-                return cls.get_completion_prompt(copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt)
+                return cls.get_completion_prompt(
+                    copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt
+                )
             elif model_mode == "chat":
-                return cls.get_chat_prompt(copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt)
-            
+                return cls.get_chat_prompt(
+                    copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt
+                )
+
     @classmethod
     def get_completion_prompt(cls, prompt_template: dict, has_context: str, context: str) -> dict:
-        if has_context == 'true':
-            prompt_template['completion_prompt_config']['prompt']['text'] = context + prompt_template['completion_prompt_config']['prompt']['text']
-        
+        if has_context == "true":
+            prompt_template["completion_prompt_config"]["prompt"]["text"] = (
+                context + prompt_template["completion_prompt_config"]["prompt"]["text"]
+            )
+
         return prompt_template
 
     @classmethod
     def get_chat_prompt(cls, prompt_template: dict, has_context: str, context: str) -> dict:
-        if has_context == 'true':
-            prompt_template['chat_prompt_config']['prompt'][0]['text'] = context + prompt_template['chat_prompt_config']['prompt'][0]['text']
-        
+        if has_context == "true":
+            prompt_template["chat_prompt_config"]["prompt"][0]["text"] = (
+                context + prompt_template["chat_prompt_config"]["prompt"][0]["text"]
+            )
+
         return prompt_template
 
     @classmethod
-    def get_baichuan_prompt(cls, app_mode: str, model_mode:str, has_context: str) -> dict:
+    def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict:
         baichuan_context_prompt = copy.deepcopy(BAICHUAN_CONTEXT)
 
         if app_mode == AppMode.CHAT.value:
             if model_mode == "completion":
-                return cls.get_completion_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt)
+                return cls.get_completion_prompt(
+                    copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt
+                )
             elif model_mode == "chat":
-                return cls.get_chat_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt)
+                return cls.get_chat_prompt(
+                    copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt
+                )
         elif app_mode == AppMode.COMPLETION.value:
             if model_mode == "completion":
-                return cls.get_completion_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt)
+                return cls.get_completion_prompt(
+                    copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG),
+                    has_context,
+                    baichuan_context_prompt,
+                )
             elif model_mode == "chat":
-                return cls.get_chat_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt)
+                return cls.get_chat_prompt(
+                    copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt
+                )

+ 71 - 61
api/services/agent_service.py

@@ -10,59 +10,65 @@ from models.model import App, Conversation, EndUser, Message, MessageAgentThough
 
 class AgentService:
     @classmethod
-    def get_agent_logs(cls, app_model: App, 
-                       conversation_id: str,
-                       message_id: str) -> dict:
+    def get_agent_logs(cls, app_model: App, conversation_id: str, message_id: str) -> dict:
         """
         Service to get agent logs
         """
-        conversation: Conversation = db.session.query(Conversation).filter(
-            Conversation.id == conversation_id,
-            Conversation.app_id == app_model.id,
-        ).first()
+        conversation: Conversation = (
+            db.session.query(Conversation)
+            .filter(
+                Conversation.id == conversation_id,
+                Conversation.app_id == app_model.id,
+            )
+            .first()
+        )
 
         if not conversation:
             raise ValueError(f"Conversation not found: {conversation_id}")
 
-        message: Message = db.session.query(Message).filter(
-            Message.id == message_id,
-            Message.conversation_id == conversation_id,
-        ).first()
+        message: Message = (
+            db.session.query(Message)
+            .filter(
+                Message.id == message_id,
+                Message.conversation_id == conversation_id,
+            )
+            .first()
+        )
 
         if not message:
             raise ValueError(f"Message not found: {message_id}")
-        
+
         agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
 
         if conversation.from_end_user_id:
             # only select name field
-            executor = db.session.query(EndUser, EndUser.name).filter(
-                EndUser.id == conversation.from_end_user_id
-            ).first()
+            executor = (
+                db.session.query(EndUser, EndUser.name).filter(EndUser.id == conversation.from_end_user_id).first()
+            )
         else:
-            executor = db.session.query(Account, Account.name).filter(
-                Account.id == conversation.from_account_id
-            ).first()
-        
+            executor = (
+                db.session.query(Account, Account.name).filter(Account.id == conversation.from_account_id).first()
+            )
+
         if executor:
             executor = executor.name
         else:
-            executor = 'Unknown'
+            executor = "Unknown"
 
         timezone = pytz.timezone(current_user.timezone)
 
         result = {
-            'meta': {
-                'status': 'success',
-                'executor': executor,
-                'start_time': message.created_at.astimezone(timezone).isoformat(),
-                'elapsed_time': message.provider_response_latency,
-                'total_tokens': message.answer_tokens + message.message_tokens,
-                'agent_mode': app_model.app_model_config.agent_mode_dict.get('strategy', 'react'),
-                'iterations': len(agent_thoughts),
+            "meta": {
+                "status": "success",
+                "executor": executor,
+                "start_time": message.created_at.astimezone(timezone).isoformat(),
+                "elapsed_time": message.provider_response_latency,
+                "total_tokens": message.answer_tokens + message.message_tokens,
+                "agent_mode": app_model.app_model_config.agent_mode_dict.get("strategy", "react"),
+                "iterations": len(agent_thoughts),
             },
-            'iterations': [],
-            'files': message.files,
+            "iterations": [],
+            "files": message.files,
         }
 
         agent_config = AgentConfigManager.convert(app_model.app_model_config.to_dict())
@@ -86,12 +92,12 @@ class AgentService:
                 tool_input = tool_inputs.get(tool_name, {})
                 tool_output = tool_outputs.get(tool_name, {})
                 tool_meta_data = tool_meta.get(tool_name, {})
-                tool_config = tool_meta_data.get('tool_config', {})
-                if tool_config.get('tool_provider_type', '') != 'dataset-retrieval':
+                tool_config = tool_meta_data.get("tool_config", {})
+                if tool_config.get("tool_provider_type", "") != "dataset-retrieval":
                     tool_icon = ToolManager.get_tool_icon(
                         tenant_id=app_model.tenant_id,
-                        provider_type=tool_config.get('tool_provider_type', ''),
-                        provider_id=tool_config.get('tool_provider', ''),
+                        provider_type=tool_config.get("tool_provider_type", ""),
+                        provider_id=tool_config.get("tool_provider", ""),
                     )
                     if not tool_icon:
                         tool_entity = find_agent_tool(tool_name)
@@ -102,30 +108,34 @@ class AgentService:
                                 provider_id=tool_entity.provider_id,
                             )
                 else:
-                    tool_icon = ''
-
-                tool_calls.append({
-                    'status': 'success' if not tool_meta_data.get('error') else 'error',
-                    'error': tool_meta_data.get('error'),
-                    'time_cost': tool_meta_data.get('time_cost', 0),
-                    'tool_name': tool_name,
-                    'tool_label': tool_label,
-                    'tool_input': tool_input,
-                    'tool_output': tool_output,
-                    'tool_parameters': tool_meta_data.get('tool_parameters', {}),
-                    'tool_icon': tool_icon,
-                })
-
-            result['iterations'].append({
-                'tokens': agent_thought.tokens,
-                'tool_calls': tool_calls,
-                'tool_raw': {
-                    'inputs': agent_thought.tool_input,
-                    'outputs': agent_thought.observation,
-                },
-                'thought': agent_thought.thought,
-                'created_at': agent_thought.created_at.isoformat(),
-                'files': agent_thought.files,
-            })
-
-        return result
+                    tool_icon = ""
+
+                tool_calls.append(
+                    {
+                        "status": "success" if not tool_meta_data.get("error") else "error",
+                        "error": tool_meta_data.get("error"),
+                        "time_cost": tool_meta_data.get("time_cost", 0),
+                        "tool_name": tool_name,
+                        "tool_label": tool_label,
+                        "tool_input": tool_input,
+                        "tool_output": tool_output,
+                        "tool_parameters": tool_meta_data.get("tool_parameters", {}),
+                        "tool_icon": tool_icon,
+                    }
+                )
+
+            result["iterations"].append(
+                {
+                    "tokens": agent_thought.tokens,
+                    "tool_calls": tool_calls,
+                    "tool_raw": {
+                        "inputs": agent_thought.tool_input,
+                        "outputs": agent_thought.observation,
+                    },
+                    "thought": agent_thought.thought,
+                    "created_at": agent_thought.created_at.isoformat(),
+                    "files": agent_thought.files,
+                }
+            )
+
+        return result

+ 193 - 179
api/services/annotation_service.py

@@ -23,21 +23,18 @@ class AppAnnotationService:
     @classmethod
     def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation:
         # get app info
-        app = db.session.query(App).filter(
-            App.id == app_id,
-            App.tenant_id == current_user.current_tenant_id,
-            App.status == 'normal'
-        ).first()
+        app = (
+            db.session.query(App)
+            .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
+            .first()
+        )
 
         if not app:
             raise NotFound("App not found")
-        if args.get('message_id'):
-            message_id = str(args['message_id'])
+        if args.get("message_id"):
+            message_id = str(args["message_id"])
             # get message info
-            message = db.session.query(Message).filter(
-                Message.id == message_id,
-                Message.app_id == app.id
-            ).first()
+            message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app.id).first()
 
             if not message:
                 raise NotFound("Message Not Exists.")
@@ -45,159 +42,166 @@ class AppAnnotationService:
             annotation = message.annotation
             # save the message annotation
             if annotation:
-                annotation.content = args['answer']
-                annotation.question = args['question']
+                annotation.content = args["answer"]
+                annotation.question = args["question"]
             else:
                 annotation = MessageAnnotation(
                     app_id=app.id,
                     conversation_id=message.conversation_id,
                     message_id=message.id,
-                    content=args['answer'],
-                    question=args['question'],
-                    account_id=current_user.id
+                    content=args["answer"],
+                    question=args["question"],
+                    account_id=current_user.id,
                 )
         else:
             annotation = MessageAnnotation(
-                app_id=app.id,
-                content=args['answer'],
-                question=args['question'],
-                account_id=current_user.id
+                app_id=app.id, content=args["answer"], question=args["question"], account_id=current_user.id
             )
         db.session.add(annotation)
         db.session.commit()
         # if annotation reply is enabled , add annotation to index
-        annotation_setting = db.session.query(AppAnnotationSetting).filter(
-            AppAnnotationSetting.app_id == app_id).first()
+        annotation_setting = (
+            db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first()
+        )
         if annotation_setting:
-            add_annotation_to_index_task.delay(annotation.id, args['question'], current_user.current_tenant_id,
-                                               app_id, annotation_setting.collection_binding_id)
+            add_annotation_to_index_task.delay(
+                annotation.id,
+                args["question"],
+                current_user.current_tenant_id,
+                app_id,
+                annotation_setting.collection_binding_id,
+            )
         return annotation
 
     @classmethod
     def enable_app_annotation(cls, args: dict, app_id: str) -> dict:
-        enable_app_annotation_key = 'enable_app_annotation_{}'.format(str(app_id))
+        enable_app_annotation_key = "enable_app_annotation_{}".format(str(app_id))
         cache_result = redis_client.get(enable_app_annotation_key)
         if cache_result is not None:
-            return {
-                'job_id': cache_result,
-                'job_status': 'processing'
-            }
+            return {"job_id": cache_result, "job_status": "processing"}
 
         # async job
         job_id = str(uuid.uuid4())
-        enable_app_annotation_job_key = 'enable_app_annotation_job_{}'.format(str(job_id))
+        enable_app_annotation_job_key = "enable_app_annotation_job_{}".format(str(job_id))
         # send batch add segments task
-        redis_client.setnx(enable_app_annotation_job_key, 'waiting')
-        enable_annotation_reply_task.delay(str(job_id), app_id, current_user.id, current_user.current_tenant_id,
-                                           args['score_threshold'],
-                                           args['embedding_provider_name'], args['embedding_model_name'])
-        return {
-            'job_id': job_id,
-            'job_status': 'waiting'
-        }
+        redis_client.setnx(enable_app_annotation_job_key, "waiting")
+        enable_annotation_reply_task.delay(
+            str(job_id),
+            app_id,
+            current_user.id,
+            current_user.current_tenant_id,
+            args["score_threshold"],
+            args["embedding_provider_name"],
+            args["embedding_model_name"],
+        )
+        return {"job_id": job_id, "job_status": "waiting"}
 
     @classmethod
     def disable_app_annotation(cls, app_id: str) -> dict:
-        disable_app_annotation_key = 'disable_app_annotation_{}'.format(str(app_id))
+        disable_app_annotation_key = "disable_app_annotation_{}".format(str(app_id))
         cache_result = redis_client.get(disable_app_annotation_key)
         if cache_result is not None:
-            return {
-                'job_id': cache_result,
-                'job_status': 'processing'
-            }
+            return {"job_id": cache_result, "job_status": "processing"}
 
         # async job
         job_id = str(uuid.uuid4())
-        disable_app_annotation_job_key = 'disable_app_annotation_job_{}'.format(str(job_id))
+        disable_app_annotation_job_key = "disable_app_annotation_job_{}".format(str(job_id))
         # send batch add segments task
-        redis_client.setnx(disable_app_annotation_job_key, 'waiting')
+        redis_client.setnx(disable_app_annotation_job_key, "waiting")
         disable_annotation_reply_task.delay(str(job_id), app_id, current_user.current_tenant_id)
-        return {
-            'job_id': job_id,
-            'job_status': 'waiting'
-        }
+        return {"job_id": job_id, "job_status": "waiting"}
 
     @classmethod
     def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keyword: str):
         # get app info
-        app = db.session.query(App).filter(
-            App.id == app_id,
-            App.tenant_id == current_user.current_tenant_id,
-            App.status == 'normal'
-        ).first()
+        app = (
+            db.session.query(App)
+            .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
+            .first()
+        )
 
         if not app:
             raise NotFound("App not found")
         if keyword:
-            annotations = (db.session.query(MessageAnnotation)
-                           .filter(MessageAnnotation.app_id == app_id)
-                           .filter(
-                or_(
-                    MessageAnnotation.question.ilike('%{}%'.format(keyword)),
-                    MessageAnnotation.content.ilike('%{}%'.format(keyword))
+            annotations = (
+                db.session.query(MessageAnnotation)
+                .filter(MessageAnnotation.app_id == app_id)
+                .filter(
+                    or_(
+                        MessageAnnotation.question.ilike("%{}%".format(keyword)),
+                        MessageAnnotation.content.ilike("%{}%".format(keyword)),
+                    )
                 )
+                .order_by(MessageAnnotation.created_at.desc())
+                .paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
             )
-                           .order_by(MessageAnnotation.created_at.desc())
-                           .paginate(page=page, per_page=limit, max_per_page=100, error_out=False))
         else:
-            annotations = (db.session.query(MessageAnnotation)
-                           .filter(MessageAnnotation.app_id == app_id)
-                           .order_by(MessageAnnotation.created_at.desc())
-                           .paginate(page=page, per_page=limit, max_per_page=100, error_out=False))
+            annotations = (
+                db.session.query(MessageAnnotation)
+                .filter(MessageAnnotation.app_id == app_id)
+                .order_by(MessageAnnotation.created_at.desc())
+                .paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
+            )
         return annotations.items, annotations.total
 
     @classmethod
     def export_annotation_list_by_app_id(cls, app_id: str):
         # get app info
-        app = db.session.query(App).filter(
-            App.id == app_id,
-            App.tenant_id == current_user.current_tenant_id,
-            App.status == 'normal'
-        ).first()
+        app = (
+            db.session.query(App)
+            .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
+            .first()
+        )
 
         if not app:
             raise NotFound("App not found")
-        annotations = (db.session.query(MessageAnnotation)
-                       .filter(MessageAnnotation.app_id == app_id)
-                       .order_by(MessageAnnotation.created_at.desc()).all())
+        annotations = (
+            db.session.query(MessageAnnotation)
+            .filter(MessageAnnotation.app_id == app_id)
+            .order_by(MessageAnnotation.created_at.desc())
+            .all()
+        )
         return annotations
 
     @classmethod
     def insert_app_annotation_directly(cls, args: dict, app_id: str) -> MessageAnnotation:
         # get app info
-        app = db.session.query(App).filter(
-            App.id == app_id,
-            App.tenant_id == current_user.current_tenant_id,
-            App.status == 'normal'
-        ).first()
+        app = (
+            db.session.query(App)
+            .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
+            .first()
+        )
 
         if not app:
             raise NotFound("App not found")
 
         annotation = MessageAnnotation(
-            app_id=app.id,
-            content=args['answer'],
-            question=args['question'],
-            account_id=current_user.id
+            app_id=app.id, content=args["answer"], question=args["question"], account_id=current_user.id
         )
         db.session.add(annotation)
         db.session.commit()
         # if annotation reply is enabled , add annotation to index
-        annotation_setting = db.session.query(AppAnnotationSetting).filter(
-            AppAnnotationSetting.app_id == app_id).first()
+        annotation_setting = (
+            db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first()
+        )
         if annotation_setting:
-            add_annotation_to_index_task.delay(annotation.id, args['question'], current_user.current_tenant_id,
-                                               app_id, annotation_setting.collection_binding_id)
+            add_annotation_to_index_task.delay(
+                annotation.id,
+                args["question"],
+                current_user.current_tenant_id,
+                app_id,
+                annotation_setting.collection_binding_id,
+            )
         return annotation
 
     @classmethod
     def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: str):
         # get app info
-        app = db.session.query(App).filter(
-            App.id == app_id,
-            App.tenant_id == current_user.current_tenant_id,
-            App.status == 'normal'
-        ).first()
+        app = (
+            db.session.query(App)
+            .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
+            .first()
+        )
 
         if not app:
             raise NotFound("App not found")
@@ -207,30 +211,34 @@ class AppAnnotationService:
         if not annotation:
             raise NotFound("Annotation not found")
 
-        annotation.content = args['answer']
-        annotation.question = args['question']
+        annotation.content = args["answer"]
+        annotation.question = args["question"]
 
         db.session.commit()
         # if annotation reply is enabled , add annotation to index
-        app_annotation_setting = db.session.query(AppAnnotationSetting).filter(
-            AppAnnotationSetting.app_id == app_id
-        ).first()
+        app_annotation_setting = (
+            db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first()
+        )
 
         if app_annotation_setting:
-            update_annotation_to_index_task.delay(annotation.id, annotation.question,
-                                                  current_user.current_tenant_id,
-                                                  app_id, app_annotation_setting.collection_binding_id)
+            update_annotation_to_index_task.delay(
+                annotation.id,
+                annotation.question,
+                current_user.current_tenant_id,
+                app_id,
+                app_annotation_setting.collection_binding_id,
+            )
 
         return annotation
 
     @classmethod
     def delete_app_annotation(cls, app_id: str, annotation_id: str):
         # get app info
-        app = db.session.query(App).filter(
-            App.id == app_id,
-            App.tenant_id == current_user.current_tenant_id,
-            App.status == 'normal'
-        ).first()
+        app = (
+            db.session.query(App)
+            .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
+            .first()
+        )
 
         if not app:
             raise NotFound("App not found")
@@ -242,33 +250,34 @@ class AppAnnotationService:
 
         db.session.delete(annotation)
 
-        annotation_hit_histories = (db.session.query(AppAnnotationHitHistory)
-                                    .filter(AppAnnotationHitHistory.annotation_id == annotation_id)
-                                    .all()
-                                    )
+        annotation_hit_histories = (
+            db.session.query(AppAnnotationHitHistory)
+            .filter(AppAnnotationHitHistory.annotation_id == annotation_id)
+            .all()
+        )
         if annotation_hit_histories:
             for annotation_hit_history in annotation_hit_histories:
                 db.session.delete(annotation_hit_history)
 
         db.session.commit()
         # if annotation reply is enabled , delete annotation index
-        app_annotation_setting = db.session.query(AppAnnotationSetting).filter(
-            AppAnnotationSetting.app_id == app_id
-        ).first()
+        app_annotation_setting = (
+            db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first()
+        )
 
         if app_annotation_setting:
-            delete_annotation_index_task.delay(annotation.id, app_id,
-                                               current_user.current_tenant_id,
-                                               app_annotation_setting.collection_binding_id)
+            delete_annotation_index_task.delay(
+                annotation.id, app_id, current_user.current_tenant_id, app_annotation_setting.collection_binding_id
+            )
 
     @classmethod
     def batch_import_app_annotations(cls, app_id, file: FileStorage) -> dict:
         # get app info
-        app = db.session.query(App).filter(
-            App.id == app_id,
-            App.tenant_id == current_user.current_tenant_id,
-            App.status == 'normal'
-        ).first()
+        app = (
+            db.session.query(App)
+            .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
+            .first()
+        )
 
         if not app:
             raise NotFound("App not found")
@@ -278,10 +287,7 @@ class AppAnnotationService:
             df = pd.read_csv(file)
             result = []
             for index, row in df.iterrows():
-                content = {
-                    'question': row[0],
-                    'answer': row[1]
-                }
+                content = {"question": row[0], "answer": row[1]}
                 result.append(content)
             if len(result) == 0:
                 raise ValueError("The CSV file is empty.")
@@ -293,28 +299,24 @@ class AppAnnotationService:
                     raise ValueError("The number of annotations exceeds the limit of your subscription.")
             # async job
             job_id = str(uuid.uuid4())
-            indexing_cache_key = 'app_annotation_batch_import_{}'.format(str(job_id))
+            indexing_cache_key = "app_annotation_batch_import_{}".format(str(job_id))
             # send batch add segments task
-            redis_client.setnx(indexing_cache_key, 'waiting')
-            batch_import_annotations_task.delay(str(job_id), result, app_id,
-                                                current_user.current_tenant_id, current_user.id)
+            redis_client.setnx(indexing_cache_key, "waiting")
+            batch_import_annotations_task.delay(
+                str(job_id), result, app_id, current_user.current_tenant_id, current_user.id
+            )
         except Exception as e:
-            return {
-                'error_msg': str(e)
-            }
-        return {
-            'job_id': job_id,
-            'job_status': 'waiting'
-        }
+            return {"error_msg": str(e)}
+        return {"job_id": job_id, "job_status": "waiting"}
 
     @classmethod
     def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, limit):
         # get app info
-        app = db.session.query(App).filter(
-            App.id == app_id,
-            App.tenant_id == current_user.current_tenant_id,
-            App.status == 'normal'
-        ).first()
+        app = (
+            db.session.query(App)
+            .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
+            .first()
+        )
 
         if not app:
             raise NotFound("App not found")
@@ -324,12 +326,15 @@ class AppAnnotationService:
         if not annotation:
             raise NotFound("Annotation not found")
 
-        annotation_hit_histories = (db.session.query(AppAnnotationHitHistory)
-                                    .filter(AppAnnotationHitHistory.app_id == app_id,
-                                            AppAnnotationHitHistory.annotation_id == annotation_id,
-                                            )
-                                    .order_by(AppAnnotationHitHistory.created_at.desc())
-                                    .paginate(page=page, per_page=limit, max_per_page=100, error_out=False))
+        annotation_hit_histories = (
+            db.session.query(AppAnnotationHitHistory)
+            .filter(
+                AppAnnotationHitHistory.app_id == app_id,
+                AppAnnotationHitHistory.annotation_id == annotation_id,
+            )
+            .order_by(AppAnnotationHitHistory.created_at.desc())
+            .paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
+        )
         return annotation_hit_histories.items, annotation_hit_histories.total
 
     @classmethod
@@ -341,15 +346,21 @@ class AppAnnotationService:
         return annotation
 
     @classmethod
-    def add_annotation_history(cls, annotation_id: str, app_id: str, annotation_question: str,
-                               annotation_content: str, query: str, user_id: str,
-                               message_id: str, from_source: str, score: float):
+    def add_annotation_history(
+        cls,
+        annotation_id: str,
+        app_id: str,
+        annotation_question: str,
+        annotation_content: str,
+        query: str,
+        user_id: str,
+        message_id: str,
+        from_source: str,
+        score: float,
+    ):
         # add hit count to annotation
-        db.session.query(MessageAnnotation).filter(
-            MessageAnnotation.id == annotation_id
-        ).update(
-            {MessageAnnotation.hit_count: MessageAnnotation.hit_count + 1},
-            synchronize_session=False
+        db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).update(
+            {MessageAnnotation.hit_count: MessageAnnotation.hit_count + 1}, synchronize_session=False
         )
 
         annotation_hit_history = AppAnnotationHitHistory(
@@ -361,7 +372,7 @@ class AppAnnotationService:
             score=score,
             message_id=message_id,
             annotation_question=annotation_question,
-            annotation_content=annotation_content
+            annotation_content=annotation_content,
         )
         db.session.add(annotation_hit_history)
         db.session.commit()
@@ -369,17 +380,18 @@ class AppAnnotationService:
     @classmethod
     def get_app_annotation_setting_by_app_id(cls, app_id: str):
         # get app info
-        app = db.session.query(App).filter(
-            App.id == app_id,
-            App.tenant_id == current_user.current_tenant_id,
-            App.status == 'normal'
-        ).first()
+        app = (
+            db.session.query(App)
+            .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
+            .first()
+        )
 
         if not app:
             raise NotFound("App not found")
 
-        annotation_setting = db.session.query(AppAnnotationSetting).filter(
-            AppAnnotationSetting.app_id == app_id).first()
+        annotation_setting = (
+            db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first()
+        )
         if annotation_setting:
             collection_binding_detail = annotation_setting.collection_binding_detail
             return {
@@ -388,32 +400,34 @@ class AppAnnotationService:
                 "score_threshold": annotation_setting.score_threshold,
                 "embedding_model": {
                     "embedding_provider_name": collection_binding_detail.provider_name,
-                    "embedding_model_name": collection_binding_detail.model_name
-                }
+                    "embedding_model_name": collection_binding_detail.model_name,
+                },
             }
-        return {
-            "enabled": False
-        }
+        return {"enabled": False}
 
     @classmethod
     def update_app_annotation_setting(cls, app_id: str, annotation_setting_id: str, args: dict):
         # get app info
-        app = db.session.query(App).filter(
-            App.id == app_id,
-            App.tenant_id == current_user.current_tenant_id,
-            App.status == 'normal'
-        ).first()
+        app = (
+            db.session.query(App)
+            .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
+            .first()
+        )
 
         if not app:
             raise NotFound("App not found")
 
-        annotation_setting = db.session.query(AppAnnotationSetting).filter(
-            AppAnnotationSetting.app_id == app_id,
-            AppAnnotationSetting.id == annotation_setting_id,
-        ).first()
+        annotation_setting = (
+            db.session.query(AppAnnotationSetting)
+            .filter(
+                AppAnnotationSetting.app_id == app_id,
+                AppAnnotationSetting.id == annotation_setting_id,
+            )
+            .first()
+        )
         if not annotation_setting:
             raise NotFound("App annotation not found")
-        annotation_setting.score_threshold = args['score_threshold']
+        annotation_setting.score_threshold = args["score_threshold"]
         annotation_setting.updated_user_id = current_user.id
         annotation_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
         db.session.add(annotation_setting)
@@ -427,6 +441,6 @@ class AppAnnotationService:
             "score_threshold": annotation_setting.score_threshold,
             "embedding_model": {
                 "embedding_provider_name": collection_binding_detail.provider_name,
-                "embedding_model_name": collection_binding_detail.model_name
-            }
+                "embedding_model_name": collection_binding_detail.model_name,
+            },
         }

+ 23 - 16
api/services/api_based_extension_service.py

@@ -5,13 +5,14 @@ from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint
 
 
 class APIBasedExtensionService:
-
     @staticmethod
     def get_all_by_tenant_id(tenant_id: str) -> list[APIBasedExtension]:
-        extension_list = db.session.query(APIBasedExtension) \
-                    .filter_by(tenant_id=tenant_id) \
-                    .order_by(APIBasedExtension.created_at.desc()) \
-                    .all()
+        extension_list = (
+            db.session.query(APIBasedExtension)
+            .filter_by(tenant_id=tenant_id)
+            .order_by(APIBasedExtension.created_at.desc())
+            .all()
+        )
 
         for extension in extension_list:
             extension.api_key = decrypt_token(extension.tenant_id, extension.api_key)
@@ -35,10 +36,12 @@ class APIBasedExtensionService:
 
     @staticmethod
     def get_with_tenant_id(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension:
-        extension = db.session.query(APIBasedExtension) \
-            .filter_by(tenant_id=tenant_id) \
-            .filter_by(id=api_based_extension_id) \
+        extension = (
+            db.session.query(APIBasedExtension)
+            .filter_by(tenant_id=tenant_id)
+            .filter_by(id=api_based_extension_id)
             .first()
+        )
 
         if not extension:
             raise ValueError("API based extension is not found")
@@ -55,20 +58,24 @@ class APIBasedExtensionService:
 
         if not extension_data.id:
             # case one: check new data, name must be unique
-            is_name_existed = db.session.query(APIBasedExtension) \
-                .filter_by(tenant_id=extension_data.tenant_id) \
-                .filter_by(name=extension_data.name) \
+            is_name_existed = (
+                db.session.query(APIBasedExtension)
+                .filter_by(tenant_id=extension_data.tenant_id)
+                .filter_by(name=extension_data.name)
                 .first()
+            )
 
             if is_name_existed:
                 raise ValueError("name must be unique, it is already existed")
         else:
             # case two: check existing data, name must be unique
-            is_name_existed = db.session.query(APIBasedExtension) \
-                .filter_by(tenant_id=extension_data.tenant_id) \
-                .filter_by(name=extension_data.name) \
-                .filter(APIBasedExtension.id != extension_data.id) \
+            is_name_existed = (
+                db.session.query(APIBasedExtension)
+                .filter_by(tenant_id=extension_data.tenant_id)
+                .filter_by(name=extension_data.name)
+                .filter(APIBasedExtension.id != extension_data.id)
                 .first()
+            )
 
             if is_name_existed:
                 raise ValueError("name must be unique, it is already existed")
@@ -92,7 +99,7 @@ class APIBasedExtensionService:
         try:
             client = APIBasedExtensionRequestor(extension_data.api_endpoint, extension_data.api_key)
             resp = client.request(point=APIBasedExtensionPoint.PING, params={})
-            if resp.get('result') != 'pong':
+            if resp.get("result") != "pong":
                 raise ValueError(resp)
         except Exception as e:
             raise ValueError("connection error: {}".format(e))

+ 92 - 93
api/services/app_dsl_service.py

@@ -75,43 +75,44 @@ class AppDslService:
         # check or repair dsl version
         import_data = cls._check_or_fix_dsl(import_data)
 
-        app_data = import_data.get('app')
+        app_data = import_data.get("app")
         if not app_data:
             raise ValueError("Missing app in data argument")
 
         # get app basic info
-        name = args.get("name") if args.get("name") else app_data.get('name')
-        description = args.get("description") if args.get("description") else app_data.get('description', '')
-        icon_type = args.get("icon_type") if args.get("icon_type") else app_data.get('icon_type')
-        icon = args.get("icon") if args.get("icon") else app_data.get('icon')
-        icon_background = args.get("icon_background") if args.get("icon_background") \
-            else app_data.get('icon_background')
+        name = args.get("name") if args.get("name") else app_data.get("name")
+        description = args.get("description") if args.get("description") else app_data.get("description", "")
+        icon_type = args.get("icon_type") if args.get("icon_type") else app_data.get("icon_type")
+        icon = args.get("icon") if args.get("icon") else app_data.get("icon")
+        icon_background = (
+            args.get("icon_background") if args.get("icon_background") else app_data.get("icon_background")
+        )
 
         # import dsl and create app
-        app_mode = AppMode.value_of(app_data.get('mode'))
+        app_mode = AppMode.value_of(app_data.get("mode"))
         if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]:
             app = cls._import_and_create_new_workflow_based_app(
                 tenant_id=tenant_id,
                 app_mode=app_mode,
-                workflow_data=import_data.get('workflow'),
+                workflow_data=import_data.get("workflow"),
                 account=account,
                 name=name,
                 description=description,
                 icon_type=icon_type,
                 icon=icon,
-                icon_background=icon_background
+                icon_background=icon_background,
             )
         elif app_mode in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION]:
             app = cls._import_and_create_new_model_config_based_app(
                 tenant_id=tenant_id,
                 app_mode=app_mode,
-                model_config_data=import_data.get('model_config'),
+                model_config_data=import_data.get("model_config"),
                 account=account,
                 name=name,
                 description=description,
                 icon_type=icon_type,
                 icon=icon,
-                icon_background=icon_background
+                icon_background=icon_background,
             )
         else:
             raise ValueError("Invalid app mode")
@@ -134,27 +135,26 @@ class AppDslService:
         # check or repair dsl version
         import_data = cls._check_or_fix_dsl(import_data)
 
-        app_data = import_data.get('app')
+        app_data = import_data.get("app")
         if not app_data:
             raise ValueError("Missing app in data argument")
 
         # import dsl and overwrite app
-        app_mode = AppMode.value_of(app_data.get('mode'))
+        app_mode = AppMode.value_of(app_data.get("mode"))
         if app_mode not in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]:
             raise ValueError("Only support import workflow in advanced-chat or workflow app.")
 
-        if app_data.get('mode') != app_model.mode:
-            raise ValueError(
-                f"App mode {app_data.get('mode')} is not matched with current app mode {app_mode.value}")
+        if app_data.get("mode") != app_model.mode:
+            raise ValueError(f"App mode {app_data.get('mode')} is not matched with current app mode {app_mode.value}")
 
         return cls._import_and_overwrite_workflow_based_app(
             app_model=app_model,
-            workflow_data=import_data.get('workflow'),
+            workflow_data=import_data.get("workflow"),
             account=account,
         )
 
     @classmethod
-    def export_dsl(cls, app_model: App, include_secret:bool = False) -> str:
+    def export_dsl(cls, app_model: App, include_secret: bool = False) -> str:
         """
         Export app
         :param app_model: App instance
@@ -168,14 +168,16 @@ class AppDslService:
             "app": {
                 "name": app_model.name,
                 "mode": app_model.mode,
-                "icon": '🤖' if app_model.icon_type == 'image' else app_model.icon,
-                "icon_background": '#FFEAD5' if app_model.icon_type == 'image' else app_model.icon_background,
-                "description": app_model.description
-            }
+                "icon": "🤖" if app_model.icon_type == "image" else app_model.icon,
+                "icon_background": "#FFEAD5" if app_model.icon_type == "image" else app_model.icon_background,
+                "description": app_model.description,
+            },
         }
 
         if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]:
-            cls._append_workflow_export_data(export_data=export_data, app_model=app_model, include_secret=include_secret)
+            cls._append_workflow_export_data(
+                export_data=export_data, app_model=app_model, include_secret=include_secret
+            )
         else:
             cls._append_model_config_export_data(export_data, app_model)
 
@@ -188,31 +190,35 @@ class AppDslService:
 
         :param import_data: import data
         """
-        if not import_data.get('version'):
-            import_data['version'] = "0.1.0"
+        if not import_data.get("version"):
+            import_data["version"] = "0.1.0"
 
-        if not import_data.get('kind') or import_data.get('kind') != "app":
-            import_data['kind'] = "app"
+        if not import_data.get("kind") or import_data.get("kind") != "app":
+            import_data["kind"] = "app"
 
-        if import_data.get('version') != current_dsl_version:
+        if import_data.get("version") != current_dsl_version:
             # Currently only one DSL version, so no difference checks or compatibility fixes will be performed.
-            logger.warning(f"DSL version {import_data.get('version')} is not compatible "
-                           f"with current version {current_dsl_version}, related to "
-                           f"Dify version {dsl_to_dify_version_mapping.get(current_dsl_version)}.")
+            logger.warning(
+                f"DSL version {import_data.get('version')} is not compatible "
+                f"with current version {current_dsl_version}, related to "
+                f"Dify version {dsl_to_dify_version_mapping.get(current_dsl_version)}."
+            )
 
         return import_data
 
     @classmethod
-    def _import_and_create_new_workflow_based_app(cls,
-                                                  tenant_id: str,
-                                                  app_mode: AppMode,
-                                                  workflow_data: dict,
-                                                  account: Account,
-                                                  name: str,
-                                                  description: str,
-                                                  icon_type: str,
-                                                  icon: str,
-                                                  icon_background: str) -> App:
+    def _import_and_create_new_workflow_based_app(
+        cls,
+        tenant_id: str,
+        app_mode: AppMode,
+        workflow_data: dict,
+        account: Account,
+        name: str,
+        description: str,
+        icon_type: str,
+        icon: str,
+        icon_background: str,
+    ) -> App:
         """
         Import app dsl and create new workflow based app
 
@@ -227,8 +233,7 @@ class AppDslService:
         :param icon_background: app icon background
         """
         if not workflow_data:
-            raise ValueError("Missing workflow in data argument "
-                             "when app mode is advanced-chat or workflow")
+            raise ValueError("Missing workflow in data argument " "when app mode is advanced-chat or workflow")
 
         app = cls._create_app(
             tenant_id=tenant_id,
@@ -238,37 +243,32 @@ class AppDslService:
             description=description,
             icon_type=icon_type,
             icon=icon,
-            icon_background=icon_background
+            icon_background=icon_background,
         )
 
         # init draft workflow
-        environment_variables_list = workflow_data.get('environment_variables') or []
+        environment_variables_list = workflow_data.get("environment_variables") or []
         environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list]
-        conversation_variables_list = workflow_data.get('conversation_variables') or []
+        conversation_variables_list = workflow_data.get("conversation_variables") or []
         conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list]
         workflow_service = WorkflowService()
         draft_workflow = workflow_service.sync_draft_workflow(
             app_model=app,
-            graph=workflow_data.get('graph', {}),
-            features=workflow_data.get('../core/app/features', {}),
+            graph=workflow_data.get("graph", {}),
+            features=workflow_data.get("../core/app/features", {}),
             unique_hash=None,
             account=account,
             environment_variables=environment_variables,
             conversation_variables=conversation_variables,
         )
-        workflow_service.publish_workflow(
-            app_model=app,
-            account=account,
-            draft_workflow=draft_workflow
-        )
+        workflow_service.publish_workflow(app_model=app, account=account, draft_workflow=draft_workflow)
 
         return app
 
     @classmethod
-    def _import_and_overwrite_workflow_based_app(cls,
-                                                 app_model: App,
-                                                 workflow_data: dict,
-                                                 account: Account) -> Workflow:
+    def _import_and_overwrite_workflow_based_app(
+        cls, app_model: App, workflow_data: dict, account: Account
+    ) -> Workflow:
         """
         Import app dsl and overwrite workflow based app
 
@@ -277,8 +277,7 @@ class AppDslService:
         :param account: Account instance
         """
         if not workflow_data:
-            raise ValueError("Missing workflow in data argument "
-                             "when app mode is advanced-chat or workflow")
+            raise ValueError("Missing workflow in data argument " "when app mode is advanced-chat or workflow")
 
         # fetch draft workflow by app_model
         workflow_service = WorkflowService()
@@ -289,14 +288,14 @@ class AppDslService:
             unique_hash = None
 
         # sync draft workflow
-        environment_variables_list = workflow_data.get('environment_variables') or []
+        environment_variables_list = workflow_data.get("environment_variables") or []
         environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list]
-        conversation_variables_list = workflow_data.get('conversation_variables') or []
+        conversation_variables_list = workflow_data.get("conversation_variables") or []
         conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list]
         draft_workflow = workflow_service.sync_draft_workflow(
             app_model=app_model,
-            graph=workflow_data.get('graph', {}),
-            features=workflow_data.get('features', {}),
+            graph=workflow_data.get("graph", {}),
+            features=workflow_data.get("features", {}),
             unique_hash=unique_hash,
             account=account,
             environment_variables=environment_variables,
@@ -306,16 +305,18 @@ class AppDslService:
         return draft_workflow
 
     @classmethod
-    def _import_and_create_new_model_config_based_app(cls,
-                                                      tenant_id: str,
-                                                      app_mode: AppMode,
-                                                      model_config_data: dict,
-                                                      account: Account,
-                                                      name: str,
-                                                      description: str,
-                                                      icon_type: str,
-                                                      icon: str,
-                                                      icon_background: str) -> App:
+    def _import_and_create_new_model_config_based_app(
+        cls,
+        tenant_id: str,
+        app_mode: AppMode,
+        model_config_data: dict,
+        account: Account,
+        name: str,
+        description: str,
+        icon_type: str,
+        icon: str,
+        icon_background: str,
+    ) -> App:
         """
         Import app dsl and create new model config based app
 
@@ -329,8 +330,7 @@ class AppDslService:
         :param icon_background: app icon background
         """
         if not model_config_data:
-            raise ValueError("Missing model_config in data argument "
-                             "when app mode is chat, agent-chat or completion")
+            raise ValueError("Missing model_config in data argument " "when app mode is chat, agent-chat or completion")
 
         app = cls._create_app(
             tenant_id=tenant_id,
@@ -340,7 +340,7 @@ class AppDslService:
             description=description,
             icon_type=icon_type,
             icon=icon,
-            icon_background=icon_background
+            icon_background=icon_background,
         )
 
         app_model_config = AppModelConfig()
@@ -352,23 +352,22 @@ class AppDslService:
 
         app.app_model_config_id = app_model_config.id
 
-        app_model_config_was_updated.send(
-            app,
-            app_model_config=app_model_config
-        )
+        app_model_config_was_updated.send(app, app_model_config=app_model_config)
 
         return app
 
     @classmethod
-    def _create_app(cls,
-                    tenant_id: str,
-                    app_mode: AppMode,
-                    account: Account,
-                    name: str,
-                    description: str,
-                    icon_type: str,
-                    icon: str,
-                    icon_background: str) -> App:
+    def _create_app(
+        cls,
+        tenant_id: str,
+        app_mode: AppMode,
+        account: Account,
+        name: str,
+        description: str,
+        icon_type: str,
+        icon: str,
+        icon_background: str,
+    ) -> App:
         """
         Create new app
 
@@ -390,7 +389,7 @@ class AppDslService:
             icon=icon,
             icon_background=icon_background,
             enable_site=True,
-            enable_api=True
+            enable_api=True,
         )
 
         db.session.add(app)
@@ -412,7 +411,7 @@ class AppDslService:
         if not workflow:
             raise ValueError("Missing draft workflow configuration, please check.")
 
-        export_data['workflow'] = workflow.to_dict(include_secret=include_secret)
+        export_data["workflow"] = workflow.to_dict(include_secret=include_secret)
 
     @classmethod
     def _append_model_config_export_data(cls, export_data: dict, app_model: App) -> None:
@@ -425,4 +424,4 @@ class AppDslService:
         if not app_model_config:
             raise ValueError("Missing app configuration, please check.")
 
-        export_data['model_config'] = app_model_config.to_dict()
+        export_data["model_config"] = app_model_config.to_dict()

+ 66 - 73
api/services/app_generate_service.py

@@ -14,14 +14,15 @@ from services.workflow_service import WorkflowService
 
 
 class AppGenerateService:
-
     @classmethod
-    def generate(cls, app_model: App,
-                 user: Union[Account, EndUser],
-                 args: Any,
-                 invoke_from: InvokeFrom,
-                 streaming: bool = True,
-                 ):
+    def generate(
+        cls,
+        app_model: App,
+        user: Union[Account, EndUser],
+        args: Any,
+        invoke_from: InvokeFrom,
+        streaming: bool = True,
+    ):
         """
         App Content Generate
         :param app_model: app model
@@ -37,51 +38,54 @@ class AppGenerateService:
         try:
             request_id = rate_limit.enter(request_id)
             if app_model.mode == AppMode.COMPLETION.value:
-                return rate_limit.generate(CompletionAppGenerator().generate(
-                    app_model=app_model,
-                    user=user,
-                    args=args,
-                    invoke_from=invoke_from,
-                    stream=streaming
-                ), request_id)
+                return rate_limit.generate(
+                    CompletionAppGenerator().generate(
+                        app_model=app_model, user=user, args=args, invoke_from=invoke_from, stream=streaming
+                    ),
+                    request_id,
+                )
             elif app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
-                return rate_limit.generate(AgentChatAppGenerator().generate(
-                    app_model=app_model,
-                    user=user,
-                    args=args,
-                    invoke_from=invoke_from,
-                    stream=streaming
-                ), request_id)
+                return rate_limit.generate(
+                    AgentChatAppGenerator().generate(
+                        app_model=app_model, user=user, args=args, invoke_from=invoke_from, stream=streaming
+                    ),
+                    request_id,
+                )
             elif app_model.mode == AppMode.CHAT.value:
-                return rate_limit.generate(ChatAppGenerator().generate(
-                    app_model=app_model,
-                    user=user,
-                    args=args,
-                    invoke_from=invoke_from,
-                    stream=streaming
-                ), request_id)
+                return rate_limit.generate(
+                    ChatAppGenerator().generate(
+                        app_model=app_model, user=user, args=args, invoke_from=invoke_from, stream=streaming
+                    ),
+                    request_id,
+                )
             elif app_model.mode == AppMode.ADVANCED_CHAT.value:
                 workflow = cls._get_workflow(app_model, invoke_from)
-                return rate_limit.generate(AdvancedChatAppGenerator().generate(
-                    app_model=app_model,
-                    workflow=workflow,
-                    user=user,
-                    args=args,
-                    invoke_from=invoke_from,
-                    stream=streaming
-                ), request_id)
+                return rate_limit.generate(
+                    AdvancedChatAppGenerator().generate(
+                        app_model=app_model,
+                        workflow=workflow,
+                        user=user,
+                        args=args,
+                        invoke_from=invoke_from,
+                        stream=streaming,
+                    ),
+                    request_id,
+                )
             elif app_model.mode == AppMode.WORKFLOW.value:
                 workflow = cls._get_workflow(app_model, invoke_from)
-                return rate_limit.generate(WorkflowAppGenerator().generate(
-                    app_model=app_model,
-                    workflow=workflow,
-                    user=user,
-                    args=args,
-                    invoke_from=invoke_from,
-                    stream=streaming
-                ), request_id)
+                return rate_limit.generate(
+                    WorkflowAppGenerator().generate(
+                        app_model=app_model,
+                        workflow=workflow,
+                        user=user,
+                        args=args,
+                        invoke_from=invoke_from,
+                        stream=streaming,
+                    ),
+                    request_id,
+                )
             else:
-                raise ValueError(f'Invalid app mode {app_model.mode}')
+                raise ValueError(f"Invalid app mode {app_model.mode}")
         finally:
             if not streaming:
                 rate_limit.exit(request_id)
@@ -94,38 +98,31 @@ class AppGenerateService:
         return max_active_requests
 
     @classmethod
-    def generate_single_iteration(cls, app_model: App,
-                                  user: Union[Account, EndUser],
-                                  node_id: str,
-                                  args: Any,
-                                  streaming: bool = True):
+    def generate_single_iteration(
+        cls, app_model: App, user: Union[Account, EndUser], node_id: str, args: Any, streaming: bool = True
+    ):
         if app_model.mode == AppMode.ADVANCED_CHAT.value:
             workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
             return AdvancedChatAppGenerator().single_iteration_generate(
-                app_model=app_model,
-                workflow=workflow,
-                node_id=node_id,
-                user=user,
-                args=args,
-                stream=streaming
+                app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, stream=streaming
             )
         elif app_model.mode == AppMode.WORKFLOW.value:
             workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
             return WorkflowAppGenerator().single_iteration_generate(
-                app_model=app_model,
-                workflow=workflow,
-                node_id=node_id,
-                user=user,
-                args=args,
-                stream=streaming
+                app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, stream=streaming
             )
         else:
-            raise ValueError(f'Invalid app mode {app_model.mode}')
+            raise ValueError(f"Invalid app mode {app_model.mode}")
 
     @classmethod
-    def generate_more_like_this(cls, app_model: App, user: Union[Account, EndUser],
-                                message_id: str, invoke_from: InvokeFrom, streaming: bool = True) \
-            -> Union[dict, Generator]:
+    def generate_more_like_this(
+        cls,
+        app_model: App,
+        user: Union[Account, EndUser],
+        message_id: str,
+        invoke_from: InvokeFrom,
+        streaming: bool = True,
+    ) -> Union[dict, Generator]:
         """
         Generate more like this
         :param app_model: app model
@@ -136,11 +133,7 @@ class AppGenerateService:
         :return:
         """
         return CompletionAppGenerator().generate_more_like_this(
-            app_model=app_model,
-            message_id=message_id,
-            user=user,
-            invoke_from=invoke_from,
-            stream=streaming
+            app_model=app_model, message_id=message_id, user=user, invoke_from=invoke_from, stream=streaming
         )
 
     @classmethod
@@ -157,12 +150,12 @@ class AppGenerateService:
             workflow = workflow_service.get_draft_workflow(app_model=app_model)
 
             if not workflow:
-                raise ValueError('Workflow not initialized')
+                raise ValueError("Workflow not initialized")
         else:
             # fetch published workflow by app_model
             workflow = workflow_service.get_published_workflow(app_model=app_model)
 
             if not workflow:
-                raise ValueError('Workflow not published')
+                raise ValueError("Workflow not published")
 
         return workflow

+ 0 - 1
api/services/app_model_config_service.py

@@ -5,7 +5,6 @@ from models.model import AppMode
 
 
 class AppModelConfigService:
-
     @classmethod
     def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode) -> dict:
         if app_mode == AppMode.CHAT:

+ 77 - 88
api/services/app_service.py

@@ -33,27 +33,22 @@ class AppService:
         :param args: request args
         :return:
         """
-        filters = [
-            App.tenant_id == tenant_id,
-            App.is_universal == False
-        ]
+        filters = [App.tenant_id == tenant_id, App.is_universal == False]
 
-        if args['mode'] == 'workflow':
+        if args["mode"] == "workflow":
             filters.append(App.mode.in_([AppMode.WORKFLOW.value, AppMode.COMPLETION.value]))
-        elif args['mode'] == 'chat':
+        elif args["mode"] == "chat":
             filters.append(App.mode.in_([AppMode.CHAT.value, AppMode.ADVANCED_CHAT.value]))
-        elif args['mode'] == 'agent-chat':
+        elif args["mode"] == "agent-chat":
             filters.append(App.mode == AppMode.AGENT_CHAT.value)
-        elif args['mode'] == 'channel':
+        elif args["mode"] == "channel":
             filters.append(App.mode == AppMode.CHANNEL.value)
 
-        if args.get('name'):
-            name = args['name'][:30]
-            filters.append(App.name.ilike(f'%{name}%'))
-        if args.get('tag_ids'):
-            target_ids = TagService.get_target_ids_by_tag_ids('app',
-                                                              tenant_id,
-                                                              args['tag_ids'])
+        if args.get("name"):
+            name = args["name"][:30]
+            filters.append(App.name.ilike(f"%{name}%"))
+        if args.get("tag_ids"):
+            target_ids = TagService.get_target_ids_by_tag_ids("app", tenant_id, args["tag_ids"])
             if target_ids:
                 filters.append(App.id.in_(target_ids))
             else:
@@ -61,9 +56,9 @@ class AppService:
 
         app_models = db.paginate(
             db.select(App).where(*filters).order_by(App.created_at.desc()),
-            page=args['page'],
-            per_page=args['limit'],
-            error_out=False
+            page=args["page"],
+            per_page=args["limit"],
+            error_out=False,
         )
 
         return app_models
@@ -75,21 +70,20 @@ class AppService:
         :param args: request args
         :param account: Account instance
         """
-        app_mode = AppMode.value_of(args['mode'])
+        app_mode = AppMode.value_of(args["mode"])
         app_template = default_app_templates[app_mode]
 
         # get model config
-        default_model_config = app_template.get('model_config')
+        default_model_config = app_template.get("model_config")
         default_model_config = default_model_config.copy() if default_model_config else None
-        if default_model_config and 'model' in default_model_config:
+        if default_model_config and "model" in default_model_config:
             # get model provider
             model_manager = ModelManager()
 
             # get default model instance
             try:
                 model_instance = model_manager.get_default_model_instance(
-                    tenant_id=account.current_tenant_id,
-                    model_type=ModelType.LLM
+                    tenant_id=account.current_tenant_id, model_type=ModelType.LLM
                 )
             except (ProviderTokenNotInitError, LLMBadRequestError):
                 model_instance = None
@@ -98,39 +92,41 @@ class AppService:
                 model_instance = None
 
             if model_instance:
-                if model_instance.model == default_model_config['model']['name'] and model_instance.provider == default_model_config['model']['provider']:
-                    default_model_dict = default_model_config['model']
+                if (
+                    model_instance.model == default_model_config["model"]["name"]
+                    and model_instance.provider == default_model_config["model"]["provider"]
+                ):
+                    default_model_dict = default_model_config["model"]
                 else:
                     llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
                     model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
 
                     default_model_dict = {
-                        'provider': model_instance.provider,
-                        'name': model_instance.model,
-                        'mode': model_schema.model_properties.get(ModelPropertyKey.MODE),
-                        'completion_params': {}
+                        "provider": model_instance.provider,
+                        "name": model_instance.model,
+                        "mode": model_schema.model_properties.get(ModelPropertyKey.MODE),
+                        "completion_params": {},
                     }
             else:
                 provider, model = model_manager.get_default_provider_model_name(
-                    tenant_id=account.current_tenant_id,
-                    model_type=ModelType.LLM
+                    tenant_id=account.current_tenant_id, model_type=ModelType.LLM
                 )
-                default_model_config['model']['provider'] = provider
-                default_model_config['model']['name'] = model
-                default_model_dict = default_model_config['model']
-
-            default_model_config['model'] = json.dumps(default_model_dict)
-
-        app = App(**app_template['app'])
-        app.name = args['name']
-        app.description = args.get('description', '')
-        app.mode = args['mode']
-        app.icon_type = args.get('icon_type', 'emoji')
-        app.icon = args['icon']
-        app.icon_background = args['icon_background']
+                default_model_config["model"]["provider"] = provider
+                default_model_config["model"]["name"] = model
+                default_model_dict = default_model_config["model"]
+
+            default_model_config["model"] = json.dumps(default_model_dict)
+
+        app = App(**app_template["app"])
+        app.name = args["name"]
+        app.description = args.get("description", "")
+        app.mode = args["mode"]
+        app.icon_type = args.get("icon_type", "emoji")
+        app.icon = args["icon"]
+        app.icon_background = args["icon_background"]
         app.tenant_id = tenant_id
-        app.api_rph = args.get('api_rph', 0)
-        app.api_rpm = args.get('api_rpm', 0)
+        app.api_rph = args.get("api_rph", 0)
+        app.api_rpm = args.get("api_rpm", 0)
 
         db.session.add(app)
         db.session.flush()
@@ -158,7 +154,7 @@ class AppService:
             model_config: AppModelConfig = app.app_model_config
             agent_mode = model_config.agent_mode_dict
             # decrypt agent tool parameters if it's secret-input
-            for tool in agent_mode.get('tools') or []:
+            for tool in agent_mode.get("tools") or []:
                 if not isinstance(tool, dict) or len(tool.keys()) <= 3:
                     continue
                 agent_tool_entity = AgentToolEntity(**tool)
@@ -174,7 +170,7 @@ class AppService:
                         tool_runtime=tool_runtime,
                         provider_name=agent_tool_entity.provider_id,
                         provider_type=agent_tool_entity.provider_type,
-                        identity_id=f'AGENT.{app.id}'
+                        identity_id=f"AGENT.{app.id}",
                     )
 
                     # get decrypted parameters
@@ -185,7 +181,7 @@ class AppService:
                         masked_parameter = {}
 
                     # override tool parameters
-                    tool['tool_parameters'] = masked_parameter
+                    tool["tool_parameters"] = masked_parameter
                 except Exception as e:
                     pass
 
@@ -215,12 +211,12 @@ class AppService:
         :param args: request args
         :return: App instance
         """
-        app.name = args.get('name')
-        app.description = args.get('description', '')
-        app.max_active_requests = args.get('max_active_requests')
-        app.icon_type = args.get('icon_type', 'emoji')
-        app.icon = args.get('icon')
-        app.icon_background = args.get('icon_background')
+        app.name = args.get("name")
+        app.description = args.get("description", "")
+        app.max_active_requests = args.get("max_active_requests")
+        app.icon_type = args.get("icon_type", "emoji")
+        app.icon = args.get("icon")
+        app.icon_background = args.get("icon_background")
         app.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
         db.session.commit()
 
@@ -298,10 +294,7 @@ class AppService:
         db.session.commit()
 
         # Trigger asynchronous deletion of app and related data
-        remove_app_and_related_data_task.delay(
-            tenant_id=app.tenant_id,
-            app_id=app.id
-        )
+        remove_app_and_related_data_task.delay(tenant_id=app.tenant_id, app_id=app.id)
 
     def get_app_meta(self, app_model: App) -> dict:
         """
@@ -311,9 +304,7 @@ class AppService:
         """
         app_mode = AppMode.value_of(app_model.mode)
 
-        meta = {
-            'tool_icons': {}
-        }
+        meta = {"tool_icons": {}}
 
         if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]:
             workflow = app_model.workflow
@@ -321,17 +312,19 @@ class AppService:
                 return meta
 
             graph = workflow.graph_dict
-            nodes = graph.get('nodes', [])
+            nodes = graph.get("nodes", [])
             tools = []
             for node in nodes:
-                if node.get('data', {}).get('type') == 'tool':
-                    node_data = node.get('data', {})
-                    tools.append({
-                        'provider_type': node_data.get('provider_type'),
-                        'provider_id': node_data.get('provider_id'),
-                        'tool_name': node_data.get('tool_name'),
-                        'tool_parameters': {}
-                    })
+                if node.get("data", {}).get("type") == "tool":
+                    node_data = node.get("data", {})
+                    tools.append(
+                        {
+                            "provider_type": node_data.get("provider_type"),
+                            "provider_id": node_data.get("provider_id"),
+                            "tool_name": node_data.get("tool_name"),
+                            "tool_parameters": {},
+                        }
+                    )
         else:
             app_model_config: AppModelConfig = app_model.app_model_config
 
@@ -341,30 +334,26 @@ class AppService:
             agent_config = app_model_config.agent_mode_dict or {}
 
             # get all tools
-            tools = agent_config.get('tools', [])
+            tools = agent_config.get("tools", [])
 
-        url_prefix = (dify_config.CONSOLE_API_URL
-                      + "/console/api/workspaces/current/tool-provider/builtin/")
+        url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/"
 
         for tool in tools:
             keys = list(tool.keys())
             if len(keys) >= 4:
                 # current tool standard
-                provider_type = tool.get('provider_type')
-                provider_id = tool.get('provider_id')
-                tool_name = tool.get('tool_name')
-                if provider_type == 'builtin':
-                    meta['tool_icons'][tool_name] = url_prefix + provider_id + '/icon'
-                elif provider_type == 'api':
+                provider_type = tool.get("provider_type")
+                provider_id = tool.get("provider_id")
+                tool_name = tool.get("tool_name")
+                if provider_type == "builtin":
+                    meta["tool_icons"][tool_name] = url_prefix + provider_id + "/icon"
+                elif provider_type == "api":
                     try:
-                        provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
-                            ApiToolProvider.id == provider_id
-                        ).first()
-                        meta['tool_icons'][tool_name] = json.loads(provider.icon)
+                        provider: ApiToolProvider = (
+                            db.session.query(ApiToolProvider).filter(ApiToolProvider.id == provider_id).first()
+                        )
+                        meta["tool_icons"][tool_name] = json.loads(provider.icon)
                     except:
-                        meta['tool_icons'][tool_name] = {
-                            "background": "#252525",
-                            "content": "\ud83d\ude01"
-                        }
+                        meta["tool_icons"][tool_name] = {"background": "#252525", "content": "\ud83d\ude01"}
 
         return meta

+ 26 - 30
api/services/audio_service.py

@@ -17,7 +17,7 @@ from services.errors.audio import (
 
 FILE_SIZE = 30
 FILE_SIZE_LIMIT = FILE_SIZE * 1024 * 1024
-ALLOWED_EXTENSIONS = ['mp3', 'mp4', 'mpeg', 'mpga', 'm4a', 'wav', 'webm', 'amr']
+ALLOWED_EXTENSIONS = ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm", "amr"]
 
 logger = logging.getLogger(__name__)
 
@@ -31,19 +31,19 @@ class AudioService:
                 raise ValueError("Speech to text is not enabled")
 
             features_dict = workflow.features_dict
-            if 'speech_to_text' not in features_dict or not features_dict['speech_to_text'].get('enabled'):
+            if "speech_to_text" not in features_dict or not features_dict["speech_to_text"].get("enabled"):
                 raise ValueError("Speech to text is not enabled")
         else:
             app_model_config: AppModelConfig = app_model.app_model_config
 
-            if not app_model_config.speech_to_text_dict['enabled']:
+            if not app_model_config.speech_to_text_dict["enabled"]:
                 raise ValueError("Speech to text is not enabled")
 
         if file is None:
             raise NoAudioUploadedServiceError()
 
         extension = file.mimetype
-        if extension not in [f'audio/{ext}' for ext in ALLOWED_EXTENSIONS]:
+        if extension not in [f"audio/{ext}" for ext in ALLOWED_EXTENSIONS]:
             raise UnsupportedAudioTypeServiceError()
 
         file_content = file.read()
@@ -55,20 +55,25 @@ class AudioService:
 
         model_manager = ModelManager()
         model_instance = model_manager.get_default_model_instance(
-            tenant_id=app_model.tenant_id,
-            model_type=ModelType.SPEECH2TEXT
+            tenant_id=app_model.tenant_id, model_type=ModelType.SPEECH2TEXT
         )
         if model_instance is None:
             raise ProviderNotSupportSpeechToTextServiceError()
 
         buffer = io.BytesIO(file_content)
-        buffer.name = 'temp.mp3'
+        buffer.name = "temp.mp3"
 
         return {"text": model_instance.invoke_speech2text(file=buffer, user=end_user)}
 
     @classmethod
-    def transcript_tts(cls, app_model: App, text: Optional[str] = None,
-                       voice: Optional[str] = None, end_user: Optional[str] = None, message_id: Optional[str] = None):
+    def transcript_tts(
+        cls,
+        app_model: App,
+        text: Optional[str] = None,
+        voice: Optional[str] = None,
+        end_user: Optional[str] = None,
+        message_id: Optional[str] = None,
+    ):
         from collections.abc import Generator
 
         from flask import Response, stream_with_context
@@ -84,65 +89,56 @@ class AudioService:
                         raise ValueError("TTS is not enabled")
 
                     features_dict = workflow.features_dict
-                    if 'text_to_speech' not in features_dict or not features_dict['text_to_speech'].get('enabled'):
+                    if "text_to_speech" not in features_dict or not features_dict["text_to_speech"].get("enabled"):
                         raise ValueError("TTS is not enabled")
 
-                    voice = features_dict['text_to_speech'].get('voice') if voice is None else voice
+                    voice = features_dict["text_to_speech"].get("voice") if voice is None else voice
                 else:
                     text_to_speech_dict = app_model.app_model_config.text_to_speech_dict
 
-                    if not text_to_speech_dict.get('enabled'):
+                    if not text_to_speech_dict.get("enabled"):
                         raise ValueError("TTS is not enabled")
 
-                    voice = text_to_speech_dict.get('voice') if voice is None else voice
+                    voice = text_to_speech_dict.get("voice") if voice is None else voice
 
                 model_manager = ModelManager()
                 model_instance = model_manager.get_default_model_instance(
-                    tenant_id=app_model.tenant_id,
-                    model_type=ModelType.TTS
+                    tenant_id=app_model.tenant_id, model_type=ModelType.TTS
                 )
                 try:
                     if not voice:
                         voices = model_instance.get_tts_voices()
                         if voices:
-                            voice = voices[0].get('value')
+                            voice = voices[0].get("value")
                         else:
                             raise ValueError("Sorry, no voice available.")
 
                     return model_instance.invoke_tts(
-                        content_text=text_content.strip(),
-                        user=end_user,
-                        tenant_id=app_model.tenant_id,
-                        voice=voice
+                        content_text=text_content.strip(), user=end_user, tenant_id=app_model.tenant_id, voice=voice
                     )
                 except Exception as e:
                     raise e
 
         if message_id:
-            message = db.session.query(Message).filter(
-                Message.id == message_id
-            ).first()
-            if message.answer == '' and message.status == 'normal':
+            message = db.session.query(Message).filter(Message.id == message_id).first()
+            if message.answer == "" and message.status == "normal":
                 return None
 
             else:
                 response = invoke_tts(message.answer, app_model=app_model, voice=voice)
                 if isinstance(response, Generator):
-                    return Response(stream_with_context(response), content_type='audio/mpeg')
+                    return Response(stream_with_context(response), content_type="audio/mpeg")
                 return response
         else:
             response = invoke_tts(text, app_model, voice)
             if isinstance(response, Generator):
-                return Response(stream_with_context(response), content_type='audio/mpeg')
+                return Response(stream_with_context(response), content_type="audio/mpeg")
             return response
 
     @classmethod
     def transcript_tts_voices(cls, tenant_id: str, language: str):
         model_manager = ModelManager()
-        model_instance = model_manager.get_default_model_instance(
-            tenant_id=tenant_id,
-            model_type=ModelType.TTS
-        )
+        model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.TTS)
         if model_instance is None:
             raise ProviderNotSupportTextToSpeechServiceError()
 

+ 2 - 4
api/services/auth/api_key_auth_factory.py

@@ -1,14 +1,12 @@
-
 from services.auth.firecrawl import FirecrawlAuth
 
 
 class ApiKeyAuthFactory:
-
     def __init__(self, provider: str, credentials: dict):
-        if provider == 'firecrawl':
+        if provider == "firecrawl":
             self.auth = FirecrawlAuth(credentials)
         else:
-            raise ValueError('Invalid provider')
+            raise ValueError("Invalid provider")
 
     def validate_credentials(self):
         return self.auth.validate_credentials()

+ 36 - 32
api/services/auth/api_key_auth_service.py

@@ -7,39 +7,43 @@ from services.auth.api_key_auth_factory import ApiKeyAuthFactory
 
 
 class ApiKeyAuthService:
-
     @staticmethod
     def get_provider_auth_list(tenant_id: str) -> list:
-        data_source_api_key_bindings = db.session.query(DataSourceApiKeyAuthBinding).filter(
-            DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
-            DataSourceApiKeyAuthBinding.disabled.is_(False)
-        ).all()
+        data_source_api_key_bindings = (
+            db.session.query(DataSourceApiKeyAuthBinding)
+            .filter(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False))
+            .all()
+        )
         return data_source_api_key_bindings
 
     @staticmethod
     def create_provider_auth(tenant_id: str, args: dict):
-        auth_result = ApiKeyAuthFactory(args['provider'], args['credentials']).validate_credentials()
+        auth_result = ApiKeyAuthFactory(args["provider"], args["credentials"]).validate_credentials()
         if auth_result:
             # Encrypt the api key
-            api_key = encrypter.encrypt_token(tenant_id, args['credentials']['config']['api_key'])
-            args['credentials']['config']['api_key'] = api_key
+            api_key = encrypter.encrypt_token(tenant_id, args["credentials"]["config"]["api_key"])
+            args["credentials"]["config"]["api_key"] = api_key
 
             data_source_api_key_binding = DataSourceApiKeyAuthBinding()
             data_source_api_key_binding.tenant_id = tenant_id
-            data_source_api_key_binding.category = args['category']
-            data_source_api_key_binding.provider = args['provider']
-            data_source_api_key_binding.credentials = json.dumps(args['credentials'], ensure_ascii=False)
+            data_source_api_key_binding.category = args["category"]
+            data_source_api_key_binding.provider = args["provider"]
+            data_source_api_key_binding.credentials = json.dumps(args["credentials"], ensure_ascii=False)
             db.session.add(data_source_api_key_binding)
             db.session.commit()
 
     @staticmethod
     def get_auth_credentials(tenant_id: str, category: str, provider: str):
-        data_source_api_key_bindings = db.session.query(DataSourceApiKeyAuthBinding).filter(
-            DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
-            DataSourceApiKeyAuthBinding.category == category,
-            DataSourceApiKeyAuthBinding.provider == provider,
-            DataSourceApiKeyAuthBinding.disabled.is_(False)
-        ).first()
+        data_source_api_key_bindings = (
+            db.session.query(DataSourceApiKeyAuthBinding)
+            .filter(
+                DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
+                DataSourceApiKeyAuthBinding.category == category,
+                DataSourceApiKeyAuthBinding.provider == provider,
+                DataSourceApiKeyAuthBinding.disabled.is_(False),
+            )
+            .first()
+        )
         if not data_source_api_key_bindings:
             return None
         credentials = json.loads(data_source_api_key_bindings.credentials)
@@ -47,24 +51,24 @@ class ApiKeyAuthService:
 
     @staticmethod
     def delete_provider_auth(tenant_id: str, binding_id: str):
-        data_source_api_key_binding = db.session.query(DataSourceApiKeyAuthBinding).filter(
-            DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
-            DataSourceApiKeyAuthBinding.id == binding_id
-        ).first()
+        data_source_api_key_binding = (
+            db.session.query(DataSourceApiKeyAuthBinding)
+            .filter(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.id == binding_id)
+            .first()
+        )
         if data_source_api_key_binding:
             db.session.delete(data_source_api_key_binding)
             db.session.commit()
 
     @classmethod
     def validate_api_key_auth_args(cls, args):
-        if 'category' not in args or not args['category']:
-            raise ValueError('category is required')
-        if 'provider' not in args or not args['provider']:
-            raise ValueError('provider is required')
-        if 'credentials' not in args or not args['credentials']:
-            raise ValueError('credentials is required')
-        if not isinstance(args['credentials'], dict):
-            raise ValueError('credentials must be a dictionary')
-        if 'auth_type' not in args['credentials'] or not args['credentials']['auth_type']:
-            raise ValueError('auth_type is required')
-
+        if "category" not in args or not args["category"]:
+            raise ValueError("category is required")
+        if "provider" not in args or not args["provider"]:
+            raise ValueError("provider is required")
+        if "credentials" not in args or not args["credentials"]:
+            raise ValueError("credentials is required")
+        if not isinstance(args["credentials"], dict):
+            raise ValueError("credentials must be a dictionary")
+        if "auth_type" not in args["credentials"] or not args["credentials"]["auth_type"]:
+            raise ValueError("auth_type is required")

+ 16 - 25
api/services/auth/firecrawl.py

@@ -8,49 +8,40 @@ from services.auth.api_key_auth_base import ApiKeyAuthBase
 class FirecrawlAuth(ApiKeyAuthBase):
     def __init__(self, credentials: dict):
         super().__init__(credentials)
-        auth_type = credentials.get('auth_type')
-        if auth_type != 'bearer':
-            raise ValueError('Invalid auth type, Firecrawl auth type must be Bearer')
-        self.api_key = credentials.get('config').get('api_key', None)
-        self.base_url = credentials.get('config').get('base_url', 'https://api.firecrawl.dev')
+        auth_type = credentials.get("auth_type")
+        if auth_type != "bearer":
+            raise ValueError("Invalid auth type, Firecrawl auth type must be Bearer")
+        self.api_key = credentials.get("config").get("api_key", None)
+        self.base_url = credentials.get("config").get("base_url", "https://api.firecrawl.dev")
 
         if not self.api_key:
-            raise ValueError('No API key provided')
+            raise ValueError("No API key provided")
 
     def validate_credentials(self):
         headers = self._prepare_headers()
         options = {
-            'url': 'https://example.com',
-            'crawlerOptions': {
-                'excludes': [],
-                'includes': [],
-                'limit': 1
-            },
-            'pageOptions': {
-                'onlyMainContent': True
-            }
+            "url": "https://example.com",
+            "crawlerOptions": {"excludes": [], "includes": [], "limit": 1},
+            "pageOptions": {"onlyMainContent": True},
         }
-        response = self._post_request(f'{self.base_url}/v0/crawl', options, headers)
+        response = self._post_request(f"{self.base_url}/v0/crawl", options, headers)
         if response.status_code == 200:
             return True
         else:
             self._handle_error(response)
 
     def _prepare_headers(self):
-        return {
-            'Content-Type': 'application/json',
-            'Authorization': f'Bearer {self.api_key}'
-        }
+        return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
 
     def _post_request(self, url, data, headers):
         return requests.post(url, headers=headers, json=data)
 
     def _handle_error(self, response):
         if response.status_code in [402, 409, 500]:
-            error_message = response.json().get('error', 'Unknown error occurred')
-            raise Exception(f'Failed to authorize. Status code: {response.status_code}. Error: {error_message}')
+            error_message = response.json().get("error", "Unknown error occurred")
+            raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
         else:
             if response.text:
-                error_message = json.loads(response.text).get('error', 'Unknown error occurred')
-                raise Exception(f'Failed to authorize. Status code: {response.status_code}. Error: {error_message}')
-            raise Exception(f'Unexpected error occurred while trying to authorize. Status code: {response.status_code}')
+                error_message = json.loads(response.text).get("error", "Unknown error occurred")
+                raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
+            raise Exception(f"Unexpected error occurred while trying to authorize. Status code: {response.status_code}")

+ 23 - 40
api/services/billing_service.py

@@ -7,58 +7,40 @@ from models.account import TenantAccountJoin, TenantAccountRole
 
 
 class BillingService:
-    base_url = os.environ.get('BILLING_API_URL', 'BILLING_API_URL')
-    secret_key = os.environ.get('BILLING_API_SECRET_KEY', 'BILLING_API_SECRET_KEY')
+    base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL")
+    secret_key = os.environ.get("BILLING_API_SECRET_KEY", "BILLING_API_SECRET_KEY")
 
     @classmethod
     def get_info(cls, tenant_id: str):
-        params = {'tenant_id': tenant_id}
+        params = {"tenant_id": tenant_id}
 
-        billing_info = cls._send_request('GET', '/subscription/info', params=params)
+        billing_info = cls._send_request("GET", "/subscription/info", params=params)
 
         return billing_info
 
     @classmethod
-    def get_subscription(cls, plan: str,
-                         interval: str,
-                         prefilled_email: str = '',
-                         tenant_id: str = ''):
-        params = {
-            'plan': plan,
-            'interval': interval,
-            'prefilled_email': prefilled_email,
-            'tenant_id': tenant_id
-        }
-        return cls._send_request('GET', '/subscription/payment-link', params=params)
+    def get_subscription(cls, plan: str, interval: str, prefilled_email: str = "", tenant_id: str = ""):
+        params = {"plan": plan, "interval": interval, "prefilled_email": prefilled_email, "tenant_id": tenant_id}
+        return cls._send_request("GET", "/subscription/payment-link", params=params)
 
     @classmethod
-    def get_model_provider_payment_link(cls,
-                                        provider_name: str,
-                                        tenant_id: str,
-                                        account_id: str,
-                                        prefilled_email: str):
+    def get_model_provider_payment_link(cls, provider_name: str, tenant_id: str, account_id: str, prefilled_email: str):
         params = {
-            'provider_name': provider_name,
-            'tenant_id': tenant_id,
-            'account_id': account_id,
-            'prefilled_email': prefilled_email
+            "provider_name": provider_name,
+            "tenant_id": tenant_id,
+            "account_id": account_id,
+            "prefilled_email": prefilled_email,
         }
-        return cls._send_request('GET', '/model-provider/payment-link', params=params)
+        return cls._send_request("GET", "/model-provider/payment-link", params=params)
 
     @classmethod
-    def get_invoices(cls, prefilled_email: str = '', tenant_id: str = ''):
-        params = {
-            'prefilled_email': prefilled_email,
-            'tenant_id': tenant_id
-        }
-        return cls._send_request('GET', '/invoices', params=params)
+    def get_invoices(cls, prefilled_email: str = "", tenant_id: str = ""):
+        params = {"prefilled_email": prefilled_email, "tenant_id": tenant_id}
+        return cls._send_request("GET", "/invoices", params=params)
 
     @classmethod
     def _send_request(cls, method, endpoint, json=None, params=None):
-        headers = {
-            "Content-Type": "application/json",
-            "Billing-Api-Secret-Key": cls.secret_key
-        }
+        headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key}
 
         url = f"{cls.base_url}{endpoint}"
         response = requests.request(method, url, json=json, params=params, headers=headers)
@@ -69,10 +51,11 @@ class BillingService:
     def is_tenant_owner_or_admin(current_user):
         tenant_id = current_user.current_tenant_id
 
-        join = db.session.query(TenantAccountJoin).filter(
-            TenantAccountJoin.tenant_id == tenant_id,
-            TenantAccountJoin.account_id == current_user.id
-        ).first()
+        join = (
+            db.session.query(TenantAccountJoin)
+            .filter(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.account_id == current_user.id)
+            .first()
+        )
 
         if not TenantAccountRole.is_privileged_role(join.role):
-            raise ValueError('Only team owner or team admin can perform this action')
+            raise ValueError("Only team owner or team admin can perform this action")

+ 9 - 6
api/services/code_based_extension_service.py

@@ -2,12 +2,15 @@ from extensions.ext_code_based_extension import code_based_extension
 
 
 class CodeBasedExtensionService:
-
     @staticmethod
     def get_code_based_extension(module: str) -> list[dict]:
         module_extensions = code_based_extension.module_extensions(module)
-        return [{
-            'name': module_extension.name,
-            'label': module_extension.label,
-            'form_schema': module_extension.form_schema
-        } for module_extension in module_extensions if not module_extension.builtin]
+        return [
+            {
+                "name": module_extension.name,
+                "label": module_extension.label,
+                "form_schema": module_extension.form_schema,
+            }
+            for module_extension in module_extensions
+            if not module_extension.builtin
+        ]

+ 46 - 33
api/services/conversation_service.py

@@ -15,22 +15,27 @@ from services.errors.message import MessageNotExistsError
 
 class ConversationService:
     @classmethod
-    def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, EndUser]],
-                              last_id: Optional[str], limit: int,
-                              invoke_from: InvokeFrom,
-                              include_ids: Optional[list] = None,
-                              exclude_ids: Optional[list] = None,
-                              sort_by: str = '-updated_at') -> InfiniteScrollPagination:
+    def pagination_by_last_id(
+        cls,
+        app_model: App,
+        user: Optional[Union[Account, EndUser]],
+        last_id: Optional[str],
+        limit: int,
+        invoke_from: InvokeFrom,
+        include_ids: Optional[list] = None,
+        exclude_ids: Optional[list] = None,
+        sort_by: str = "-updated_at",
+    ) -> InfiniteScrollPagination:
         if not user:
             return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
 
         base_query = db.session.query(Conversation).filter(
             Conversation.is_deleted == False,
             Conversation.app_id == app_model.id,
-            Conversation.from_source == ('api' if isinstance(user, EndUser) else 'console'),
+            Conversation.from_source == ("api" if isinstance(user, EndUser) else "console"),
             Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
             Conversation.from_account_id == (user.id if isinstance(user, Account) else None),
-            or_(Conversation.invoke_from.is_(None), Conversation.invoke_from == invoke_from.value)
+            or_(Conversation.invoke_from.is_(None), Conversation.invoke_from == invoke_from.value),
         )
 
         if include_ids is not None:
@@ -58,28 +63,26 @@ class ConversationService:
         has_more = False
         if len(conversations) == limit:
             current_page_last_conversation = conversations[-1]
-            rest_filter_condition = cls._build_filter_condition(sort_field, sort_direction,
-                                                                current_page_last_conversation, is_next_page=True)
+            rest_filter_condition = cls._build_filter_condition(
+                sort_field, sort_direction, current_page_last_conversation, is_next_page=True
+            )
             rest_count = base_query.filter(rest_filter_condition).count()
 
             if rest_count > 0:
                 has_more = True
 
-        return InfiniteScrollPagination(
-            data=conversations,
-            limit=limit,
-            has_more=has_more
-        )
+        return InfiniteScrollPagination(data=conversations, limit=limit, has_more=has_more)
 
     @classmethod
     def _get_sort_params(cls, sort_by: str) -> tuple[str, callable]:
-        if sort_by.startswith('-'):
+        if sort_by.startswith("-"):
             return sort_by[1:], desc
         return sort_by, asc
 
     @classmethod
-    def _build_filter_condition(cls, sort_field: str, sort_direction: callable, reference_conversation: Conversation,
-                                is_next_page: bool = False):
+    def _build_filter_condition(
+        cls, sort_field: str, sort_direction: callable, reference_conversation: Conversation, is_next_page: bool = False
+    ):
         field_value = getattr(reference_conversation, sort_field)
         if (sort_direction == desc and not is_next_page) or (sort_direction == asc and is_next_page):
             return getattr(Conversation, sort_field) < field_value
@@ -87,8 +90,14 @@ class ConversationService:
             return getattr(Conversation, sort_field) > field_value
 
     @classmethod
-    def rename(cls, app_model: App, conversation_id: str,
-               user: Optional[Union[Account, EndUser]], name: str, auto_generate: bool):
+    def rename(
+        cls,
+        app_model: App,
+        conversation_id: str,
+        user: Optional[Union[Account, EndUser]],
+        name: str,
+        auto_generate: bool,
+    ):
         conversation = cls.get_conversation(app_model, conversation_id, user)
 
         if auto_generate:
@@ -103,11 +112,12 @@ class ConversationService:
     @classmethod
     def auto_generate_name(cls, app_model: App, conversation: Conversation):
         # get conversation first message
-        message = db.session.query(Message) \
-            .filter(
-            Message.app_id == app_model.id,
-            Message.conversation_id == conversation.id
-        ).order_by(Message.created_at.asc()).first()
+        message = (
+            db.session.query(Message)
+            .filter(Message.app_id == app_model.id, Message.conversation_id == conversation.id)
+            .order_by(Message.created_at.asc())
+            .first()
+        )
 
         if not message:
             raise MessageNotExistsError()
@@ -127,15 +137,18 @@ class ConversationService:
 
     @classmethod
     def get_conversation(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
-        conversation = db.session.query(Conversation) \
+        conversation = (
+            db.session.query(Conversation)
             .filter(
-            Conversation.id == conversation_id,
-            Conversation.app_id == app_model.id,
-            Conversation.from_source == ('api' if isinstance(user, EndUser) else 'console'),
-            Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
-            Conversation.from_account_id == (user.id if isinstance(user, Account) else None),
-            Conversation.is_deleted == False
-        ).first()
+                Conversation.id == conversation_id,
+                Conversation.app_id == app_model.id,
+                Conversation.from_source == ("api" if isinstance(user, EndUser) else "console"),
+                Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
+                Conversation.from_account_id == (user.id if isinstance(user, Account) else None),
+                Conversation.is_deleted == False,
+            )
+            .first()
+        )
 
         if not conversation:
             raise ConversationNotExistsError()

File diff suppressed because it is too large
+ 263 - 278
api/services/dataset_service.py


+ 3 - 6
api/services/enterprise/base.py

@@ -4,15 +4,12 @@ import requests
 
 
 class EnterpriseRequest:
-    base_url = os.environ.get('ENTERPRISE_API_URL', 'ENTERPRISE_API_URL')
-    secret_key = os.environ.get('ENTERPRISE_API_SECRET_KEY', 'ENTERPRISE_API_SECRET_KEY')
+    base_url = os.environ.get("ENTERPRISE_API_URL", "ENTERPRISE_API_URL")
+    secret_key = os.environ.get("ENTERPRISE_API_SECRET_KEY", "ENTERPRISE_API_SECRET_KEY")
 
     @classmethod
     def send_request(cls, method, endpoint, json=None, params=None):
-        headers = {
-            "Content-Type": "application/json",
-            "Enterprise-Api-Secret-Key": cls.secret_key
-        }
+        headers = {"Content-Type": "application/json", "Enterprise-Api-Secret-Key": cls.secret_key}
 
         url = f"{cls.base_url}{endpoint}"
         response = requests.request(method, url, json=json, params=params, headers=headers)

+ 2 - 3
api/services/enterprise/enterprise_service.py

@@ -2,11 +2,10 @@ from services.enterprise.base import EnterpriseRequest
 
 
 class EnterpriseService:
-
     @classmethod
     def get_info(cls):
-        return EnterpriseRequest.send_request('GET', '/info')
+        return EnterpriseRequest.send_request("GET", "/info")
 
     @classmethod
     def get_app_web_sso_enabled(cls, app_code):
-        return EnterpriseRequest.send_request('GET', f'/app-sso-setting?appCode={app_code}')
+        return EnterpriseRequest.send_request("GET", f"/app-sso-setting?appCode={app_code}")

+ 18 - 20
api/services/entities/model_provider_entities.py

@@ -22,14 +22,16 @@ class CustomConfigurationStatus(Enum):
     """
     Enum class for custom configuration status.
     """
-    ACTIVE = 'active'
-    NO_CONFIGURE = 'no-configure'
+
+    ACTIVE = "active"
+    NO_CONFIGURE = "no-configure"
 
 
 class CustomConfigurationResponse(BaseModel):
     """
     Model class for provider custom configuration response.
     """
+
     status: CustomConfigurationStatus
 
 
@@ -37,6 +39,7 @@ class SystemConfigurationResponse(BaseModel):
     """
     Model class for provider system configuration response.
     """
+
     enabled: bool
     current_quota_type: Optional[ProviderQuotaType] = None
     quota_configurations: list[QuotaConfiguration] = []
@@ -46,6 +49,7 @@ class ProviderResponse(BaseModel):
     """
     Model class for provider response.
     """
+
     provider: str
     label: I18nObject
     description: Optional[I18nObject] = None
@@ -67,18 +71,15 @@ class ProviderResponse(BaseModel):
     def __init__(self, **data) -> None:
         super().__init__(**data)
 
-        url_prefix = (dify_config.CONSOLE_API_URL
-                      + f"/console/api/workspaces/current/model-providers/{self.provider}")
+        url_prefix = dify_config.CONSOLE_API_URL + f"/console/api/workspaces/current/model-providers/{self.provider}"
         if self.icon_small is not None:
             self.icon_small = I18nObject(
-                en_US=f"{url_prefix}/icon_small/en_US",
-                zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
+                en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
             )
 
         if self.icon_large is not None:
             self.icon_large = I18nObject(
-                en_US=f"{url_prefix}/icon_large/en_US",
-                zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
+                en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
             )
 
 
@@ -86,6 +87,7 @@ class ProviderWithModelsResponse(BaseModel):
     """
     Model class for provider with models response.
     """
+
     provider: str
     label: I18nObject
     icon_small: Optional[I18nObject] = None
@@ -96,18 +98,15 @@ class ProviderWithModelsResponse(BaseModel):
     def __init__(self, **data) -> None:
         super().__init__(**data)
 
-        url_prefix = (dify_config.CONSOLE_API_URL
-                      + f"/console/api/workspaces/current/model-providers/{self.provider}")
+        url_prefix = dify_config.CONSOLE_API_URL + f"/console/api/workspaces/current/model-providers/{self.provider}"
         if self.icon_small is not None:
             self.icon_small = I18nObject(
-                en_US=f"{url_prefix}/icon_small/en_US",
-                zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
+                en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
             )
 
         if self.icon_large is not None:
             self.icon_large = I18nObject(
-                en_US=f"{url_prefix}/icon_large/en_US",
-                zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
+                en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
             )
 
 
@@ -119,18 +118,15 @@ class SimpleProviderEntityResponse(SimpleProviderEntity):
     def __init__(self, **data) -> None:
         super().__init__(**data)
 
-        url_prefix = (dify_config.CONSOLE_API_URL
-                      + f"/console/api/workspaces/current/model-providers/{self.provider}")
+        url_prefix = dify_config.CONSOLE_API_URL + f"/console/api/workspaces/current/model-providers/{self.provider}"
         if self.icon_small is not None:
             self.icon_small = I18nObject(
-                en_US=f"{url_prefix}/icon_small/en_US",
-                zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
+                en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
             )
 
         if self.icon_large is not None:
             self.icon_large = I18nObject(
-                en_US=f"{url_prefix}/icon_large/en_US",
-                zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
+                en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
             )
 
 
@@ -138,6 +134,7 @@ class DefaultModelResponse(BaseModel):
     """
     Default model entity.
     """
+
     model: str
     model_type: ModelType
     provider: SimpleProviderEntityResponse
@@ -150,6 +147,7 @@ class ModelWithProviderEntityResponse(ModelWithProviderEntity):
     """
     Model with provider entity.
     """
+
     provider: SimpleProviderEntityResponse
 
     def __init__(self, model: ModelWithProviderEntity) -> None:

+ 0 - 1
api/services/errors/account.py

@@ -55,4 +55,3 @@ class RoleAlreadyAssignedError(BaseServiceError):
 
 class RateLimitExceededError(BaseServiceError):
     pass
-

+ 1 - 1
api/services/errors/base.py

@@ -1,3 +1,3 @@
 class BaseServiceError(Exception):
     def __init__(self, description: str = None):
-        self.description = description
+        self.description = description

+ 34 - 34
api/services/feature_service.py

@@ -6,8 +6,8 @@ from services.enterprise.enterprise_service import EnterpriseService
 
 
 class SubscriptionModel(BaseModel):
-    plan: str = 'sandbox'
-    interval: str = ''
+    plan: str = "sandbox"
+    interval: str = ""
 
 
 class BillingModel(BaseModel):
@@ -27,7 +27,7 @@ class FeatureModel(BaseModel):
     vector_space: LimitationModel = LimitationModel(size=0, limit=5)
     annotation_quota_limit: LimitationModel = LimitationModel(size=0, limit=10)
     documents_upload_quota: LimitationModel = LimitationModel(size=0, limit=50)
-    docs_processing: str = 'standard'
+    docs_processing: str = "standard"
     can_replace_logo: bool = False
     model_load_balancing_enabled: bool = False
     dataset_operator_enabled: bool = False
@@ -38,13 +38,13 @@ class FeatureModel(BaseModel):
 
 class SystemFeatureModel(BaseModel):
     sso_enforced_for_signin: bool = False
-    sso_enforced_for_signin_protocol: str = ''
+    sso_enforced_for_signin_protocol: str = ""
     sso_enforced_for_web: bool = False
-    sso_enforced_for_web_protocol: str = ''
+    sso_enforced_for_web_protocol: str = ""
     enable_web_sso_switch_component: bool = False
 
-class FeatureService:
 
+class FeatureService:
     @classmethod
     def get_features(cls, tenant_id: str) -> FeatureModel:
         features = FeatureModel()
@@ -76,44 +76,44 @@ class FeatureService:
     def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str):
         billing_info = BillingService.get_info(tenant_id)
 
-        features.billing.enabled = billing_info['enabled']
-        features.billing.subscription.plan = billing_info['subscription']['plan']
-        features.billing.subscription.interval = billing_info['subscription']['interval']
+        features.billing.enabled = billing_info["enabled"]
+        features.billing.subscription.plan = billing_info["subscription"]["plan"]
+        features.billing.subscription.interval = billing_info["subscription"]["interval"]
 
-        if 'members' in billing_info:
-            features.members.size = billing_info['members']['size']
-            features.members.limit = billing_info['members']['limit']
+        if "members" in billing_info:
+            features.members.size = billing_info["members"]["size"]
+            features.members.limit = billing_info["members"]["limit"]
 
-        if 'apps' in billing_info:
-            features.apps.size = billing_info['apps']['size']
-            features.apps.limit = billing_info['apps']['limit']
+        if "apps" in billing_info:
+            features.apps.size = billing_info["apps"]["size"]
+            features.apps.limit = billing_info["apps"]["limit"]
 
-        if 'vector_space' in billing_info:
-            features.vector_space.size = billing_info['vector_space']['size']
-            features.vector_space.limit = billing_info['vector_space']['limit']
+        if "vector_space" in billing_info:
+            features.vector_space.size = billing_info["vector_space"]["size"]
+            features.vector_space.limit = billing_info["vector_space"]["limit"]
 
-        if 'documents_upload_quota' in billing_info:
-            features.documents_upload_quota.size = billing_info['documents_upload_quota']['size']
-            features.documents_upload_quota.limit = billing_info['documents_upload_quota']['limit']
+        if "documents_upload_quota" in billing_info:
+            features.documents_upload_quota.size = billing_info["documents_upload_quota"]["size"]
+            features.documents_upload_quota.limit = billing_info["documents_upload_quota"]["limit"]
 
-        if 'annotation_quota_limit' in billing_info:
-            features.annotation_quota_limit.size = billing_info['annotation_quota_limit']['size']
-            features.annotation_quota_limit.limit = billing_info['annotation_quota_limit']['limit']
+        if "annotation_quota_limit" in billing_info:
+            features.annotation_quota_limit.size = billing_info["annotation_quota_limit"]["size"]
+            features.annotation_quota_limit.limit = billing_info["annotation_quota_limit"]["limit"]
 
-        if 'docs_processing' in billing_info:
-            features.docs_processing = billing_info['docs_processing']
+        if "docs_processing" in billing_info:
+            features.docs_processing = billing_info["docs_processing"]
 
-        if 'can_replace_logo' in billing_info:
-            features.can_replace_logo = billing_info['can_replace_logo']
+        if "can_replace_logo" in billing_info:
+            features.can_replace_logo = billing_info["can_replace_logo"]
 
-        if 'model_load_balancing_enabled' in billing_info:
-            features.model_load_balancing_enabled = billing_info['model_load_balancing_enabled']
+        if "model_load_balancing_enabled" in billing_info:
+            features.model_load_balancing_enabled = billing_info["model_load_balancing_enabled"]
 
     @classmethod
     def _fulfill_params_from_enterprise(cls, features):
         enterprise_info = EnterpriseService.get_info()
 
-        features.sso_enforced_for_signin = enterprise_info['sso_enforced_for_signin']
-        features.sso_enforced_for_signin_protocol = enterprise_info['sso_enforced_for_signin_protocol']
-        features.sso_enforced_for_web = enterprise_info['sso_enforced_for_web']
-        features.sso_enforced_for_web_protocol = enterprise_info['sso_enforced_for_web_protocol']
+        features.sso_enforced_for_signin = enterprise_info["sso_enforced_for_signin"]
+        features.sso_enforced_for_signin_protocol = enterprise_info["sso_enforced_for_signin_protocol"]
+        features.sso_enforced_for_web = enterprise_info["sso_enforced_for_web"]
+        features.sso_enforced_for_web_protocol = enterprise_info["sso_enforced_for_web_protocol"]

+ 40 - 28
api/services/file_service.py

@@ -17,27 +17,45 @@ from models.account import Account
 from models.model import EndUser, UploadFile
 from services.errors.file import FileTooLargeError, UnsupportedFileTypeError
 
-IMAGE_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif', 'svg']
+IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"]
 IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])
 
-ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'xls', 'docx', 'csv']
-UNSTRUCTURED_ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'xls',
-                                   'docx', 'csv', 'eml', 'msg', 'pptx', 'ppt', 'xml', 'epub']
+ALLOWED_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls", "docx", "csv"]
+UNSTRUCTURED_ALLOWED_EXTENSIONS = [
+    "txt",
+    "markdown",
+    "md",
+    "pdf",
+    "html",
+    "htm",
+    "xlsx",
+    "xls",
+    "docx",
+    "csv",
+    "eml",
+    "msg",
+    "pptx",
+    "ppt",
+    "xml",
+    "epub",
+]
 
 PREVIEW_WORDS_LIMIT = 3000
 
 
 class FileService:
-
     @staticmethod
     def upload_file(file: FileStorage, user: Union[Account, EndUser], only_image: bool = False) -> UploadFile:
         filename = file.filename
-        extension = file.filename.split('.')[-1]
+        extension = file.filename.split(".")[-1]
         if len(filename) > 200:
-            filename = filename.split('.')[0][:200] + '.' + extension
+            filename = filename.split(".")[0][:200] + "." + extension
         etl_type = dify_config.ETL_TYPE
-        allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS if etl_type == 'Unstructured' \
+        allowed_extensions = (
+            UNSTRUCTURED_ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS
+            if etl_type == "Unstructured"
             else ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS
+        )
         if extension.lower() not in allowed_extensions:
             raise UnsupportedFileTypeError()
         elif only_image and extension.lower() not in IMAGE_EXTENSIONS:
@@ -55,7 +73,7 @@ class FileService:
             file_size_limit = dify_config.UPLOAD_FILE_SIZE_LIMIT * 1024 * 1024
 
         if file_size > file_size_limit:
-            message = f'File size exceeded. {file_size} > {file_size_limit}'
+            message = f"File size exceeded. {file_size} > {file_size_limit}"
             raise FileTooLargeError(message)
 
         # user uuid as file name
@@ -67,7 +85,7 @@ class FileService:
             # end_user
             current_tenant_id = user.tenant_id
 
-        file_key = 'upload_files/' + current_tenant_id + '/' + file_uuid + '.' + extension
+        file_key = "upload_files/" + current_tenant_id + "/" + file_uuid + "." + extension
 
         # save file to storage
         storage.save(file_key, file_content)
@@ -81,11 +99,11 @@ class FileService:
             size=file_size,
             extension=extension,
             mime_type=file.mimetype,
-            created_by_role=('account' if isinstance(user, Account) else 'end_user'),
+            created_by_role=("account" if isinstance(user, Account) else "end_user"),
             created_by=user.id,
             created_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
             used=False,
-            hash=hashlib.sha3_256(file_content).hexdigest()
+            hash=hashlib.sha3_256(file_content).hexdigest(),
         )
 
         db.session.add(upload_file)
@@ -99,10 +117,10 @@ class FileService:
             text_name = text_name[:200]
         # user uuid as file name
         file_uuid = str(uuid.uuid4())
-        file_key = 'upload_files/' + current_user.current_tenant_id + '/' + file_uuid + '.txt'
+        file_key = "upload_files/" + current_user.current_tenant_id + "/" + file_uuid + ".txt"
 
         # save file to storage
-        storage.save(file_key, text.encode('utf-8'))
+        storage.save(file_key, text.encode("utf-8"))
 
         # save file to db
         upload_file = UploadFile(
@@ -111,13 +129,13 @@ class FileService:
             key=file_key,
             name=text_name,
             size=len(text),
-            extension='txt',
-            mime_type='text/plain',
+            extension="txt",
+            mime_type="text/plain",
             created_by=current_user.id,
             created_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
             used=True,
             used_by=current_user.id,
-            used_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+            used_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
         )
 
         db.session.add(upload_file)
@@ -127,9 +145,7 @@ class FileService:
 
     @staticmethod
     def get_file_preview(file_id: str) -> str:
-        upload_file = db.session.query(UploadFile) \
-            .filter(UploadFile.id == file_id) \
-            .first()
+        upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first()
 
         if not upload_file:
             raise NotFound("File not found")
@@ -137,12 +153,12 @@ class FileService:
         # extract text from file
         extension = upload_file.extension
         etl_type = dify_config.ETL_TYPE
-        allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == 'Unstructured' else ALLOWED_EXTENSIONS
+        allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == "Unstructured" else ALLOWED_EXTENSIONS
         if extension.lower() not in allowed_extensions:
             raise UnsupportedFileTypeError()
 
         text = ExtractProcessor.load_from_upload_file(upload_file, return_text=True)
-        text = text[0:PREVIEW_WORDS_LIMIT] if text else ''
+        text = text[0:PREVIEW_WORDS_LIMIT] if text else ""
 
         return text
 
@@ -152,9 +168,7 @@ class FileService:
         if not result:
             raise NotFound("File not found or signature is invalid")
 
-        upload_file = db.session.query(UploadFile) \
-            .filter(UploadFile.id == file_id) \
-            .first()
+        upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first()
 
         if not upload_file:
             raise NotFound("File not found or signature is invalid")
@@ -170,9 +184,7 @@ class FileService:
 
     @staticmethod
     def get_public_image_preview(file_id: str) -> tuple[Generator, str]:
-        upload_file = db.session.query(UploadFile) \
-            .filter(UploadFile.id == file_id) \
-            .first()
+        upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first()
 
         if not upload_file:
             raise NotFound("File not found or signature is invalid")

+ 40 - 39
api/services/hit_testing_service.py

@@ -9,14 +9,11 @@ from models.account import Account
 from models.dataset import Dataset, DatasetQuery, DocumentSegment
 
 default_retrieval_model = {
-    'search_method': RetrievalMethod.SEMANTIC_SEARCH.value,
-    'reranking_enable': False,
-    'reranking_model': {
-        'reranking_provider_name': '',
-        'reranking_model_name': ''
-    },
-    'top_k': 2,
-    'score_threshold_enabled': False
+    "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
+    "reranking_enable": False,
+    "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
+    "top_k": 2,
+    "score_threshold_enabled": False,
 }
 
 
@@ -27,9 +24,9 @@ class HitTestingService:
             return {
                 "query": {
                     "content": query,
-                    "tsne_position": {'x': 0, 'y': 0},
+                    "tsne_position": {"x": 0, "y": 0},
                 },
-                "records": []
+                "records": [],
             }
 
         start = time.perf_counter()
@@ -38,28 +35,28 @@ class HitTestingService:
         if not retrieval_model:
             retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
 
-        all_documents = RetrievalService.retrieve(retrival_method=retrieval_model.get('search_method', 'semantic_search'),
-                                                  dataset_id=dataset.id,
-                                                  query=cls.escape_query_for_search(query),
-                                                  top_k=retrieval_model.get('top_k', 2),
-                                                  score_threshold=retrieval_model.get('score_threshold', .0)
-                                                  if retrieval_model['score_threshold_enabled'] else None,
-                                                  reranking_model=retrieval_model.get('reranking_model', None)
-                                                  if retrieval_model['reranking_enable'] else None,
-                                                  reranking_mode=retrieval_model.get('reranking_mode')
-                                                  if retrieval_model.get('reranking_mode') else 'reranking_model',
-                                                  weights=retrieval_model.get('weights', None),
-                                                  )
+        all_documents = RetrievalService.retrieve(
+            retrival_method=retrieval_model.get("search_method", "semantic_search"),
+            dataset_id=dataset.id,
+            query=cls.escape_query_for_search(query),
+            top_k=retrieval_model.get("top_k", 2),
+            score_threshold=retrieval_model.get("score_threshold", 0.0)
+            if retrieval_model["score_threshold_enabled"]
+            else None,
+            reranking_model=retrieval_model.get("reranking_model", None)
+            if retrieval_model["reranking_enable"]
+            else None,
+            reranking_mode=retrieval_model.get("reranking_mode")
+            if retrieval_model.get("reranking_mode")
+            else "reranking_model",
+            weights=retrieval_model.get("weights", None),
+        )
 
         end = time.perf_counter()
         logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds")
 
         dataset_query = DatasetQuery(
-            dataset_id=dataset.id,
-            content=query,
-            source='hit_testing',
-            created_by_role='account',
-            created_by=account.id
+            dataset_id=dataset.id, content=query, source="hit_testing", created_by_role="account", created_by=account.id
         )
 
         db.session.add(dataset_query)
@@ -72,14 +69,18 @@ class HitTestingService:
         i = 0
         records = []
         for document in documents:
-            index_node_id = document.metadata['doc_id']
-
-            segment = db.session.query(DocumentSegment).filter(
-                DocumentSegment.dataset_id == dataset.id,
-                DocumentSegment.enabled == True,
-                DocumentSegment.status == 'completed',
-                DocumentSegment.index_node_id == index_node_id
-            ).first()
+            index_node_id = document.metadata["doc_id"]
+
+            segment = (
+                db.session.query(DocumentSegment)
+                .filter(
+                    DocumentSegment.dataset_id == dataset.id,
+                    DocumentSegment.enabled == True,
+                    DocumentSegment.status == "completed",
+                    DocumentSegment.index_node_id == index_node_id,
+                )
+                .first()
+            )
 
             if not segment:
                 i += 1
@@ -87,7 +88,7 @@ class HitTestingService:
 
             record = {
                 "segment": segment,
-                "score": document.metadata.get('score', None),
+                "score": document.metadata.get("score", None),
             }
 
             records.append(record)
@@ -98,15 +99,15 @@ class HitTestingService:
             "query": {
                 "content": query,
             },
-            "records": records
+            "records": records,
         }
 
     @classmethod
     def hit_testing_args_check(cls, args):
-        query = args['query']
+        query = args["query"]
 
         if not query or len(query) > 250:
-            raise ValueError('Query is required and cannot exceed 250 characters')
+            raise ValueError("Query is required and cannot exceed 250 characters")
 
     @staticmethod
     def escape_query_for_search(query: str) -> str:

+ 99 - 93
api/services/message_service.py

@@ -27,8 +27,14 @@ from services.workflow_service import WorkflowService
 
 class MessageService:
     @classmethod
-    def pagination_by_first_id(cls, app_model: App, user: Optional[Union[Account, EndUser]],
-                               conversation_id: str, first_id: Optional[str], limit: int) -> InfiniteScrollPagination:
+    def pagination_by_first_id(
+        cls,
+        app_model: App,
+        user: Optional[Union[Account, EndUser]],
+        conversation_id: str,
+        first_id: Optional[str],
+        limit: int,
+    ) -> InfiniteScrollPagination:
         if not user:
             return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
 
@@ -36,52 +42,69 @@ class MessageService:
             return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
 
         conversation = ConversationService.get_conversation(
-            app_model=app_model,
-            user=user,
-            conversation_id=conversation_id
+            app_model=app_model, user=user, conversation_id=conversation_id
         )
 
         if first_id:
-            first_message = db.session.query(Message) \
-                .filter(Message.conversation_id == conversation.id, Message.id == first_id).first()
+            first_message = (
+                db.session.query(Message)
+                .filter(Message.conversation_id == conversation.id, Message.id == first_id)
+                .first()
+            )
 
             if not first_message:
                 raise FirstMessageNotExistsError()
 
-            history_messages = db.session.query(Message).filter(
-                Message.conversation_id == conversation.id,
-                Message.created_at < first_message.created_at,
-                Message.id != first_message.id
-            ) \
-                .order_by(Message.created_at.desc()).limit(limit).all()
+            history_messages = (
+                db.session.query(Message)
+                .filter(
+                    Message.conversation_id == conversation.id,
+                    Message.created_at < first_message.created_at,
+                    Message.id != first_message.id,
+                )
+                .order_by(Message.created_at.desc())
+                .limit(limit)
+                .all()
+            )
         else:
-            history_messages = db.session.query(Message).filter(Message.conversation_id == conversation.id) \
-                .order_by(Message.created_at.desc()).limit(limit).all()
+            history_messages = (
+                db.session.query(Message)
+                .filter(Message.conversation_id == conversation.id)
+                .order_by(Message.created_at.desc())
+                .limit(limit)
+                .all()
+            )
 
         has_more = False
         if len(history_messages) == limit:
             current_page_first_message = history_messages[-1]
-            rest_count = db.session.query(Message).filter(
-                Message.conversation_id == conversation.id,
-                Message.created_at < current_page_first_message.created_at,
-                Message.id != current_page_first_message.id
-            ).count()
+            rest_count = (
+                db.session.query(Message)
+                .filter(
+                    Message.conversation_id == conversation.id,
+                    Message.created_at < current_page_first_message.created_at,
+                    Message.id != current_page_first_message.id,
+                )
+                .count()
+            )
 
             if rest_count > 0:
                 has_more = True
 
         history_messages = list(reversed(history_messages))
 
-        return InfiniteScrollPagination(
-            data=history_messages,
-            limit=limit,
-            has_more=has_more
-        )
+        return InfiniteScrollPagination(data=history_messages, limit=limit, has_more=has_more)
 
     @classmethod
-    def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, EndUser]],
-                              last_id: Optional[str], limit: int, conversation_id: Optional[str] = None,
-                              include_ids: Optional[list] = None) -> InfiniteScrollPagination:
+    def pagination_by_last_id(
+        cls,
+        app_model: App,
+        user: Optional[Union[Account, EndUser]],
+        last_id: Optional[str],
+        limit: int,
+        conversation_id: Optional[str] = None,
+        include_ids: Optional[list] = None,
+    ) -> InfiniteScrollPagination:
         if not user:
             return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
 
@@ -89,9 +112,7 @@ class MessageService:
 
         if conversation_id is not None:
             conversation = ConversationService.get_conversation(
-                app_model=app_model,
-                user=user,
-                conversation_id=conversation_id
+                app_model=app_model, user=user, conversation_id=conversation_id
             )
 
             base_query = base_query.filter(Message.conversation_id == conversation.id)
@@ -105,10 +126,12 @@ class MessageService:
             if not last_message:
                 raise LastMessageNotExistsError()
 
-            history_messages = base_query.filter(
-                Message.created_at < last_message.created_at,
-                Message.id != last_message.id
-            ).order_by(Message.created_at.desc()).limit(limit).all()
+            history_messages = (
+                base_query.filter(Message.created_at < last_message.created_at, Message.id != last_message.id)
+                .order_by(Message.created_at.desc())
+                .limit(limit)
+                .all()
+            )
         else:
             history_messages = base_query.order_by(Message.created_at.desc()).limit(limit).all()
 
@@ -116,30 +139,22 @@ class MessageService:
         if len(history_messages) == limit:
             current_page_first_message = history_messages[-1]
             rest_count = base_query.filter(
-                Message.created_at < current_page_first_message.created_at,
-                Message.id != current_page_first_message.id
+                Message.created_at < current_page_first_message.created_at, Message.id != current_page_first_message.id
             ).count()
 
             if rest_count > 0:
                 has_more = True
 
-        return InfiniteScrollPagination(
-            data=history_messages,
-            limit=limit,
-            has_more=has_more
-        )
+        return InfiniteScrollPagination(data=history_messages, limit=limit, has_more=has_more)
 
     @classmethod
-    def create_feedback(cls, app_model: App, message_id: str, user: Optional[Union[Account, EndUser]],
-                        rating: Optional[str]) -> MessageFeedback:
+    def create_feedback(
+        cls, app_model: App, message_id: str, user: Optional[Union[Account, EndUser]], rating: Optional[str]
+    ) -> MessageFeedback:
         if not user:
-            raise ValueError('user cannot be None')
+            raise ValueError("user cannot be None")
 
-        message = cls.get_message(
-            app_model=app_model,
-            user=user,
-            message_id=message_id
-        )
+        message = cls.get_message(app_model=app_model, user=user, message_id=message_id)
 
         feedback = message.user_feedback if isinstance(user, EndUser) else message.admin_feedback
 
@@ -148,14 +163,14 @@ class MessageService:
         elif rating and feedback:
             feedback.rating = rating
         elif not rating and not feedback:
-            raise ValueError('rating cannot be None when feedback not exists')
+            raise ValueError("rating cannot be None when feedback not exists")
         else:
             feedback = MessageFeedback(
                 app_id=app_model.id,
                 conversation_id=message.conversation_id,
                 message_id=message.id,
                 rating=rating,
-                from_source=('user' if isinstance(user, EndUser) else 'admin'),
+                from_source=("user" if isinstance(user, EndUser) else "admin"),
                 from_end_user_id=(user.id if isinstance(user, EndUser) else None),
                 from_account_id=(user.id if isinstance(user, Account) else None),
             )
@@ -167,13 +182,17 @@ class MessageService:
 
     @classmethod
     def get_message(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str):
-        message = db.session.query(Message).filter(
-            Message.id == message_id,
-            Message.app_id == app_model.id,
-            Message.from_source == ('api' if isinstance(user, EndUser) else 'console'),
-            Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
-            Message.from_account_id == (user.id if isinstance(user, Account) else None),
-        ).first()
+        message = (
+            db.session.query(Message)
+            .filter(
+                Message.id == message_id,
+                Message.app_id == app_model.id,
+                Message.from_source == ("api" if isinstance(user, EndUser) else "console"),
+                Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
+                Message.from_account_id == (user.id if isinstance(user, Account) else None),
+            )
+            .first()
+        )
 
         if not message:
             raise MessageNotExistsError()
@@ -181,27 +200,22 @@ class MessageService:
         return message
 
     @classmethod
-    def get_suggested_questions_after_answer(cls, app_model: App, user: Optional[Union[Account, EndUser]],
-                                             message_id: str, invoke_from: InvokeFrom) -> list[Message]:
+    def get_suggested_questions_after_answer(
+        cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str, invoke_from: InvokeFrom
+    ) -> list[Message]:
         if not user:
-            raise ValueError('user cannot be None')
+            raise ValueError("user cannot be None")
 
-        message = cls.get_message(
-            app_model=app_model,
-            user=user,
-            message_id=message_id
-        )
+        message = cls.get_message(app_model=app_model, user=user, message_id=message_id)
 
         conversation = ConversationService.get_conversation(
-            app_model=app_model,
-            conversation_id=message.conversation_id,
-            user=user
+            app_model=app_model, conversation_id=message.conversation_id, user=user
         )
 
         if not conversation:
             raise ConversationNotExistsError()
 
-        if conversation.status != 'normal':
+        if conversation.status != "normal":
             raise ConversationCompletedError()
 
         model_manager = ModelManager()
@@ -216,24 +230,23 @@ class MessageService:
             if workflow is None:
                 return []
 
-            app_config = AdvancedChatAppConfigManager.get_app_config(
-                app_model=app_model,
-                workflow=workflow
-            )
+            app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
 
             if not app_config.additional_features.suggested_questions_after_answer:
                 raise SuggestedQuestionsAfterAnswerDisabledError()
 
             model_instance = model_manager.get_default_model_instance(
-                tenant_id=app_model.tenant_id,
-                model_type=ModelType.LLM
+                tenant_id=app_model.tenant_id, model_type=ModelType.LLM
             )
         else:
             if not conversation.override_model_configs:
-                app_model_config = db.session.query(AppModelConfig).filter(
-                    AppModelConfig.id == conversation.app_model_config_id,
-                    AppModelConfig.app_id == app_model.id
-                ).first()
+                app_model_config = (
+                    db.session.query(AppModelConfig)
+                    .filter(
+                        AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id
+                    )
+                    .first()
+                )
             else:
                 conversation_override_model_configs = json.loads(conversation.override_model_configs)
                 app_model_config = AppModelConfig(
@@ -249,16 +262,13 @@ class MessageService:
 
             model_instance = model_manager.get_model_instance(
                 tenant_id=app_model.tenant_id,
-                provider=app_model_config.model_dict['provider'],
+                provider=app_model_config.model_dict["provider"],
                 model_type=ModelType.LLM,
-                model=app_model_config.model_dict['name']
+                model=app_model_config.model_dict["name"],
             )
 
         # get memory of conversation (read-only)
-        memory = TokenBufferMemory(
-            conversation=conversation,
-            model_instance=model_instance
-        )
+        memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
 
         histories = memory.get_history_prompt_text(
             max_token_limit=3000,
@@ -267,18 +277,14 @@ class MessageService:
 
         with measure_time() as timer:
             questions = LLMGenerator.generate_suggested_questions_after_answer(
-                tenant_id=app_model.tenant_id,
-                histories=histories
+                tenant_id=app_model.tenant_id, histories=histories
             )
 
         # get tracing instance
         trace_manager = TraceQueueManager(app_id=app_model.id)
         trace_manager.add_trace_task(
             TraceTask(
-                TraceTaskName.SUGGESTED_QUESTION_TRACE,
-                message_id=message_id,
-                suggested_question=questions,
-                timer=timer
+                TraceTaskName.SUGGESTED_QUESTION_TRACE, message_id=message_id, suggested_question=questions, timer=timer
             )
         )
 

+ 119 - 110
api/services/model_load_balancing_service.py

@@ -23,7 +23,6 @@ logger = logging.getLogger(__name__)
 
 
 class ModelLoadBalancingService:
-
     def __init__(self) -> None:
         self.provider_manager = ProviderManager()
 
@@ -46,10 +45,7 @@ class ModelLoadBalancingService:
             raise ValueError(f"Provider {provider} does not exist.")
 
         # Enable model load balancing
-        provider_configuration.enable_model_load_balancing(
-            model=model,
-            model_type=ModelType.value_of(model_type)
-        )
+        provider_configuration.enable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type))
 
     def disable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
         """
@@ -70,13 +66,11 @@ class ModelLoadBalancingService:
             raise ValueError(f"Provider {provider} does not exist.")
 
         # disable model load balancing
-        provider_configuration.disable_model_load_balancing(
-            model=model,
-            model_type=ModelType.value_of(model_type)
-        )
+        provider_configuration.disable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type))
 
-    def get_load_balancing_configs(self, tenant_id: str, provider: str, model: str, model_type: str) \
-            -> tuple[bool, list[dict]]:
+    def get_load_balancing_configs(
+        self, tenant_id: str, provider: str, model: str, model_type: str
+    ) -> tuple[bool, list[dict]]:
         """
         Get load balancing configurations.
         :param tenant_id: workspace id
@@ -107,20 +101,24 @@ class ModelLoadBalancingService:
             is_load_balancing_enabled = True
 
         # Get load balancing configurations
-        load_balancing_configs = db.session.query(LoadBalancingModelConfig) \
+        load_balancing_configs = (
+            db.session.query(LoadBalancingModelConfig)
             .filter(
-            LoadBalancingModelConfig.tenant_id == tenant_id,
-            LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
-            LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
-            LoadBalancingModelConfig.model_name == model
-        ).order_by(LoadBalancingModelConfig.created_at).all()
+                LoadBalancingModelConfig.tenant_id == tenant_id,
+                LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
+                LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
+                LoadBalancingModelConfig.model_name == model,
+            )
+            .order_by(LoadBalancingModelConfig.created_at)
+            .all()
+        )
 
         if provider_configuration.custom_configuration.provider:
             # check if the inherit configuration exists,
             # inherit is represented for the provider or model custom credentials
             inherit_config_exists = False
             for load_balancing_config in load_balancing_configs:
-                if load_balancing_config.name == '__inherit__':
+                if load_balancing_config.name == "__inherit__":
                     inherit_config_exists = True
                     break
 
@@ -133,7 +131,7 @@ class ModelLoadBalancingService:
             else:
                 # move the inherit configuration to the first
                 for i, load_balancing_config in enumerate(load_balancing_configs[:]):
-                    if load_balancing_config.name == '__inherit__':
+                    if load_balancing_config.name == "__inherit__":
                         inherit_config = load_balancing_configs.pop(i)
                         load_balancing_configs.insert(0, inherit_config)
 
@@ -151,7 +149,7 @@ class ModelLoadBalancingService:
                 provider=provider,
                 model=model,
                 model_type=model_type,
-                config_id=load_balancing_config.id
+                config_id=load_balancing_config.id,
             )
 
             try:
@@ -172,32 +170,32 @@ class ModelLoadBalancingService:
                 if variable in credentials:
                     try:
                         credentials[variable] = encrypter.decrypt_token_with_decoding(
-                            credentials.get(variable),
-                            decoding_rsa_key,
-                            decoding_cipher_rsa
+                            credentials.get(variable), decoding_rsa_key, decoding_cipher_rsa
                         )
                     except ValueError:
                         pass
 
             # Obfuscate credentials
             credentials = provider_configuration.obfuscated_credentials(
-                credentials=credentials,
-                credential_form_schemas=credential_schemas.credential_form_schemas
+                credentials=credentials, credential_form_schemas=credential_schemas.credential_form_schemas
             )
 
-            datas.append({
-                'id': load_balancing_config.id,
-                'name': load_balancing_config.name,
-                'credentials': credentials,
-                'enabled': load_balancing_config.enabled,
-                'in_cooldown': in_cooldown,
-                'ttl': ttl
-            })
+            datas.append(
+                {
+                    "id": load_balancing_config.id,
+                    "name": load_balancing_config.name,
+                    "credentials": credentials,
+                    "enabled": load_balancing_config.enabled,
+                    "in_cooldown": in_cooldown,
+                    "ttl": ttl,
+                }
+            )
 
         return is_load_balancing_enabled, datas
 
-    def get_load_balancing_config(self, tenant_id: str, provider: str, model: str, model_type: str, config_id: str) \
-            -> Optional[dict]:
+    def get_load_balancing_config(
+        self, tenant_id: str, provider: str, model: str, model_type: str, config_id: str
+    ) -> Optional[dict]:
         """
         Get load balancing configuration.
         :param tenant_id: workspace id
@@ -219,14 +217,17 @@ class ModelLoadBalancingService:
         model_type = ModelType.value_of(model_type)
 
         # Get load balancing configurations
-        load_balancing_model_config = db.session.query(LoadBalancingModelConfig) \
+        load_balancing_model_config = (
+            db.session.query(LoadBalancingModelConfig)
             .filter(
-            LoadBalancingModelConfig.tenant_id == tenant_id,
-            LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
-            LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
-            LoadBalancingModelConfig.model_name == model,
-            LoadBalancingModelConfig.id == config_id
-        ).first()
+                LoadBalancingModelConfig.tenant_id == tenant_id,
+                LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
+                LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
+                LoadBalancingModelConfig.model_name == model,
+                LoadBalancingModelConfig.id == config_id,
+            )
+            .first()
+        )
 
         if not load_balancing_model_config:
             return None
@@ -244,19 +245,19 @@ class ModelLoadBalancingService:
 
         # Obfuscate credentials
         credentials = provider_configuration.obfuscated_credentials(
-            credentials=credentials,
-            credential_form_schemas=credential_schemas.credential_form_schemas
+            credentials=credentials, credential_form_schemas=credential_schemas.credential_form_schemas
         )
 
         return {
-            'id': load_balancing_model_config.id,
-            'name': load_balancing_model_config.name,
-            'credentials': credentials,
-            'enabled': load_balancing_model_config.enabled
+            "id": load_balancing_model_config.id,
+            "name": load_balancing_model_config.name,
+            "credentials": credentials,
+            "enabled": load_balancing_model_config.enabled,
         }
 
-    def _init_inherit_config(self, tenant_id: str, provider: str, model: str, model_type: ModelType) \
-            -> LoadBalancingModelConfig:
+    def _init_inherit_config(
+        self, tenant_id: str, provider: str, model: str, model_type: ModelType
+    ) -> LoadBalancingModelConfig:
         """
         Initialize the inherit configuration.
         :param tenant_id: workspace id
@@ -271,18 +272,16 @@ class ModelLoadBalancingService:
             provider_name=provider,
             model_type=model_type.to_origin_model_type(),
             model_name=model,
-            name='__inherit__'
+            name="__inherit__",
         )
         db.session.add(inherit_config)
         db.session.commit()
 
         return inherit_config
 
-    def update_load_balancing_configs(self, tenant_id: str,
-                                      provider: str,
-                                      model: str,
-                                      model_type: str,
-                                      configs: list[dict]) -> None:
+    def update_load_balancing_configs(
+        self, tenant_id: str, provider: str, model: str, model_type: str, configs: list[dict]
+    ) -> None:
         """
         Update load balancing configurations.
         :param tenant_id: workspace id
@@ -304,15 +303,18 @@ class ModelLoadBalancingService:
         model_type = ModelType.value_of(model_type)
 
         if not isinstance(configs, list):
-            raise ValueError('Invalid load balancing configs')
+            raise ValueError("Invalid load balancing configs")
 
-        current_load_balancing_configs = db.session.query(LoadBalancingModelConfig) \
+        current_load_balancing_configs = (
+            db.session.query(LoadBalancingModelConfig)
             .filter(
-            LoadBalancingModelConfig.tenant_id == tenant_id,
-            LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
-            LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
-            LoadBalancingModelConfig.model_name == model
-        ).all()
+                LoadBalancingModelConfig.tenant_id == tenant_id,
+                LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
+                LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
+                LoadBalancingModelConfig.model_name == model,
+            )
+            .all()
+        )
 
         # id as key, config as value
         current_load_balancing_configs_dict = {config.id: config for config in current_load_balancing_configs}
@@ -320,25 +322,25 @@ class ModelLoadBalancingService:
 
         for config in configs:
             if not isinstance(config, dict):
-                raise ValueError('Invalid load balancing config')
+                raise ValueError("Invalid load balancing config")
 
-            config_id = config.get('id')
-            name = config.get('name')
-            credentials = config.get('credentials')
-            enabled = config.get('enabled')
+            config_id = config.get("id")
+            name = config.get("name")
+            credentials = config.get("credentials")
+            enabled = config.get("enabled")
 
             if not name:
-                raise ValueError('Invalid load balancing config name')
+                raise ValueError("Invalid load balancing config name")
 
             if enabled is None:
-                raise ValueError('Invalid load balancing config enabled')
+                raise ValueError("Invalid load balancing config enabled")
 
             # is config exists
             if config_id:
                 config_id = str(config_id)
 
                 if config_id not in current_load_balancing_configs_dict:
-                    raise ValueError('Invalid load balancing config id: {}'.format(config_id))
+                    raise ValueError("Invalid load balancing config id: {}".format(config_id))
 
                 updated_config_ids.add(config_id)
 
@@ -347,11 +349,11 @@ class ModelLoadBalancingService:
                 # check duplicate name
                 for current_load_balancing_config in current_load_balancing_configs:
                     if current_load_balancing_config.id != config_id and current_load_balancing_config.name == name:
-                        raise ValueError('Load balancing config name {} already exists'.format(name))
+                        raise ValueError("Load balancing config name {} already exists".format(name))
 
                 if credentials:
                     if not isinstance(credentials, dict):
-                        raise ValueError('Invalid load balancing config credentials')
+                        raise ValueError("Invalid load balancing config credentials")
 
                     # validate custom provider config
                     credentials = self._custom_credentials_validate(
@@ -361,7 +363,7 @@ class ModelLoadBalancingService:
                         model=model,
                         credentials=credentials,
                         load_balancing_model_config=load_balancing_config,
-                        validate=False
+                        validate=False,
                     )
 
                     # update load balancing config
@@ -375,19 +377,19 @@ class ModelLoadBalancingService:
                 self._clear_credentials_cache(tenant_id, config_id)
             else:
                 # create load balancing config
-                if name == '__inherit__':
-                    raise ValueError('Invalid load balancing config name')
+                if name == "__inherit__":
+                    raise ValueError("Invalid load balancing config name")
 
                 # check duplicate name
                 for current_load_balancing_config in current_load_balancing_configs:
                     if current_load_balancing_config.name == name:
-                        raise ValueError('Load balancing config name {} already exists'.format(name))
+                        raise ValueError("Load balancing config name {} already exists".format(name))
 
                 if not credentials:
-                    raise ValueError('Invalid load balancing config credentials')
+                    raise ValueError("Invalid load balancing config credentials")
 
                 if not isinstance(credentials, dict):
-                    raise ValueError('Invalid load balancing config credentials')
+                    raise ValueError("Invalid load balancing config credentials")
 
                 # validate custom provider config
                 credentials = self._custom_credentials_validate(
@@ -396,7 +398,7 @@ class ModelLoadBalancingService:
                     model_type=model_type,
                     model=model,
                     credentials=credentials,
-                    validate=False
+                    validate=False,
                 )
 
                 # create load balancing config
@@ -406,7 +408,7 @@ class ModelLoadBalancingService:
                     model_type=model_type.to_origin_model_type(),
                     model_name=model,
                     name=name,
-                    encrypted_config=json.dumps(credentials)
+                    encrypted_config=json.dumps(credentials),
                 )
 
                 db.session.add(load_balancing_model_config)
@@ -420,12 +422,15 @@ class ModelLoadBalancingService:
 
             self._clear_credentials_cache(tenant_id, config_id)
 
-    def validate_load_balancing_credentials(self, tenant_id: str,
-                                            provider: str,
-                                            model: str,
-                                            model_type: str,
-                                            credentials: dict,
-                                            config_id: Optional[str] = None) -> None:
+    def validate_load_balancing_credentials(
+        self,
+        tenant_id: str,
+        provider: str,
+        model: str,
+        model_type: str,
+        credentials: dict,
+        config_id: Optional[str] = None,
+    ) -> None:
         """
         Validate load balancing credentials.
         :param tenant_id: workspace id
@@ -450,14 +455,17 @@ class ModelLoadBalancingService:
         load_balancing_model_config = None
         if config_id:
             # Get load balancing config
-            load_balancing_model_config = db.session.query(LoadBalancingModelConfig) \
+            load_balancing_model_config = (
+                db.session.query(LoadBalancingModelConfig)
                 .filter(
-                LoadBalancingModelConfig.tenant_id == tenant_id,
-                LoadBalancingModelConfig.provider_name == provider,
-                LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
-                LoadBalancingModelConfig.model_name == model,
-                LoadBalancingModelConfig.id == config_id
-            ).first()
+                    LoadBalancingModelConfig.tenant_id == tenant_id,
+                    LoadBalancingModelConfig.provider_name == provider,
+                    LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
+                    LoadBalancingModelConfig.model_name == model,
+                    LoadBalancingModelConfig.id == config_id,
+                )
+                .first()
+            )
 
             if not load_balancing_model_config:
                 raise ValueError(f"Load balancing config {config_id} does not exist.")
@@ -469,16 +477,19 @@ class ModelLoadBalancingService:
             model_type=model_type,
             model=model,
             credentials=credentials,
-            load_balancing_model_config=load_balancing_model_config
+            load_balancing_model_config=load_balancing_model_config,
         )
 
-    def _custom_credentials_validate(self, tenant_id: str,
-                                     provider_configuration: ProviderConfiguration,
-                                     model_type: ModelType,
-                                     model: str,
-                                     credentials: dict,
-                                     load_balancing_model_config: Optional[LoadBalancingModelConfig] = None,
-                                     validate: bool = True) -> dict:
+    def _custom_credentials_validate(
+        self,
+        tenant_id: str,
+        provider_configuration: ProviderConfiguration,
+        model_type: ModelType,
+        model: str,
+        credentials: dict,
+        load_balancing_model_config: Optional[LoadBalancingModelConfig] = None,
+        validate: bool = True,
+    ) -> dict:
         """
         Validate custom credentials.
         :param tenant_id: workspace id
@@ -521,12 +532,11 @@ class ModelLoadBalancingService:
                     provider=provider_configuration.provider.provider,
                     model_type=model_type,
                     model=model,
-                    credentials=credentials
+                    credentials=credentials,
                 )
             else:
                 credentials = model_provider_factory.provider_credentials_validate(
-                    provider=provider_configuration.provider.provider,
-                    credentials=credentials
+                    provider=provider_configuration.provider.provider, credentials=credentials
                 )
 
         for key, value in credentials.items():
@@ -535,8 +545,9 @@ class ModelLoadBalancingService:
 
         return credentials
 
-    def _get_credential_schema(self, provider_configuration: ProviderConfiguration) \
-            -> ModelCredentialSchema | ProviderCredentialSchema:
+    def _get_credential_schema(
+        self, provider_configuration: ProviderConfiguration
+    ) -> ModelCredentialSchema | ProviderCredentialSchema:
         """
         Get form schemas.
         :param provider_configuration: provider configuration
@@ -558,9 +569,7 @@ class ModelLoadBalancingService:
         :return:
         """
         provider_model_credentials_cache = ProviderCredentialsCache(
-            tenant_id=tenant_id,
-            identity_id=config_id,
-            cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL
+            tenant_id=tenant_id, identity_id=config_id, cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL
         )
 
         provider_model_credentials_cache.delete()

+ 79 - 121
api/services/model_provider_service.py

@@ -73,8 +73,8 @@ class ModelProviderService:
                 system_configuration=SystemConfigurationResponse(
                     enabled=provider_configuration.system_configuration.enabled,
                     current_quota_type=provider_configuration.system_configuration.current_quota_type,
-                    quota_configurations=provider_configuration.system_configuration.quota_configurations
-                )
+                    quota_configurations=provider_configuration.system_configuration.quota_configurations,
+                ),
             )
 
             provider_responses.append(provider_response)
@@ -95,9 +95,9 @@ class ModelProviderService:
         provider_configurations = self.provider_manager.get_configurations(tenant_id)
 
         # Get provider available models
-        return [ModelWithProviderEntityResponse(model) for model in provider_configurations.get_models(
-            provider=provider
-        )]
+        return [
+            ModelWithProviderEntityResponse(model) for model in provider_configurations.get_models(provider=provider)
+        ]
 
     def get_provider_credentials(self, tenant_id: str, provider: str) -> dict:
         """
@@ -195,13 +195,12 @@ class ModelProviderService:
 
         # Get model custom credentials from ProviderModel if exists
         return provider_configuration.get_custom_model_credentials(
-            model_type=ModelType.value_of(model_type),
-            model=model,
-            obfuscated=True
+            model_type=ModelType.value_of(model_type), model=model, obfuscated=True
         )
 
-    def model_credentials_validate(self, tenant_id: str, provider: str, model_type: str, model: str,
-                                   credentials: dict) -> None:
+    def model_credentials_validate(
+        self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict
+    ) -> None:
         """
         validate model credentials.
 
@@ -222,13 +221,12 @@ class ModelProviderService:
 
         # Validate model credentials
         provider_configuration.custom_model_credentials_validate(
-            model_type=ModelType.value_of(model_type),
-            model=model,
-            credentials=credentials
+            model_type=ModelType.value_of(model_type), model=model, credentials=credentials
         )
 
-    def save_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str,
-                               credentials: dict) -> None:
+    def save_model_credentials(
+        self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict
+    ) -> None:
         """
         save model credentials.
 
@@ -249,9 +247,7 @@ class ModelProviderService:
 
         # Add or update custom model credentials
         provider_configuration.add_or_update_custom_model_credentials(
-            model_type=ModelType.value_of(model_type),
-            model=model,
-            credentials=credentials
+            model_type=ModelType.value_of(model_type), model=model, credentials=credentials
         )
 
     def remove_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str) -> None:
@@ -273,10 +269,7 @@ class ModelProviderService:
             raise ValueError(f"Provider {provider} does not exist.")
 
         # Remove custom model credentials
-        provider_configuration.delete_custom_model_credentials(
-            model_type=ModelType.value_of(model_type),
-            model=model
-        )
+        provider_configuration.delete_custom_model_credentials(model_type=ModelType.value_of(model_type), model=model)
 
     def get_models_by_model_type(self, tenant_id: str, model_type: str) -> list[ProviderWithModelsResponse]:
         """
@@ -290,9 +283,7 @@ class ModelProviderService:
         provider_configurations = self.provider_manager.get_configurations(tenant_id)
 
         # Get provider available models
-        models = provider_configurations.get_models(
-            model_type=ModelType.value_of(model_type)
-        )
+        models = provider_configurations.get_models(model_type=ModelType.value_of(model_type))
 
         # Group models by provider
         provider_models = {}
@@ -323,16 +314,19 @@ class ModelProviderService:
                     icon_small=first_model.provider.icon_small,
                     icon_large=first_model.provider.icon_large,
                     status=CustomConfigurationStatus.ACTIVE,
-                    models=[ProviderModelWithStatusEntity(
-                        model=model.model,
-                        label=model.label,
-                        model_type=model.model_type,
-                        features=model.features,
-                        fetch_from=model.fetch_from,
-                        model_properties=model.model_properties,
-                        status=model.status,
-                        load_balancing_enabled=model.load_balancing_enabled
-                    ) for model in models]
+                    models=[
+                        ProviderModelWithStatusEntity(
+                            model=model.model,
+                            label=model.label,
+                            model_type=model.model_type,
+                            features=model.features,
+                            fetch_from=model.fetch_from,
+                            model_properties=model.model_properties,
+                            status=model.status,
+                            load_balancing_enabled=model.load_balancing_enabled,
+                        )
+                        for model in models
+                    ],
                 )
             )
 
@@ -361,19 +355,13 @@ class ModelProviderService:
         model_type_instance = cast(LargeLanguageModel, model_type_instance)
 
         # fetch credentials
-        credentials = provider_configuration.get_current_credentials(
-            model_type=ModelType.LLM,
-            model=model
-        )
+        credentials = provider_configuration.get_current_credentials(model_type=ModelType.LLM, model=model)
 
         if not credentials:
             return []
 
         # Call get_parameter_rules method of model instance to get model parameter rules
-        return model_type_instance.get_parameter_rules(
-            model=model,
-            credentials=credentials
-        )
+        return model_type_instance.get_parameter_rules(model=model, credentials=credentials)
 
     def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> Optional[DefaultModelResponse]:
         """
@@ -384,22 +372,23 @@ class ModelProviderService:
         :return:
         """
         model_type_enum = ModelType.value_of(model_type)
-        result = self.provider_manager.get_default_model(
-            tenant_id=tenant_id,
-            model_type=model_type_enum
-        )
+        result = self.provider_manager.get_default_model(tenant_id=tenant_id, model_type=model_type_enum)
         try:
-            return DefaultModelResponse(
-                model=result.model,
-                model_type=result.model_type,
-                provider=SimpleProviderEntityResponse(
-                    provider=result.provider.provider,
-                    label=result.provider.label,
-                    icon_small=result.provider.icon_small,
-                    icon_large=result.provider.icon_large,
-                    supported_model_types=result.provider.supported_model_types
+            return (
+                DefaultModelResponse(
+                    model=result.model,
+                    model_type=result.model_type,
+                    provider=SimpleProviderEntityResponse(
+                        provider=result.provider.provider,
+                        label=result.provider.label,
+                        icon_small=result.provider.icon_small,
+                        icon_large=result.provider.icon_large,
+                        supported_model_types=result.provider.supported_model_types,
+                    ),
                 )
-            ) if result else None
+                if result
+                else None
+            )
         except Exception as e:
             logger.info(f"get_default_model_of_model_type error: {e}")
             return None
@@ -416,13 +405,12 @@ class ModelProviderService:
         """
         model_type_enum = ModelType.value_of(model_type)
         self.provider_manager.update_default_model_record(
-            tenant_id=tenant_id,
-            model_type=model_type_enum,
-            provider=provider,
-            model=model
+            tenant_id=tenant_id, model_type=model_type_enum, provider=provider, model=model
         )
 
-    def get_model_provider_icon(self, provider: str, icon_type: str, lang: str) -> tuple[Optional[bytes], Optional[str]]:
+    def get_model_provider_icon(
+        self, provider: str, icon_type: str, lang: str
+    ) -> tuple[Optional[bytes], Optional[str]]:
         """
         get model provider icon.
 
@@ -434,11 +422,11 @@ class ModelProviderService:
         provider_instance = model_provider_factory.get_provider_instance(provider)
         provider_schema = provider_instance.get_provider_schema()
 
-        if icon_type.lower() == 'icon_small':
+        if icon_type.lower() == "icon_small":
             if not provider_schema.icon_small:
                 raise ValueError(f"Provider {provider} does not have small icon.")
 
-            if lang.lower() == 'zh_hans':
+            if lang.lower() == "zh_hans":
                 file_name = provider_schema.icon_small.zh_Hans
             else:
                 file_name = provider_schema.icon_small.en_US
@@ -446,13 +434,15 @@ class ModelProviderService:
             if not provider_schema.icon_large:
                 raise ValueError(f"Provider {provider} does not have large icon.")
 
-            if lang.lower() == 'zh_hans':
+            if lang.lower() == "zh_hans":
                 file_name = provider_schema.icon_large.zh_Hans
             else:
                 file_name = provider_schema.icon_large.en_US
 
         root_path = current_app.root_path
-        provider_instance_path = os.path.dirname(os.path.join(root_path, provider_instance.__class__.__module__.replace('.', '/')))
+        provider_instance_path = os.path.dirname(
+            os.path.join(root_path, provider_instance.__class__.__module__.replace(".", "/"))
+        )
         file_path = os.path.join(provider_instance_path, "_assets")
         file_path = os.path.join(file_path, file_name)
 
@@ -460,10 +450,10 @@ class ModelProviderService:
             return None, None
 
         mimetype, _ = mimetypes.guess_type(file_path)
-        mimetype = mimetype or 'application/octet-stream'
+        mimetype = mimetype or "application/octet-stream"
 
         # read binary from file
-        with open(file_path, 'rb') as f:
+        with open(file_path, "rb") as f:
             byte_data = f.read()
             return byte_data, mimetype
 
@@ -509,10 +499,7 @@ class ModelProviderService:
             raise ValueError(f"Provider {provider} does not exist.")
 
         # Enable model
-        provider_configuration.enable_model(
-            model=model,
-            model_type=ModelType.value_of(model_type)
-        )
+        provider_configuration.enable_model(model=model, model_type=ModelType.value_of(model_type))
 
     def disable_model(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
         """
@@ -533,78 +520,49 @@ class ModelProviderService:
             raise ValueError(f"Provider {provider} does not exist.")
 
         # Enable model
-        provider_configuration.disable_model(
-            model=model,
-            model_type=ModelType.value_of(model_type)
-        )
+        provider_configuration.disable_model(model=model, model_type=ModelType.value_of(model_type))
 
     def free_quota_submit(self, tenant_id: str, provider: str):
         api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY")
         api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL")
-        api_url = api_base_url + '/api/v1/providers/apply'
+        api_url = api_base_url + "/api/v1/providers/apply"
 
-        headers = {
-            'Content-Type': 'application/json',
-            'Authorization': f"Bearer {api_key}"
-        }
-        response = requests.post(api_url, headers=headers, json={'workspace_id': tenant_id, 'provider_name': provider})
+        headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
+        response = requests.post(api_url, headers=headers, json={"workspace_id": tenant_id, "provider_name": provider})
         if not response.ok:
             logger.error(f"Request FREE QUOTA APPLY SERVER Error: {response.status_code} ")
             raise ValueError(f"Error: {response.status_code} ")
 
-        if response.json()["code"] != 'success':
-            raise ValueError(
-                f"error: {response.json()['message']}"
-            )
+        if response.json()["code"] != "success":
+            raise ValueError(f"error: {response.json()['message']}")
 
         rst = response.json()
 
-        if rst['type'] == 'redirect':
-            return {
-                'type': rst['type'],
-                'redirect_url': rst['redirect_url']
-            }
+        if rst["type"] == "redirect":
+            return {"type": rst["type"], "redirect_url": rst["redirect_url"]}
         else:
-            return {
-                'type': rst['type'],
-                'result': 'success'
-            }
+            return {"type": rst["type"], "result": "success"}
 
     def free_quota_qualification_verify(self, tenant_id: str, provider: str, token: Optional[str]):
         api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY")
         api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL")
-        api_url = api_base_url + '/api/v1/providers/qualification-verify'
+        api_url = api_base_url + "/api/v1/providers/qualification-verify"
 
-        headers = {
-            'Content-Type': 'application/json',
-            'Authorization': f"Bearer {api_key}"
-        }
-        json_data = {'workspace_id': tenant_id, 'provider_name': provider}
+        headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
+        json_data = {"workspace_id": tenant_id, "provider_name": provider}
         if token:
-            json_data['token'] = token
-        response = requests.post(api_url, headers=headers,
-                                 json=json_data)
+            json_data["token"] = token
+        response = requests.post(api_url, headers=headers, json=json_data)
         if not response.ok:
             logger.error(f"Request FREE QUOTA APPLY SERVER Error: {response.status_code} ")
             raise ValueError(f"Error: {response.status_code} ")
 
         rst = response.json()
-        if rst["code"] != 'success':
-            raise ValueError(
-                f"error: {rst['message']}"
-            )
+        if rst["code"] != "success":
+            raise ValueError(f"error: {rst['message']}")
 
-        data = rst['data']
-        if data['qualified'] is True:
-            return {
-                'result': 'success',
-                'provider_name': provider,
-                'flag': True
-            }
+        data = rst["data"]
+        if data["qualified"] is True:
+            return {"result": "success", "provider_name": provider, "flag": True}
         else:
-            return {
-                'result': 'success',
-                'provider_name': provider,
-                'flag': False,
-                'reason': data['reason']
-            }
+            return {"result": "success", "provider_name": provider, "flag": False, "reason": data["reason"]}

+ 5 - 4
api/services/moderation_service.py

@@ -4,17 +4,18 @@ from models.model import App, AppModelConfig
 
 
 class ModerationService:
-
     def moderation_for_outputs(self, app_id: str, app_model: App, text: str) -> ModerationOutputsResult:
         app_model_config: AppModelConfig = None
 
-        app_model_config = db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first()
+        app_model_config = (
+            db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first()
+        )
 
         if not app_model_config:
             raise ValueError("app model config not found")
 
-        name = app_model_config.sensitive_word_avoidance_dict['type']
-        config = app_model_config.sensitive_word_avoidance_dict['config']
+        name = app_model_config.sensitive_word_avoidance_dict["type"]
+        config = app_model_config.sensitive_word_avoidance_dict["config"]
 
         moderation = ModerationFactory(name, app_id, app_model.tenant_id, config)
         return moderation.moderation_for_outputs(text)

+ 10 - 13
api/services/operation_service.py

@@ -4,15 +4,12 @@ import requests
 
 
 class OperationService:
-    base_url = os.environ.get('BILLING_API_URL', 'BILLING_API_URL')
-    secret_key = os.environ.get('BILLING_API_SECRET_KEY', 'BILLING_API_SECRET_KEY')
+    base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL")
+    secret_key = os.environ.get("BILLING_API_SECRET_KEY", "BILLING_API_SECRET_KEY")
 
     @classmethod
     def _send_request(cls, method, endpoint, json=None, params=None):
-        headers = {
-            "Content-Type": "application/json",
-            "Billing-Api-Secret-Key": cls.secret_key
-        }
+        headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key}
 
         url = f"{cls.base_url}{endpoint}"
         response = requests.request(method, url, json=json, params=params, headers=headers)
@@ -22,11 +19,11 @@ class OperationService:
     @classmethod
     def record_utm(cls, tenant_id: str, utm_info: dict):
         params = {
-            'tenant_id': tenant_id,
-            'utm_source': utm_info.get('utm_source', ''),
-            'utm_medium': utm_info.get('utm_medium', ''),
-            'utm_campaign': utm_info.get('utm_campaign', ''),
-            'utm_content': utm_info.get('utm_content', ''),
-            'utm_term': utm_info.get('utm_term', '')
+            "tenant_id": tenant_id,
+            "utm_source": utm_info.get("utm_source", ""),
+            "utm_medium": utm_info.get("utm_medium", ""),
+            "utm_campaign": utm_info.get("utm_campaign", ""),
+            "utm_content": utm_info.get("utm_content", ""),
+            "utm_term": utm_info.get("utm_term", ""),
         }
-        return cls._send_request('POST', '/tenant_utms', params=params)
+        return cls._send_request("POST", "/tenant_utms", params=params)

+ 33 - 19
api/services/ops_service.py

@@ -12,19 +12,25 @@ class OpsService:
         :param tracing_provider: tracing provider
         :return:
         """
-        trace_config_data: TraceAppConfig = db.session.query(TraceAppConfig).filter(
-            TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider
-        ).first()
+        trace_config_data: TraceAppConfig = (
+            db.session.query(TraceAppConfig)
+            .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
+            .first()
+        )
 
         if not trace_config_data:
             return None
 
         # decrypt_token and obfuscated_token
         tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id
-        decrypt_tracing_config = OpsTraceManager.decrypt_tracing_config(tenant_id, tracing_provider, trace_config_data.tracing_config)
-        if tracing_provider == 'langfuse' and ('project_key' not in decrypt_tracing_config or not decrypt_tracing_config.get('project_key')):
+        decrypt_tracing_config = OpsTraceManager.decrypt_tracing_config(
+            tenant_id, tracing_provider, trace_config_data.tracing_config
+        )
+        if tracing_provider == "langfuse" and (
+            "project_key" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_key")
+        ):
             project_key = OpsTraceManager.get_trace_config_project_key(decrypt_tracing_config, tracing_provider)
-            decrypt_tracing_config['project_key'] = project_key
+            decrypt_tracing_config["project_key"] = project_key
 
         decrypt_tracing_config = OpsTraceManager.obfuscated_decrypt_token(tracing_provider, decrypt_tracing_config)
 
@@ -44,8 +50,10 @@ class OpsService:
         if tracing_provider not in provider_config_map.keys() and tracing_provider:
             return {"error": f"Invalid tracing provider: {tracing_provider}"}
 
-        config_class, other_keys = provider_config_map[tracing_provider]['config_class'], \
-            provider_config_map[tracing_provider]['other_keys']
+        config_class, other_keys = (
+            provider_config_map[tracing_provider]["config_class"],
+            provider_config_map[tracing_provider]["other_keys"],
+        )
         default_config_instance = config_class(**tracing_config)
         for key in other_keys:
             if key in tracing_config and tracing_config[key] == "":
@@ -59,9 +67,11 @@ class OpsService:
         project_key = OpsTraceManager.get_trace_config_project_key(tracing_config, tracing_provider)
 
         # check if trace config already exists
-        trace_config_data: TraceAppConfig = db.session.query(TraceAppConfig).filter(
-            TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider
-        ).first()
+        trace_config_data: TraceAppConfig = (
+            db.session.query(TraceAppConfig)
+            .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
+            .first()
+        )
 
         if trace_config_data:
             return None
@@ -69,8 +79,8 @@ class OpsService:
         # get tenant id
         tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id
         tracing_config = OpsTraceManager.encrypt_tracing_config(tenant_id, tracing_provider, tracing_config)
-        if tracing_provider == 'langfuse':
-            tracing_config['project_key'] = project_key
+        if tracing_provider == "langfuse":
+            tracing_config["project_key"] = project_key
         trace_config_data = TraceAppConfig(
             app_id=app_id,
             tracing_provider=tracing_provider,
@@ -94,9 +104,11 @@ class OpsService:
             raise ValueError(f"Invalid tracing provider: {tracing_provider}")
 
         # check if trace config already exists
-        current_trace_config = db.session.query(TraceAppConfig).filter(
-            TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider
-        ).first()
+        current_trace_config = (
+            db.session.query(TraceAppConfig)
+            .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
+            .first()
+        )
 
         if not current_trace_config:
             return None
@@ -126,9 +138,11 @@ class OpsService:
         :param tracing_provider: tracing provider
         :return:
         """
-        trace_config = db.session.query(TraceAppConfig).filter(
-            TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider
-        ).first()
+        trace_config = (
+            db.session.query(TraceAppConfig)
+            .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
+            .first()
+        )
 
         if not trace_config:
             return None

+ 62 - 63
api/services/recommended_app_service.py

@@ -16,7 +16,6 @@ logger = logging.getLogger(__name__)
 
 
 class RecommendedAppService:
-
     builtin_data: Optional[dict] = None
 
     @classmethod
@@ -27,21 +26,21 @@ class RecommendedAppService:
         :return:
         """
         mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE
-        if mode == 'remote':
+        if mode == "remote":
             try:
                 result = cls._fetch_recommended_apps_from_dify_official(language)
             except Exception as e:
-                logger.warning(f'fetch recommended apps from dify official failed: {e}, switch to built-in.')
+                logger.warning(f"fetch recommended apps from dify official failed: {e}, switch to built-in.")
                 result = cls._fetch_recommended_apps_from_builtin(language)
-        elif mode == 'db':
+        elif mode == "db":
             result = cls._fetch_recommended_apps_from_db(language)
-        elif mode == 'builtin':
+        elif mode == "builtin":
             result = cls._fetch_recommended_apps_from_builtin(language)
         else:
-            raise ValueError(f'invalid fetch recommended apps mode: {mode}')
+            raise ValueError(f"invalid fetch recommended apps mode: {mode}")
 
-        if not result.get('recommended_apps') and language != 'en-US':
-            result = cls._fetch_recommended_apps_from_builtin('en-US')
+        if not result.get("recommended_apps") and language != "en-US":
+            result = cls._fetch_recommended_apps_from_builtin("en-US")
 
         return result
 
@@ -52,16 +51,18 @@ class RecommendedAppService:
         :param language: language
         :return:
         """
-        recommended_apps = db.session.query(RecommendedApp).filter(
-            RecommendedApp.is_listed == True,
-            RecommendedApp.language == language
-        ).all()
+        recommended_apps = (
+            db.session.query(RecommendedApp)
+            .filter(RecommendedApp.is_listed == True, RecommendedApp.language == language)
+            .all()
+        )
 
         if len(recommended_apps) == 0:
-            recommended_apps = db.session.query(RecommendedApp).filter(
-                RecommendedApp.is_listed == True,
-                RecommendedApp.language == languages[0]
-            ).all()
+            recommended_apps = (
+                db.session.query(RecommendedApp)
+                .filter(RecommendedApp.is_listed == True, RecommendedApp.language == languages[0])
+                .all()
+            )
 
         categories = set()
         recommended_apps_result = []
@@ -75,28 +76,28 @@ class RecommendedAppService:
                 continue
 
             recommended_app_result = {
-                'id': recommended_app.id,
-                'app': {
-                    'id': app.id,
-                    'name': app.name,
-                    'mode': app.mode,
-                    'icon': app.icon,
-                    'icon_background': app.icon_background
+                "id": recommended_app.id,
+                "app": {
+                    "id": app.id,
+                    "name": app.name,
+                    "mode": app.mode,
+                    "icon": app.icon,
+                    "icon_background": app.icon_background,
                 },
-                'app_id': recommended_app.app_id,
-                'description': site.description,
-                'copyright': site.copyright,
-                'privacy_policy': site.privacy_policy,
-                'custom_disclaimer': site.custom_disclaimer,
-                'category': recommended_app.category,
-                'position': recommended_app.position,
-                'is_listed': recommended_app.is_listed
+                "app_id": recommended_app.app_id,
+                "description": site.description,
+                "copyright": site.copyright,
+                "privacy_policy": site.privacy_policy,
+                "custom_disclaimer": site.custom_disclaimer,
+                "category": recommended_app.category,
+                "position": recommended_app.position,
+                "is_listed": recommended_app.is_listed,
             }
             recommended_apps_result.append(recommended_app_result)
 
             categories.add(recommended_app.category)  # add category to categories
 
-        return {'recommended_apps': recommended_apps_result, 'categories': sorted(categories)}
+        return {"recommended_apps": recommended_apps_result, "categories": sorted(categories)}
 
     @classmethod
     def _fetch_recommended_apps_from_dify_official(cls, language: str) -> dict:
@@ -106,16 +107,16 @@ class RecommendedAppService:
         :return:
         """
         domain = dify_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN
-        url = f'{domain}/apps?language={language}'
+        url = f"{domain}/apps?language={language}"
         response = requests.get(url, timeout=(3, 10))
         if response.status_code != 200:
-            raise ValueError(f'fetch recommended apps failed, status code: {response.status_code}')
+            raise ValueError(f"fetch recommended apps failed, status code: {response.status_code}")
 
         result = response.json()
 
         if "categories" in result:
             result["categories"] = sorted(result["categories"])
-        
+
         return result
 
     @classmethod
@@ -126,7 +127,7 @@ class RecommendedAppService:
         :return:
         """
         builtin_data = cls._get_builtin_data()
-        return builtin_data.get('recommended_apps', {}).get(language)
+        return builtin_data.get("recommended_apps", {}).get(language)
 
     @classmethod
     def get_recommend_app_detail(cls, app_id: str) -> Optional[dict]:
@@ -136,18 +137,18 @@ class RecommendedAppService:
         :return:
         """
         mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE
-        if mode == 'remote':
+        if mode == "remote":
             try:
                 result = cls._fetch_recommended_app_detail_from_dify_official(app_id)
             except Exception as e:
-                logger.warning(f'fetch recommended app detail from dify official failed: {e}, switch to built-in.')
+                logger.warning(f"fetch recommended app detail from dify official failed: {e}, switch to built-in.")
                 result = cls._fetch_recommended_app_detail_from_builtin(app_id)
-        elif mode == 'db':
+        elif mode == "db":
             result = cls._fetch_recommended_app_detail_from_db(app_id)
-        elif mode == 'builtin':
+        elif mode == "builtin":
             result = cls._fetch_recommended_app_detail_from_builtin(app_id)
         else:
-            raise ValueError(f'invalid fetch recommended app detail mode: {mode}')
+            raise ValueError(f"invalid fetch recommended app detail mode: {mode}")
 
         return result
 
@@ -159,7 +160,7 @@ class RecommendedAppService:
         :return:
         """
         domain = dify_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN
-        url = f'{domain}/apps/{app_id}'
+        url = f"{domain}/apps/{app_id}"
         response = requests.get(url, timeout=(3, 10))
         if response.status_code != 200:
             return None
@@ -174,10 +175,11 @@ class RecommendedAppService:
         :return:
         """
         # is in public recommended list
-        recommended_app = db.session.query(RecommendedApp).filter(
-            RecommendedApp.is_listed == True,
-            RecommendedApp.app_id == app_id
-        ).first()
+        recommended_app = (
+            db.session.query(RecommendedApp)
+            .filter(RecommendedApp.is_listed == True, RecommendedApp.app_id == app_id)
+            .first()
+        )
 
         if not recommended_app:
             return None
@@ -188,12 +190,12 @@ class RecommendedAppService:
             return None
 
         return {
-            'id': app_model.id,
-            'name': app_model.name,
-            'icon': app_model.icon,
-            'icon_background': app_model.icon_background,
-            'mode': app_model.mode,
-            'export_data': AppDslService.export_dsl(app_model=app_model)
+            "id": app_model.id,
+            "name": app_model.name,
+            "icon": app_model.icon,
+            "icon_background": app_model.icon_background,
+            "mode": app_model.mode,
+            "export_data": AppDslService.export_dsl(app_model=app_model),
         }
 
     @classmethod
@@ -204,7 +206,7 @@ class RecommendedAppService:
         :return:
         """
         builtin_data = cls._get_builtin_data()
-        return builtin_data.get('app_details', {}).get(app_id)
+        return builtin_data.get("app_details", {}).get(app_id)
 
     @classmethod
     def _get_builtin_data(cls) -> dict:
@@ -216,7 +218,7 @@ class RecommendedAppService:
             return cls.builtin_data
 
         root_path = current_app.root_path
-        with open(path.join(root_path, 'constants', 'recommended_apps.json'), encoding='utf-8') as f:
+        with open(path.join(root_path, "constants", "recommended_apps.json"), encoding="utf-8") as f:
             json_data = f.read()
             data = json.loads(json_data)
             cls.builtin_data = data
@@ -229,27 +231,24 @@ class RecommendedAppService:
         Fetch all recommended apps and export datas
         :return:
         """
-        templates = {
-            "recommended_apps": {},
-            "app_details": {}
-        }
+        templates = {"recommended_apps": {}, "app_details": {}}
         for language in languages:
             try:
                 result = cls._fetch_recommended_apps_from_dify_official(language)
             except Exception as e:
-                logger.warning(f'fetch recommended apps from dify official failed: {e}, skip.')
+                logger.warning(f"fetch recommended apps from dify official failed: {e}, skip.")
                 continue
 
-            templates['recommended_apps'][language] = result
+            templates["recommended_apps"][language] = result
 
-            for recommended_app in result.get('recommended_apps'):
-                app_id = recommended_app.get('app_id')
+            for recommended_app in result.get("recommended_apps"):
+                app_id = recommended_app.get("app_id")
 
                 # get app detail
                 app_detail = cls._fetch_recommended_app_detail_from_dify_official(app_id)
                 if not app_detail:
                     continue
 
-                templates['app_details'][app_id] = app_detail
+                templates["app_details"][app_id] = app_detail
 
         return templates

+ 37 - 31
api/services/saved_message_service.py

@@ -10,46 +10,48 @@ from services.message_service import MessageService
 
 class SavedMessageService:
     @classmethod
-    def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, EndUser]],
-                              last_id: Optional[str], limit: int) -> InfiniteScrollPagination:
-        saved_messages = db.session.query(SavedMessage).filter(
-            SavedMessage.app_id == app_model.id,
-            SavedMessage.created_by_role == ('account' if isinstance(user, Account) else 'end_user'),
-            SavedMessage.created_by == user.id
-        ).order_by(SavedMessage.created_at.desc()).all()
+    def pagination_by_last_id(
+        cls, app_model: App, user: Optional[Union[Account, EndUser]], last_id: Optional[str], limit: int
+    ) -> InfiniteScrollPagination:
+        saved_messages = (
+            db.session.query(SavedMessage)
+            .filter(
+                SavedMessage.app_id == app_model.id,
+                SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
+                SavedMessage.created_by == user.id,
+            )
+            .order_by(SavedMessage.created_at.desc())
+            .all()
+        )
         message_ids = [sm.message_id for sm in saved_messages]
 
         return MessageService.pagination_by_last_id(
-            app_model=app_model,
-            user=user,
-            last_id=last_id,
-            limit=limit,
-            include_ids=message_ids
+            app_model=app_model, user=user, last_id=last_id, limit=limit, include_ids=message_ids
         )
 
     @classmethod
     def save(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str):
-        saved_message = db.session.query(SavedMessage).filter(
-            SavedMessage.app_id == app_model.id,
-            SavedMessage.message_id == message_id,
-            SavedMessage.created_by_role == ('account' if isinstance(user, Account) else 'end_user'),
-            SavedMessage.created_by == user.id
-        ).first()
+        saved_message = (
+            db.session.query(SavedMessage)
+            .filter(
+                SavedMessage.app_id == app_model.id,
+                SavedMessage.message_id == message_id,
+                SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
+                SavedMessage.created_by == user.id,
+            )
+            .first()
+        )
 
         if saved_message:
             return
 
-        message = MessageService.get_message(
-            app_model=app_model,
-            user=user,
-            message_id=message_id
-        )
+        message = MessageService.get_message(app_model=app_model, user=user, message_id=message_id)
 
         saved_message = SavedMessage(
             app_id=app_model.id,
             message_id=message.id,
-            created_by_role='account' if isinstance(user, Account) else 'end_user',
-            created_by=user.id
+            created_by_role="account" if isinstance(user, Account) else "end_user",
+            created_by=user.id,
         )
 
         db.session.add(saved_message)
@@ -57,12 +59,16 @@ class SavedMessageService:
 
     @classmethod
     def delete(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str):
-        saved_message = db.session.query(SavedMessage).filter(
-            SavedMessage.app_id == app_model.id,
-            SavedMessage.message_id == message_id,
-            SavedMessage.created_by_role == ('account' if isinstance(user, Account) else 'end_user'),
-            SavedMessage.created_by == user.id
-        ).first()
+        saved_message = (
+            db.session.query(SavedMessage)
+            .filter(
+                SavedMessage.app_id == app_model.id,
+                SavedMessage.message_id == message_id,
+                SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
+                SavedMessage.created_by == user.id,
+            )
+            .first()
+        )
 
         if not saved_message:
             return

+ 58 - 62
api/services/tag_service.py

@@ -12,38 +12,32 @@ from models.model import App, Tag, TagBinding
 class TagService:
     @staticmethod
     def get_tags(tag_type: str, current_tenant_id: str, keyword: str = None) -> list:
-        query = db.session.query(
-            Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label('binding_count')
-        ).outerjoin(
-            TagBinding, Tag.id == TagBinding.tag_id
-        ).filter(
-            Tag.type == tag_type,
-            Tag.tenant_id == current_tenant_id
+        query = (
+            db.session.query(Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label("binding_count"))
+            .outerjoin(TagBinding, Tag.id == TagBinding.tag_id)
+            .filter(Tag.type == tag_type, Tag.tenant_id == current_tenant_id)
         )
         if keyword:
-            query = query.filter(db.and_(Tag.name.ilike(f'%{keyword}%')))
-        query = query.group_by(
-            Tag.id
-        )
+            query = query.filter(db.and_(Tag.name.ilike(f"%{keyword}%")))
+        query = query.group_by(Tag.id)
         results = query.order_by(Tag.created_at.desc()).all()
         return results
 
     @staticmethod
     def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list) -> list:
-        tags = db.session.query(Tag).filter(
-            Tag.id.in_(tag_ids),
-            Tag.tenant_id == current_tenant_id,
-            Tag.type == tag_type
-        ).all()
+        tags = (
+            db.session.query(Tag)
+            .filter(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
+            .all()
+        )
         if not tags:
             return []
         tag_ids = [tag.id for tag in tags]
-        tag_bindings = db.session.query(
-            TagBinding.target_id
-        ).filter(
-            TagBinding.tag_id.in_(tag_ids),
-            TagBinding.tenant_id == current_tenant_id
-        ).all()
+        tag_bindings = (
+            db.session.query(TagBinding.target_id)
+            .filter(TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id)
+            .all()
+        )
         if not tag_bindings:
             return []
         results = [tag_binding.target_id for tag_binding in tag_bindings]
@@ -51,27 +45,28 @@ class TagService:
 
     @staticmethod
     def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str) -> list:
-        tags = db.session.query(Tag).join(
-            TagBinding,
-            Tag.id == TagBinding.tag_id
-        ).filter(
-            TagBinding.target_id == target_id,
-            TagBinding.tenant_id == current_tenant_id,
-            Tag.tenant_id == current_tenant_id,
-            Tag.type == tag_type
-        ).all()
+        tags = (
+            db.session.query(Tag)
+            .join(TagBinding, Tag.id == TagBinding.tag_id)
+            .filter(
+                TagBinding.target_id == target_id,
+                TagBinding.tenant_id == current_tenant_id,
+                Tag.tenant_id == current_tenant_id,
+                Tag.type == tag_type,
+            )
+            .all()
+        )
 
         return tags if tags else []
 
-
     @staticmethod
     def save_tags(args: dict) -> Tag:
         tag = Tag(
             id=str(uuid.uuid4()),
-            name=args['name'],
-            type=args['type'],
+            name=args["name"],
+            type=args["type"],
             created_by=current_user.id,
-            tenant_id=current_user.current_tenant_id
+            tenant_id=current_user.current_tenant_id,
         )
         db.session.add(tag)
         db.session.commit()
@@ -82,7 +77,7 @@ class TagService:
         tag = db.session.query(Tag).filter(Tag.id == tag_id).first()
         if not tag:
             raise NotFound("Tag not found")
-        tag.name = args['name']
+        tag.name = args["name"]
         db.session.commit()
         return tag
 
@@ -107,20 +102,21 @@ class TagService:
     @staticmethod
     def save_tag_binding(args):
         # check if target exists
-        TagService.check_target_exists(args['type'], args['target_id'])
+        TagService.check_target_exists(args["type"], args["target_id"])
         # save tag binding
-        for tag_id in args['tag_ids']:
-            tag_binding = db.session.query(TagBinding).filter(
-                TagBinding.tag_id == tag_id,
-                TagBinding.target_id == args['target_id']
-            ).first()
+        for tag_id in args["tag_ids"]:
+            tag_binding = (
+                db.session.query(TagBinding)
+                .filter(TagBinding.tag_id == tag_id, TagBinding.target_id == args["target_id"])
+                .first()
+            )
             if tag_binding:
                 continue
             new_tag_binding = TagBinding(
                 tag_id=tag_id,
-                target_id=args['target_id'],
+                target_id=args["target_id"],
                 tenant_id=current_user.current_tenant_id,
-                created_by=current_user.id
+                created_by=current_user.id,
             )
             db.session.add(new_tag_binding)
         db.session.commit()
@@ -128,34 +124,34 @@ class TagService:
     @staticmethod
     def delete_tag_binding(args):
         # check if target exists
-        TagService.check_target_exists(args['type'], args['target_id'])
+        TagService.check_target_exists(args["type"], args["target_id"])
         # delete tag binding
-        tag_bindings = db.session.query(TagBinding).filter(
-            TagBinding.target_id == args['target_id'],
-            TagBinding.tag_id == (args['tag_id'])
-        ).first()
+        tag_bindings = (
+            db.session.query(TagBinding)
+            .filter(TagBinding.target_id == args["target_id"], TagBinding.tag_id == (args["tag_id"]))
+            .first()
+        )
         if tag_bindings:
             db.session.delete(tag_bindings)
             db.session.commit()
 
-
-
     @staticmethod
     def check_target_exists(type: str, target_id: str):
-        if type == 'knowledge':
-            dataset = db.session.query(Dataset).filter(
-                Dataset.tenant_id == current_user.current_tenant_id,
-                Dataset.id == target_id
-            ).first()
+        if type == "knowledge":
+            dataset = (
+                db.session.query(Dataset)
+                .filter(Dataset.tenant_id == current_user.current_tenant_id, Dataset.id == target_id)
+                .first()
+            )
             if not dataset:
                 raise NotFound("Dataset not found")
-        elif type == 'app':
-            app = db.session.query(App).filter(
-                App.tenant_id == current_user.current_tenant_id,
-                App.id == target_id
-            ).first()
+        elif type == "app":
+            app = (
+                db.session.query(App)
+                .filter(App.tenant_id == current_user.current_tenant_id, App.id == target_id)
+                .first()
+            )
             if not app:
                 raise NotFound("App not found")
         else:
             raise NotFound("Invalid binding type")
-

+ 178 - 171
api/services/tools/api_tools_manage_service.py

@@ -29,111 +29,107 @@ class ApiToolManageService:
     @staticmethod
     def parser_api_schema(schema: str) -> list[ApiToolBundle]:
         """
-            parse api schema to tool bundle
+        parse api schema to tool bundle
         """
         try:
             warnings = {}
             try:
                 tool_bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, warning=warnings)
             except Exception as e:
-                raise ValueError(f'invalid schema: {str(e)}')
-            
+                raise ValueError(f"invalid schema: {str(e)}")
+
             credentials_schema = [
                 ToolProviderCredentials(
-                    name='auth_type',
+                    name="auth_type",
                     type=ToolProviderCredentials.CredentialsType.SELECT,
                     required=True,
-                    default='none',
+                    default="none",
                     options=[
-                        ToolCredentialsOption(value='none', label=I18nObject(
-                            en_US='None',
-                            zh_Hans='无'
-                        )),
-                        ToolCredentialsOption(value='api_key', label=I18nObject(
-                            en_US='Api Key',
-                            zh_Hans='Api Key'
-                        )),
+                        ToolCredentialsOption(value="none", label=I18nObject(en_US="None", zh_Hans="无")),
+                        ToolCredentialsOption(value="api_key", label=I18nObject(en_US="Api Key", zh_Hans="Api Key")),
                     ],
-                    placeholder=I18nObject(
-                        en_US='Select auth type',
-                        zh_Hans='选择认证方式'
-                    )
+                    placeholder=I18nObject(en_US="Select auth type", zh_Hans="选择认证方式"),
                 ),
                 ToolProviderCredentials(
-                    name='api_key_header',
+                    name="api_key_header",
                     type=ToolProviderCredentials.CredentialsType.TEXT_INPUT,
                     required=False,
-                    placeholder=I18nObject(
-                        en_US='Enter api key header',
-                        zh_Hans='输入 api key header,如:X-API-KEY'
-                    ),
-                    default='api_key',
-                    help=I18nObject(
-                        en_US='HTTP header name for api key',
-                        zh_Hans='HTTP 头部字段名,用于传递 api key'
-                    )
+                    placeholder=I18nObject(en_US="Enter api key header", zh_Hans="输入 api key header,如:X-API-KEY"),
+                    default="api_key",
+                    help=I18nObject(en_US="HTTP header name for api key", zh_Hans="HTTP 头部字段名,用于传递 api key"),
                 ),
                 ToolProviderCredentials(
-                    name='api_key_value',
+                    name="api_key_value",
                     type=ToolProviderCredentials.CredentialsType.TEXT_INPUT,
                     required=False,
-                    placeholder=I18nObject(
-                        en_US='Enter api key',
-                        zh_Hans='输入 api key'
-                    ),
-                    default=''
+                    placeholder=I18nObject(en_US="Enter api key", zh_Hans="输入 api key"),
+                    default="",
                 ),
             ]
 
-            return jsonable_encoder({
-                'schema_type': schema_type,
-                'parameters_schema': tool_bundles,
-                'credentials_schema': credentials_schema,
-                'warning': warnings
-            })
+            return jsonable_encoder(
+                {
+                    "schema_type": schema_type,
+                    "parameters_schema": tool_bundles,
+                    "credentials_schema": credentials_schema,
+                    "warning": warnings,
+                }
+            )
         except Exception as e:
-            raise ValueError(f'invalid schema: {str(e)}')
+            raise ValueError(f"invalid schema: {str(e)}")
 
     @staticmethod
     def convert_schema_to_tool_bundles(schema: str, extra_info: dict = None) -> list[ApiToolBundle]:
         """
-            convert schema to tool bundles
+        convert schema to tool bundles
 
-            :return: the list of tool bundles, description
+        :return: the list of tool bundles, description
         """
         try:
             tool_bundles = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, extra_info=extra_info)
             return tool_bundles
         except Exception as e:
-            raise ValueError(f'invalid schema: {str(e)}')
+            raise ValueError(f"invalid schema: {str(e)}")
 
     @staticmethod
     def create_api_tool_provider(
-        user_id: str, tenant_id: str, provider_name: str, icon: dict, credentials: dict,
-        schema_type: str, schema: str, privacy_policy: str, custom_disclaimer: str, labels: list[str]
+        user_id: str,
+        tenant_id: str,
+        provider_name: str,
+        icon: dict,
+        credentials: dict,
+        schema_type: str,
+        schema: str,
+        privacy_policy: str,
+        custom_disclaimer: str,
+        labels: list[str],
     ):
         """
-            create api tool provider
+        create api tool provider
         """
         if schema_type not in [member.value for member in ApiProviderSchemaType]:
-            raise ValueError(f'invalid schema type {schema}')
-        
+            raise ValueError(f"invalid schema type {schema}")
+
         # check if the provider exists
-        provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
-            ApiToolProvider.tenant_id == tenant_id,
-            ApiToolProvider.name == provider_name,
-        ).first()
+        provider: ApiToolProvider = (
+            db.session.query(ApiToolProvider)
+            .filter(
+                ApiToolProvider.tenant_id == tenant_id,
+                ApiToolProvider.name == provider_name,
+            )
+            .first()
+        )
 
         if provider is not None:
-            raise ValueError(f'provider {provider_name} already exists')
+            raise ValueError(f"provider {provider_name} already exists")
 
         # parse openapi to tool bundle
         extra_info = {}
         # extra info like description will be set here
         tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
-        
+
         if len(tool_bundles) > 100:
-            raise ValueError('the number of apis should be less than 100')
+            raise ValueError("the number of apis should be less than 100")
 
         # create db provider
         db_provider = ApiToolProvider(
@@ -142,19 +138,19 @@ class ApiToolManageService:
             name=provider_name,
             icon=json.dumps(icon),
             schema=schema,
-            description=extra_info.get('description', ''),
+            description=extra_info.get("description", ""),
             schema_type_str=schema_type,
             tools_str=json.dumps(jsonable_encoder(tool_bundles)),
             credentials_str={},
             privacy_policy=privacy_policy,
-            custom_disclaimer=custom_disclaimer
+            custom_disclaimer=custom_disclaimer,
         )
 
-        if 'auth_type' not in credentials:
-            raise ValueError('auth_type is required')
+        if "auth_type" not in credentials:
+            raise ValueError("auth_type is required")
 
         # get auth type, none or api key
-        auth_type = ApiProviderAuthType.value_of(credentials['auth_type'])
+        auth_type = ApiProviderAuthType.value_of(credentials["auth_type"])
 
         # create provider entity
         provider_controller = ApiToolProviderController.from_db(db_provider, auth_type)
@@ -172,14 +168,12 @@ class ApiToolManageService:
         # update labels
         ToolLabelManager.update_tool_labels(provider_controller, labels)
 
-        return { 'result': 'success' }
-    
+        return {"result": "success"}
+
     @staticmethod
-    def get_api_tool_provider_remote_schema(
-        user_id: str, tenant_id: str, url: str
-    ):
+    def get_api_tool_provider_remote_schema(user_id: str, tenant_id: str, url: str):
         """
-            get api tool provider remote schema
+        get api tool provider remote schema
         """
         headers = {
             "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36 Edg/120.0.0.0",
@@ -189,84 +183,98 @@ class ApiToolManageService:
         try:
             response = get(url, headers=headers, timeout=10)
             if response.status_code != 200:
-                raise ValueError(f'Got status code {response.status_code}')
+                raise ValueError(f"Got status code {response.status_code}")
             schema = response.text
 
             # try to parse schema, avoid SSRF attack
             ApiToolManageService.parser_api_schema(schema)
         except Exception as e:
             logger.error(f"parse api schema error: {str(e)}")
-            raise ValueError('invalid schema, please check the url you provided')
-        
-        return {
-            'schema': schema
-        }
+            raise ValueError("invalid schema, please check the url you provided")
+
+        return {"schema": schema}
 
     @staticmethod
-    def list_api_tool_provider_tools(
-        user_id: str, tenant_id: str, provider: str
-    ) -> list[UserTool]:
+    def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider: str) -> list[UserTool]:
         """
-            list api tool provider tools
+        list api tool provider tools
         """
-        provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
-            ApiToolProvider.tenant_id == tenant_id,
-            ApiToolProvider.name == provider,
-        ).first()
+        provider: ApiToolProvider = (
+            db.session.query(ApiToolProvider)
+            .filter(
+                ApiToolProvider.tenant_id == tenant_id,
+                ApiToolProvider.name == provider,
+            )
+            .first()
+        )
 
         if provider is None:
-            raise ValueError(f'you have not added provider {provider}')
-        
+            raise ValueError(f"you have not added provider {provider}")
+
         controller = ToolTransformService.api_provider_to_controller(db_provider=provider)
         labels = ToolLabelManager.get_tool_labels(controller)
-        
+
         return [
             ToolTransformService.tool_to_user_tool(
                 tool_bundle,
                 labels=labels,
-            ) for tool_bundle in provider.tools
+            )
+            for tool_bundle in provider.tools
         ]
 
     @staticmethod
     def update_api_tool_provider(
-        user_id: str, tenant_id: str, provider_name: str, original_provider: str, icon: dict, credentials: dict, 
-        schema_type: str, schema: str, privacy_policy: str, custom_disclaimer: str, labels: list[str]
+        user_id: str,
+        tenant_id: str,
+        provider_name: str,
+        original_provider: str,
+        icon: dict,
+        credentials: dict,
+        schema_type: str,
+        schema: str,
+        privacy_policy: str,
+        custom_disclaimer: str,
+        labels: list[str],
     ):
         """
-            update api tool provider
+        update api tool provider
         """
         if schema_type not in [member.value for member in ApiProviderSchemaType]:
-            raise ValueError(f'invalid schema type {schema}')
-        
+            raise ValueError(f"invalid schema type {schema}")
+
         # check if the provider exists
-        provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
-            ApiToolProvider.tenant_id == tenant_id,
-            ApiToolProvider.name == original_provider,
-        ).first()
+        provider: ApiToolProvider = (
+            db.session.query(ApiToolProvider)
+            .filter(
+                ApiToolProvider.tenant_id == tenant_id,
+                ApiToolProvider.name == original_provider,
+            )
+            .first()
+        )
 
         if provider is None:
-            raise ValueError(f'api provider {provider_name} does not exists')
+            raise ValueError(f"api provider {provider_name} does not exists")
 
         # parse openapi to tool bundle
         extra_info = {}
         # extra info like description will be set here
         tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
-        
+
         # update db provider
         provider.name = provider_name
         provider.icon = json.dumps(icon)
         provider.schema = schema
-        provider.description = extra_info.get('description', '')
+        provider.description = extra_info.get("description", "")
         provider.schema_type_str = ApiProviderSchemaType.OPENAPI.value
         provider.tools_str = json.dumps(jsonable_encoder(tool_bundles))
         provider.privacy_policy = privacy_policy
         provider.custom_disclaimer = custom_disclaimer
 
-        if 'auth_type' not in credentials:
-            raise ValueError('auth_type is required')
+        if "auth_type" not in credentials:
+            raise ValueError("auth_type is required")
 
         # get auth type, none or api key
-        auth_type = ApiProviderAuthType.value_of(credentials['auth_type'])
+        auth_type = ApiProviderAuthType.value_of(credentials["auth_type"])
 
         # create provider entity
         provider_controller = ApiToolProviderController.from_db(provider, auth_type)
@@ -295,84 +303,91 @@ class ApiToolManageService:
         # update labels
         ToolLabelManager.update_tool_labels(provider_controller, labels)
 
-        return { 'result': 'success' }
-    
+        return {"result": "success"}
+
     @staticmethod
-    def delete_api_tool_provider(
-        user_id: str, tenant_id: str, provider_name: str
-    ):
+    def delete_api_tool_provider(user_id: str, tenant_id: str, provider_name: str):
         """
-            delete tool provider
+        delete tool provider
         """
-        provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
-            ApiToolProvider.tenant_id == tenant_id,
-            ApiToolProvider.name == provider_name,
-        ).first()
+        provider: ApiToolProvider = (
+            db.session.query(ApiToolProvider)
+            .filter(
+                ApiToolProvider.tenant_id == tenant_id,
+                ApiToolProvider.name == provider_name,
+            )
+            .first()
+        )
 
         if provider is None:
-            raise ValueError(f'you have not added provider {provider_name}')
-        
+            raise ValueError(f"you have not added provider {provider_name}")
+
         db.session.delete(provider)
         db.session.commit()
 
-        return { 'result': 'success' }
-    
+        return {"result": "success"}
+
     @staticmethod
-    def get_api_tool_provider(
-        user_id: str, tenant_id: str, provider: str
-    ):
+    def get_api_tool_provider(user_id: str, tenant_id: str, provider: str):
         """
-            get api tool provider
+        get api tool provider
         """
         return ToolManager.user_get_api_provider(provider=provider, tenant_id=tenant_id)
-    
+
     @staticmethod
     def test_api_tool_preview(
-        tenant_id: str, 
+        tenant_id: str,
         provider_name: str,
-        tool_name: str, 
-        credentials: dict, 
-        parameters: dict, 
-        schema_type: str, 
-        schema: str
+        tool_name: str,
+        credentials: dict,
+        parameters: dict,
+        schema_type: str,
+        schema: str,
     ):
         """
-            test api tool before adding api tool provider
+        test api tool before adding api tool provider
         """
         if schema_type not in [member.value for member in ApiProviderSchemaType]:
-            raise ValueError(f'invalid schema type {schema_type}')
-        
+            raise ValueError(f"invalid schema type {schema_type}")
+
         try:
             tool_bundles, _ = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema)
         except Exception as e:
-            raise ValueError('invalid schema')
-        
+            raise ValueError("invalid schema")
+
         # get tool bundle
         tool_bundle = next(filter(lambda tb: tb.operation_id == tool_name, tool_bundles), None)
         if tool_bundle is None:
-            raise ValueError(f'invalid tool name {tool_name}')
-        
-        db_provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
-            ApiToolProvider.tenant_id == tenant_id,
-            ApiToolProvider.name == provider_name,
-        ).first()
+            raise ValueError(f"invalid tool name {tool_name}")
+
+        db_provider: ApiToolProvider = (
+            db.session.query(ApiToolProvider)
+            .filter(
+                ApiToolProvider.tenant_id == tenant_id,
+                ApiToolProvider.name == provider_name,
+            )
+            .first()
+        )
 
         if not db_provider:
             # create a fake db provider
             db_provider = ApiToolProvider(
-                tenant_id='', user_id='', name='', icon='',
+                tenant_id="",
+                user_id="",
+                name="",
+                icon="",
                 schema=schema,
-                description='',
+                description="",
                 schema_type_str=ApiProviderSchemaType.OPENAPI.value,
                 tools_str=json.dumps(jsonable_encoder(tool_bundles)),
                 credentials_str=json.dumps(credentials),
             )
 
-        if 'auth_type' not in credentials:
-            raise ValueError('auth_type is required')
+        if "auth_type" not in credentials:
+            raise ValueError("auth_type is required")
 
         # get auth type, none or api key
-        auth_type = ApiProviderAuthType.value_of(credentials['auth_type'])
+        auth_type = ApiProviderAuthType.value_of(credentials["auth_type"])
 
         # create provider entity
         provider_controller = ApiToolProviderController.from_db(db_provider, auth_type)
@@ -381,10 +396,7 @@ class ApiToolManageService:
 
         # decrypt credentials
         if db_provider.id:
-            tool_configuration = ToolConfigurationManager(
-                tenant_id=tenant_id, 
-                provider_controller=provider_controller
-            )
+            tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
             decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
             # check if the credential has changed, save the original credential
             masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
@@ -396,27 +408,27 @@ class ApiToolManageService:
             provider_controller.validate_credentials_format(credentials)
             # get tool
             tool = provider_controller.get_tool(tool_name)
-            tool = tool.fork_tool_runtime(runtime={
-                'credentials': credentials,
-                'tenant_id': tenant_id,
-            })
+            tool = tool.fork_tool_runtime(
+                runtime={
+                    "credentials": credentials,
+                    "tenant_id": tenant_id,
+                }
+            )
             result = tool.validate_credentials(credentials, parameters)
         except Exception as e:
-            return { 'error': str(e) }
-        
-        return { 'result': result or 'empty response' }
-    
+            return {"error": str(e)}
+
+        return {"result": result or "empty response"}
+
     @staticmethod
-    def list_api_tools(
-        user_id: str, tenant_id: str
-    ) -> list[UserToolProvider]:
+    def list_api_tools(user_id: str, tenant_id: str) -> list[UserToolProvider]:
         """
-            list api tools
+        list api tools
         """
         # get all api providers
-        db_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider).filter(
-            ApiToolProvider.tenant_id == tenant_id
-        ).all() or []
+        db_providers: list[ApiToolProvider] = (
+            db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all() or []
+        )
 
         result: list[UserToolProvider] = []
 
@@ -425,26 +437,21 @@ class ApiToolManageService:
             provider_controller = ToolTransformService.api_provider_to_controller(db_provider=provider)
             labels = ToolLabelManager.get_tool_labels(provider_controller)
             user_provider = ToolTransformService.api_provider_to_user_provider(
-                provider_controller,
-                db_provider=provider,
-                decrypt_credentials=True
+                provider_controller, db_provider=provider, decrypt_credentials=True
             )
             user_provider.labels = labels
 
             # add icon
             ToolTransformService.repack_provider(user_provider)
 
-            tools = provider_controller.get_tools(
-                user_id=user_id, tenant_id=tenant_id
-            )
+            tools = provider_controller.get_tools(user_id=user_id, tenant_id=tenant_id)
 
             for tool in tools:
-                user_provider.tools.append(ToolTransformService.tool_to_user_tool(
-                    tenant_id=tenant_id,
-                    tool=tool, 
-                    credentials=user_provider.original_credentials, 
-                    labels=labels
-                ))
+                user_provider.tools.append(
+                    ToolTransformService.tool_to_user_tool(
+                        tenant_id=tenant_id, tool=tool, credentials=user_provider.original_credentials, labels=labels
+                    )
+                )
 
             result.append(user_provider)
 

+ 80 - 72
api/services/tools/builtin_tools_manage_service.py

@@ -20,21 +20,25 @@ logger = logging.getLogger(__name__)
 
 class BuiltinToolManageService:
     @staticmethod
-    def list_builtin_tool_provider_tools(
-        user_id: str, tenant_id: str, provider: str
-    ) -> list[UserTool]:
+    def list_builtin_tool_provider_tools(user_id: str, tenant_id: str, provider: str) -> list[UserTool]:
         """
-            list builtin tool provider tools
+        list builtin tool provider tools
         """
         provider_controller: ToolProviderController = ToolManager.get_builtin_provider(provider)
         tools = provider_controller.get_tools()
 
-        tool_provider_configurations = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
+        tool_provider_configurations = ToolConfigurationManager(
+            tenant_id=tenant_id, provider_controller=provider_controller
+        )
         # check if user has added the provider
-        builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
-            BuiltinToolProvider.tenant_id == tenant_id,
-            BuiltinToolProvider.provider == provider,
-        ).first()
+        builtin_provider: BuiltinToolProvider = (
+            db.session.query(BuiltinToolProvider)
+            .filter(
+                BuiltinToolProvider.tenant_id == tenant_id,
+                BuiltinToolProvider.provider == provider,
+            )
+            .first()
+        )
 
         credentials = {}
         if builtin_provider is not None:
@@ -44,47 +48,47 @@ class BuiltinToolManageService:
 
         result = []
         for tool in tools:
-            result.append(ToolTransformService.tool_to_user_tool(
-                tool=tool,
-                credentials=credentials,
-                tenant_id=tenant_id,
-                labels=ToolLabelManager.get_tool_labels(provider_controller)
-            ))
+            result.append(
+                ToolTransformService.tool_to_user_tool(
+                    tool=tool,
+                    credentials=credentials,
+                    tenant_id=tenant_id,
+                    labels=ToolLabelManager.get_tool_labels(provider_controller),
+                )
+            )
 
         return result
 
     @staticmethod
-    def list_builtin_provider_credentials_schema(
-        provider_name
-    ):
+    def list_builtin_provider_credentials_schema(provider_name):
         """
-            list builtin provider credentials schema
+        list builtin provider credentials schema
 
-            :return: the list of tool providers
+        :return: the list of tool providers
         """
         provider = ToolManager.get_builtin_provider(provider_name)
-        return jsonable_encoder([
-            v for _, v in (provider.credentials_schema or {}).items()
-        ])
+        return jsonable_encoder([v for _, v in (provider.credentials_schema or {}).items()])
 
     @staticmethod
-    def update_builtin_tool_provider(
-        user_id: str, tenant_id: str, provider_name: str, credentials: dict
-    ):
+    def update_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str, credentials: dict):
         """
-            update builtin tool provider
+        update builtin tool provider
         """
         # get if the provider exists
-        provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
-            BuiltinToolProvider.tenant_id == tenant_id,
-            BuiltinToolProvider.provider == provider_name,
-        ).first()
+        provider: BuiltinToolProvider = (
+            db.session.query(BuiltinToolProvider)
+            .filter(
+                BuiltinToolProvider.tenant_id == tenant_id,
+                BuiltinToolProvider.provider == provider_name,
+            )
+            .first()
+        )
 
         try:
             # get provider
             provider_controller = ToolManager.get_builtin_provider(provider_name)
             if not provider_controller.need_credentials:
-                raise ValueError(f'provider {provider_name} does not need credentials')
+                raise ValueError(f"provider {provider_name} does not need credentials")
             tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
             # get original credentials if exists
             if provider is not None:
@@ -121,19 +125,21 @@ class BuiltinToolManageService:
             # delete cache
             tool_configuration.delete_tool_credentials_cache()
 
-        return {'result': 'success'}
+        return {"result": "success"}
 
     @staticmethod
-    def get_builtin_tool_provider_credentials(
-        user_id: str, tenant_id: str, provider: str
-    ):
+    def get_builtin_tool_provider_credentials(user_id: str, tenant_id: str, provider: str):
         """
-            get builtin tool provider credentials
+        get builtin tool provider credentials
         """
-        provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
-            BuiltinToolProvider.tenant_id == tenant_id,
-            BuiltinToolProvider.provider == provider,
-        ).first()
+        provider: BuiltinToolProvider = (
+            db.session.query(BuiltinToolProvider)
+            .filter(
+                BuiltinToolProvider.tenant_id == tenant_id,
+                BuiltinToolProvider.provider == provider,
+            )
+            .first()
+        )
 
         if provider is None:
             return {}
@@ -145,19 +151,21 @@ class BuiltinToolManageService:
         return credentials
 
     @staticmethod
-    def delete_builtin_tool_provider(
-        user_id: str, tenant_id: str, provider_name: str
-    ):
+    def delete_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str):
         """
-            delete tool provider
+        delete tool provider
         """
-        provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
-            BuiltinToolProvider.tenant_id == tenant_id,
-            BuiltinToolProvider.provider == provider_name,
-        ).first()
+        provider: BuiltinToolProvider = (
+            db.session.query(BuiltinToolProvider)
+            .filter(
+                BuiltinToolProvider.tenant_id == tenant_id,
+                BuiltinToolProvider.provider == provider_name,
+            )
+            .first()
+        )
 
         if provider is None:
-            raise ValueError(f'you have not added provider {provider_name}')
+            raise ValueError(f"you have not added provider {provider_name}")
 
         db.session.delete(provider)
         db.session.commit()
@@ -167,38 +175,36 @@ class BuiltinToolManageService:
         tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
         tool_configuration.delete_tool_credentials_cache()
 
-        return {'result': 'success'}
+        return {"result": "success"}
 
     @staticmethod
-    def get_builtin_tool_provider_icon(
-        provider: str
-    ):
+    def get_builtin_tool_provider_icon(provider: str):
         """
-            get tool provider icon and it's mimetype
+        get tool provider icon and it's mimetype
         """
         icon_path, mime_type = ToolManager.get_builtin_provider_icon(provider)
-        with open(icon_path, 'rb') as f:
+        with open(icon_path, "rb") as f:
             icon_bytes = f.read()
 
         return icon_bytes, mime_type
 
     @staticmethod
-    def list_builtin_tools(
-        user_id: str, tenant_id: str
-    ) -> list[UserToolProvider]:
+    def list_builtin_tools(user_id: str, tenant_id: str) -> list[UserToolProvider]:
         """
-            list builtin tools
+        list builtin tools
         """
         # get all builtin providers
         provider_controllers = ToolManager.list_builtin_providers()
 
         # get all user added providers
-        db_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider).filter(
-            BuiltinToolProvider.tenant_id == tenant_id
-        ).all() or []
+        db_providers: list[BuiltinToolProvider] = (
+            db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() or []
+        )
 
         # find provider
-        find_provider = lambda provider: next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None)
+        find_provider = lambda provider: next(
+            filter(lambda db_provider: db_provider.provider == provider, db_providers), None
+        )
 
         result: list[UserToolProvider] = []
 
@@ -209,7 +215,7 @@ class BuiltinToolManageService:
                     include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
                     exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
                     data=provider_controller,
-                    name_func=lambda x: x.identity.name
+                    name_func=lambda x: x.identity.name,
                 ):
                     continue
 
@@ -217,7 +223,7 @@ class BuiltinToolManageService:
                 user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
                     provider_controller=provider_controller,
                     db_provider=find_provider(provider_controller.identity.name),
-                    decrypt_credentials=True
+                    decrypt_credentials=True,
                 )
 
                 # add icon
@@ -225,12 +231,14 @@ class BuiltinToolManageService:
 
                 tools = provider_controller.get_tools()
                 for tool in tools:
-                    user_builtin_provider.tools.append(ToolTransformService.tool_to_user_tool(
-                        tenant_id=tenant_id,
-                        tool=tool,
-                        credentials=user_builtin_provider.original_credentials,
-                        labels=ToolLabelManager.get_tool_labels(provider_controller)
-                    ))
+                    user_builtin_provider.tools.append(
+                        ToolTransformService.tool_to_user_tool(
+                            tenant_id=tenant_id,
+                            tool=tool,
+                            credentials=user_builtin_provider.original_credentials,
+                            labels=ToolLabelManager.get_tool_labels(provider_controller),
+                        )
+                    )
 
                 result.append(user_builtin_provider)
             except Exception as e:

+ 1 - 1
api/services/tools/tool_labels_service.py

@@ -5,4 +5,4 @@ from core.tools.entities.values import default_tool_labels
 class ToolLabelsService:
     @classmethod
     def list_tool_labels(cls) -> list[ToolLabel]:
-        return default_tool_labels
+        return default_tool_labels

+ 3 - 6
api/services/tools/tools_manage_service.py

@@ -11,13 +11,11 @@ class ToolCommonService:
     @staticmethod
     def list_tool_providers(user_id: str, tenant_id: str, typ: UserToolProviderTypeLiteral = None):
         """
-            list tool providers
+        list tool providers
 
-            :return: the list of tool providers
+        :return: the list of tool providers
         """
-        providers = ToolManager.user_list_providers(
-            user_id, tenant_id, typ
-        )
+        providers = ToolManager.user_list_providers(user_id, tenant_id, typ)
 
         # add icon
         for provider in providers:
@@ -26,4 +24,3 @@ class ToolCommonService:
         result = [provider.to_dict() for provider in providers]
 
         return result
-    

+ 47 - 63
api/services/tools/tools_transform_service.py

@@ -22,46 +22,39 @@ from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvi
 
 logger = logging.getLogger(__name__)
 
+
 class ToolTransformService:
     @staticmethod
     def get_tool_provider_icon_url(provider_type: str, provider_name: str, icon: str) -> Union[str, dict]:
         """
-            get tool provider icon url
+        get tool provider icon url
         """
-        url_prefix = (dify_config.CONSOLE_API_URL
-                      + "/console/api/workspaces/current/tool-provider/")
-        
+        url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/"
+
         if provider_type == ToolProviderType.BUILT_IN.value:
-            return url_prefix + 'builtin/' + provider_name + '/icon'
+            return url_prefix + "builtin/" + provider_name + "/icon"
         elif provider_type in [ToolProviderType.API.value, ToolProviderType.WORKFLOW.value]:
             try:
                 return json.loads(icon)
             except:
-                return {
-                    "background": "#252525",
-                    "content": "\ud83d\ude01"
-                }
-        
-        return ''
-        
+                return {"background": "#252525", "content": "\ud83d\ude01"}
+
+        return ""
+
     @staticmethod
     def repack_provider(provider: Union[dict, UserToolProvider]):
         """
-            repack provider
+        repack provider
 
-            :param provider: the provider dict
+        :param provider: the provider dict
         """
-        if isinstance(provider, dict) and 'icon' in provider:
-            provider['icon'] = ToolTransformService.get_tool_provider_icon_url(
-                provider_type=provider['type'],
-                provider_name=provider['name'],
-                icon=provider['icon']
+        if isinstance(provider, dict) and "icon" in provider:
+            provider["icon"] = ToolTransformService.get_tool_provider_icon_url(
+                provider_type=provider["type"], provider_name=provider["name"], icon=provider["icon"]
             )
         elif isinstance(provider, UserToolProvider):
             provider.icon = ToolTransformService.get_tool_provider_icon_url(
-                provider_type=provider.type.value,
-                provider_name=provider.name,
-                icon=provider.icon
+                provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon
             )
 
     @staticmethod
@@ -92,14 +85,13 @@ class ToolTransformService:
             masked_credentials={},
             is_team_authorization=False,
             tools=[],
-            labels=provider_controller.tool_labels
+            labels=provider_controller.tool_labels,
         )
 
         # get credentials schema
         schema = provider_controller.get_credentials_schema()
         for name, value in schema.items():
-            result.masked_credentials[name] = \
-                ToolProviderCredentials.CredentialsType.default(value.type)
+            result.masked_credentials[name] = ToolProviderCredentials.CredentialsType.default(value.type)
 
         # check if the provider need credentials
         if not provider_controller.need_credentials:
@@ -113,8 +105,7 @@ class ToolTransformService:
 
                 # init tool configuration
                 tool_configuration = ToolConfigurationManager(
-                    tenant_id=db_provider.tenant_id, 
-                    provider_controller=provider_controller
+                    tenant_id=db_provider.tenant_id, provider_controller=provider_controller
                 )
                 # decrypt the credentials and mask the credentials
                 decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials)
@@ -124,7 +115,7 @@ class ToolTransformService:
                 result.original_credentials = decrypted_credentials
 
         return result
-    
+
     @staticmethod
     def api_provider_to_controller(
         db_provider: ApiToolProvider,
@@ -135,25 +126,23 @@ class ToolTransformService:
         # package tool provider controller
         controller = ApiToolProviderController.from_db(
             db_provider=db_provider,
-            auth_type=ApiProviderAuthType.API_KEY if db_provider.credentials['auth_type'] == 'api_key' else 
-            ApiProviderAuthType.NONE
+            auth_type=ApiProviderAuthType.API_KEY
+            if db_provider.credentials["auth_type"] == "api_key"
+            else ApiProviderAuthType.NONE,
         )
 
         return controller
-    
+
     @staticmethod
-    def workflow_provider_to_controller(
-        db_provider: WorkflowToolProvider
-    ) -> WorkflowToolProviderController:
+    def workflow_provider_to_controller(db_provider: WorkflowToolProvider) -> WorkflowToolProviderController:
         """
         convert provider controller to provider
         """
         return WorkflowToolProviderController.from_db(db_provider)
-    
+
     @staticmethod
     def workflow_provider_to_user_provider(
-        provider_controller: WorkflowToolProviderController,
-        labels: list[str] = None
+        provider_controller: WorkflowToolProviderController, labels: list[str] = None
     ):
         """
         convert provider controller to user provider
@@ -175,7 +164,7 @@ class ToolTransformService:
             masked_credentials={},
             is_team_authorization=True,
             tools=[],
-            labels=labels or []
+            labels=labels or [],
         )
 
     @staticmethod
@@ -183,16 +172,16 @@ class ToolTransformService:
         provider_controller: ApiToolProviderController,
         db_provider: ApiToolProvider,
         decrypt_credentials: bool = True,
-        labels: list[str] = None
+        labels: list[str] = None,
     ) -> UserToolProvider:
         """
         convert provider controller to user provider
         """
-        username = 'Anonymous'
+        username = "Anonymous"
         try:
             username = db_provider.user.name
         except Exception as e:
-            logger.error(f'failed to get user name for api provider {db_provider.id}: {str(e)}')
+            logger.error(f"failed to get user name for api provider {db_provider.id}: {str(e)}")
         # add provider into providers
         credentials = db_provider.credentials
         result = UserToolProvider(
@@ -212,14 +201,13 @@ class ToolTransformService:
             masked_credentials={},
             is_team_authorization=True,
             tools=[],
-            labels=labels or []
+            labels=labels or [],
         )
 
         if decrypt_credentials:
             # init tool configuration
             tool_configuration = ToolConfigurationManager(
-                tenant_id=db_provider.tenant_id, 
-                provider_controller=provider_controller
+                tenant_id=db_provider.tenant_id, provider_controller=provider_controller
             )
 
             # decrypt the credentials and mask the credentials
@@ -229,23 +217,25 @@ class ToolTransformService:
             result.masked_credentials = masked_credentials
 
         return result
-    
+
     @staticmethod
     def tool_to_user_tool(
-        tool: Union[ApiToolBundle, WorkflowTool, Tool], 
-        credentials: dict = None, 
+        tool: Union[ApiToolBundle, WorkflowTool, Tool],
+        credentials: dict = None,
         tenant_id: str = None,
-        labels: list[str] = None
+        labels: list[str] = None,
     ) -> UserTool:
         """
         convert tool to user tool
         """
         if isinstance(tool, Tool):
             # fork tool runtime
-            tool = tool.fork_tool_runtime(runtime={
-                'credentials': credentials,
-                'tenant_id': tenant_id,
-            })
+            tool = tool.fork_tool_runtime(
+                runtime={
+                    "credentials": credentials,
+                    "tenant_id": tenant_id,
+                }
+            )
 
             # get tool parameters
             parameters = tool.parameters or []
@@ -270,20 +260,14 @@ class ToolTransformService:
                 label=tool.identity.label,
                 description=tool.description.human,
                 parameters=current_parameters,
-                labels=labels
+                labels=labels,
             )
         if isinstance(tool, ApiToolBundle):
             return UserTool(
                 author=tool.author,
                 name=tool.operation_id,
-                label=I18nObject(
-                    en_US=tool.operation_id,
-                    zh_Hans=tool.operation_id
-                ),
-                description=I18nObject(
-                    en_US=tool.summary or '',
-                    zh_Hans=tool.summary or ''
-                ),
+                label=I18nObject(en_US=tool.operation_id, zh_Hans=tool.operation_id),
+                description=I18nObject(en_US=tool.summary or "", zh_Hans=tool.summary or ""),
                 parameters=tool.parameters,
-                labels=labels
-            )
+                labels=labels,
+            )

+ 127 - 120
api/services/tools/workflow_tools_manage_service.py

@@ -19,10 +19,21 @@ class WorkflowToolManageService:
     """
     Service class for managing workflow tools.
     """
+
     @classmethod
-    def create_workflow_tool(cls, user_id: str, tenant_id: str, workflow_app_id: str, name: str, 
-                                label: str, icon: dict, description: str,
-                                parameters: list[dict], privacy_policy: str = '', labels: list[str] = None) -> dict:
+    def create_workflow_tool(
+        cls,
+        user_id: str,
+        tenant_id: str,
+        workflow_app_id: str,
+        name: str,
+        label: str,
+        icon: dict,
+        description: str,
+        parameters: list[dict],
+        privacy_policy: str = "",
+        labels: list[str] = None,
+    ) -> dict:
         """
         Create a workflow tool.
         :param user_id: the user id
@@ -38,27 +49,28 @@ class WorkflowToolManageService:
         WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)
 
         # check if the name is unique
-        existing_workflow_tool_provider = db.session.query(WorkflowToolProvider).filter(
-            WorkflowToolProvider.tenant_id == tenant_id,
-            # name or app_id
-            or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id)
-        ).first()
+        existing_workflow_tool_provider = (
+            db.session.query(WorkflowToolProvider)
+            .filter(
+                WorkflowToolProvider.tenant_id == tenant_id,
+                # name or app_id
+                or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id),
+            )
+            .first()
+        )
 
         if existing_workflow_tool_provider is not None:
-            raise ValueError(f'Tool with name {name} or app_id {workflow_app_id} already exists')
-        
-        app: App = db.session.query(App).filter(
-            App.id == workflow_app_id,
-            App.tenant_id == tenant_id
-        ).first()
+            raise ValueError(f"Tool with name {name} or app_id {workflow_app_id} already exists")
+
+        app: App = db.session.query(App).filter(App.id == workflow_app_id, App.tenant_id == tenant_id).first()
 
         if app is None:
-            raise ValueError(f'App {workflow_app_id} not found')
-        
+            raise ValueError(f"App {workflow_app_id} not found")
+
         workflow: Workflow = app.workflow
         if workflow is None:
-            raise ValueError(f'Workflow not found for app {workflow_app_id}')
-        
+            raise ValueError(f"Workflow not found for app {workflow_app_id}")
+
         workflow_tool_provider = WorkflowToolProvider(
             tenant_id=tenant_id,
             user_id=user_id,
@@ -76,19 +88,26 @@ class WorkflowToolManageService:
             WorkflowToolProviderController.from_db(workflow_tool_provider)
         except Exception as e:
             raise ValueError(str(e))
-        
+
         db.session.add(workflow_tool_provider)
         db.session.commit()
 
-        return {
-            'result': 'success'
-        }
-
+        return {"result": "success"}
 
     @classmethod
-    def update_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: str, 
-                             name: str, label: str, icon: dict, description: str, 
-                             parameters: list[dict], privacy_policy: str = '', labels: list[str] = None) -> dict:
+    def update_workflow_tool(
+        cls,
+        user_id: str,
+        tenant_id: str,
+        workflow_tool_id: str,
+        name: str,
+        label: str,
+        icon: dict,
+        description: str,
+        parameters: list[dict],
+        privacy_policy: str = "",
+        labels: list[str] = None,
+    ) -> dict:
         """
         Update a workflow tool.
         :param user_id: the user id
@@ -106,35 +125,39 @@ class WorkflowToolManageService:
         WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)
 
         # check if the name is unique
-        existing_workflow_tool_provider = db.session.query(WorkflowToolProvider).filter(
-            WorkflowToolProvider.tenant_id == tenant_id,
-            WorkflowToolProvider.name == name,
-            WorkflowToolProvider.id != workflow_tool_id
-        ).first()
+        existing_workflow_tool_provider = (
+            db.session.query(WorkflowToolProvider)
+            .filter(
+                WorkflowToolProvider.tenant_id == tenant_id,
+                WorkflowToolProvider.name == name,
+                WorkflowToolProvider.id != workflow_tool_id,
+            )
+            .first()
+        )
 
         if existing_workflow_tool_provider is not None:
-            raise ValueError(f'Tool with name {name} already exists')
-        
-        workflow_tool_provider: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
-            WorkflowToolProvider.tenant_id == tenant_id,
-            WorkflowToolProvider.id == workflow_tool_id
-        ).first()
+            raise ValueError(f"Tool with name {name} already exists")
+
+        workflow_tool_provider: WorkflowToolProvider = (
+            db.session.query(WorkflowToolProvider)
+            .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
+            .first()
+        )
 
         if workflow_tool_provider is None:
-            raise ValueError(f'Tool {workflow_tool_id} not found')
-        
-        app: App = db.session.query(App).filter(
-            App.id == workflow_tool_provider.app_id,
-            App.tenant_id == tenant_id
-        ).first()
+            raise ValueError(f"Tool {workflow_tool_id} not found")
+
+        app: App = (
+            db.session.query(App).filter(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).first()
+        )
 
         if app is None:
-            raise ValueError(f'App {workflow_tool_provider.app_id} not found')
-        
+            raise ValueError(f"App {workflow_tool_provider.app_id} not found")
+
         workflow: Workflow = app.workflow
         if workflow is None:
-            raise ValueError(f'Workflow not found for app {workflow_tool_provider.app_id}')
-        
+            raise ValueError(f"Workflow not found for app {workflow_tool_provider.app_id}")
+
         workflow_tool_provider.name = name
         workflow_tool_provider.label = label
         workflow_tool_provider.icon = json.dumps(icon)
@@ -154,13 +177,10 @@ class WorkflowToolManageService:
 
         if labels is not None:
             ToolLabelManager.update_tool_labels(
-                ToolTransformService.workflow_provider_to_controller(workflow_tool_provider),
-                labels
+                ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels
             )
 
-        return {
-            'result': 'success'
-        }
+        return {"result": "success"}
 
     @classmethod
     def list_tenant_workflow_tools(cls, user_id: str, tenant_id: str) -> list[UserToolProvider]:
@@ -170,9 +190,7 @@ class WorkflowToolManageService:
         :param tenant_id: the tenant id
         :return: the list of tools
         """
-        db_tools = db.session.query(WorkflowToolProvider).filter(
-            WorkflowToolProvider.tenant_id == tenant_id
-        ).all()
+        db_tools = db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
 
         tools = []
         for provider in db_tools:
@@ -188,14 +206,12 @@ class WorkflowToolManageService:
 
         for tool in tools:
             user_tool_provider = ToolTransformService.workflow_provider_to_user_provider(
-                provider_controller=tool,
-                labels=labels.get(tool.provider_id, [])
+                provider_controller=tool, labels=labels.get(tool.provider_id, [])
             )
             ToolTransformService.repack_provider(user_tool_provider)
             user_tool_provider.tools = [
                 ToolTransformService.tool_to_user_tool(
-                    tool.get_tools(user_id, tenant_id)[0],
-                    labels=labels.get(tool.provider_id, [])
+                    tool.get_tools(user_id, tenant_id)[0], labels=labels.get(tool.provider_id, [])
                 )
             ]
             result.append(user_tool_provider)
@@ -211,15 +227,12 @@ class WorkflowToolManageService:
         :param workflow_app_id: the workflow app id
         """
         db.session.query(WorkflowToolProvider).filter(
-            WorkflowToolProvider.tenant_id == tenant_id,
-            WorkflowToolProvider.id == workflow_tool_id
+            WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id
         ).delete()
 
         db.session.commit()
 
-        return {
-            'result': 'success'
-        }
+        return {"result": "success"}
 
     @classmethod
     def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> dict:
@@ -230,40 +243,37 @@ class WorkflowToolManageService:
         :param workflow_app_id: the workflow app id
         :return: the tool
         """
-        db_tool: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
-            WorkflowToolProvider.tenant_id == tenant_id,
-            WorkflowToolProvider.id == workflow_tool_id
-        ).first()
+        db_tool: WorkflowToolProvider = (
+            db.session.query(WorkflowToolProvider)
+            .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
+            .first()
+        )
 
         if db_tool is None:
-            raise ValueError(f'Tool {workflow_tool_id} not found')
-        
-        workflow_app: App = db.session.query(App).filter(
-            App.id == db_tool.app_id,
-            App.tenant_id == tenant_id
-        ).first()
+            raise ValueError(f"Tool {workflow_tool_id} not found")
+
+        workflow_app: App = db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first()
 
         if workflow_app is None:
-            raise ValueError(f'App {db_tool.app_id} not found')
+            raise ValueError(f"App {db_tool.app_id} not found")
 
         tool = ToolTransformService.workflow_provider_to_controller(db_tool)
 
         return {
-            'name': db_tool.name,
-            'label': db_tool.label,
-            'workflow_tool_id': db_tool.id,
-            'workflow_app_id': db_tool.app_id,
-            'icon': json.loads(db_tool.icon),
-            'description': db_tool.description,
-            'parameters': jsonable_encoder(db_tool.parameter_configurations),
-            'tool': ToolTransformService.tool_to_user_tool(
-                tool.get_tools(user_id, tenant_id)[0],
-                labels=ToolLabelManager.get_tool_labels(tool)
+            "name": db_tool.name,
+            "label": db_tool.label,
+            "workflow_tool_id": db_tool.id,
+            "workflow_app_id": db_tool.app_id,
+            "icon": json.loads(db_tool.icon),
+            "description": db_tool.description,
+            "parameters": jsonable_encoder(db_tool.parameter_configurations),
+            "tool": ToolTransformService.tool_to_user_tool(
+                tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool)
             ),
-            'synced': workflow_app.workflow.version == db_tool.version,
-            'privacy_policy': db_tool.privacy_policy,
+            "synced": workflow_app.workflow.version == db_tool.version,
+            "privacy_policy": db_tool.privacy_policy,
         }
-    
+
     @classmethod
     def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str) -> dict:
         """
@@ -273,40 +283,37 @@ class WorkflowToolManageService:
         :param workflow_app_id: the workflow app id
         :return: the tool
         """
-        db_tool: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
-            WorkflowToolProvider.tenant_id == tenant_id,
-            WorkflowToolProvider.app_id == workflow_app_id
-        ).first()
+        db_tool: WorkflowToolProvider = (
+            db.session.query(WorkflowToolProvider)
+            .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id)
+            .first()
+        )
 
         if db_tool is None:
-            raise ValueError(f'Tool {workflow_app_id} not found')
-        
-        workflow_app: App = db.session.query(App).filter(
-            App.id == db_tool.app_id,
-            App.tenant_id == tenant_id
-        ).first()
+            raise ValueError(f"Tool {workflow_app_id} not found")
+
+        workflow_app: App = db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first()
 
         if workflow_app is None:
-            raise ValueError(f'App {db_tool.app_id} not found')
+            raise ValueError(f"App {db_tool.app_id} not found")
 
         tool = ToolTransformService.workflow_provider_to_controller(db_tool)
 
         return {
-            'name': db_tool.name,
-            'label': db_tool.label,
-            'workflow_tool_id': db_tool.id,
-            'workflow_app_id': db_tool.app_id,
-            'icon': json.loads(db_tool.icon),
-            'description': db_tool.description,
-            'parameters': jsonable_encoder(db_tool.parameter_configurations),
-            'tool': ToolTransformService.tool_to_user_tool(
-                tool.get_tools(user_id, tenant_id)[0],
-                labels=ToolLabelManager.get_tool_labels(tool)
+            "name": db_tool.name,
+            "label": db_tool.label,
+            "workflow_tool_id": db_tool.id,
+            "workflow_app_id": db_tool.app_id,
+            "icon": json.loads(db_tool.icon),
+            "description": db_tool.description,
+            "parameters": jsonable_encoder(db_tool.parameter_configurations),
+            "tool": ToolTransformService.tool_to_user_tool(
+                tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool)
             ),
-            'synced': workflow_app.workflow.version == db_tool.version,
-            'privacy_policy': db_tool.privacy_policy
+            "synced": workflow_app.workflow.version == db_tool.version,
+            "privacy_policy": db_tool.privacy_policy,
         }
-    
+
     @classmethod
     def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[dict]:
         """
@@ -316,19 +323,19 @@ class WorkflowToolManageService:
         :param workflow_app_id: the workflow app id
         :return: the list of tools
         """
-        db_tool: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
-            WorkflowToolProvider.tenant_id == tenant_id,
-            WorkflowToolProvider.id == workflow_tool_id
-        ).first()
+        db_tool: WorkflowToolProvider = (
+            db.session.query(WorkflowToolProvider)
+            .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
+            .first()
+        )
 
         if db_tool is None:
-            raise ValueError(f'Tool {workflow_tool_id} not found')
+            raise ValueError(f"Tool {workflow_tool_id} not found")
 
         tool = ToolTransformService.workflow_provider_to_controller(db_tool)
 
         return [
             ToolTransformService.tool_to_user_tool(
-                tool.get_tools(user_id, tenant_id)[0],
-                labels=ToolLabelManager.get_tool_labels(tool)
+                tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool)
             )
-        ]
+        ]

+ 9 - 13
api/services/vector_service.py

@@ -7,10 +7,10 @@ from models.dataset import Dataset, DocumentSegment
 
 
 class VectorService:
-
     @classmethod
-    def create_segments_vector(cls, keywords_list: Optional[list[list[str]]],
-                               segments: list[DocumentSegment], dataset: Dataset):
+    def create_segments_vector(
+        cls, keywords_list: Optional[list[list[str]]], segments: list[DocumentSegment], dataset: Dataset
+    ):
         documents = []
         for segment in segments:
             document = Document(
@@ -20,14 +20,12 @@ class VectorService:
                     "doc_hash": segment.index_node_hash,
                     "document_id": segment.document_id,
                     "dataset_id": segment.dataset_id,
-                }
+                },
             )
             documents.append(document)
-        if dataset.indexing_technique == 'high_quality':
+        if dataset.indexing_technique == "high_quality":
             # save vector index
-            vector = Vector(
-                dataset=dataset
-            )
+            vector = Vector(dataset=dataset)
             vector.add_texts(documents, duplicate_check=True)
 
         # save keyword index
@@ -50,13 +48,11 @@ class VectorService:
                 "doc_hash": segment.index_node_hash,
                 "document_id": segment.document_id,
                 "dataset_id": segment.dataset_id,
-            }
+            },
         )
-        if dataset.indexing_technique == 'high_quality':
+        if dataset.indexing_technique == "high_quality":
             # update vector index
-            vector = Vector(
-                dataset=dataset
-            )
+            vector = Vector(dataset=dataset)
             vector.delete_by_ids([segment.index_node_id])
             vector.add_texts([document], duplicate_check=True)
 

+ 44 - 27
api/services/web_conversation_service.py

@@ -11,18 +11,29 @@ from services.conversation_service import ConversationService
 
 class WebConversationService:
     @classmethod
-    def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, EndUser]],
-                              last_id: Optional[str], limit: int, invoke_from: InvokeFrom,
-                              pinned: Optional[bool] = None,
-                              sort_by='-updated_at') -> InfiniteScrollPagination:
+    def pagination_by_last_id(
+        cls,
+        app_model: App,
+        user: Optional[Union[Account, EndUser]],
+        last_id: Optional[str],
+        limit: int,
+        invoke_from: InvokeFrom,
+        pinned: Optional[bool] = None,
+        sort_by="-updated_at",
+    ) -> InfiniteScrollPagination:
         include_ids = None
         exclude_ids = None
         if pinned is not None:
-            pinned_conversations = db.session.query(PinnedConversation).filter(
-                PinnedConversation.app_id == app_model.id,
-                PinnedConversation.created_by_role == ('account' if isinstance(user, Account) else 'end_user'),
-                PinnedConversation.created_by == user.id
-            ).order_by(PinnedConversation.created_at.desc()).all()
+            pinned_conversations = (
+                db.session.query(PinnedConversation)
+                .filter(
+                    PinnedConversation.app_id == app_model.id,
+                    PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
+                    PinnedConversation.created_by == user.id,
+                )
+                .order_by(PinnedConversation.created_at.desc())
+                .all()
+            )
             pinned_conversation_ids = [pc.conversation_id for pc in pinned_conversations]
             if pinned:
                 include_ids = pinned_conversation_ids
@@ -37,32 +48,34 @@ class WebConversationService:
             invoke_from=invoke_from,
             include_ids=include_ids,
             exclude_ids=exclude_ids,
-            sort_by=sort_by
+            sort_by=sort_by,
         )
 
     @classmethod
     def pin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
-        pinned_conversation = db.session.query(PinnedConversation).filter(
-            PinnedConversation.app_id == app_model.id,
-            PinnedConversation.conversation_id == conversation_id,
-            PinnedConversation.created_by_role == ('account' if isinstance(user, Account) else 'end_user'),
-            PinnedConversation.created_by == user.id
-        ).first()
+        pinned_conversation = (
+            db.session.query(PinnedConversation)
+            .filter(
+                PinnedConversation.app_id == app_model.id,
+                PinnedConversation.conversation_id == conversation_id,
+                PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
+                PinnedConversation.created_by == user.id,
+            )
+            .first()
+        )
 
         if pinned_conversation:
             return
 
         conversation = ConversationService.get_conversation(
-            app_model=app_model,
-            conversation_id=conversation_id,
-            user=user
+            app_model=app_model, conversation_id=conversation_id, user=user
         )
 
         pinned_conversation = PinnedConversation(
             app_id=app_model.id,
             conversation_id=conversation.id,
-            created_by_role='account' if isinstance(user, Account) else 'end_user',
-            created_by=user.id
+            created_by_role="account" if isinstance(user, Account) else "end_user",
+            created_by=user.id,
         )
 
         db.session.add(pinned_conversation)
@@ -70,12 +83,16 @@ class WebConversationService:
 
     @classmethod
     def unpin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
-        pinned_conversation = db.session.query(PinnedConversation).filter(
-            PinnedConversation.app_id == app_model.id,
-            PinnedConversation.conversation_id == conversation_id,
-            PinnedConversation.created_by_role == ('account' if isinstance(user, Account) else 'end_user'),
-            PinnedConversation.created_by == user.id
-        ).first()
+        pinned_conversation = (
+            db.session.query(PinnedConversation)
+            .filter(
+                PinnedConversation.app_id == app_model.id,
+                PinnedConversation.conversation_id == conversation_id,
+                PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
+                PinnedConversation.created_by == user.id,
+            )
+            .first()
+        )
 
         if not pinned_conversation:
             return

+ 59 - 94
api/services/website_service.py

@@ -11,161 +11,126 @@ from services.auth.api_key_auth_service import ApiKeyAuthService
 
 
 class WebsiteService:
-
     @classmethod
     def document_create_args_validate(cls, args: dict):
-        if 'url' not in args or not args['url']:
-            raise ValueError('url is required')
-        if 'options' not in args or not args['options']:
-            raise ValueError('options is required')
-        if 'limit' not in args['options'] or not args['options']['limit']:
-            raise ValueError('limit is required')
+        if "url" not in args or not args["url"]:
+            raise ValueError("url is required")
+        if "options" not in args or not args["options"]:
+            raise ValueError("options is required")
+        if "limit" not in args["options"] or not args["options"]["limit"]:
+            raise ValueError("limit is required")
 
     @classmethod
     def crawl_url(cls, args: dict) -> dict:
-        provider = args.get('provider')
-        url = args.get('url')
-        options = args.get('options')
-        credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id,
-                                                             'website',
-                                                             provider)
-        if provider == 'firecrawl':
+        provider = args.get("provider")
+        url = args.get("url")
+        options = args.get("options")
+        credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, "website", provider)
+        if provider == "firecrawl":
             # decrypt api_key
             api_key = encrypter.decrypt_token(
-                tenant_id=current_user.current_tenant_id,
-                token=credentials.get('config').get('api_key')
+                tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key")
             )
-            firecrawl_app = FirecrawlApp(api_key=api_key,
-                                         base_url=credentials.get('config').get('base_url', None))
-            crawl_sub_pages = options.get('crawl_sub_pages', False)
-            only_main_content = options.get('only_main_content', False)
+            firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None))
+            crawl_sub_pages = options.get("crawl_sub_pages", False)
+            only_main_content = options.get("only_main_content", False)
             if not crawl_sub_pages:
                 params = {
-                    'crawlerOptions': {
+                    "crawlerOptions": {
                         "includes": [],
                         "excludes": [],
                         "generateImgAltText": True,
                         "limit": 1,
-                        'returnOnlyUrls': False,
-                        'pageOptions': {
-                            'onlyMainContent': only_main_content,
-                            "includeHtml": False
-                        }
+                        "returnOnlyUrls": False,
+                        "pageOptions": {"onlyMainContent": only_main_content, "includeHtml": False},
                     }
                 }
             else:
-                includes = options.get('includes').split(',') if options.get('includes') else []
-                excludes = options.get('excludes').split(',') if options.get('excludes') else []
+                includes = options.get("includes").split(",") if options.get("includes") else []
+                excludes = options.get("excludes").split(",") if options.get("excludes") else []
                 params = {
-                    'crawlerOptions': {
+                    "crawlerOptions": {
                         "includes": includes if includes else [],
                         "excludes": excludes if excludes else [],
                         "generateImgAltText": True,
-                        "limit": options.get('limit', 1),
-                        'returnOnlyUrls': False,
-                        'pageOptions': {
-                            'onlyMainContent': only_main_content,
-                            "includeHtml": False
-                        }
+                        "limit": options.get("limit", 1),
+                        "returnOnlyUrls": False,
+                        "pageOptions": {"onlyMainContent": only_main_content, "includeHtml": False},
                     }
                 }
-                if options.get('max_depth'):
-                    params['crawlerOptions']['maxDepth'] = options.get('max_depth')
+                if options.get("max_depth"):
+                    params["crawlerOptions"]["maxDepth"] = options.get("max_depth")
             job_id = firecrawl_app.crawl_url(url, params)
-            website_crawl_time_cache_key = f'website_crawl_{job_id}'
+            website_crawl_time_cache_key = f"website_crawl_{job_id}"
             time = str(datetime.datetime.now().timestamp())
             redis_client.setex(website_crawl_time_cache_key, 3600, time)
-            return {
-                'status': 'active',
-                'job_id': job_id
-            }
+            return {"status": "active", "job_id": job_id}
         else:
-            raise ValueError('Invalid provider')
+            raise ValueError("Invalid provider")
 
     @classmethod
     def get_crawl_status(cls, job_id: str, provider: str) -> dict:
-        credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id,
-                                                             'website',
-                                                             provider)
-        if provider == 'firecrawl':
+        credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, "website", provider)
+        if provider == "firecrawl":
             # decrypt api_key
             api_key = encrypter.decrypt_token(
-                tenant_id=current_user.current_tenant_id,
-                token=credentials.get('config').get('api_key')
+                tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key")
             )
-            firecrawl_app = FirecrawlApp(api_key=api_key,
-                                         base_url=credentials.get('config').get('base_url', None))
+            firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None))
             result = firecrawl_app.check_crawl_status(job_id)
             crawl_status_data = {
-                'status': result.get('status', 'active'),
-                'job_id': job_id,
-                'total': result.get('total', 0),
-                'current': result.get('current', 0),
-                'data': result.get('data', [])
+                "status": result.get("status", "active"),
+                "job_id": job_id,
+                "total": result.get("total", 0),
+                "current": result.get("current", 0),
+                "data": result.get("data", []),
             }
-            if crawl_status_data['status'] == 'completed':
-                website_crawl_time_cache_key = f'website_crawl_{job_id}'
+            if crawl_status_data["status"] == "completed":
+                website_crawl_time_cache_key = f"website_crawl_{job_id}"
                 start_time = redis_client.get(website_crawl_time_cache_key)
                 if start_time:
                     end_time = datetime.datetime.now().timestamp()
                     time_consuming = abs(end_time - float(start_time))
-                    crawl_status_data['time_consuming'] = f"{time_consuming:.2f}"
+                    crawl_status_data["time_consuming"] = f"{time_consuming:.2f}"
                     redis_client.delete(website_crawl_time_cache_key)
         else:
-            raise ValueError('Invalid provider')
+            raise ValueError("Invalid provider")
         return crawl_status_data
 
     @classmethod
     def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict | None:
-        credentials = ApiKeyAuthService.get_auth_credentials(tenant_id,
-                                                             'website',
-                                                             provider)
-        if provider == 'firecrawl':
-            file_key = 'website_files/' + job_id + '.txt'
+        credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider)
+        if provider == "firecrawl":
+            file_key = "website_files/" + job_id + ".txt"
             if storage.exists(file_key):
                 data = storage.load_once(file_key)
                 if data:
-                    data = json.loads(data.decode('utf-8'))
+                    data = json.loads(data.decode("utf-8"))
             else:
                 # decrypt api_key
-                api_key = encrypter.decrypt_token(
-                    tenant_id=tenant_id,
-                    token=credentials.get('config').get('api_key')
-                )
-                firecrawl_app = FirecrawlApp(api_key=api_key,
-                                             base_url=credentials.get('config').get('base_url', None))
+                api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key"))
+                firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None))
                 result = firecrawl_app.check_crawl_status(job_id)
-                if result.get('status') != 'completed':
-                    raise ValueError('Crawl job is not completed')
-                data = result.get('data')
+                if result.get("status") != "completed":
+                    raise ValueError("Crawl job is not completed")
+                data = result.get("data")
             if data:
                 for item in data:
-                    if item.get('source_url') == url:
+                    if item.get("source_url") == url:
                         return item
             return None
         else:
-            raise ValueError('Invalid provider')
+            raise ValueError("Invalid provider")
 
     @classmethod
     def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict | None:
-        credentials = ApiKeyAuthService.get_auth_credentials(tenant_id,
-                                                             'website',
-                                                             provider)
-        if provider == 'firecrawl':
+        credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider)
+        if provider == "firecrawl":
             # decrypt api_key
-            api_key = encrypter.decrypt_token(
-                tenant_id=tenant_id,
-                token=credentials.get('config').get('api_key')
-            )
-            firecrawl_app = FirecrawlApp(api_key=api_key,
-                                         base_url=credentials.get('config').get('base_url', None))
-            params = {
-                'pageOptions': {
-                    'onlyMainContent': only_main_content,
-                    "includeHtml": False
-                }
-            }
+            api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key"))
+            firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None))
+            params = {"pageOptions": {"onlyMainContent": only_main_content, "includeHtml": False}}
             result = firecrawl_app.scrape_url(url, params)
             return result
         else:
-            raise ValueError('Invalid provider')
+            raise ValueError("Invalid provider")

+ 9 - 23
api/services/workflow_app_service.py

@@ -10,7 +10,6 @@ from models.workflow import WorkflowAppLog, WorkflowRun, WorkflowRunStatus
 
 
 class WorkflowAppService:
-
     def get_paginate_workflow_app_logs(self, app_model: App, args: dict) -> Pagination:
         """
         Get paginate workflow app logs
@@ -18,20 +17,14 @@ class WorkflowAppService:
         :param args: request args
         :return:
         """
-        query = (
-            db.select(WorkflowAppLog)
-            .where(
-                WorkflowAppLog.tenant_id == app_model.tenant_id,
-                WorkflowAppLog.app_id == app_model.id
-            )
+        query = db.select(WorkflowAppLog).where(
+            WorkflowAppLog.tenant_id == app_model.tenant_id, WorkflowAppLog.app_id == app_model.id
         )
 
-        status = WorkflowRunStatus.value_of(args.get('status')) if args.get('status') else None
-        keyword = args['keyword']
+        status = WorkflowRunStatus.value_of(args.get("status")) if args.get("status") else None
+        keyword = args["keyword"]
         if keyword or status:
-            query = query.join(
-                WorkflowRun, WorkflowRun.id == WorkflowAppLog.workflow_run_id
-            )
+            query = query.join(WorkflowRun, WorkflowRun.id == WorkflowAppLog.workflow_run_id)
 
         if keyword:
             keyword_like_val = f"%{args['keyword'][:30]}%"
@@ -39,7 +32,7 @@ class WorkflowAppService:
                 WorkflowRun.inputs.ilike(keyword_like_val),
                 WorkflowRun.outputs.ilike(keyword_like_val),
                 # filter keyword by end user session id if created by end user role
-                and_(WorkflowRun.created_by_role == 'end_user', EndUser.session_id.ilike(keyword_like_val))
+                and_(WorkflowRun.created_by_role == "end_user", EndUser.session_id.ilike(keyword_like_val)),
             ]
 
             # filter keyword by workflow run id
@@ -49,23 +42,16 @@ class WorkflowAppService:
 
             query = query.outerjoin(
                 EndUser,
-                and_(WorkflowRun.created_by == EndUser.id, WorkflowRun.created_by_role == CreatedByRole.END_USER.value)
+                and_(WorkflowRun.created_by == EndUser.id, WorkflowRun.created_by_role == CreatedByRole.END_USER.value),
             ).filter(or_(*keyword_conditions))
 
         if status:
             # join with workflow_run and filter by status
-            query = query.filter(
-                WorkflowRun.status == status.value
-            )
+            query = query.filter(WorkflowRun.status == status.value)
 
         query = query.order_by(WorkflowAppLog.created_at.desc())
 
-        pagination = db.paginate(
-            query,
-            page=args['page'],
-            per_page=args['limit'],
-            error_out=False
-        )
+        pagination = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False)
 
         return pagination
 

+ 39 - 31
api/services/workflow_run_service.py

@@ -18,6 +18,7 @@ class WorkflowRunService:
         :param app_model: app model
         :param args: request args
         """
+
         class WorkflowWithMessage:
             message_id: str
             conversation_id: str
@@ -33,9 +34,7 @@ class WorkflowRunService:
         with_message_workflow_runs = []
         for workflow_run in pagination.data:
             message = workflow_run.message
-            with_message_workflow_run = WorkflowWithMessage(
-                workflow_run=workflow_run
-            )
+            with_message_workflow_run = WorkflowWithMessage(workflow_run=workflow_run)
             if message:
                 with_message_workflow_run.message_id = message.id
                 with_message_workflow_run.conversation_id = message.conversation_id
@@ -53,26 +52,30 @@ class WorkflowRunService:
         :param app_model: app model
         :param args: request args
         """
-        limit = int(args.get('limit', 20))
+        limit = int(args.get("limit", 20))
 
         base_query = db.session.query(WorkflowRun).filter(
             WorkflowRun.tenant_id == app_model.tenant_id,
             WorkflowRun.app_id == app_model.id,
-            WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING.value
+            WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING.value,
         )
 
-        if args.get('last_id'):
+        if args.get("last_id"):
             last_workflow_run = base_query.filter(
-                WorkflowRun.id == args.get('last_id'),
+                WorkflowRun.id == args.get("last_id"),
             ).first()
 
             if not last_workflow_run:
-                raise ValueError('Last workflow run not exists')
-
-            workflow_runs = base_query.filter(
-                WorkflowRun.created_at < last_workflow_run.created_at,
-                WorkflowRun.id != last_workflow_run.id
-            ).order_by(WorkflowRun.created_at.desc()).limit(limit).all()
+                raise ValueError("Last workflow run not exists")
+
+            workflow_runs = (
+                base_query.filter(
+                    WorkflowRun.created_at < last_workflow_run.created_at, WorkflowRun.id != last_workflow_run.id
+                )
+                .order_by(WorkflowRun.created_at.desc())
+                .limit(limit)
+                .all()
+            )
         else:
             workflow_runs = base_query.order_by(WorkflowRun.created_at.desc()).limit(limit).all()
 
@@ -81,17 +84,13 @@ class WorkflowRunService:
             current_page_first_workflow_run = workflow_runs[-1]
             rest_count = base_query.filter(
                 WorkflowRun.created_at < current_page_first_workflow_run.created_at,
-                WorkflowRun.id != current_page_first_workflow_run.id
+                WorkflowRun.id != current_page_first_workflow_run.id,
             ).count()
 
             if rest_count > 0:
                 has_more = True
 
-        return InfiniteScrollPagination(
-            data=workflow_runs,
-            limit=limit,
-            has_more=has_more
-        )
+        return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more)
 
     def get_workflow_run(self, app_model: App, run_id: str) -> WorkflowRun:
         """
@@ -100,11 +99,15 @@ class WorkflowRunService:
         :param app_model: app model
         :param run_id: workflow run id
         """
-        workflow_run = db.session.query(WorkflowRun).filter(
-            WorkflowRun.tenant_id == app_model.tenant_id,
-            WorkflowRun.app_id == app_model.id,
-            WorkflowRun.id == run_id,
-        ).first()
+        workflow_run = (
+            db.session.query(WorkflowRun)
+            .filter(
+                WorkflowRun.tenant_id == app_model.tenant_id,
+                WorkflowRun.app_id == app_model.id,
+                WorkflowRun.id == run_id,
+            )
+            .first()
+        )
 
         return workflow_run
 
@@ -117,12 +120,17 @@ class WorkflowRunService:
         if not workflow_run:
             return []
 
-        node_executions = db.session.query(WorkflowNodeExecution).filter(
-            WorkflowNodeExecution.tenant_id == app_model.tenant_id,
-            WorkflowNodeExecution.app_id == app_model.id,
-            WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
-            WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
-            WorkflowNodeExecution.workflow_run_id == run_id,
-        ).order_by(WorkflowNodeExecution.index.desc()).all()
+        node_executions = (
+            db.session.query(WorkflowNodeExecution)
+            .filter(
+                WorkflowNodeExecution.tenant_id == app_model.tenant_id,
+                WorkflowNodeExecution.app_id == app_model.id,
+                WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
+                WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
+                WorkflowNodeExecution.workflow_run_id == run_id,
+            )
+            .order_by(WorkflowNodeExecution.index.desc())
+            .all()
+        )
 
         return node_executions

+ 37 - 40
api/services/workflow_service.py

@@ -37,11 +37,13 @@ class WorkflowService:
         Get draft workflow
         """
         # fetch draft workflow by app_model
-        workflow = db.session.query(Workflow).filter(
-            Workflow.tenant_id == app_model.tenant_id,
-            Workflow.app_id == app_model.id,
-            Workflow.version == 'draft'
-        ).first()
+        workflow = (
+            db.session.query(Workflow)
+            .filter(
+                Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.version == "draft"
+            )
+            .first()
+        )
 
         # return draft workflow
         return workflow
@@ -55,11 +57,15 @@ class WorkflowService:
             return None
 
         # fetch published workflow by workflow_id
-        workflow = db.session.query(Workflow).filter(
-            Workflow.tenant_id == app_model.tenant_id,
-            Workflow.app_id == app_model.id,
-            Workflow.id == app_model.workflow_id
-        ).first()
+        workflow = (
+            db.session.query(Workflow)
+            .filter(
+                Workflow.tenant_id == app_model.tenant_id,
+                Workflow.app_id == app_model.id,
+                Workflow.id == app_model.workflow_id,
+            )
+            .first()
+        )
 
         return workflow
 
@@ -85,10 +91,7 @@ class WorkflowService:
             raise WorkflowHashNotEqualError()
 
         # validate features structure
-        self.validate_features_structure(
-            app_model=app_model,
-            features=features
-        )
+        self.validate_features_structure(app_model=app_model, features=features)
 
         # create draft workflow if not found
         if not workflow:
@@ -96,7 +99,7 @@ class WorkflowService:
                 tenant_id=app_model.tenant_id,
                 app_id=app_model.id,
                 type=WorkflowType.from_app_mode(app_model.mode).value,
-                version='draft',
+                version="draft",
                 graph=json.dumps(graph),
                 features=json.dumps(features),
                 created_by=account.id,
@@ -122,9 +125,7 @@ class WorkflowService:
         # return draft workflow
         return workflow
 
-    def publish_workflow(self, app_model: App,
-                         account: Account,
-                         draft_workflow: Optional[Workflow] = None) -> Workflow:
+    def publish_workflow(self, app_model: App, account: Account, draft_workflow: Optional[Workflow] = None) -> Workflow:
         """
         Publish workflow from draft
 
@@ -137,7 +138,7 @@ class WorkflowService:
             draft_workflow = self.get_draft_workflow(app_model=app_model)
 
         if not draft_workflow:
-            raise ValueError('No valid workflow found.')
+            raise ValueError("No valid workflow found.")
 
         # create new workflow
         workflow = Workflow(
@@ -187,17 +188,16 @@ class WorkflowService:
         workflow_engine_manager = WorkflowEngineManager()
         return workflow_engine_manager.get_default_config(node_type, filters)
 
-    def run_draft_workflow_node(self, app_model: App,
-                                node_id: str,
-                                user_inputs: dict,
-                                account: Account) -> WorkflowNodeExecution:
+    def run_draft_workflow_node(
+        self, app_model: App, node_id: str, user_inputs: dict, account: Account
+    ) -> WorkflowNodeExecution:
         """
         Run draft workflow node
         """
         # fetch draft workflow by app_model
         draft_workflow = self.get_draft_workflow(app_model=app_model)
         if not draft_workflow:
-            raise ValueError('Workflow not initialized')
+            raise ValueError("Workflow not initialized")
 
         # run draft workflow node
         workflow_engine_manager = WorkflowEngineManager()
@@ -226,7 +226,7 @@ class WorkflowService:
                 created_by_role=CreatedByRole.ACCOUNT.value,
                 created_by=account.id,
                 created_at=datetime.now(timezone.utc).replace(tzinfo=None),
-                finished_at=datetime.now(timezone.utc).replace(tzinfo=None)
+                finished_at=datetime.now(timezone.utc).replace(tzinfo=None),
             )
             db.session.add(workflow_node_execution)
             db.session.commit()
@@ -247,14 +247,15 @@ class WorkflowService:
                 inputs=json.dumps(node_run_result.inputs) if node_run_result.inputs else None,
                 process_data=json.dumps(node_run_result.process_data) if node_run_result.process_data else None,
                 outputs=json.dumps(jsonable_encoder(node_run_result.outputs)) if node_run_result.outputs else None,
-                execution_metadata=(json.dumps(jsonable_encoder(node_run_result.metadata))
-                                    if node_run_result.metadata else None),
+                execution_metadata=(
+                    json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None
+                ),
                 status=WorkflowNodeExecutionStatus.SUCCEEDED.value,
                 elapsed_time=time.perf_counter() - start_at,
                 created_by_role=CreatedByRole.ACCOUNT.value,
                 created_by=account.id,
                 created_at=datetime.now(timezone.utc).replace(tzinfo=None),
-                finished_at=datetime.now(timezone.utc).replace(tzinfo=None)
+                finished_at=datetime.now(timezone.utc).replace(tzinfo=None),
             )
         else:
             # create workflow node execution
@@ -273,7 +274,7 @@ class WorkflowService:
                 created_by_role=CreatedByRole.ACCOUNT.value,
                 created_by=account.id,
                 created_at=datetime.now(timezone.utc).replace(tzinfo=None),
-                finished_at=datetime.now(timezone.utc).replace(tzinfo=None)
+                finished_at=datetime.now(timezone.utc).replace(tzinfo=None),
             )
 
         db.session.add(workflow_node_execution)
@@ -295,16 +296,16 @@ class WorkflowService:
         workflow_converter = WorkflowConverter()
 
         if app_model.mode not in [AppMode.CHAT.value, AppMode.COMPLETION.value]:
-            raise ValueError(f'Current App mode: {app_model.mode} is not supported convert to workflow.')
+            raise ValueError(f"Current App mode: {app_model.mode} is not supported convert to workflow.")
 
         # convert to workflow
         new_app = workflow_converter.convert_to_workflow(
             app_model=app_model,
             account=account,
-            name=args.get('name'),
-            icon_type=args.get('icon_type'),
-            icon=args.get('icon'),
-            icon_background=args.get('icon_background'),
+            name=args.get("name"),
+            icon_type=args.get("icon_type"),
+            icon=args.get("icon"),
+            icon_background=args.get("icon_background"),
         )
 
         return new_app
@@ -312,15 +313,11 @@ class WorkflowService:
     def validate_features_structure(self, app_model: App, features: dict) -> dict:
         if app_model.mode == AppMode.ADVANCED_CHAT.value:
             return AdvancedChatAppConfigManager.config_validate(
-                tenant_id=app_model.tenant_id,
-                config=features,
-                only_structure_validate=True
+                tenant_id=app_model.tenant_id, config=features, only_structure_validate=True
             )
         elif app_model.mode == AppMode.WORKFLOW.value:
             return WorkflowAppConfigManager.config_validate(
-                tenant_id=app_model.tenant_id,
-                config=features,
-                only_structure_validate=True
+                tenant_id=app_model.tenant_id, config=features, only_structure_validate=True
             )
         else:
             raise ValueError(f"Invalid app mode: {app_model.mode}")

+ 30 - 25
api/services/workspace_service.py

@@ -1,4 +1,3 @@
-
 from flask_login import current_user
 
 from configs import dify_config
@@ -14,34 +13,40 @@ class WorkspaceService:
         if not tenant:
             return None
         tenant_info = {
-            'id': tenant.id,
-            'name': tenant.name,
-            'plan': tenant.plan,
-            'status': tenant.status,
-            'created_at': tenant.created_at,
-            'in_trail': True,
-            'trial_end_reason': None,
-            'role': 'normal',
+            "id": tenant.id,
+            "name": tenant.name,
+            "plan": tenant.plan,
+            "status": tenant.status,
+            "created_at": tenant.created_at,
+            "in_trail": True,
+            "trial_end_reason": None,
+            "role": "normal",
         }
 
         # Get role of user
-        tenant_account_join = db.session.query(TenantAccountJoin).filter(
-            TenantAccountJoin.tenant_id == tenant.id,
-            TenantAccountJoin.account_id == current_user.id
-        ).first()
-        tenant_info['role'] = tenant_account_join.role
-
-        can_replace_logo = FeatureService.get_features(tenant_info['id']).can_replace_logo
-
-        if can_replace_logo and TenantService.has_roles(tenant, 
-        [TenantAccountJoinRole.OWNER, TenantAccountJoinRole.ADMIN]):
+        tenant_account_join = (
+            db.session.query(TenantAccountJoin)
+            .filter(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == current_user.id)
+            .first()
+        )
+        tenant_info["role"] = tenant_account_join.role
+
+        can_replace_logo = FeatureService.get_features(tenant_info["id"]).can_replace_logo
+
+        if can_replace_logo and TenantService.has_roles(
+            tenant, [TenantAccountJoinRole.OWNER, TenantAccountJoinRole.ADMIN]
+        ):
             base_url = dify_config.FILES_URL
-            replace_webapp_logo = f'{base_url}/files/workspaces/{tenant.id}/webapp-logo' if tenant.custom_config_dict.get('replace_webapp_logo') else None
-            remove_webapp_brand = tenant.custom_config_dict.get('remove_webapp_brand', False)
-
-            tenant_info['custom_config'] = {
-                'remove_webapp_brand': remove_webapp_brand,
-                'replace_webapp_logo': replace_webapp_logo,
+            replace_webapp_logo = (
+                f"{base_url}/files/workspaces/{tenant.id}/webapp-logo"
+                if tenant.custom_config_dict.get("replace_webapp_logo")
+                else None
+            )
+            remove_webapp_brand = tenant.custom_config_dict.get("remove_webapp_brand", False)
+
+            tenant_info["custom_config"] = {
+                "remove_webapp_brand": remove_webapp_brand,
+                "replace_webapp_logo": replace_webapp_logo,
             }
 
         return tenant_info

Some files were not shown because too many files changed in this diff