Browse Source

chore: make doc extractor node also can extract text by file extension (#9543)

非法操作 5 months ago
parent
commit
2346b0ab99

+ 34 - 4
api/core/workflow/nodes/document_extractor/node.py

@@ -75,7 +75,7 @@ class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
             )
 
 
-def _extract_text(*, file_content: bytes, mime_type: str) -> str:
+def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str:
     """Extract text from a file based on its MIME type."""
     if mime_type.startswith("text/plain") or mime_type in {"text/html", "text/htm", "text/markdown", "text/xml"}:
         return _extract_text_from_plain_text(file_content)
@@ -107,6 +107,33 @@ def _extract_text(*, file_content: bytes, mime_type: str) -> str:
         raise UnsupportedFileTypeError(f"Unsupported MIME type: {mime_type}")
 
 
+def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str) -> str:
+    """Extract text from a file based on its file extension."""
+    match file_extension:
+        case ".txt" | ".markdown" | ".md" | ".html" | ".htm" | ".xml":
+            return _extract_text_from_plain_text(file_content)
+        case ".pdf":
+            return _extract_text_from_pdf(file_content)
+        case ".doc" | ".docx":
+            return _extract_text_from_doc(file_content)
+        case ".csv":
+            return _extract_text_from_csv(file_content)
+        case ".xls" | ".xlsx":
+            return _extract_text_from_excel(file_content)
+        case ".ppt":
+            return _extract_text_from_ppt(file_content)
+        case ".pptx":
+            return _extract_text_from_pptx(file_content)
+        case ".epub":
+            return _extract_text_from_epub(file_content)
+        case ".eml":
+            return _extract_text_from_eml(file_content)
+        case ".msg":
+            return _extract_text_from_msg(file_content)
+        case _:
+            raise UnsupportedFileTypeError(f"Unsupported Extension Type: {file_extension}")
+
+
 def _extract_text_from_plain_text(file_content: bytes) -> str:
     try:
         return file_content.decode("utf-8")
@@ -159,7 +186,10 @@ def _extract_text_from_file(file: File):
     if file.mime_type is None:
         raise UnsupportedFileTypeError("Unable to determine file type: MIME type is missing")
     file_content = _download_file_content(file)
-    extracted_text = _extract_text(file_content=file_content, mime_type=file.mime_type)
+    if file.transfer_method == FileTransferMethod.REMOTE_URL:
+        extracted_text = _extract_text_by_mime_type(file_content=file_content, mime_type=file.mime_type)
+    else:
+        extracted_text = _extract_text_by_file_extension(file_content=file_content, file_extension=file.extension)
     return extracted_text
 
 
@@ -172,7 +202,7 @@ def _extract_text_from_csv(file_content: bytes) -> str:
         if not rows:
             return ""
 
-        # Create markdown table
+        # Create Markdown table
         markdown_table = "| " + " | ".join(rows[0]) + " |\n"
         markdown_table += "| " + " | ".join(["---"] * len(rows[0])) + " |\n"
         for row in rows[1:]:
@@ -192,7 +222,7 @@ def _extract_text_from_excel(file_content: bytes) -> str:
         # Drop rows where all elements are NaN
         df.dropna(how="all", inplace=True)
 
-        # Convert DataFrame to markdown table
+        # Convert DataFrame to Markdown table
         markdown_table = df.to_markdown(index=False)
         return markdown_table
     except Exception as e:

+ 14 - 5
api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py

@@ -63,17 +63,24 @@ def test_run_invalid_variable_type(document_extractor_node, mock_graph_runtime_s
 
 
 @pytest.mark.parametrize(
-    ("mime_type", "file_content", "expected_text", "transfer_method"),
+    ("mime_type", "file_content", "expected_text", "transfer_method", "extension"),
     [
-        ("text/plain", b"Hello, world!", ["Hello, world!"], FileTransferMethod.LOCAL_FILE),
-        ("application/pdf", b"%PDF-1.5\n%Test PDF content", ["Mocked PDF content"], FileTransferMethod.LOCAL_FILE),
+        ("text/plain", b"Hello, world!", ["Hello, world!"], FileTransferMethod.LOCAL_FILE, ".txt"),
+        (
+            "application/pdf",
+            b"%PDF-1.5\n%Test PDF content",
+            ["Mocked PDF content"],
+            FileTransferMethod.LOCAL_FILE,
+            ".pdf",
+        ),
         (
             "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
             b"PK\x03\x04",
             ["Mocked DOCX content"],
-            FileTransferMethod.LOCAL_FILE,
+            FileTransferMethod.REMOTE_URL,
+            "",
         ),
-        ("text/plain", b"Remote content", ["Remote content"], FileTransferMethod.REMOTE_URL),
+        ("text/plain", b"Remote content", ["Remote content"], FileTransferMethod.REMOTE_URL, None),
     ],
 )
 def test_run_extract_text(
@@ -83,6 +90,7 @@ def test_run_extract_text(
     file_content,
     expected_text,
     transfer_method,
+    extension,
     monkeypatch,
 ):
     document_extractor_node.graph_runtime_state = mock_graph_runtime_state
@@ -92,6 +100,7 @@ def test_run_extract_text(
     mock_file.transfer_method = transfer_method
     mock_file.related_id = "test_file_id" if transfer_method == FileTransferMethod.LOCAL_FILE else None
     mock_file.remote_url = "https://example.com/file.txt" if transfer_method == FileTransferMethod.REMOTE_URL else None
+    mock_file.extension = extension
 
     mock_array_file_segment = Mock(spec=ArrayFileSegment)
     mock_array_file_segment.value = [mock_file]