Jelajahi Sumber

[Feat] Add New Postproc (Cut Road Connection) (#52)

*
Yizhou Chen 2 tahun lalu
induk
melakukan
7390b9e6ad

+ 15 - 1
README.md

@@ -119,6 +119,7 @@ PaddleRS具有以下五大特色:
           <li>ReduceDim</li>  
           <li>SelectBand</li>  
           <li>RandomSwap</li>
+          <li>AppendIndex</li>
           <li>...</li>
         </ul>  
       </td>
@@ -138,6 +139,17 @@ PaddleRS具有以下五大特色:
           <li>辐射校正</li>
           <li>...</li>
         </ul>
+        <b>数据后处理</b><br>
+        <ul>
+          <li>建筑边界规则化</li>
+          <li>道路断线连接</li>
+          <li>...</li>
+        </ul>
+        <b>数据可视化</b><br>
+        <ul>
+          <li>地图-栅格可视化</li>
+          <li>...</li>
+        </ul>
       </td>
       <td>
         <b>遥感场景分类</b><br>
@@ -177,8 +189,10 @@ PaddleRS目录树中关键部分如下:
 │     ├── datasets       # 数据集接口实现
 │     ├── models         # 视觉模型实现
 │     ├── tasks          # 训练器实现
-│     └── transforms     # 数据预处理/数据增强实现
+│     ├── transforms     # 数据预处理/数据增强实现
+│     └── utils          # 数据下载/可视化/后处理等
 ├── tools                # 遥感影像处理工具集
+├── examples             # 相关实践案例
 └── tutorials
       └── train          # 模型训练教程
 ```

+ 1 - 0
paddlers/utils/postprocs/__init__.py

@@ -13,3 +13,4 @@
 # limitations under the License.
 
 from .regularization import building_regularization
+from .connection import cut_road_connection

+ 278 - 0
paddlers/utils/postprocs/connection.py

@@ -0,0 +1,278 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import itertools
+import warnings
+
+import cv2
+import numpy as np
+from skimage import morphology
+from scipy import ndimage, optimize
+
+with warnings.catch_warnings():
+    warnings.filterwarnings("ignore", category=DeprecationWarning)
+    from sklearn import metrics
+    from sklearn.cluster import KMeans
+
+from .utils import prepro_mask, calc_distance
+
+
+def cut_road_connection(mask: np.ndarray, line_width: int=6) -> np.ndarray:
+    """
+    Connecting cut road lines.
+
+    The original article refers to
+    Wang B, Chen Z, et al. "Road extraction of high-resolution satellite remote sensing images in U-Net network with consideration of connectivity."
+    (http://hgs.publish.founderss.cn/thesisDetails?columnId=4759509).
+
+    This algorithm has no public code.
+    The implementation procedure refers to original article,
+    and it is not fully consistent with the article:
+    1. The way to determine the optimal number of clusters k used in k-means clustering is not described in the original article. In this implementation, we use the k that reports the highest silhouette score.
+    2. We unmark the breakpoints if the angle between the two road extensions is less than 90°.
+
+    Args:
+        mask (np.ndarray): Mask of road.
+        line_width (int, optional): Width of the line used for patching.
+            . Default is 6.
+
+    Returns:
+        np.ndarray: Mask of road after connecting cut road lines.
+    """
+    mask = prepro_mask(mask)
+    skeleton = morphology.skeletonize(mask).astype("uint8")
+    break_points = _find_breakpoint(skeleton)
+    labels = _k_means(break_points)
+    match_points = _get_match_points(break_points, labels)
+    res = _draw_curve(mask, skeleton, match_points, line_width)
+    return res
+
+
+def _find_breakpoint(skeleton):
+    kernel_3x3 = np.ones((3, 3), dtype="uint8")
+    k3 = ndimage.convolve(skeleton, kernel_3x3)
+    point_map = np.zeros_like(k3)
+    point_map[k3 == 2] = 1
+    point_map *= skeleton * 255
+    # boundary filtering
+    filter_w = 5
+    cropped = point_map[filter_w:-filter_w, filter_w:-filter_w]
+    padded = np.pad(cropped, (filter_w, filter_w), mode="constant")
+    breakpoints = np.column_stack(np.where(padded == 255))
+    return breakpoints
+
+
+def _k_means(data):
+    silhouette_int = -1  # threshold
+    labels = None
+    for k in range(2, data.shape[0]):
+        kms = KMeans(k, random_state=66)
+        labels_tmp = kms.fit_predict(data)  # train
+        silhouette = metrics.silhouette_score(data, labels_tmp)
+        if silhouette > silhouette_int:  # better
+            silhouette_int = silhouette
+            labels = labels_tmp
+    return labels
+
+
+def _get_match_points(break_points, labels):
+    match_points = {}
+    for point, lab in zip(break_points, labels):
+        if lab in match_points.keys():
+            match_points[lab].append(point)
+        else:
+            match_points[lab] = [point]
+    return match_points
+
+
+def _draw_curve(mask, skeleton, match_points, line_width):
+    result = mask * 255
+    for v in match_points.values():
+        p_num = len(v)
+        if p_num == 2:
+            points_list = _curve_backtracking(v, skeleton)
+            if points_list is not None:
+                result = _broken_wire_repair(result, points_list, line_width)
+        elif p_num == 3:
+            sim_v = list(itertools.combinations(v, 2))
+            min_di = 1e6
+            for vij in sim_v:
+                di = calc_distance(vij[0][np.newaxis], vij[1][np.newaxis])
+                if di < min_di:
+                    vv = vij
+                    min_di = di
+            points_list = _curve_backtracking(vv, skeleton)
+            if points_list is not None:
+                result = _broken_wire_repair(result, points_list, line_width)
+    return result
+
+
+def _curve_backtracking(add_lines, skeleton):
+    points_list = []
+    p1 = add_lines[0]
+    p2 = add_lines[1]
+    bpk1, ps1 = _calc_angle_by_road(p1, skeleton)
+    bpk2, ps2 = _calc_angle_by_road(p2, skeleton)
+    if _check_angle(bpk1, bpk2):
+        points_list.append((
+            np.array(
+                ps1, dtype="int64"),
+            add_lines[0],
+            add_lines[1],
+            np.array(
+                ps2, dtype="int64"), ))
+        return points_list
+    else:
+        return None
+
+
+def _broken_wire_repair(mask, points_list, line_width):
+    d_mask = mask.copy()
+    for points in points_list:
+        nx, ny = _line_cubic(points)
+        for i in range(len(nx) - 1):
+            loc_p1 = (int(ny[i]), int(nx[i]))
+            loc_p2 = (int(ny[i + 1]), int(nx[i + 1]))
+            cv2.line(d_mask, loc_p1, loc_p2, [255], line_width)
+    return d_mask
+
+
+def _calc_angle_by_road(p, skeleton, num_circle=10):
+    def _not_in(p1, ps):
+        for p in ps:
+            if p1[0] == p[0] and p1[1] == p[1]:
+                return False
+        return True
+
+    h, w = skeleton.shape
+    tmp_p = p.tolist() if isinstance(p, np.ndarray) else p
+    tmp_p = [int(tmp_p[0]), int(tmp_p[1])]
+    ps = []
+    ps.append(tmp_p)
+    for _ in range(num_circle):
+        t_x = 0 if tmp_p[0] - 1 < 0 else tmp_p[0] - 1
+        t_y = 0 if tmp_p[1] - 1 < 0 else tmp_p[1] - 1
+        b_x = w if tmp_p[0] + 1 >= w else tmp_p[0] + 1
+        b_y = h if tmp_p[1] + 1 >= h else tmp_p[1] + 1
+        if int(np.sum(skeleton[t_x:b_x + 1, t_y:b_y + 1])) <= 3:
+            for i in range(t_x, b_x + 1):
+                for j in range(t_y, b_y + 1):
+                    if skeleton[i, j] == 1:
+                        pp = [int(i), int(j)]
+                        if _not_in(pp, ps):
+                            tmp_p = pp
+                            ps.append(tmp_p)
+    # calc angle
+    theta = _angle_regression(ps)
+    dx, dy = np.cos(theta), np.sin(theta)
+    # calc direction
+    start = ps[-1]
+    end = ps[0]
+    if end[1] < start[1] or (end[1] == start[1] and end[0] < start[0]):
+        dx *= -1
+        dy *= -1
+    return [dx, dy], start
+
+
+def _angle_regression(datas):
+    def _linear(x: float, k: float, b: float) -> float:
+        return k * x + b
+
+    xs = []
+    ys = []
+    for data in datas:
+        xs.append(data[0])
+        ys.append(data[1])
+    xs_arr = np.array(xs)
+    ys_arr = np.array(ys)
+    # horizontal
+    if len(np.unique(xs_arr)) == 1:
+        theta = np.pi / 2
+    # vertical
+    elif len(np.unique(ys_arr)) == 1:
+        theta = 0
+    # cross calc
+    else:
+        k1, b1 = optimize.curve_fit(_linear, xs_arr, ys_arr)[0]
+        k2, b2 = optimize.curve_fit(_linear, ys_arr, xs_arr)[0]
+        err1 = 0
+        err2 = 0
+        for x, y in zip(xs_arr, ys_arr):
+            err1 += abs(_linear(x, k1, b1) - y) / np.sqrt(k1**2 + 1)
+            err2 += abs(_linear(y, k2, b2) - x) / np.sqrt(k2**2 + 1)
+        if err1 <= err2:
+            theta = (np.arctan(k1) + 2 * np.pi) % (2 * np.pi)
+        else:
+            theta = (np.pi / 2.0 - np.arctan(k2) + 2 * np.pi) % (2 * np.pi)
+    # [0, 180)
+    theta = theta * 180 / np.pi + 90
+    while theta >= 180:
+        theta -= 180
+    theta -= 90
+    if theta < 0:
+        theta += 180
+    return theta * np.pi / 180
+
+
+def _cubic(x, y):
+    def _func(x, a, b, c, d):
+        return a * x**3 + b * x**2 + c * x + d
+
+    arr_x = np.array(x).reshape((4, ))
+    arr_y = np.array(y).reshape((4, ))
+    popt1 = np.polyfit(arr_x, arr_y, 3)
+    popt2 = np.polyfit(arr_y, arr_x, 3)
+    x_min = np.min(arr_x)
+    x_max = np.max(arr_x)
+    y_min = np.min(arr_y)
+    y_max = np.max(arr_y)
+    nx = np.arange(x_min, x_max + 1, 1)
+    y_estimate = [_func(i, popt1[0], popt1[1], popt1[2], popt1[3]) for i in nx]
+    ny = np.arange(y_min, y_max + 1, 1)
+    x_estimate = [_func(i, popt2[0], popt2[1], popt2[2], popt2[3]) for i in ny]
+    if np.max(y_estimate) - np.min(y_estimate) <= np.max(x_estimate) - np.min(
+            x_estimate):
+        return nx, y_estimate
+    else:
+        return x_estimate, ny
+
+
+def _line_cubic(points):
+    xs = []
+    ys = []
+    for p in points:
+        x, y = p
+        xs.append(x)
+        ys.append(y)
+    nx, ny = _cubic(xs, ys)
+    return nx, ny
+
+
+def _get_theta(dy, dx):
+    theta = np.arctan2(dy, dx) * 180 / np.pi
+    if theta < 0.0:
+        theta = 360.0 - abs(theta)
+    return float(theta)
+
+
+def _check_angle(bpk1, bpk2, ang_threshold=90):
+    af1 = _get_theta(bpk1[0], bpk1[1])
+    af2 = _get_theta(bpk2[0], bpk2[1])
+    ang_diff = abs(af1 - af2)
+    if ang_diff > 180:
+        ang_diff = 360 - ang_diff
+    if ang_diff > ang_threshold:
+        return True
+    else:
+        return False

+ 120 - 35
paddlers/utils/postprocs/regularization.py

@@ -13,11 +13,11 @@
 # limitations under the License.
 
 import math
+
 import cv2
 import numpy as np
-from .utils import (calc_distance, calc_angle, calc_azimuth, rotation, line,
-                    intersection, calc_distance_between_lines,
-                    calc_project_in_line)
+
+from .utils import prepro_mask, calc_distance
 
 S = 20
 TD = 3
@@ -52,15 +52,7 @@ def building_regularization(mask: np.ndarray, W: int=32) -> np.ndarray:
         np.ndarray: Mask of building after regularized.
     """
     # check and pro processing
-    mask_shape = mask.shape
-    if len(mask_shape) != 2:
-        mask = mask[..., 0]
-    mask = cv2.medianBlur(mask, 5)
-    class_num = len(np.unique(mask))
-    if class_num != 2:
-        _, mask = cv2.threshold(mask, 0, 255, cv2.THRESH_BINARY |
-                                cv2.THRESH_OTSU)
-    mask = np.clip(mask, 0, 1).astype("uint8")  # 0-255 / 0-1 -> 0-1
+    mask = prepro_mask(mask)
     mask_shape = mask.shape
     # find contours
     contours, hierarchys = cv2.findContours(mask, cv2.RETR_TREE,
@@ -115,7 +107,7 @@ def _coarse(contour, img_shape):
             continue
         # remove over-sharp angles with threshold α.
         # remove over-smooth angles with threshold β.
-        angle = calc_angle(last_point, current_point, next_point)
+        angle = _calc_angle(last_point, current_point, next_point)
         if (ALPHA > angle or angle > BETA) and _inline_check(current_point,
                                                              img_shape):
             contour = np.delete(contour, idx, axis=0)
@@ -143,7 +135,7 @@ def _fine(contour, W):
         next_idx = (idx + 1) % p_number
         next_point = contour[next_idx]
         distance_list.append(calc_distance(current_point, next_point))
-        azimuth_list.append(calc_azimuth(current_point, next_point))
+        azimuth_list.append(_calc_azimuth(current_point, next_point))
         indexs_list.append((idx, next_idx))
     # add the direction of the longest edge to the list of main direction.
     longest_distance_idx = np.argmax(distance_list)
@@ -177,11 +169,11 @@ def _fine(contour, W):
             abs_rotate_ang = abs(rotate_ang)
             # adjust long edges according to the list and angles.
             if abs_rotate_ang < DELTA or abs_rotate_ang > (180 - DELTA):
-                rp1 = rotation(p1, pm, rotate_ang)
-                rp2 = rotation(p2, pm, rotate_ang)
+                rp1 = _rotation(p1, pm, rotate_ang)
+                rp2 = _rotation(p2, pm, rotate_ang)
             elif (90 - DELTA) < abs_rotate_ang < (90 + DELTA):
-                rp1 = rotation(p1, pm, rotate_ang - 90)
-                rp2 = rotation(p2, pm, rotate_ang - 90)
+                rp1 = _rotation(p1, pm, rotate_ang - 90)
+                rp2 = _rotation(p2, pm, rotate_ang - 90)
             else:
                 rp1, rp2 = p1, p2
         # adjust short edges (judged by a threshold θ) according to the list and angles.
@@ -189,11 +181,11 @@ def _fine(contour, W):
             rotate_ang = md_used_list[-1] - azimuth
             abs_rotate_ang = abs(rotate_ang)
             if abs_rotate_ang < THETA or abs_rotate_ang > (180 - THETA):
-                rp1 = rotation(p1, pm, rotate_ang)
-                rp2 = rotation(p2, pm, rotate_ang)
+                rp1 = _rotation(p1, pm, rotate_ang)
+                rp2 = _rotation(p2, pm, rotate_ang)
             else:
-                rp1 = rotation(p1, pm, rotate_ang - 90)
-                rp2 = rotation(p2, pm, rotate_ang - 90)
+                rp1 = _rotation(p1, pm, rotate_ang - 90)
+                rp2 = _rotation(p2, pm, rotate_ang - 90)
         # contour_by_lines.extend([rp1, rp2])
         contour_by_lines.append([rp1[0], rp2[0]])
     correct_points = np.array(contour_by_lines)
@@ -208,35 +200,35 @@ def _fine(contour, W):
         cur_edge_p2 = correct_points[idx][1]
         next_edge_p1 = correct_points[next_idx][0]
         next_edge_p2 = correct_points[next_idx][1]
-        L1 = line(cur_edge_p1, cur_edge_p2)
-        L2 = line(next_edge_p1, next_edge_p2)
-        A1 = calc_azimuth([cur_edge_p1], [cur_edge_p2])
-        A2 = calc_azimuth([next_edge_p1], [next_edge_p2])
+        L1 = _line(cur_edge_p1, cur_edge_p2)
+        L2 = _line(next_edge_p1, next_edge_p2)
+        A1 = _calc_azimuth([cur_edge_p1], [cur_edge_p2])
+        A2 = _calc_azimuth([next_edge_p1], [next_edge_p2])
         dif_azi = abs(A1 - A2)
         # find intersection point if not parallel
         if (90 - DELTA) < dif_azi < (90 + DELTA):
-            point_intersection = intersection(L1, L2)
+            point_intersection = _intersection(L1, L2)
             if point_intersection is not None:
                 final_points.append(point_intersection)
         # move or add lines when parallel
         elif dif_azi < 1e-6:
-            marg = calc_distance_between_lines(L1, L2)
+            marg = _calc_distance_between_lines(L1, L2)
             if marg < D:
                 # move
-                point_move = calc_project_in_line(next_edge_p1, cur_edge_p1,
-                                                  cur_edge_p2)
+                point_move = _calc_project_in_line(next_edge_p1, cur_edge_p1,
+                                                   cur_edge_p2)
                 final_points.append(point_move)
                 # update next
                 correct_points[next_idx][0] = point_move
-                correct_points[next_idx][1] = calc_project_in_line(
+                correct_points[next_idx][1] = _calc_project_in_line(
                     next_edge_p2, cur_edge_p1, cur_edge_p2)
             else:
                 # add line
                 add_mid_point = (cur_edge_p2 + next_edge_p1) / 2
-                rp1 = calc_project_in_line(add_mid_point, cur_edge_p1,
-                                           cur_edge_p2)
-                rp2 = calc_project_in_line(add_mid_point, next_edge_p1,
-                                           next_edge_p2)
+                rp1 = _calc_project_in_line(add_mid_point, cur_edge_p1,
+                                            cur_edge_p2)
+                rp2 = _calc_project_in_line(add_mid_point, next_edge_p1,
+                                            next_edge_p2)
                 final_points.extend([rp1, rp2])
         else:
             final_points.extend(
@@ -262,3 +254,96 @@ def _fill(img, coarse_conts):
         else:
             cv2.fillPoly(result, [contour.astype(np.int32)], (255, 255, 255))
     return result
+
+
+def _calc_angle(p1, vertex, p2):
+    x1, y1 = p1[0]
+    xv, yv = vertex[0]
+    x2, y2 = p2[0]
+    a = ((xv - x2) * (xv - x2) + (yv - y2) * (yv - y2))**0.5
+    b = ((x1 - x2) * (x1 - x2) + (y1 - y2) * (y1 - y2))**0.5
+    c = ((x1 - xv) * (x1 - xv) + (y1 - yv) * (y1 - yv))**0.5
+    return math.degrees(math.acos((b**2 - a**2 - c**2) / (-2 * a * c)))
+
+
+def _calc_azimuth(p1, p2):
+    x1, y1 = p1[0]
+    x2, y2 = p2[0]
+    if y1 == y2:
+        return 0.0
+    if x1 == x2:
+        return 90.0
+    elif x1 < x2:
+        if y1 < y2:
+            ang = math.atan((y2 - y1) / (x2 - x1))
+            return math.degrees(ang)
+        else:
+            ang = math.atan((y1 - y2) / (x2 - x1))
+            return 180 - math.degrees(ang)
+    else:  # x1 > x2
+        if y1 < y2:
+            ang = math.atan((y2 - y1) / (x1 - x2))
+            return 180 - math.degrees(ang)
+        else:
+            ang = math.atan((y1 - y2) / (x1 - x2))
+            return math.degrees(ang)
+
+
+def _rotation(point, center, angle):
+    if angle == 0:
+        return point
+    x, y = point[0]
+    cx, cy = center[0]
+    radian = math.radians(abs(angle))
+    if angle > 0:  # clockwise
+        rx = (x - cx) * math.cos(radian) - (y - cy) * math.sin(radian) + cx
+        ry = (x - cx) * math.sin(radian) + (y - cy) * math.cos(radian) + cy
+    else:
+        rx = (x - cx) * math.cos(radian) + (y - cy) * math.sin(radian) + cx
+        ry = (y - cy) * math.cos(radian) - (x - cx) * math.sin(radian) + cy
+    return np.array([[rx, ry]])
+
+
+def _line(p1, p2):
+    A = (p1[1] - p2[1])
+    B = (p2[0] - p1[0])
+    C = (p1[0] * p2[1] - p2[0] * p1[1])
+    return A, B, -C
+
+
+def _intersection(L1, L2):
+    D = L1[0] * L2[1] - L1[1] * L2[0]
+    Dx = L1[2] * L2[1] - L1[1] * L2[2]
+    Dy = L1[0] * L2[2] - L1[2] * L2[0]
+    if D != 0:
+        x = Dx / D
+        y = Dy / D
+        return np.array([[x, y]])
+    else:
+        return None
+
+
+def _calc_distance_between_lines(L1, L2):
+    eps = 1e-16
+    A1, _, C1 = L1
+    A2, B2, C2 = L2
+    new_C1 = C1 / (A1 + eps)
+    new_A2 = 1
+    new_B2 = B2 / (A2 + eps)
+    new_C2 = C2 / (A2 + eps)
+    dist = (np.abs(new_C1 - new_C2)) / (
+        np.sqrt(new_A2 * new_A2 + new_B2 * new_B2) + eps)
+    return dist
+
+
+def _calc_project_in_line(point, line_point1, line_point2):
+    eps = 1e-16
+    m, n = point
+    x1, y1 = line_point1
+    x2, y2 = line_point2
+    F = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1)
+    x = (m * (x2 - x1) * (x2 - x1) + n * (y2 - y1) * (x2 - x1) +
+         (x1 * y2 - x2 * y1) * (y2 - y1)) / (F + eps)
+    y = (m * (x2 - x1) * (y2 - y1) + n * (y2 - y1) * (y2 - y1) +
+         (x2 * y1 - x1 * y2) * (x2 - x1)) / (F + eps)
+    return np.array([[x, y]])

+ 15 - 94
paddlers/utils/postprocs/utils.py

@@ -13,101 +13,22 @@
 # limitations under the License.
 
 import numpy as np
-import math
+import cv2
 
 
-def calc_distance(p1: np.ndarray, p2: np.ndarray) -> float:
-    return float(np.sqrt(np.sum(np.power((p1[0] - p2[0]), 2))))
-
-
-def calc_angle(p1: np.ndarray, vertex: np.ndarray, p2: np.ndarray) -> float:
-    x1, y1 = p1[0]
-    xv, yv = vertex[0]
-    x2, y2 = p2[0]
-    a = ((xv - x2) * (xv - x2) + (yv - y2) * (yv - y2))**0.5
-    b = ((x1 - x2) * (x1 - x2) + (y1 - y2) * (y1 - y2))**0.5
-    c = ((x1 - xv) * (x1 - xv) + (y1 - yv) * (y1 - yv))**0.5
-    return math.degrees(math.acos((b**2 - a**2 - c**2) / (-2 * a * c)))
-
-
-def calc_azimuth(p1: np.ndarray, p2: np.ndarray) -> float:
-    x1, y1 = p1[0]
-    x2, y2 = p2[0]
-    if y1 == y2:
-        return 0.0
-    if x1 == x2:
-        return 90.0
-    elif x1 < x2:
-        if y1 < y2:
-            ang = math.atan((y2 - y1) / (x2 - x1))
-            return math.degrees(ang)
-        else:
-            ang = math.atan((y1 - y2) / (x2 - x1))
-            return 180 - math.degrees(ang)
-    else:  # x1 > x2
-        if y1 < y2:
-            ang = math.atan((y2 - y1) / (x1 - x2))
-            return 180 - math.degrees(ang)
-        else:
-            ang = math.atan((y1 - y2) / (x1 - x2))
-            return math.degrees(ang)
-
-
-def rotation(point: np.ndarray, center: np.ndarray, angle: float) -> np.ndarray:
-    if angle == 0:
-        return point
-    x, y = point[0]
-    cx, cy = center[0]
-    radian = math.radians(abs(angle))
-    if angle > 0:  # clockwise
-        rx = (x - cx) * math.cos(radian) - (y - cy) * math.sin(radian) + cx
-        ry = (x - cx) * math.sin(radian) + (y - cy) * math.cos(radian) + cy
-    else:
-        rx = (x - cx) * math.cos(radian) + (y - cy) * math.sin(radian) + cx
-        ry = (y - cy) * math.cos(radian) - (x - cx) * math.sin(radian) + cy
-    return np.array([[rx, ry]])
+def prepro_mask(mask: np.ndarray):
+    mask_shape = mask.shape
+    if len(mask_shape) != 2:
+        mask = mask[..., 0]
+    mask = mask.astype("uint8")
+    mask = cv2.medianBlur(mask, 5)
+    class_num = len(np.unique(mask))
+    if class_num != 2:
+        _, mask = cv2.threshold(mask, 0, 255, cv2.THRESH_BINARY |
+                                cv2.THRESH_OTSU)
+    mask = np.clip(mask, 0, 1).astype("uint8")  # 0-255 / 0-1 -> 0-1
+    return mask
 
 
-def line(p1, p2):
-    A = (p1[1] - p2[1])
-    B = (p2[0] - p1[0])
-    C = (p1[0] * p2[1] - p2[0] * p1[1])
-    return A, B, -C
-
-
-def intersection(L1, L2):
-    D = L1[0] * L2[1] - L1[1] * L2[0]
-    Dx = L1[2] * L2[1] - L1[1] * L2[2]
-    Dy = L1[0] * L2[2] - L1[2] * L2[0]
-    if D != 0:
-        x = Dx / D
-        y = Dy / D
-        return np.array([[x, y]])
-    else:
-        return None
-
-
-def calc_distance_between_lines(L1, L2):
-    eps = 1e-16
-    A1, _, C1 = L1
-    A2, B2, C2 = L2
-    new_C1 = C1 / (A1 + eps)
-    new_A2 = 1
-    new_B2 = B2 / (A2 + eps)
-    new_C2 = C2 / (A2 + eps)
-    dist = (np.abs(new_C1 - new_C2)) / (
-        np.sqrt(new_A2 * new_A2 + new_B2 * new_B2) + eps)
-    return dist
-
-
-def calc_project_in_line(point, line_point1, line_point2):
-    eps = 1e-16
-    m, n = point
-    x1, y1 = line_point1
-    x2, y2 = line_point2
-    F = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1)
-    x = (m * (x2 - x1) * (x2 - x1) + n * (y2 - y1) * (x2 - x1) +
-         (x1 * y2 - x2 * y1) * (y2 - y1)) / (F + eps)
-    y = (m * (x2 - x1) * (y2 - y1) + n * (y2 - y1) * (y2 - y1) +
-         (x2 * y1 - x1 * y2) * (x2 - x1)) / (F + eps)
-    return np.array([[x, y]])
+def calc_distance(p1: np.ndarray, p2: np.ndarray) -> float:
+    return float(np.sqrt(np.sum(np.power((p1[0] - p2[0]), 2))))