Browse Source

chore(api/libs): Apply ruff format. (#7301)

-LAN- 8 months ago
parent
commit
9414143b5f

+ 5 - 5
api/libs/bearer_data_source.py

@@ -25,7 +25,7 @@ class FireCrawlDataSource(BearerDataSource):
         TEST_CRAWL_SITE_URL = "https://www.google.com"
         FIRECRAWL_API_VERSION = "v0"
 
-        test_api_endpoint = self.api_base_url.rstrip('/') + f"/{FIRECRAWL_API_VERSION}/scrape"
+        test_api_endpoint = self.api_base_url.rstrip("/") + f"/{FIRECRAWL_API_VERSION}/scrape"
 
         headers = {
             "Authorization": f"Bearer {self.api_key}",
@@ -45,9 +45,9 @@ class FireCrawlDataSource(BearerDataSource):
         data_source_binding = DataSourceBearerBinding.query.filter(
             db.and_(
                 DataSourceBearerBinding.tenant_id == current_user.current_tenant_id,
-                DataSourceBearerBinding.provider == 'firecrawl',
+                DataSourceBearerBinding.provider == "firecrawl",
                 DataSourceBearerBinding.endpoint_url == self.api_base_url,
-                DataSourceBearerBinding.bearer_key == self.api_key
+                DataSourceBearerBinding.bearer_key == self.api_key,
             )
         ).first()
         if data_source_binding:
@@ -56,9 +56,9 @@ class FireCrawlDataSource(BearerDataSource):
         else:
             new_data_source_binding = DataSourceBearerBinding(
                 tenant_id=current_user.current_tenant_id,
-                provider='firecrawl',
+                provider="firecrawl",
                 endpoint_url=self.api_base_url,
-                bearer_key=self.api_key
+                bearer_key=self.api_key,
             )
             db.session.add(new_data_source_binding)
             db.session.commit()

+ 2 - 2
api/libs/exception.py

@@ -4,7 +4,7 @@ from werkzeug.exceptions import HTTPException
 
 
 class BaseHTTPException(HTTPException):
-    error_code: str = 'unknown'
+    error_code: str = "unknown"
     data: Optional[dict] = None
 
     def __init__(self, description=None, response=None):
@@ -14,4 +14,4 @@ class BaseHTTPException(HTTPException):
             "code": self.error_code,
             "message": self.description,
             "status": self.code,
-        }
+        }

+ 30 - 40
api/libs/external_api.py

@@ -10,7 +10,6 @@ from core.errors.error import AppInvokeQuotaExceededError
 
 
 class ExternalApi(Api):
-
     def handle_error(self, e):
         """Error handler for the API transforms a raised exception into a Flask
         response, with the appropriate HTTP status code and body.
@@ -29,54 +28,57 @@ class ExternalApi(Api):
 
             status_code = e.code
             default_data = {
-                'code': re.sub(r'(?<!^)(?=[A-Z])', '_', type(e).__name__).lower(),
-                'message': getattr(e, 'description', http_status_message(status_code)),
-                'status': status_code
+                "code": re.sub(r"(?<!^)(?=[A-Z])", "_", type(e).__name__).lower(),
+                "message": getattr(e, "description", http_status_message(status_code)),
+                "status": status_code,
             }
 
-            if default_data['message'] and default_data['message'] == 'Failed to decode JSON object: Expecting value: line 1 column 1 (char 0)':
-                default_data['message'] = 'Invalid JSON payload received or JSON payload is empty.'
+            if (
+                default_data["message"]
+                and default_data["message"] == "Failed to decode JSON object: Expecting value: line 1 column 1 (char 0)"
+            ):
+                default_data["message"] = "Invalid JSON payload received or JSON payload is empty."
 
             headers = e.get_response().headers
         elif isinstance(e, ValueError):
             status_code = 400
             default_data = {
-                'code': 'invalid_param',
-                'message': str(e),
-                'status': status_code
+                "code": "invalid_param",
+                "message": str(e),
+                "status": status_code,
             }
         elif isinstance(e, AppInvokeQuotaExceededError):
             status_code = 429
             default_data = {
-                'code': 'too_many_requests',
-                'message': str(e),
-                'status': status_code
+                "code": "too_many_requests",
+                "message": str(e),
+                "status": status_code,
             }
         else:
             status_code = 500
             default_data = {
-                'message': http_status_message(status_code),
+                "message": http_status_message(status_code),
             }
 
         # Werkzeug exceptions generate a content-length header which is added
         # to the response in addition to the actual content-length header
         # https://github.com/flask-restful/flask-restful/issues/534
-        remove_headers = ('Content-Length',)
+        remove_headers = ("Content-Length",)
 
         for header in remove_headers:
             headers.pop(header, None)
 
-        data = getattr(e, 'data', default_data)
+        data = getattr(e, "data", default_data)
 
         error_cls_name = type(e).__name__
         if error_cls_name in self.errors:
             custom_data = self.errors.get(error_cls_name, {})
             custom_data = custom_data.copy()
-            status_code = custom_data.get('status', 500)
+            status_code = custom_data.get("status", 500)
 
-            if 'message' in custom_data:
-                custom_data['message'] = custom_data['message'].format(
-                    message=str(e.description if hasattr(e, 'description') else e)
+            if "message" in custom_data:
+                custom_data["message"] = custom_data["message"].format(
+                    message=str(e.description if hasattr(e, "description") else e)
                 )
             data.update(custom_data)
 
@@ -94,32 +96,20 @@ class ExternalApi(Api):
             # another NotAcceptable error).
             supported_mediatypes = list(self.representations.keys())  # only supported application/json
             fallback_mediatype = supported_mediatypes[0] if supported_mediatypes else "text/plain"
-            data = {
-                'code': 'not_acceptable',
-                'message': data.get('message')
-            }
-            resp = self.make_response(
-                data,
-                status_code,
-                headers,
-                fallback_mediatype = fallback_mediatype
-            )
+            data = {"code": "not_acceptable", "message": data.get("message")}
+            resp = self.make_response(data, status_code, headers, fallback_mediatype=fallback_mediatype)
         elif status_code == 400:
-            if isinstance(data.get('message'), dict):
-                param_key, param_value = list(data.get('message').items())[0]
-                data = {
-                    'code': 'invalid_param',
-                    'message': param_value,
-                    'params': param_key
-                }
+            if isinstance(data.get("message"), dict):
+                param_key, param_value = list(data.get("message").items())[0]
+                data = {"code": "invalid_param", "message": param_value, "params": param_key}
             else:
-                if 'code' not in data:
-                    data['code'] = 'unknown'
+                if "code" not in data:
+                    data["code"] = "unknown"
 
             resp = self.make_response(data, status_code, headers)
         else:
-            if 'code' not in data:
-                data['code'] = 'unknown'
+            if "code" not in data:
+                data["code"] = "unknown"
 
             resp = self.make_response(data, status_code, headers)
 

+ 15 - 14
api/libs/gmpy2_pkcs10aep_cipher.py

@@ -70,7 +70,7 @@ class PKCS1OAEP_Cipher:
         if mgfunc:
             self._mgf = mgfunc
         else:
-            self._mgf = lambda x,y: MGF1(x,y,self._hashObj)
+            self._mgf = lambda x, y: MGF1(x, y, self._hashObj)
 
         self._label = _copy_bytes(None, None, label)
         self._randfunc = randfunc
@@ -107,7 +107,7 @@ class PKCS1OAEP_Cipher:
 
         # See 7.1.1 in RFC3447
         modBits = Crypto.Util.number.size(self._key.n)
-        k = ceil_div(modBits, 8) # Convert from bits to bytes
+        k = ceil_div(modBits, 8)  # Convert from bits to bytes
         hLen = self._hashObj.digest_size
         mLen = len(message)
 
@@ -118,13 +118,13 @@ class PKCS1OAEP_Cipher:
         # Step 2a
         lHash = sha1(self._label).digest()
         # Step 2b
-        ps = b'\x00' * ps_len
+        ps = b"\x00" * ps_len
         # Step 2c
-        db = lHash + ps + b'\x01' + _copy_bytes(None, None, message)
+        db = lHash + ps + b"\x01" + _copy_bytes(None, None, message)
         # Step 2d
         ros = self._randfunc(hLen)
         # Step 2e
-        dbMask = self._mgf(ros, k-hLen-1)
+        dbMask = self._mgf(ros, k - hLen - 1)
         # Step 2f
         maskedDB = strxor(db, dbMask)
         # Step 2g
@@ -132,7 +132,7 @@ class PKCS1OAEP_Cipher:
         # Step 2h
         maskedSeed = strxor(ros, seedMask)
         # Step 2i
-        em = b'\x00' + maskedSeed + maskedDB
+        em = b"\x00" + maskedSeed + maskedDB
         # Step 3a (OS2IP)
         em_int = bytes_to_long(em)
         # Step 3b (RSAEP)
@@ -160,10 +160,10 @@ class PKCS1OAEP_Cipher:
         """
         # See 7.1.2 in RFC3447
         modBits = Crypto.Util.number.size(self._key.n)
-        k = ceil_div(modBits,8) # Convert from bits to bytes
+        k = ceil_div(modBits, 8)  # Convert from bits to bytes
         hLen = self._hashObj.digest_size
         # Step 1b and 1c
-        if len(ciphertext) != k or k<hLen+2:
+        if len(ciphertext) != k or k < hLen + 2:
             raise ValueError("Ciphertext with incorrect length.")
         # Step 2a (O2SIP)
         ct_int = bytes_to_long(ciphertext)
@@ -178,18 +178,18 @@ class PKCS1OAEP_Cipher:
         y = em[0]
         # y must be 0, but we MUST NOT check it here in order not to
         # allow attacks like Manger's (http://dl.acm.org/citation.cfm?id=704143)
-        maskedSeed = em[1:hLen+1]
-        maskedDB = em[hLen+1:]
+        maskedSeed = em[1 : hLen + 1]
+        maskedDB = em[hLen + 1 :]
         # Step 3c
         seedMask = self._mgf(maskedDB, hLen)
         # Step 3d
         seed = strxor(maskedSeed, seedMask)
         # Step 3e
-        dbMask = self._mgf(seed, k-hLen-1)
+        dbMask = self._mgf(seed, k - hLen - 1)
         # Step 3f
         db = strxor(maskedDB, dbMask)
         # Step 3g
-        one_pos = hLen + db[hLen:].find(b'\x01')
+        one_pos = hLen + db[hLen:].find(b"\x01")
         lHash1 = db[:hLen]
         invalid = bord(y) | int(one_pos < hLen)
         hash_compare = strxor(lHash1, lHash)
@@ -200,9 +200,10 @@ class PKCS1OAEP_Cipher:
         if invalid != 0:
             raise ValueError("Incorrect decryption.")
         # Step 4
-        return db[one_pos + 1:]
+        return db[one_pos + 1 :]
 
-def new(key, hashAlgo=None, mgfunc=None, label=b'', randfunc=None):
+
+def new(key, hashAlgo=None, mgfunc=None, label=b"", randfunc=None):
     """Return a cipher object :class:`PKCS1OAEP_Cipher` that can be used to perform PKCS#1 OAEP encryption or decryption.
 
     :param key:

+ 40 - 46
api/libs/helper.py

@@ -21,7 +21,7 @@ from models.account import Account
 
 
 def run(script):
-    return subprocess.getstatusoutput('source /root/.bashrc && ' + script)
+    return subprocess.getstatusoutput("source /root/.bashrc && " + script)
 
 
 class TimestampField(fields.Raw):
@@ -36,29 +36,29 @@ def email(email):
     if re.match(pattern, email) is not None:
         return email
 
-    error = ('{email} is not a valid email.'
-             .format(email=email))
+    error = "{email} is not a valid email.".format(email=email)
     raise ValueError(error)
 
 
 def uuid_value(value):
-    if value == '':
+    if value == "":
         return str(value)
 
     try:
         uuid_obj = uuid.UUID(value)
         return str(uuid_obj)
     except ValueError:
-        error = ('{value} is not a valid uuid.'
-                 .format(value=value))
+        error = "{value} is not a valid uuid.".format(value=value)
         raise ValueError(error)
 
+
 def alphanumeric(value: str):
     # check if the value is alphanumeric and underlined
-    if re.match(r'^[a-zA-Z0-9_]+$', value):
+    if re.match(r"^[a-zA-Z0-9_]+$", value):
         return value
 
-    raise ValueError(f'{value} is not a valid alphanumeric value')
+    raise ValueError(f"{value} is not a valid alphanumeric value")
+
 
 def timestamp_value(timestamp):
     try:
@@ -67,31 +67,32 @@ def timestamp_value(timestamp):
             raise ValueError
         return int_timestamp
     except ValueError:
-        error = ('{timestamp} is not a valid timestamp.'
-                 .format(timestamp=timestamp))
+        error = "{timestamp} is not a valid timestamp.".format(timestamp=timestamp)
         raise ValueError(error)
 
 
 class str_len:
-    """ Restrict input to an integer in a range (inclusive) """
+    """Restrict input to an integer in a range (inclusive)"""
 
-    def __init__(self, max_length, argument='argument'):
+    def __init__(self, max_length, argument="argument"):
         self.max_length = max_length
         self.argument = argument
 
     def __call__(self, value):
         length = len(value)
         if length > self.max_length:
-            error = ('Invalid {arg}: {val}. {arg} cannot exceed length {length}'
-                     .format(arg=self.argument, val=value, length=self.max_length))
+            error = "Invalid {arg}: {val}. {arg} cannot exceed length {length}".format(
+                arg=self.argument, val=value, length=self.max_length
+            )
             raise ValueError(error)
 
         return value
 
 
 class float_range:
-    """ Restrict input to an float in a range (inclusive) """
-    def __init__(self, low, high, argument='argument'):
+    """Restrict input to an float in a range (inclusive)"""
+
+    def __init__(self, low, high, argument="argument"):
         self.low = low
         self.high = high
         self.argument = argument
@@ -99,15 +100,16 @@ class float_range:
     def __call__(self, value):
         value = _get_float(value)
         if value < self.low or value > self.high:
-            error = ('Invalid {arg}: {val}. {arg} must be within the range {lo} - {hi}'
-                     .format(arg=self.argument, val=value, lo=self.low, hi=self.high))
+            error = "Invalid {arg}: {val}. {arg} must be within the range {lo} - {hi}".format(
+                arg=self.argument, val=value, lo=self.low, hi=self.high
+            )
             raise ValueError(error)
 
         return value
 
 
 class datetime_string:
-    def __init__(self, format, argument='argument'):
+    def __init__(self, format, argument="argument"):
         self.format = format
         self.argument = argument
 
@@ -115,8 +117,9 @@ class datetime_string:
         try:
             datetime.strptime(value, self.format)
         except ValueError:
-            error = ('Invalid {arg}: {val}. {arg} must be conform to the format {format}'
-                     .format(arg=self.argument, val=value, format=self.format))
+            error = "Invalid {arg}: {val}. {arg} must be conform to the format {format}".format(
+                arg=self.argument, val=value, format=self.format
+            )
             raise ValueError(error)
 
         return value
@@ -126,14 +129,14 @@ def _get_float(value):
     try:
         return float(value)
     except (TypeError, ValueError):
-        raise ValueError('{} is not a valid float'.format(value))
+        raise ValueError("{} is not a valid float".format(value))
+
 
 def timezone(timezone_string):
     if timezone_string and timezone_string in available_timezones():
         return timezone_string
 
-    error = ('{timezone_string} is not a valid timezone.'
-             .format(timezone_string=timezone_string))
+    error = "{timezone_string} is not a valid timezone.".format(timezone_string=timezone_string)
     raise ValueError(error)
 
 
@@ -147,8 +150,8 @@ def generate_string(n):
 
 
 def get_remote_ip(request) -> str:
-    if request.headers.get('CF-Connecting-IP'):
-        return request.headers.get('Cf-Connecting-Ip')
+    if request.headers.get("CF-Connecting-IP"):
+        return request.headers.get("Cf-Connecting-Ip")
     elif request.headers.getlist("X-Forwarded-For"):
         return request.headers.getlist("X-Forwarded-For")[0]
     else:
@@ -156,54 +159,45 @@ def get_remote_ip(request) -> str:
 
 
 def generate_text_hash(text: str) -> str:
-    hash_text = str(text) + 'None'
+    hash_text = str(text) + "None"
     return sha256(hash_text.encode()).hexdigest()
 
 
 def compact_generate_response(response: Union[dict, RateLimitGenerator]) -> Response:
     if isinstance(response, dict):
-        return Response(response=json.dumps(response), status=200, mimetype='application/json')
+        return Response(response=json.dumps(response), status=200, mimetype="application/json")
     else:
+
         def generate() -> Generator:
             yield from response
 
-        return Response(stream_with_context(generate()), status=200,
-                        mimetype='text/event-stream')
+        return Response(stream_with_context(generate()), status=200, mimetype="text/event-stream")
 
 
 class TokenManager:
-
     @classmethod
     def generate_token(cls, account: Account, token_type: str, additional_data: dict = None) -> str:
         old_token = cls._get_current_token_for_account(account.id, token_type)
         if old_token:
             if isinstance(old_token, bytes):
-                old_token = old_token.decode('utf-8')
+                old_token = old_token.decode("utf-8")
             cls.revoke_token(old_token, token_type)
 
         token = str(uuid.uuid4())
-        token_data = {
-            'account_id': account.id,
-            'email': account.email,
-            'token_type': token_type
-        }
+        token_data = {"account_id": account.id, "email": account.email, "token_type": token_type}
         if additional_data:
             token_data.update(additional_data)
 
-        expiry_hours = current_app.config[f'{token_type.upper()}_TOKEN_EXPIRY_HOURS']
+        expiry_hours = current_app.config[f"{token_type.upper()}_TOKEN_EXPIRY_HOURS"]
         token_key = cls._get_token_key(token, token_type)
-        redis_client.setex(
-            token_key,
-            expiry_hours * 60 * 60,
-            json.dumps(token_data)
-        )
+        redis_client.setex(token_key, expiry_hours * 60 * 60, json.dumps(token_data))
 
         cls._set_current_token_for_account(account.id, token, token_type, expiry_hours)
         return token
 
     @classmethod
     def _get_token_key(cls, token: str, token_type: str) -> str:
-        return f'{token_type}:token:{token}'
+        return f"{token_type}:token:{token}"
 
     @classmethod
     def revoke_token(cls, token: str, token_type: str):
@@ -233,7 +227,7 @@ class TokenManager:
 
     @classmethod
     def _get_account_token_key(cls, account_id: str, token_type: str) -> str:
-        return f'{token_type}:account:{account_id}'
+        return f"{token_type}:account:{account_id}"
 
 
 class RateLimiter:
@@ -250,7 +244,7 @@ class RateLimiter:
         current_time = int(time.time())
         window_start_time = current_time - self.time_window
 
-        redis_client.zremrangebyscore(key, '-inf', window_start_time)
+        redis_client.zremrangebyscore(key, "-inf", window_start_time)
         attempts = redis_client.zcard(key)
 
         if attempts and int(attempts) >= self.max_attempts:

+ 0 - 1
api/libs/infinite_scroll_pagination.py

@@ -1,4 +1,3 @@
-
 class InfiniteScrollPagination:
     def __init__(self, data, limit, has_more):
         self.data = data

+ 3 - 4
api/libs/json_in_md_parser.py

@@ -10,13 +10,13 @@ def parse_json_markdown(json_string: str) -> dict:
     end_index = json_string.find("```", start_index + len("```json"))
 
     if start_index != -1 and end_index != -1:
-        extracted_content = json_string[start_index + len("```json"):end_index].strip()
+        extracted_content = json_string[start_index + len("```json") : end_index].strip()
 
         # Parse the JSON string into a Python dictionary
         parsed = json.loads(extracted_content)
     elif start_index != -1 and end_index == -1 and json_string.endswith("``"):
         end_index = json_string.find("``", start_index + len("```json"))
-        extracted_content = json_string[start_index + len("```json"):end_index].strip()
+        extracted_content = json_string[start_index + len("```json") : end_index].strip()
 
         # Parse the JSON string into a Python dictionary
         parsed = json.loads(extracted_content)
@@ -37,7 +37,6 @@ def parse_and_check_json_markdown(text: str, expected_keys: list[str]) -> dict:
     for key in expected_keys:
         if key not in json_obj:
             raise OutputParserException(
-                f"Got invalid return object. Expected key `{key}` "
-                f"to be present, but got {json_obj}"
+                f"Got invalid return object. Expected key `{key}` " f"to be present, but got {json_obj}"
             )
     return json_obj

+ 16 - 14
api/libs/login.py

@@ -51,27 +51,29 @@ def login_required(func):
 
     @wraps(func)
     def decorated_view(*args, **kwargs):
-        auth_header = request.headers.get('Authorization')
-        admin_api_key_enable = os.getenv('ADMIN_API_KEY_ENABLE', default='False')
-        if admin_api_key_enable.lower() == 'true':
+        auth_header = request.headers.get("Authorization")
+        admin_api_key_enable = os.getenv("ADMIN_API_KEY_ENABLE", default="False")
+        if admin_api_key_enable.lower() == "true":
             if auth_header:
-                if ' ' not in auth_header:
-                    raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
+                if " " not in auth_header:
+                    raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
                 auth_scheme, auth_token = auth_header.split(None, 1)
                 auth_scheme = auth_scheme.lower()
-                if auth_scheme != 'bearer':
-                    raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
-                admin_api_key = os.getenv('ADMIN_API_KEY')
+                if auth_scheme != "bearer":
+                    raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
+                admin_api_key = os.getenv("ADMIN_API_KEY")
 
                 if admin_api_key:
-                    if os.getenv('ADMIN_API_KEY') == auth_token:
-                        workspace_id = request.headers.get('X-WORKSPACE-ID')
+                    if os.getenv("ADMIN_API_KEY") == auth_token:
+                        workspace_id = request.headers.get("X-WORKSPACE-ID")
                         if workspace_id:
-                            tenant_account_join = db.session.query(Tenant, TenantAccountJoin) \
-                                .filter(Tenant.id == workspace_id) \
-                                .filter(TenantAccountJoin.tenant_id == Tenant.id) \
-                                .filter(TenantAccountJoin.role == 'owner') \
+                            tenant_account_join = (
+                                db.session.query(Tenant, TenantAccountJoin)
+                                .filter(Tenant.id == workspace_id)
+                                .filter(TenantAccountJoin.tenant_id == Tenant.id)
+                                .filter(TenantAccountJoin.role == "owner")
                                 .one_or_none()
+                            )
                             if tenant_account_join:
                                 tenant, ta = tenant_account_join
                                 account = Account.query.filter_by(id=ta.account_id).first()

+ 34 - 44
api/libs/oauth.py

@@ -35,31 +35,31 @@ class OAuth:
 
 
 class GitHubOAuth(OAuth):
-    _AUTH_URL = 'https://github.com/login/oauth/authorize'
-    _TOKEN_URL = 'https://github.com/login/oauth/access_token'
-    _USER_INFO_URL = 'https://api.github.com/user'
-    _EMAIL_INFO_URL = 'https://api.github.com/user/emails'
+    _AUTH_URL = "https://github.com/login/oauth/authorize"
+    _TOKEN_URL = "https://github.com/login/oauth/access_token"
+    _USER_INFO_URL = "https://api.github.com/user"
+    _EMAIL_INFO_URL = "https://api.github.com/user/emails"
 
     def get_authorization_url(self):
         params = {
-            'client_id': self.client_id,
-            'redirect_uri': self.redirect_uri,
-            'scope': 'user:email'  # Request only basic user information
+            "client_id": self.client_id,
+            "redirect_uri": self.redirect_uri,
+            "scope": "user:email",  # Request only basic user information
         }
         return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
 
     def get_access_token(self, code: str):
         data = {
-            'client_id': self.client_id,
-            'client_secret': self.client_secret,
-            'code': code,
-            'redirect_uri': self.redirect_uri
+            "client_id": self.client_id,
+            "client_secret": self.client_secret,
+            "code": code,
+            "redirect_uri": self.redirect_uri,
         }
-        headers = {'Accept': 'application/json'}
+        headers = {"Accept": "application/json"}
         response = requests.post(self._TOKEN_URL, data=data, headers=headers)
 
         response_json = response.json()
-        access_token = response_json.get('access_token')
+        access_token = response_json.get("access_token")
 
         if not access_token:
             raise ValueError(f"Error in GitHub OAuth: {response_json}")
@@ -67,55 +67,51 @@ class GitHubOAuth(OAuth):
         return access_token
 
     def get_raw_user_info(self, token: str):
-        headers = {'Authorization': f"token {token}"}
+        headers = {"Authorization": f"token {token}"}
         response = requests.get(self._USER_INFO_URL, headers=headers)
         response.raise_for_status()
         user_info = response.json()
 
         email_response = requests.get(self._EMAIL_INFO_URL, headers=headers)
         email_info = email_response.json()
-        primary_email = next((email for email in email_info if email['primary'] == True), None)
+        primary_email = next((email for email in email_info if email["primary"] == True), None)
 
-        return {**user_info, 'email': primary_email['email']}
+        return {**user_info, "email": primary_email["email"]}
 
     def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
-        email = raw_info.get('email')
+        email = raw_info.get("email")
         if not email:
             email = f"{raw_info['id']}+{raw_info['login']}@users.noreply.github.com"
-        return OAuthUserInfo(
-            id=str(raw_info['id']),
-            name=raw_info['name'],
-            email=email
-        )
+        return OAuthUserInfo(id=str(raw_info["id"]), name=raw_info["name"], email=email)
 
 
 class GoogleOAuth(OAuth):
-    _AUTH_URL = 'https://accounts.google.com/o/oauth2/v2/auth'
-    _TOKEN_URL = 'https://oauth2.googleapis.com/token'
-    _USER_INFO_URL = 'https://www.googleapis.com/oauth2/v3/userinfo'
+    _AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth"
+    _TOKEN_URL = "https://oauth2.googleapis.com/token"
+    _USER_INFO_URL = "https://www.googleapis.com/oauth2/v3/userinfo"
 
     def get_authorization_url(self):
         params = {
-            'client_id': self.client_id,
-            'response_type': 'code',
-            'redirect_uri': self.redirect_uri,
-            'scope': 'openid email'
+            "client_id": self.client_id,
+            "response_type": "code",
+            "redirect_uri": self.redirect_uri,
+            "scope": "openid email",
         }
         return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
 
     def get_access_token(self, code: str):
         data = {
-            'client_id': self.client_id,
-            'client_secret': self.client_secret,
-            'code': code,
-            'grant_type': 'authorization_code',
-            'redirect_uri': self.redirect_uri
+            "client_id": self.client_id,
+            "client_secret": self.client_secret,
+            "code": code,
+            "grant_type": "authorization_code",
+            "redirect_uri": self.redirect_uri,
         }
-        headers = {'Accept': 'application/json'}
+        headers = {"Accept": "application/json"}
         response = requests.post(self._TOKEN_URL, data=data, headers=headers)
 
         response_json = response.json()
-        access_token = response_json.get('access_token')
+        access_token = response_json.get("access_token")
 
         if not access_token:
             raise ValueError(f"Error in Google OAuth: {response_json}")
@@ -123,16 +119,10 @@ class GoogleOAuth(OAuth):
         return access_token
 
     def get_raw_user_info(self, token: str):
-        headers = {'Authorization': f"Bearer {token}"}
+        headers = {"Authorization": f"Bearer {token}"}
         response = requests.get(self._USER_INFO_URL, headers=headers)
         response.raise_for_status()
         return response.json()
 
     def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
-        return OAuthUserInfo(
-            id=str(raw_info['sub']),
-            name=None,
-            email=raw_info['email']
-        )
-
-
+        return OAuthUserInfo(id=str(raw_info["sub"]), name=None, email=raw_info["email"])

+ 102 - 128
api/libs/oauth_data_source.py

@@ -21,53 +21,49 @@ class OAuthDataSource:
 
 
 class NotionOAuth(OAuthDataSource):
-    _AUTH_URL = 'https://api.notion.com/v1/oauth/authorize'
-    _TOKEN_URL = 'https://api.notion.com/v1/oauth/token'
+    _AUTH_URL = "https://api.notion.com/v1/oauth/authorize"
+    _TOKEN_URL = "https://api.notion.com/v1/oauth/token"
     _NOTION_PAGE_SEARCH = "https://api.notion.com/v1/search"
     _NOTION_BLOCK_SEARCH = "https://api.notion.com/v1/blocks"
     _NOTION_BOT_USER = "https://api.notion.com/v1/users/me"
 
     def get_authorization_url(self):
         params = {
-            'client_id': self.client_id,
-            'response_type': 'code',
-            'redirect_uri': self.redirect_uri,
-            'owner': 'user'
+            "client_id": self.client_id,
+            "response_type": "code",
+            "redirect_uri": self.redirect_uri,
+            "owner": "user",
         }
         return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
 
     def get_access_token(self, code: str):
-        data = {
-            'code': code,
-            'grant_type': 'authorization_code',
-            'redirect_uri': self.redirect_uri
-        }
-        headers = {'Accept': 'application/json'}
+        data = {"code": code, "grant_type": "authorization_code", "redirect_uri": self.redirect_uri}
+        headers = {"Accept": "application/json"}
         auth = (self.client_id, self.client_secret)
         response = requests.post(self._TOKEN_URL, data=data, auth=auth, headers=headers)
 
         response_json = response.json()
-        access_token = response_json.get('access_token')
+        access_token = response_json.get("access_token")
         if not access_token:
             raise ValueError(f"Error in Notion OAuth: {response_json}")
-        workspace_name = response_json.get('workspace_name')
-        workspace_icon = response_json.get('workspace_icon')
-        workspace_id = response_json.get('workspace_id')
+        workspace_name = response_json.get("workspace_name")
+        workspace_icon = response_json.get("workspace_icon")
+        workspace_id = response_json.get("workspace_id")
         # get all authorized pages
         pages = self.get_authorized_pages(access_token)
         source_info = {
-            'workspace_name': workspace_name,
-            'workspace_icon': workspace_icon,
-            'workspace_id': workspace_id,
-            'pages': pages,
-            'total': len(pages)
+            "workspace_name": workspace_name,
+            "workspace_icon": workspace_icon,
+            "workspace_id": workspace_id,
+            "pages": pages,
+            "total": len(pages),
         }
         # save data source binding
         data_source_binding = DataSourceOauthBinding.query.filter(
             db.and_(
                 DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
-                DataSourceOauthBinding.provider == 'notion',
-                DataSourceOauthBinding.access_token == access_token
+                DataSourceOauthBinding.provider == "notion",
+                DataSourceOauthBinding.access_token == access_token,
             )
         ).first()
         if data_source_binding:
@@ -79,7 +75,7 @@ class NotionOAuth(OAuthDataSource):
                 tenant_id=current_user.current_tenant_id,
                 access_token=access_token,
                 source_info=source_info,
-                provider='notion'
+                provider="notion",
             )
             db.session.add(new_data_source_binding)
             db.session.commit()
@@ -91,18 +87,18 @@ class NotionOAuth(OAuthDataSource):
         # get all authorized pages
         pages = self.get_authorized_pages(access_token)
         source_info = {
-            'workspace_name': workspace_name,
-            'workspace_icon': workspace_icon,
-            'workspace_id': workspace_id,
-            'pages': pages,
-            'total': len(pages)
+            "workspace_name": workspace_name,
+            "workspace_icon": workspace_icon,
+            "workspace_id": workspace_id,
+            "pages": pages,
+            "total": len(pages),
         }
         # save data source binding
         data_source_binding = DataSourceOauthBinding.query.filter(
             db.and_(
                 DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
-                DataSourceOauthBinding.provider == 'notion',
-                DataSourceOauthBinding.access_token == access_token
+                DataSourceOauthBinding.provider == "notion",
+                DataSourceOauthBinding.access_token == access_token,
             )
         ).first()
         if data_source_binding:
@@ -114,7 +110,7 @@ class NotionOAuth(OAuthDataSource):
                 tenant_id=current_user.current_tenant_id,
                 access_token=access_token,
                 source_info=source_info,
-                provider='notion'
+                provider="notion",
             )
             db.session.add(new_data_source_binding)
             db.session.commit()
@@ -124,9 +120,9 @@ class NotionOAuth(OAuthDataSource):
         data_source_binding = DataSourceOauthBinding.query.filter(
             db.and_(
                 DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
-                DataSourceOauthBinding.provider == 'notion',
+                DataSourceOauthBinding.provider == "notion",
                 DataSourceOauthBinding.id == binding_id,
-                DataSourceOauthBinding.disabled == False
+                DataSourceOauthBinding.disabled == False,
             )
         ).first()
         if data_source_binding:
@@ -134,17 +130,17 @@ class NotionOAuth(OAuthDataSource):
             pages = self.get_authorized_pages(data_source_binding.access_token)
             source_info = data_source_binding.source_info
             new_source_info = {
-                'workspace_name': source_info['workspace_name'],
-                'workspace_icon': source_info['workspace_icon'],
-                'workspace_id': source_info['workspace_id'],
-                'pages': pages,
-                'total': len(pages)
+                "workspace_name": source_info["workspace_name"],
+                "workspace_icon": source_info["workspace_icon"],
+                "workspace_id": source_info["workspace_id"],
+                "pages": pages,
+                "total": len(pages),
             }
             data_source_binding.source_info = new_source_info
             data_source_binding.disabled = False
             db.session.commit()
         else:
-            raise ValueError('Data source binding not found')
+            raise ValueError("Data source binding not found")
 
     def get_authorized_pages(self, access_token: str):
         pages = []
@@ -152,143 +148,121 @@ class NotionOAuth(OAuthDataSource):
         database_results = self.notion_database_search(access_token)
         # get page detail
         for page_result in page_results:
-            page_id = page_result['id']
-            page_name = 'Untitled'
-            for key in page_result['properties']:
-                if 'title' in page_result['properties'][key] and page_result['properties'][key]['title']:
-                    title_list = page_result['properties'][key]['title']
-                    if len(title_list) > 0 and 'plain_text' in title_list[0]:
-                        page_name = title_list[0]['plain_text']
-            page_icon = page_result['icon']
+            page_id = page_result["id"]
+            page_name = "Untitled"
+            for key in page_result["properties"]:
+                if "title" in page_result["properties"][key] and page_result["properties"][key]["title"]:
+                    title_list = page_result["properties"][key]["title"]
+                    if len(title_list) > 0 and "plain_text" in title_list[0]:
+                        page_name = title_list[0]["plain_text"]
+            page_icon = page_result["icon"]
             if page_icon:
-                icon_type = page_icon['type']
-                if icon_type == 'external' or icon_type == 'file':
-                    url = page_icon[icon_type]['url']
-                    icon = {
-                        'type': 'url',
-                        'url': url if url.startswith('http') else f'https://www.notion.so{url}'
-                    }
+                icon_type = page_icon["type"]
+                if icon_type == "external" or icon_type == "file":
+                    url = page_icon[icon_type]["url"]
+                    icon = {"type": "url", "url": url if url.startswith("http") else f"https://www.notion.so{url}"}
                 else:
-                    icon = {
-                        'type': 'emoji',
-                        'emoji': page_icon[icon_type]
-                    }
+                    icon = {"type": "emoji", "emoji": page_icon[icon_type]}
             else:
                 icon = None
-            parent = page_result['parent']
-            parent_type = parent['type']
-            if parent_type == 'block_id':
+            parent = page_result["parent"]
+            parent_type = parent["type"]
+            if parent_type == "block_id":
                 parent_id = self.notion_block_parent_page_id(access_token, parent[parent_type])
-            elif parent_type == 'workspace':
-                parent_id = 'root'
+            elif parent_type == "workspace":
+                parent_id = "root"
             else:
                 parent_id = parent[parent_type]
             page = {
-                'page_id': page_id,
-                'page_name': page_name,
-                'page_icon': icon,
-                'parent_id': parent_id,
-                'type': 'page'
+                "page_id": page_id,
+                "page_name": page_name,
+                "page_icon": icon,
+                "parent_id": parent_id,
+                "type": "page",
             }
             pages.append(page)
             # get database detail
         for database_result in database_results:
-            page_id = database_result['id']
-            if len(database_result['title']) > 0:
-                page_name = database_result['title'][0]['plain_text']
+            page_id = database_result["id"]
+            if len(database_result["title"]) > 0:
+                page_name = database_result["title"][0]["plain_text"]
             else:
-                page_name = 'Untitled'
-            page_icon = database_result['icon']
+                page_name = "Untitled"
+            page_icon = database_result["icon"]
             if page_icon:
-                icon_type = page_icon['type']
-                if icon_type == 'external' or icon_type == 'file':
-                    url = page_icon[icon_type]['url']
-                    icon = {
-                        'type': 'url',
-                        'url': url if url.startswith('http') else f'https://www.notion.so{url}'
-                    }
+                icon_type = page_icon["type"]
+                if icon_type == "external" or icon_type == "file":
+                    url = page_icon[icon_type]["url"]
+                    icon = {"type": "url", "url": url if url.startswith("http") else f"https://www.notion.so{url}"}
                 else:
-                    icon = {
-                        'type': icon_type,
-                        icon_type: page_icon[icon_type]
-                    }
+                    icon = {"type": icon_type, icon_type: page_icon[icon_type]}
             else:
                 icon = None
-            parent = database_result['parent']
-            parent_type = parent['type']
-            if parent_type == 'block_id':
+            parent = database_result["parent"]
+            parent_type = parent["type"]
+            if parent_type == "block_id":
                 parent_id = self.notion_block_parent_page_id(access_token, parent[parent_type])
-            elif parent_type == 'workspace':
-                parent_id = 'root'
+            elif parent_type == "workspace":
+                parent_id = "root"
             else:
                 parent_id = parent[parent_type]
             page = {
-                'page_id': page_id,
-                'page_name': page_name,
-                'page_icon': icon,
-                'parent_id': parent_id,
-                'type': 'database'
+                "page_id": page_id,
+                "page_name": page_name,
+                "page_icon": icon,
+                "parent_id": parent_id,
+                "type": "database",
             }
             pages.append(page)
         return pages
 
     def notion_page_search(self, access_token: str):
-        data = {
-            'filter': {
-                "value": "page",
-                "property": "object"
-            }
-        }
+        data = {"filter": {"value": "page", "property": "object"}}
         headers = {
-            'Content-Type': 'application/json',
-            'Authorization': f"Bearer {access_token}",
-            'Notion-Version': '2022-06-28',
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {access_token}",
+            "Notion-Version": "2022-06-28",
         }
         response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
         response_json = response.json()
-        results = response_json.get('results', [])
+        results = response_json.get("results", [])
         return results
 
     def notion_block_parent_page_id(self, access_token: str, block_id: str):
         headers = {
-            'Authorization': f"Bearer {access_token}",
-            'Notion-Version': '2022-06-28',
+            "Authorization": f"Bearer {access_token}",
+            "Notion-Version": "2022-06-28",
         }
-        response = requests.get(url=f'{self._NOTION_BLOCK_SEARCH}/{block_id}', headers=headers)
+        response = requests.get(url=f"{self._NOTION_BLOCK_SEARCH}/{block_id}", headers=headers)
         response_json = response.json()
-        parent = response_json['parent']
-        parent_type = parent['type']
-        if parent_type == 'block_id':
+        parent = response_json["parent"]
+        parent_type = parent["type"]
+        if parent_type == "block_id":
             return self.notion_block_parent_page_id(access_token, parent[parent_type])
         return parent[parent_type]
 
     def notion_workspace_name(self, access_token: str):
         headers = {
-            'Authorization': f"Bearer {access_token}",
-            'Notion-Version': '2022-06-28',
+            "Authorization": f"Bearer {access_token}",
+            "Notion-Version": "2022-06-28",
         }
         response = requests.get(url=self._NOTION_BOT_USER, headers=headers)
         response_json = response.json()
-        if 'object' in response_json and response_json['object'] == 'user':
-            user_type = response_json['type']
+        if "object" in response_json and response_json["object"] == "user":
+            user_type = response_json["type"]
             user_info = response_json[user_type]
-            if 'workspace_name' in user_info:
-                return user_info['workspace_name']
-        return 'workspace'
+            if "workspace_name" in user_info:
+                return user_info["workspace_name"]
+        return "workspace"
 
     def notion_database_search(self, access_token: str):
-        data = {
-            'filter': {
-                "value": "database",
-                "property": "object"
-            }
-        }
+        data = {"filter": {"value": "database", "property": "object"}}
         headers = {
-            'Content-Type': 'application/json',
-            'Authorization': f"Bearer {access_token}",
-            'Notion-Version': '2022-06-28',
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {access_token}",
+            "Notion-Version": "2022-06-28",
         }
         response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
         response_json = response.json()
-        results = response_json.get('results', [])
+        results = response_json.get("results", [])
         return results

+ 5 - 5
api/libs/passport.py

@@ -9,14 +9,14 @@ class PassportService:
         self.sk = dify_config.SECRET_KEY
 
     def issue(self, payload):
-        return jwt.encode(payload, self.sk, algorithm='HS256')
+        return jwt.encode(payload, self.sk, algorithm="HS256")
 
     def verify(self, token):
         try:
-            return jwt.decode(token, self.sk, algorithms=['HS256'])
+            return jwt.decode(token, self.sk, algorithms=["HS256"])
         except jwt.exceptions.InvalidSignatureError:
-            raise Unauthorized('Invalid token signature.')
+            raise Unauthorized("Invalid token signature.")
         except jwt.exceptions.DecodeError:
-            raise Unauthorized('Invalid token.')
+            raise Unauthorized("Invalid token.")
         except jwt.exceptions.ExpiredSignatureError:
-            raise Unauthorized('Token has expired.')
+            raise Unauthorized("Token has expired.")

+ 3 - 2
api/libs/password.py

@@ -5,6 +5,7 @@ import re
 
 password_pattern = r"^(?=.*[a-zA-Z])(?=.*\d).{8,}$"
 
+
 def valid_password(password):
     # Define a regex pattern for password rules
     pattern = password_pattern
@@ -12,11 +13,11 @@ def valid_password(password):
     if re.match(pattern, password) is not None:
         return password
 
-    raise ValueError('Not a valid password.')
+    raise ValueError("Not a valid password.")
 
 
 def hash_password(password_str, salt_byte):
-    dk = hashlib.pbkdf2_hmac('sha256', password_str.encode('utf-8'), salt_byte, 10000)
+    dk = hashlib.pbkdf2_hmac("sha256", password_str.encode("utf-8"), salt_byte, 10000)
     return binascii.hexlify(dk)
 
 

+ 6 - 6
api/libs/rsa.py

@@ -48,7 +48,7 @@ def encrypt(text, public_key):
 def get_decrypt_decoding(tenant_id):
     filepath = "privkeys/{tenant_id}".format(tenant_id=tenant_id) + "/private.pem"
 
-    cache_key = 'tenant_privkey:{hash}'.format(hash=hashlib.sha3_256(filepath.encode()).hexdigest())
+    cache_key = "tenant_privkey:{hash}".format(hash=hashlib.sha3_256(filepath.encode()).hexdigest())
     private_key = redis_client.get(cache_key)
     if not private_key:
         try:
@@ -66,12 +66,12 @@ def get_decrypt_decoding(tenant_id):
 
 def decrypt_token_with_decoding(encrypted_text, rsa_key, cipher_rsa):
     if encrypted_text.startswith(prefix_hybrid):
-        encrypted_text = encrypted_text[len(prefix_hybrid):]
+        encrypted_text = encrypted_text[len(prefix_hybrid) :]
 
-        enc_aes_key = encrypted_text[:rsa_key.size_in_bytes()]
-        nonce = encrypted_text[rsa_key.size_in_bytes():rsa_key.size_in_bytes() + 16]
-        tag = encrypted_text[rsa_key.size_in_bytes() + 16:rsa_key.size_in_bytes() + 32]
-        ciphertext = encrypted_text[rsa_key.size_in_bytes() + 32:]
+        enc_aes_key = encrypted_text[: rsa_key.size_in_bytes()]
+        nonce = encrypted_text[rsa_key.size_in_bytes() : rsa_key.size_in_bytes() + 16]
+        tag = encrypted_text[rsa_key.size_in_bytes() + 16 : rsa_key.size_in_bytes() + 32]
+        ciphertext = encrypted_text[rsa_key.size_in_bytes() + 32 :]
 
         aes_key = cipher_rsa.decrypt(enc_aes_key)
 

+ 9 - 7
api/libs/smtp.py

@@ -5,7 +5,9 @@ from email.mime.text import MIMEText
 
 
 class SMTPClient:
-    def __init__(self, server: str, port: int, username: str, password: str, _from: str, use_tls=False, opportunistic_tls=False):
+    def __init__(
+        self, server: str, port: int, username: str, password: str, _from: str, use_tls=False, opportunistic_tls=False
+    ):
         self.server = server
         self.port = port
         self._from = _from
@@ -25,17 +27,17 @@ class SMTPClient:
                     smtp = smtplib.SMTP_SSL(self.server, self.port, timeout=10)
             else:
                 smtp = smtplib.SMTP(self.server, self.port, timeout=10)
-                
+
             if self.username and self.password:
                 smtp.login(self.username, self.password)
 
             msg = MIMEMultipart()
-            msg['Subject'] = mail['subject']
-            msg['From'] = self._from
-            msg['To'] = mail['to']
-            msg.attach(MIMEText(mail['html'], 'html'))
+            msg["Subject"] = mail["subject"]
+            msg["From"] = self._from
+            msg["To"] = mail["to"]
+            msg.attach(MIMEText(mail["html"], "html"))
 
-            smtp.sendmail(self._from, mail['to'], msg.as_string())
+            smtp.sendmail(self._from, mail["to"], msg.as_string())
         except smtplib.SMTPException as e:
             logging.error(f"SMTP error occurred: {str(e)}")
             raise

+ 0 - 2
api/pyproject.toml

@@ -73,12 +73,10 @@ exclude = [
     "core/**/*.py",
     "controllers/**/*.py",
     "models/**/*.py",
-    "utils/**/*.py",
     "migrations/**/*",
     "services/**/*.py",
     "tasks/**/*.py",
     "tests/**/*.py",
-    "libs/**/*.py",
     "configs/**/*.py",
 ]