Ver Fonte

[Fix] Update road connection (#79)

* [Fix] Update road connection

* Remove note

* Fix code style

Co-authored-by: Bobholamovic <mhlin425@whu.edu.cn>
Yizhou Chen há 2 anos atrás
pai
commit
4bf530af99
2 ficheiros alterados com 25 adições e 7 exclusões
  1. 10 5
      paddlers/utils/postprocs/connection.py
  2. 15 2
      paddlers/utils/postprocs/utils.py

+ 10 - 5
paddlers/utils/postprocs/connection.py

@@ -28,7 +28,9 @@ with warnings.catch_warnings():
 from .utils import prepro_mask, calc_distance
 
 
-def cut_road_connection(mask: np.ndarray, line_width: int=6) -> np.ndarray:
+def cut_road_connection(mask: np.ndarray,
+                        area_threshold: int=32,
+                        line_width: int=6) -> np.ndarray:
     """
     Connecting cut road lines.
 
@@ -39,21 +41,24 @@ def cut_road_connection(mask: np.ndarray, line_width: int=6) -> np.ndarray:
     This algorithm has no public code.
     The implementation procedure refers to original article,
     and it is not fully consistent with the article:
-    1. The way to determine the optimal number of clusters k used in k-means clustering is not described in the original article. In this implementation, we use the k that reports the highest silhouette score.
+    1. The way to determine the optimal number of clusters k used in k-means clustering is not described in the original article. 
+        In this implementation, we use the k that reports the highest silhouette score.
     2. We unmark the breakpoints if the angle between the two road extensions is less than 90°.
 
     Args:
         mask (np.ndarray): Mask of road.
-        line_width (int, optional): Width of the line used for patching.
-            . Default is 6.
+        area_threshold (int, optional): Threshold to filter out small connected area. Default is 32.
+        line_width (int, optional): Width of the line used for patching. Default is 6.
 
     Returns:
         np.ndarray: Mask of road after connecting cut road lines.
     """
-    mask = prepro_mask(mask)
+    mask = prepro_mask(mask, area_threshold)
     skeleton = morphology.skeletonize(mask).astype("uint8")
     break_points = _find_breakpoint(skeleton)
     labels = _k_means(break_points)
+    if labels is None:
+        return mask * 255
     match_points = _get_match_points(break_points, labels)
     res = _draw_curve(mask, skeleton, match_points, line_width)
     return res

+ 15 - 2
paddlers/utils/postprocs/utils.py

@@ -16,12 +16,12 @@ import numpy as np
 import cv2
 
 
-def prepro_mask(mask: np.ndarray):
+def prepro_mask(mask: np.ndarray, area_threshold: int=32) -> np.ndarray:
     mask_shape = mask.shape
     if len(mask_shape) != 2:
         mask = mask[..., 0]
     mask = mask.astype("uint8")
-    mask = cv2.medianBlur(mask, 5)
+    mask = _del_small_connection(mask, area_threshold)
     class_num = len(np.unique(mask))
     if class_num != 2:
         _, mask = cv2.threshold(mask, 0, 255, cv2.THRESH_BINARY |
@@ -32,3 +32,16 @@ def prepro_mask(mask: np.ndarray):
 
 def calc_distance(p1: np.ndarray, p2: np.ndarray) -> float:
     return float(np.sqrt(np.sum(np.power((p1[0] - p2[0]), 2))))
+
+
+def _del_small_connection(pred: np.ndarray, threshold: int=32) -> np.ndarray:
+    result = np.zeros_like(pred)
+    contours, reals = cv2.findContours(pred, cv2.RETR_TREE,
+                                       cv2.CHAIN_APPROX_NONE)
+    for contour, real in zip(contours, reals[0]):
+        if real[-1] == -1:
+            if cv2.contourArea(contour) > threshold:
+                cv2.fillPoly(result, [contour], (1))
+        else:
+            cv2.fillPoly(result, [contour], (0))
+    return result.astype("uint8")