Forráskód Böngészése

feat: update notion extractor (#3898)

Co-authored-by: duyalei <>
yalei 11 hónapja
szülő
commit
026175c8f7

+ 16 - 22
api/core/rag/extractor/notion_extractor.py

@@ -19,8 +19,12 @@ SEARCH_URL = "https://api.notion.com/v1/search"
 
 RETRIEVE_PAGE_URL_TMPL = "https://api.notion.com/v1/pages/{page_id}"
 RETRIEVE_DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}"
-HEADING_TYPE = ['heading_1', 'heading_2', 'heading_3']
-
+# if user want split by headings, use the corresponding splitter
+HEADING_SPLITTER = {
+    'heading_1': '# ',
+    'heading_2': '## ',
+    'heading_3': '### ',
+}
 
 class NotionExtractor(BaseExtractor):
 
@@ -73,8 +77,7 @@ class NotionExtractor(BaseExtractor):
             docs.extend(page_text_documents)
         elif notion_page_type == 'page':
             page_text_list = self._get_notion_block_data(notion_obj_id)
-            for page_text in page_text_list:
-                docs.append(Document(page_content=page_text))
+            docs.append(Document(page_content='\n'.join(page_text_list)))
         else:
             raise ValueError("notion page type not supported")
 
@@ -96,7 +99,7 @@ class NotionExtractor(BaseExtractor):
 
         data = res.json()
 
-        database_content_list = []
+        database_content = []
         if 'results' not in data or data["results"] is None:
             return []
         for result in data["results"]:
@@ -131,10 +134,9 @@ class NotionExtractor(BaseExtractor):
                     row_content = row_content + f'{key}:{value_content}\n'
                 else:
                     row_content = row_content + f'{key}:{value}\n'
-            document = Document(page_content=row_content)
-            database_content_list.append(document)
+            database_content.append(row_content)
 
-        return database_content_list
+        return [Document(page_content='\n'.join(database_content))]
 
     def _get_notion_block_data(self, page_id: str) -> list[str]:
         result_lines_arr = []
@@ -154,8 +156,6 @@ class NotionExtractor(BaseExtractor):
                 json=query_dict
             )
             data = res.json()
-            # current block's heading
-            heading = ''
             for result in data["results"]:
                 result_type = result["type"]
                 result_obj = result[result_type]
@@ -172,8 +172,6 @@ class NotionExtractor(BaseExtractor):
                             if "text" in rich_text:
                                 text = rich_text["text"]["content"]
                                 cur_result_text_arr.append(text)
-                                if result_type in HEADING_TYPE:
-                                    heading = text
 
                     result_block_id = result["id"]
                     has_children = result["has_children"]
@@ -185,11 +183,10 @@ class NotionExtractor(BaseExtractor):
                         cur_result_text_arr.append(children_text)
 
                     cur_result_text = "\n".join(cur_result_text_arr)
-                    cur_result_text += "\n\n"
-                    if result_type in HEADING_TYPE:
-                        result_lines_arr.append(cur_result_text)
+                    if result_type in HEADING_SPLITTER:
+                        result_lines_arr.append(f"{HEADING_SPLITTER[result_type]}{cur_result_text}")
                     else:
-                        result_lines_arr.append(f'{heading}\n{cur_result_text}')
+                        result_lines_arr.append(cur_result_text + '\n\n')
 
             if data["next_cursor"] is None:
                 break
@@ -218,7 +215,6 @@ class NotionExtractor(BaseExtractor):
             data = res.json()
             if 'results' not in data or data["results"] is None:
                 break
-            heading = ''
             for result in data["results"]:
                 result_type = result["type"]
                 result_obj = result[result_type]
@@ -235,8 +231,6 @@ class NotionExtractor(BaseExtractor):
                                 text = rich_text["text"]["content"]
                                 prefix = "\t" * num_tabs
                                 cur_result_text_arr.append(prefix + text)
-                                if result_type in HEADING_TYPE:
-                                    heading = text
                     result_block_id = result["id"]
                     has_children = result["has_children"]
                     block_type = result["type"]
@@ -247,10 +241,10 @@ class NotionExtractor(BaseExtractor):
                         cur_result_text_arr.append(children_text)
 
                     cur_result_text = "\n".join(cur_result_text_arr)
-                    if result_type in HEADING_TYPE:
-                        result_lines_arr.append(cur_result_text)
+                    if result_type in HEADING_SPLITTER:
+                        result_lines_arr.append(f'{HEADING_SPLITTER[result_type]}{cur_result_text}')
                     else:
-                        result_lines_arr.append(f'{heading}\n{cur_result_text}')
+                        result_lines_arr.append(cur_result_text + '\n\n')
 
             if data["next_cursor"] is None:
                 break

+ 0 - 0
api/tests/unit_tests/core/rag/extractor/__init__.py


+ 102 - 0
api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py

@@ -0,0 +1,102 @@
+from unittest import mock
+
+from core.rag.extractor import notion_extractor
+
+user_id = "user1"
+database_id = "database1"
+page_id = "page1"
+
+
+extractor = notion_extractor.NotionExtractor(
+        notion_workspace_id='x',
+        notion_obj_id='x',
+        notion_page_type='page',
+        tenant_id='x',
+        notion_access_token='x')
+
+
+def _generate_page(page_title: str):
+    return {
+        "object": "page",
+        "id": page_id,
+        "properties": {
+            "Page": {
+                "type": "title", 
+                "title": [
+                    {
+                        "type": "text",
+                        "text": {"content": page_title},
+                        "plain_text": page_title
+                    }
+                ]
+            }
+        }
+    }
+
+
+def _generate_block(block_id: str, block_type: str, block_text: str):
+    return {
+        "object": "block",
+        "id": block_id,
+        "parent": {
+            "type": "page_id",
+            "page_id": page_id
+        },
+        "type": block_type,
+        "has_children": False,
+        block_type: {
+            "rich_text": [
+                {
+                    "type": "text",
+                    "text": {"content": block_text},
+                   "plain_text": block_text,
+               }]
+           }
+       }
+
+
+def _mock_response(data):
+    response = mock.Mock()
+    response.status_code = 200
+    response.json.return_value = data
+    return response
+
+
+def _remove_multiple_new_lines(text):
+    while '\n\n' in text:
+        text = text.replace("\n\n", "\n")
+    return text.strip()
+
+
+def test_notion_page(mocker):
+    texts = ["Head 1", "1.1", "paragraph 1", "1.1.1"]
+    mocked_notion_page = {
+    "object": "list",
+    "results": [
+        _generate_block("b1", "heading_1", texts[0]),
+        _generate_block("b2", "heading_2", texts[1]),
+        _generate_block("b3", "paragraph", texts[2]),
+        _generate_block("b4", "heading_3", texts[3])
+    ],
+    "next_cursor": None
+    }
+    mocker.patch("requests.request", return_value=_mock_response(mocked_notion_page))
+
+    page_docs = extractor._load_data_as_documents(page_id, "page")
+    assert len(page_docs) == 1
+    content = _remove_multiple_new_lines(page_docs[0].page_content)
+    assert content == '# Head 1\n## 1.1\nparagraph 1\n### 1.1.1'
+
+
+def test_notion_database(mocker):
+    page_title_list = ["page1", "page2", "page3"]
+    mocked_notion_database = {
+        "object": "list",
+        "results": [_generate_page(i) for i in page_title_list],
+        "next_cursor": None
+    }
+    mocker.patch("requests.post", return_value=_mock_response(mocked_notion_database))
+    database_docs = extractor._load_data_as_documents(database_id, "database")
+    assert len(database_docs) == 1
+    content = _remove_multiple_new_lines(database_docs[0].page_content)
+    assert content == '\n'.join([f'Page:{i}' for i in page_title_list])