Browse Source

refactor(question_classifier): improve error handling with custom exceptions (#10365)

-LAN- 5 months ago
parent
commit
d3e9930235

+ 6 - 0
api/core/workflow/nodes/question_classifier/exc.py

@@ -0,0 +1,6 @@
+class QuestionClassifierNodeError(ValueError):
+    """Base class for QuestionClassifierNode errors."""
+
+
+class InvalidModelTypeError(QuestionClassifierNodeError):
+    """Raised when the model is not a Large Language Model."""

+ 4 - 2
api/core/workflow/nodes/question_classifier/question_classifier_node.py

@@ -4,6 +4,7 @@ from collections.abc import Mapping, Sequence
 from typing import TYPE_CHECKING, Any, Optional, cast
 
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
+from core.llm_generator.output_parser.errors import OutputParserError
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
 from core.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole
@@ -24,6 +25,7 @@ from libs.json_in_md_parser import parse_and_check_json_markdown
 from models.workflow import WorkflowNodeExecutionStatus
 
 from .entities import QuestionClassifierNodeData
+from .exc import InvalidModelTypeError
 from .template_prompts import (
     QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1,
     QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2,
@@ -124,7 +126,7 @@ class QuestionClassifierNode(LLMNode):
                     category_name = classes_map[category_id_result]
                     category_id = category_id_result
 
-        except Exception:
+        except OutputParserError:
             logging.error(f"Failed to parse result text: {result_text}")
         try:
             process_data = {
@@ -309,4 +311,4 @@ class QuestionClassifierNode(LLMNode):
             )
 
         else:
-            raise ValueError(f"Model mode {model_mode} not support.")
+            raise InvalidModelTypeError(f"Model mode {model_mode} not support.")

+ 1 - 1
api/libs/json_in_md_parser.py

@@ -9,6 +9,7 @@ def parse_json_markdown(json_string: str) -> dict:
     starts = ["```json", "```", "``", "`", "{"]
     ends = ["```", "``", "`", "}"]
     end_index = -1
+    start_index = 0
     for s in starts:
         start_index = json_string.find(s)
         if start_index != -1:
@@ -24,7 +25,6 @@ def parse_json_markdown(json_string: str) -> dict:
                 break
     if start_index != -1 and end_index != -1 and start_index < end_index:
         extracted_content = json_string[start_index:end_index].strip()
-        print("content:", extracted_content, start_index, end_index)
         parsed = json.loads(extracted_content)
     else:
         raise Exception("Could not find JSON block in the output.")