Sfoglia il codice sorgente

[Fix] Update road connection (#79)

* [Fix] Update road connection

* Remove note

* Fix code style

Co-authored-by: Bobholamovic <mhlin425@whu.edu.cn>
Yizhou Chen 2 anni fa
parent
commit
4bf530af99

+ 10 - 5
paddlers/utils/postprocs/connection.py

@@ -28,7 +28,9 @@ with warnings.catch_warnings():
 from .utils import prepro_mask, calc_distance
 from .utils import prepro_mask, calc_distance
 
 
 
 
-def cut_road_connection(mask: np.ndarray, line_width: int=6) -> np.ndarray:
+def cut_road_connection(mask: np.ndarray,
+                        area_threshold: int=32,
+                        line_width: int=6) -> np.ndarray:
     """
     """
     Connecting cut road lines.
     Connecting cut road lines.
 
 
@@ -39,21 +41,24 @@ def cut_road_connection(mask: np.ndarray, line_width: int=6) -> np.ndarray:
     This algorithm has no public code.
     This algorithm has no public code.
     The implementation procedure refers to original article,
     The implementation procedure refers to original article,
     and it is not fully consistent with the article:
     and it is not fully consistent with the article:
-    1. The way to determine the optimal number of clusters k used in k-means clustering is not described in the original article. In this implementation, we use the k that reports the highest silhouette score.
+    1. The way to determine the optimal number of clusters k used in k-means clustering is not described in the original article. 
+        In this implementation, we use the k that reports the highest silhouette score.
     2. We unmark the breakpoints if the angle between the two road extensions is less than 90°.
     2. We unmark the breakpoints if the angle between the two road extensions is less than 90°.
 
 
     Args:
     Args:
         mask (np.ndarray): Mask of road.
         mask (np.ndarray): Mask of road.
-        line_width (int, optional): Width of the line used for patching.
-            . Default is 6.
+        area_threshold (int, optional): Threshold to filter out small connected area. Default is 32.
+        line_width (int, optional): Width of the line used for patching. Default is 6.
 
 
     Returns:
     Returns:
         np.ndarray: Mask of road after connecting cut road lines.
         np.ndarray: Mask of road after connecting cut road lines.
     """
     """
-    mask = prepro_mask(mask)
+    mask = prepro_mask(mask, area_threshold)
     skeleton = morphology.skeletonize(mask).astype("uint8")
     skeleton = morphology.skeletonize(mask).astype("uint8")
     break_points = _find_breakpoint(skeleton)
     break_points = _find_breakpoint(skeleton)
     labels = _k_means(break_points)
     labels = _k_means(break_points)
+    if labels is None:
+        return mask * 255
     match_points = _get_match_points(break_points, labels)
     match_points = _get_match_points(break_points, labels)
     res = _draw_curve(mask, skeleton, match_points, line_width)
     res = _draw_curve(mask, skeleton, match_points, line_width)
     return res
     return res

+ 15 - 2
paddlers/utils/postprocs/utils.py

@@ -16,12 +16,12 @@ import numpy as np
 import cv2
 import cv2
 
 
 
 
-def prepro_mask(mask: np.ndarray):
+def prepro_mask(mask: np.ndarray, area_threshold: int=32) -> np.ndarray:
     mask_shape = mask.shape
     mask_shape = mask.shape
     if len(mask_shape) != 2:
     if len(mask_shape) != 2:
         mask = mask[..., 0]
         mask = mask[..., 0]
     mask = mask.astype("uint8")
     mask = mask.astype("uint8")
-    mask = cv2.medianBlur(mask, 5)
+    mask = _del_small_connection(mask, area_threshold)
     class_num = len(np.unique(mask))
     class_num = len(np.unique(mask))
     if class_num != 2:
     if class_num != 2:
         _, mask = cv2.threshold(mask, 0, 255, cv2.THRESH_BINARY |
         _, mask = cv2.threshold(mask, 0, 255, cv2.THRESH_BINARY |
@@ -32,3 +32,16 @@ def prepro_mask(mask: np.ndarray):
 
 
 def calc_distance(p1: np.ndarray, p2: np.ndarray) -> float:
 def calc_distance(p1: np.ndarray, p2: np.ndarray) -> float:
     return float(np.sqrt(np.sum(np.power((p1[0] - p2[0]), 2))))
     return float(np.sqrt(np.sum(np.power((p1[0] - p2[0]), 2))))
+
+
+def _del_small_connection(pred: np.ndarray, threshold: int=32) -> np.ndarray:
+    result = np.zeros_like(pred)
+    contours, reals = cv2.findContours(pred, cv2.RETR_TREE,
+                                       cv2.CHAIN_APPROX_NONE)
+    for contour, real in zip(contours, reals[0]):
+        if real[-1] == -1:
+            if cv2.contourArea(contour) > threshold:
+                cv2.fillPoly(result, [contour], (1))
+        else:
+            cv2.fillPoly(result, [contour], (0))
+    return result.astype("uint8")