connection.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import itertools
  15. import warnings
  16. import cv2
  17. import numpy as np
  18. from skimage import morphology
  19. from scipy import ndimage, optimize
  20. with warnings.catch_warnings():
  21. warnings.filterwarnings("ignore", category=DeprecationWarning)
  22. from sklearn import metrics
  23. from sklearn.cluster import KMeans
  24. from .utils import prepro_mask, calc_distance
  25. def cut_road_connection(mask: np.ndarray, line_width: int=6) -> np.ndarray:
  26. """
  27. Connecting cut road lines.
  28. The original article refers to
  29. Wang B, Chen Z, et al. "Road extraction of high-resolution satellite remote sensing images in U-Net network with consideration of connectivity."
  30. (http://hgs.publish.founderss.cn/thesisDetails?columnId=4759509).
  31. This algorithm has no public code.
  32. The implementation procedure refers to original article,
  33. and it is not fully consistent with the article:
  34. 1. The way to determine the optimal number of clusters k used in k-means clustering is not described in the original article. In this implementation, we use the k that reports the highest silhouette score.
  35. 2. We unmark the breakpoints if the angle between the two road extensions is less than 90°.
  36. Args:
  37. mask (np.ndarray): Mask of road.
  38. line_width (int, optional): Width of the line used for patching.
  39. . Default is 6.
  40. Returns:
  41. np.ndarray: Mask of road after connecting cut road lines.
  42. """
  43. mask = prepro_mask(mask)
  44. skeleton = morphology.skeletonize(mask).astype("uint8")
  45. break_points = _find_breakpoint(skeleton)
  46. labels = _k_means(break_points)
  47. match_points = _get_match_points(break_points, labels)
  48. res = _draw_curve(mask, skeleton, match_points, line_width)
  49. return res
  50. def _find_breakpoint(skeleton):
  51. kernel_3x3 = np.ones((3, 3), dtype="uint8")
  52. k3 = ndimage.convolve(skeleton, kernel_3x3)
  53. point_map = np.zeros_like(k3)
  54. point_map[k3 == 2] = 1
  55. point_map *= skeleton * 255
  56. # boundary filtering
  57. filter_w = 5
  58. cropped = point_map[filter_w:-filter_w, filter_w:-filter_w]
  59. padded = np.pad(cropped, (filter_w, filter_w), mode="constant")
  60. breakpoints = np.column_stack(np.where(padded == 255))
  61. return breakpoints
  62. def _k_means(data):
  63. silhouette_int = -1 # threshold
  64. labels = None
  65. for k in range(2, data.shape[0]):
  66. kms = KMeans(k, random_state=66)
  67. labels_tmp = kms.fit_predict(data) # train
  68. silhouette = metrics.silhouette_score(data, labels_tmp)
  69. if silhouette > silhouette_int: # better
  70. silhouette_int = silhouette
  71. labels = labels_tmp
  72. return labels
  73. def _get_match_points(break_points, labels):
  74. match_points = {}
  75. for point, lab in zip(break_points, labels):
  76. if lab in match_points.keys():
  77. match_points[lab].append(point)
  78. else:
  79. match_points[lab] = [point]
  80. return match_points
  81. def _draw_curve(mask, skeleton, match_points, line_width):
  82. result = mask * 255
  83. for v in match_points.values():
  84. p_num = len(v)
  85. if p_num == 2:
  86. points_list = _curve_backtracking(v, skeleton)
  87. if points_list is not None:
  88. result = _broken_wire_repair(result, points_list, line_width)
  89. elif p_num == 3:
  90. sim_v = list(itertools.combinations(v, 2))
  91. min_di = 1e6
  92. for vij in sim_v:
  93. di = calc_distance(vij[0][np.newaxis], vij[1][np.newaxis])
  94. if di < min_di:
  95. vv = vij
  96. min_di = di
  97. points_list = _curve_backtracking(vv, skeleton)
  98. if points_list is not None:
  99. result = _broken_wire_repair(result, points_list, line_width)
  100. return result
  101. def _curve_backtracking(add_lines, skeleton):
  102. points_list = []
  103. p1 = add_lines[0]
  104. p2 = add_lines[1]
  105. bpk1, ps1 = _calc_angle_by_road(p1, skeleton)
  106. bpk2, ps2 = _calc_angle_by_road(p2, skeleton)
  107. if _check_angle(bpk1, bpk2):
  108. points_list.append((
  109. np.array(
  110. ps1, dtype="int64"),
  111. add_lines[0],
  112. add_lines[1],
  113. np.array(
  114. ps2, dtype="int64"), ))
  115. return points_list
  116. else:
  117. return None
  118. def _broken_wire_repair(mask, points_list, line_width):
  119. d_mask = mask.copy()
  120. for points in points_list:
  121. nx, ny = _line_cubic(points)
  122. for i in range(len(nx) - 1):
  123. loc_p1 = (int(ny[i]), int(nx[i]))
  124. loc_p2 = (int(ny[i + 1]), int(nx[i + 1]))
  125. cv2.line(d_mask, loc_p1, loc_p2, [255], line_width)
  126. return d_mask
  127. def _calc_angle_by_road(p, skeleton, num_circle=10):
  128. def _not_in(p1, ps):
  129. for p in ps:
  130. if p1[0] == p[0] and p1[1] == p[1]:
  131. return False
  132. return True
  133. h, w = skeleton.shape
  134. tmp_p = p.tolist() if isinstance(p, np.ndarray) else p
  135. tmp_p = [int(tmp_p[0]), int(tmp_p[1])]
  136. ps = []
  137. ps.append(tmp_p)
  138. for _ in range(num_circle):
  139. t_x = 0 if tmp_p[0] - 1 < 0 else tmp_p[0] - 1
  140. t_y = 0 if tmp_p[1] - 1 < 0 else tmp_p[1] - 1
  141. b_x = w if tmp_p[0] + 1 >= w else tmp_p[0] + 1
  142. b_y = h if tmp_p[1] + 1 >= h else tmp_p[1] + 1
  143. if int(np.sum(skeleton[t_x:b_x + 1, t_y:b_y + 1])) <= 3:
  144. for i in range(t_x, b_x + 1):
  145. for j in range(t_y, b_y + 1):
  146. if skeleton[i, j] == 1:
  147. pp = [int(i), int(j)]
  148. if _not_in(pp, ps):
  149. tmp_p = pp
  150. ps.append(tmp_p)
  151. # calc angle
  152. theta = _angle_regression(ps)
  153. dx, dy = np.cos(theta), np.sin(theta)
  154. # calc direction
  155. start = ps[-1]
  156. end = ps[0]
  157. if end[1] < start[1] or (end[1] == start[1] and end[0] < start[0]):
  158. dx *= -1
  159. dy *= -1
  160. return [dx, dy], start
  161. def _angle_regression(datas):
  162. def _linear(x: float, k: float, b: float) -> float:
  163. return k * x + b
  164. xs = []
  165. ys = []
  166. for data in datas:
  167. xs.append(data[0])
  168. ys.append(data[1])
  169. xs_arr = np.array(xs)
  170. ys_arr = np.array(ys)
  171. # horizontal
  172. if len(np.unique(xs_arr)) == 1:
  173. theta = np.pi / 2
  174. # vertical
  175. elif len(np.unique(ys_arr)) == 1:
  176. theta = 0
  177. # cross calc
  178. else:
  179. k1, b1 = optimize.curve_fit(_linear, xs_arr, ys_arr)[0]
  180. k2, b2 = optimize.curve_fit(_linear, ys_arr, xs_arr)[0]
  181. err1 = 0
  182. err2 = 0
  183. for x, y in zip(xs_arr, ys_arr):
  184. err1 += abs(_linear(x, k1, b1) - y) / np.sqrt(k1**2 + 1)
  185. err2 += abs(_linear(y, k2, b2) - x) / np.sqrt(k2**2 + 1)
  186. if err1 <= err2:
  187. theta = (np.arctan(k1) + 2 * np.pi) % (2 * np.pi)
  188. else:
  189. theta = (np.pi / 2.0 - np.arctan(k2) + 2 * np.pi) % (2 * np.pi)
  190. # [0, 180)
  191. theta = theta * 180 / np.pi + 90
  192. while theta >= 180:
  193. theta -= 180
  194. theta -= 90
  195. if theta < 0:
  196. theta += 180
  197. return theta * np.pi / 180
  198. def _cubic(x, y):
  199. def _func(x, a, b, c, d):
  200. return a * x**3 + b * x**2 + c * x + d
  201. arr_x = np.array(x).reshape((4, ))
  202. arr_y = np.array(y).reshape((4, ))
  203. popt1 = np.polyfit(arr_x, arr_y, 3)
  204. popt2 = np.polyfit(arr_y, arr_x, 3)
  205. x_min = np.min(arr_x)
  206. x_max = np.max(arr_x)
  207. y_min = np.min(arr_y)
  208. y_max = np.max(arr_y)
  209. nx = np.arange(x_min, x_max + 1, 1)
  210. y_estimate = [_func(i, popt1[0], popt1[1], popt1[2], popt1[3]) for i in nx]
  211. ny = np.arange(y_min, y_max + 1, 1)
  212. x_estimate = [_func(i, popt2[0], popt2[1], popt2[2], popt2[3]) for i in ny]
  213. if np.max(y_estimate) - np.min(y_estimate) <= np.max(x_estimate) - np.min(
  214. x_estimate):
  215. return nx, y_estimate
  216. else:
  217. return x_estimate, ny
  218. def _line_cubic(points):
  219. xs = []
  220. ys = []
  221. for p in points:
  222. x, y = p
  223. xs.append(x)
  224. ys.append(y)
  225. nx, ny = _cubic(xs, ys)
  226. return nx, ny
  227. def _get_theta(dy, dx):
  228. theta = np.arctan2(dy, dx) * 180 / np.pi
  229. if theta < 0.0:
  230. theta = 360.0 - abs(theta)
  231. return float(theta)
  232. def _check_angle(bpk1, bpk2, ang_threshold=90):
  233. af1 = _get_theta(bpk1[0], bpk1[1])
  234. af2 = _get_theta(bpk2[0], bpk2[1])
  235. ang_diff = abs(af1 - af2)
  236. if ang_diff > 180:
  237. ang_diff = 360 - ang_diff
  238. if ang_diff > ang_threshold:
  239. return True
  240. else:
  241. return False