Browse Source

Feature/self host notion import (#397)

Jyong 1 year ago
parent
commit
226f28edcb

+ 3 - 0
api/config.py

@@ -190,6 +190,9 @@ class Config:
         # notion import setting
         self.NOTION_CLIENT_ID = get_env('NOTION_CLIENT_ID')
         self.NOTION_CLIENT_SECRET = get_env('NOTION_CLIENT_SECRET')
+        self.NOTION_INTEGRATION_TYPE = get_env('NOTION_INTEGRATION_TYPE')
+        self.NOTION_INTERNAL_SECRET = get_env('NOTION_INTERNAL_SECRET')
+
 
 class CloudEditionConfig(Config):
 

+ 8 - 2
api/controllers/console/auth/data_source_oauth.py

@@ -39,9 +39,15 @@ class OAuthDataSource(Resource):
             print(vars(oauth_provider))
         if not oauth_provider:
             return {'error': 'Invalid provider'}, 400
+        if current_app.config.get('NOTION_INTEGRATION_TYPE') == 'internal':
+            internal_secret = current_app.config.get('NOTION_INTERNAL_SECRET')
+            oauth_provider.save_internal_access_token(internal_secret)
+            return redirect(f'{current_app.config.get("CONSOLE_URL")}?oauth_data_source=success')
+        else:
+            auth_url = oauth_provider.get_authorization_url()
+            return redirect(auth_url)
+
 
-        auth_url = oauth_provider.get_authorization_url()
-        return redirect(auth_url)
 
 
 class OAuthDataSourceCallback(Resource):

+ 4 - 2
api/core/data_source/notion.py

@@ -84,7 +84,8 @@ class NotionPageReader(BaseReader):
                                     heading = text
                     result_block_id = result["id"]
                     has_children = result["has_children"]
-                    if has_children:
+                    block_type = result["type"]
+                    if has_children and block_type != 'child_page':
                         children_text = self._read_block(
                             result_block_id, num_tabs=num_tabs + 1
                         )
@@ -184,7 +185,8 @@ class NotionPageReader(BaseReader):
 
                     result_block_id = result["id"]
                     has_children = result["has_children"]
-                    if has_children:
+                    block_type = result["type"]
+                    if has_children and block_type != 'child_page':
                         children_text = self._read_block(
                             result_block_id, num_tabs=num_tabs + 1
                         )

+ 58 - 2
api/libs/oauth_data_source.py

@@ -26,6 +26,7 @@ class NotionOAuth(OAuthDataSource):
     _TOKEN_URL = 'https://api.notion.com/v1/oauth/token'
     _NOTION_PAGE_SEARCH = "https://api.notion.com/v1/search"
     _NOTION_BLOCK_SEARCH = "https://api.notion.com/v1/blocks"
+    _NOTION_BOT_USER = "https://api.notion.com/v1/users/me"
 
     def get_authorization_url(self):
         params = {
@@ -84,6 +85,41 @@ class NotionOAuth(OAuthDataSource):
             db.session.add(new_data_source_binding)
             db.session.commit()
 
+    def save_internal_access_token(self, access_token: str):  # Create or refresh the current tenant's Notion binding for the given access token.
+        workspace_name = self.notion_workspace_name(access_token)  # resolves to the literal 'workspace' when the lookup finds no name
+        workspace_icon = None  # no icon is looked up on this path
+        workspace_id = current_user.current_tenant_id  # the tenant id is stored in the workspace_id field here
+        # fetch every page this token is authorized to read
+        pages = self.get_authorized_pages(access_token)
+        source_info = {
+            'workspace_name': workspace_name,
+            'workspace_icon': workspace_icon,
+            'workspace_id': workspace_id,
+            'pages': pages,
+            'total': len(pages)  # count of authorized pages returned above
+        }
+        # look for an existing binding with the same tenant, provider and token
+        data_source_binding = DataSourceBinding.query.filter(
+            db.and_(
+                DataSourceBinding.tenant_id == current_user.current_tenant_id,
+                DataSourceBinding.provider == 'notion',
+                DataSourceBinding.access_token == access_token
+            )
+        ).first()
+        if data_source_binding:  # existing binding: refresh its page snapshot and re-enable it
+            data_source_binding.source_info = source_info
+            data_source_binding.disabled = False
+            db.session.commit()
+        else:  # no binding for this tenant/token yet: insert a new one
+            new_data_source_binding = DataSourceBinding(
+                tenant_id=current_user.current_tenant_id,
+                access_token=access_token,
+                source_info=source_info,
+                provider='notion'
+            )
+            db.session.add(new_data_source_binding)
+            db.session.commit()
+
     def sync_data_source(self, binding_id: str):
         # save data source binding
         data_source_binding = DataSourceBinding.query.filter(
@@ -222,7 +258,10 @@ class NotionOAuth(OAuthDataSource):
         }
         response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
         response_json = response.json()
-        results = response_json['results']
+        if 'results' in response_json:
+            results = response_json['results']
+        else:
+            results = []
         return results
 
     def notion_block_parent_page_id(self, access_token: str, block_id: str):
@@ -238,6 +277,20 @@ class NotionOAuth(OAuthDataSource):
             return self.notion_block_parent_page_id(access_token, parent[parent_type])
         return parent[parent_type]
 
+    def notion_workspace_name(self, access_token: str):  # Return the workspace name reported for this token, or the literal 'workspace' as a fallback.
+        headers = {
+            'Authorization': f"Bearer {access_token}",
+            'Notion-Version': '2022-06-28',
+        }
+        response = requests.get(url=self._NOTION_BOT_USER, headers=headers)  # NOTE(review): no timeout is passed to this request
+        response_json = response.json()
+        if 'object' in response_json and response_json['object'] == 'user':  # only trust payloads that describe a user object
+            user_type = response_json['type']  # NOTE(review): raises KeyError if the payload has no 'type' key
+            user_info = response_json[user_type]  # details live under the key named by the user's type
+            if 'workspace_name' in user_info:
+                return user_info['workspace_name']
+        return 'workspace'  # fallback when the payload is not a user object or carries no workspace name
+
     def notion_database_search(self, access_token: str):
         data = {
             'filter': {
@@ -252,5 +305,8 @@ class NotionOAuth(OAuthDataSource):
         }
         response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
         response_json = response.json()
-        results = response_json['results']
+        if 'results' in response_json:
+            results = response_json['results']
+        else:
+            results = []
         return results