functions.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import copy
  15. import cv2
  16. import numpy as np
  17. import shapely.ops
  18. from shapely.geometry import Polygon, MultiPolygon, GeometryCollection
  19. from sklearn.linear_model import LinearRegression
  20. from skimage import exposure
  21. from joblib import load
  def normalize(im, mean, std, min_value=(0, 0, 0), max_value=(255, 255, 255)):
      """Min-max rescale an image, then standardize it channel by channel.

      Args:
          im (np.ndarray): Input image; its last axis is broadcast against the
              per-channel values below.
          mean (sequence|np.ndarray|float): Mean subtracted after rescaling.
          std (sequence|np.ndarray|float): Standard deviation divided out after
              rescaling.
          min_value (sequence, optional): Per-channel lower bound of the input
              range. Defaults to (0, 0, 0).
          max_value (sequence, optional): Per-channel upper bound of the input
              range. Defaults to (255, 255, 255).

      Returns:
          np.ndarray: ((im - min_value) / (max_value - min_value) - mean) / std
      """
      # Defaults are tuples rather than lists so no mutable state is shared
      # between calls.
      # Rescaling (min-max normalization)
      range_value = np.asarray(
          [1. / (max_value[i] - min_value[i]) for i in range(len(max_value))],
          dtype=np.float32)
      im = (im - np.asarray(min_value, dtype=np.float32)) * range_value
      # Standardization (Z-score normalization); `im` is a fresh array here,
      # so the in-place operations do not touch the caller's data.
      im -= mean
      im /= std
      return im
  def permute(im, to_bgr=False):
      """Move the channel axis of an HWC image to the front (HWC -> CHW).

      Args:
          im (np.ndarray): Image whose axes 0, 1, 2 are height, width and
              channel. Any further axes keep their position after axis 2.
          to_bgr (bool, optional): If True, keep channels 2, 1, 0 in that order
              after the move, i.e. RGB -> BGR for a 3-channel image.
              Defaults to False.

      Returns:
          np.ndarray: The channel-first image.
      """
      chw = np.moveaxis(im, 2, 0)
      return chw[[2, 1, 0], :, :] if to_bgr else chw
  def center_crop(im, crop_size=224):
      """Crop a crop_size x crop_size window from the center of an image.

      Args:
          im (np.ndarray): Image with height on axis 0 and width on axis 1;
              trailing axes (e.g. channels) are kept whole.
          crop_size (int, optional): Side length of the square crop.
              Defaults to 224.

      Returns:
          np.ndarray: The cropped view. Along an axis shorter than `crop_size`
          the full extent of that axis is kept.
      """
      height, width = im.shape[:2]
      # Clamp at 0: a negative start would be read by numpy as an offset from
      # the end of the axis and silently yield an off-center strip.
      h_start = max((height - crop_size) // 2, 0)
      w_start = max((width - crop_size) // 2, 0)
      return im[h_start:h_start + crop_size, w_start:w_start + crop_size, ...]
  46. # region flip
  def img_flip(im, method=0):
      """Flip a 2D or higher-dimensional image in one of five ways.

      The work is delegated to the flip helpers defined alongside this
      function (horizontal_flip, vertical_flip, hv_flip, rt2lb_flip,
      lt2rb_flip).

      Args:
          im (np.ndarray): Image array with at least two dimensions.
          method (int|str, optional): Which flip to apply. Defaults to 0.
              0 or 'h'               : horizontal flip (the most common one).
              1 or 'v'               : vertical flip.
              2 or 'hv'              : horizontal and vertical flip.
              3 or 'rt2lb' or 'dia'  : flip about the main diagonal, which
                                       swaps the right-top and left-bottom
                                       parts.
              4 or 'lt2rb' or 'adia' : flip about the anti-diagonal, which
                                       swaps the left-top and right-bottom
                                       parts.
              Any other value returns `im` unchanged (no error is raised).

      Returns:
          np.ndarray: The flipped image.

      Raises:
          ValueError: If `im` has fewer than two dimensions.

      Examples:
          >>> img = np.array([[1, 2], [3, 4]])
          >>> img_flip(img, 'h').tolist()
          [[2, 1], [4, 3]]
          >>> img_flip(img, 'dia').tolist()
          [[1, 3], [2, 4]]
      """
      if len(im.shape) < 2:
          raise ValueError("Shape of image should 2d, 3d or more")
      if method in (0, 'h'):
          return horizontal_flip(im)
      if method in (1, 'v'):
          return vertical_flip(im)
      if method in (2, 'hv'):
          return hv_flip(im)
      if method in (3, 'rt2lb', 'dia'):
          return rt2lb_flip(im)
      if method in (4, 'lt2rb', 'adia'):
          return lt2rb_flip(im)
      # Unrecognized methods fall through to a no-op.
      return im
  def horizontal_flip(im):
      """Mirror an image left-to-right by reversing axis 1 (columns)."""
      return im[:, ::-1]
  def vertical_flip(im):
      """Mirror an image top-to-bottom by reversing axis 0 (rows)."""
      return im[::-1, :]
  def hv_flip(im):
      """Reverse both axis 0 and axis 1 (horizontal plus vertical flip)."""
      return im[::-1, ::-1]
  def rt2lb_flip(im):
      """Flip about the main diagonal by exchanging axes 0 and 1.

      Trailing axes stay in place, so an HWC image becomes WHC.
      """
      return im.swapaxes(0, 1)
  def lt2rb_flip(im):
      """Flip about the anti-diagonal.

      Reverses axes 0 and 1 and then exchanges them; trailing axes stay in
      place.
      """
      reversed_both = im[::-1, ::-1]
      return reversed_both.swapaxes(0, 1)
  135. # endregion
  136. # region rotation
  def img_simple_rotate(im, method=0):
      """Rotate a 2D or higher-dimensional image clockwise by a multiple of 90 degrees.

      The work is delegated to the helpers rot_90, rot_180 and rot_270 defined
      alongside this function.

      Args:
          im (np.ndarray): Image array with at least two dimensions.
          method (int, optional): Rotation to apply. Defaults to 0.
              0 or 90  : rotate 90 degrees clockwise.
              1 or 180 : rotate 180 degrees.
              2 or 270 : rotate 270 degrees clockwise.
              Any other value returns `im` unchanged (no error is raised).

      Returns:
          np.ndarray: The rotated image.

      Raises:
          ValueError: If `im` has fewer than two dimensions.

      Examples:
          >>> img = np.array([[1, 2], [3, 4]])
          >>> img_simple_rotate(img, 90).tolist()
          [[3, 1], [4, 2]]
      """
      if len(im.shape) < 2:
          raise ValueError("Shape of image should 2d, 3d or more")
      if method in (0, 90):
          return rot_90(im)
      if method in (1, 180):
          return rot_180(im)
      if method in (2, 270):
          return rot_270(im)
      # Unrecognized methods fall through to a no-op.
      return im
  def rot_90(im):
      """Rotate 90 degrees clockwise: reverse the rows, then exchange axes 0 and 1."""
      upside_down = im[::-1, :]
      return upside_down.swapaxes(0, 1)
  def rot_180(im):
      """Rotate 180 degrees by reversing both axis 0 and axis 1."""
      return im[::-1, ::-1]
  def rot_270(im):
      """Rotate 270 degrees clockwise (90 counter-clockwise).

      Reverses the columns, then exchanges axes 0 and 1.
      """
      mirrored = im[:, ::-1]
      return mirrored.swapaxes(0, 1)
  200. # endregion
  def rgb2bgr(im):
      """Reverse the order of the channels on axis 2 (RGB <-> BGR for HWC images)."""
      return np.flip(im, axis=2)
  def is_poly(poly):
      """Tell a list segmentation apart from a dict segmentation.

      Lists hold polygon coordinates; dicts presumably hold run-length-encoded
      masks.

      Args:
          poly (list|dict): The segmentation annotation.

      Returns:
          bool: True if `poly` is a list, False if it is a dict.

      Raises:
          AssertionError: If `poly` is neither a list nor a dict (only when
              assertions are enabled).
      """
      is_list = isinstance(poly, list)
      assert is_list or isinstance(poly, dict), \
          "Invalid poly type: {}".format(type(poly))
      return is_list
  def horizontal_flip_poly(poly, width):
      """Mirror a flat [x0, y0, x1, y1, ...] polygon left-to-right.

      Args:
          poly (list): Flat polygon coordinates.
          width (int|float): Image width; each x becomes width - x.

      Returns:
          list: The flipped coordinates (y values unchanged).
      """
      coords = np.array(poly)
      coords[::2] = width - coords[::2]
      return coords.tolist()
  def horizontal_flip_rle(rle, height, width):
      """Mirror a run-length-encoded mask left-to-right.

      Args:
          rle (dict): RLE mask. If `rle['counts']` is a list it is first
              converted with `mask_util.frPyObjects`.
          height (int): Mask height, used for that conversion.
          width (int): Mask width, used for that conversion.

      Returns:
          The flipped mask re-encoded with `mask_util.encode`.
      """
      import pycocotools.mask as mask_util
      # isinstance instead of an exact type() comparison also accepts list
      # subclasses.
      if 'counts' in rle and isinstance(rle['counts'], list):
          rle = mask_util.frPyObjects(rle, height, width)
      flipped = mask_util.decode(rle)[:, ::-1]
      # Copy the reversed view into a Fortran-ordered uint8 array for encode.
      return mask_util.encode(np.array(flipped, order='F', dtype=np.uint8))
  def vertical_flip_poly(poly, height):
      """Mirror a flat [x0, y0, x1, y1, ...] polygon top-to-bottom.

      Args:
          poly (list): Flat polygon coordinates.
          height (int|float): Image height; each y becomes height - y.

      Returns:
          list: The flipped coordinates (x values unchanged).
      """
      coords = np.array(poly)
      coords[1::2] = height - coords[1::2]
      return coords.tolist()
  def vertical_flip_rle(rle, height, width):
      """Mirror a run-length-encoded mask top-to-bottom.

      Args:
          rle (dict): RLE mask. If `rle['counts']` is a list it is first
              converted with `mask_util.frPyObjects`.
          height (int): Mask height, used for that conversion.
          width (int): Mask width, used for that conversion.

      Returns:
          The flipped mask re-encoded with `mask_util.encode`.
      """
      import pycocotools.mask as mask_util
      # isinstance instead of an exact type() comparison also accepts list
      # subclasses.
      if 'counts' in rle and isinstance(rle['counts'], list):
          rle = mask_util.frPyObjects(rle, height, width)
      flipped = mask_util.decode(rle)[::-1, :]
      # Copy the reversed view into a Fortran-ordered uint8 array for encode.
      return mask_util.encode(np.array(flipped, order='F', dtype=np.uint8))
  def crop_poly(segm, crop):
      """Clip polygon segmentations to a box and shift them into its frame.

      Args:
          segm (list[list[float]]): Polygons, each a flat
              [x0, y0, x1, y1, ...] coordinate list.
          crop (sequence): Crop box as (xmin, ymin, xmax, ymax).

      Returns:
          list[list[float]]: Flat coordinate lists of the clipped polygons,
          expressed relative to (xmin, ymin). Only exterior rings are kept.
          Empty intersections and non-polygon pieces (lines, points) are
          dropped.
      """
      xmin, ymin, xmax, ymax = crop
      crop_p = Polygon([(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin)])

      def _to_local_coords(polygon):
          # Flatten the exterior ring (without its closing duplicate vertex)
          # and translate it into the crop's coordinate frame.
          coords = np.array(polygon.exterior.coords[:-1]).ravel()
          coords[0::2] -= xmin
          coords[1::2] -= ymin
          return coords.tolist()

      crop_segm = []
      for poly in segm:
          polygon = Polygon(np.array(poly).reshape(len(poly) // 2, 2))
          if not polygon.is_valid:
              # Invalid ring (e.g. self-intersecting): node the ring against
              # itself and rebuild the enclosed faces as separate polygons.
              exterior = polygon.exterior
              multi_lines = exterior.intersection(exterior)
              polygon = MultiPolygon(list(shapely.ops.polygonize(multi_lines)))
          # Multi-part geometries are not iterable in Shapely 2.0; use .geoms.
          parts = polygon.geoms if isinstance(polygon, MultiPolygon) else [polygon]
          for per_polygon in parts:
              inter = per_polygon.intersection(crop_p)
              if inter.is_empty:
                  continue
              if isinstance(inter, (MultiPolygon, GeometryCollection)):
                  crop_segm.extend(
                      _to_local_coords(part) for part in inter.geoms
                      if isinstance(part, Polygon))
              elif isinstance(inter, Polygon):
                  crop_segm.append(_to_local_coords(inter))
      return crop_segm
  def crop_rle(rle, crop, height, width):
      """Crop a run-length-encoded mask to a box.

      Args:
          rle (dict): RLE mask. If `rle['counts']` is a list it is first
              converted with `mask_util.frPyObjects`.
          crop (sequence): Box as (xmin, ymin, xmax, ymax), used directly as
              slice bounds (rows ymin:ymax, columns xmin:xmax).
          height (int): Full mask height, used for the list conversion.
          width (int): Full mask width, used for the list conversion.

      Returns:
          The cropped mask re-encoded with `mask_util.encode`.
      """
      import pycocotools.mask as mask_util
      # isinstance instead of an exact type() comparison also accepts list
      # subclasses.
      if 'counts' in rle and isinstance(rle['counts'], list):
          rle = mask_util.frPyObjects(rle, height, width)
      mask = mask_util.decode(rle)
      xmin, ymin, xmax, ymax = crop[0], crop[1], crop[2], crop[3]
      cropped = mask[ymin:ymax, xmin:xmax]
      # Copy the sliced view into a Fortran-ordered uint8 array for encode.
      return mask_util.encode(np.array(cropped, order='F', dtype=np.uint8))
  def expand_poly(poly, x, y):
      """Translate a flat [x0, y0, x1, y1, ...] polygon by (x, y).

      Args:
          poly (list): Flat polygon coordinates.
          x: Offset added to every x coordinate.
          y: Offset added to every y coordinate.

      Returns:
          list: The shifted coordinates.
      """
      coords = np.array(poly)
      for start, offset in ((0, x), (1, y)):
          coords[start::2] += offset
      return coords.tolist()
  def expand_rle(rle, x, y, height, width, h, w):
      """Paste a run-length-encoded mask into a larger, zero-filled canvas.

      Args:
          rle (dict): RLE mask of size (height, width). If `rle['counts']` is
              a list it is first converted with `mask_util.frPyObjects`.
          x (int): Column on the canvas where the mask's left edge goes.
          y (int): Row on the canvas where the mask's top edge goes.
          height (int): Height of the source mask.
          width (int): Width of the source mask.
          h (int): Canvas height.
          w (int): Canvas width.

      Returns:
          The expanded mask re-encoded with `mask_util.encode`.
      """
      import pycocotools.mask as mask_util
      # isinstance instead of an exact type() comparison also accepts list
      # subclasses.
      if 'counts' in rle and isinstance(rle['counts'], list):
          rle = mask_util.frPyObjects(rle, height, width)
      mask = mask_util.decode(rle)
      # Allocate directly in the mask's dtype rather than building an int64
      # array and converting it.
      expanded_mask = np.zeros((h, w), dtype=mask.dtype)
      expanded_mask[y:y + height, x:x + width] = mask
      return mask_util.encode(
          np.array(expanded_mask, order='F', dtype=np.uint8))
  def resize_poly(poly, im_scale_x, im_scale_y):
      """Scale a flat [x0, y0, x1, y1, ...] polygon.

      Computation is done in float32 on a copy of `poly`.

      Args:
          poly (list): Flat polygon coordinates.
          im_scale_x (float): Factor applied to every x coordinate.
          im_scale_y (float): Factor applied to every y coordinate.

      Returns:
          list: The scaled coordinates.
      """
      coords = np.array(poly, dtype=np.float32)
      for start, factor in ((0, im_scale_x), (1, im_scale_y)):
          coords[start::2] *= factor
      return coords.tolist()
  def resize_rle(rle, im_h, im_w, im_scale_x, im_scale_y, interp):
      """Resize a run-length-encoded mask by per-axis scale factors.

      Args:
          rle (dict): RLE mask. If `rle['counts']` is a list it is first
              converted with `mask_util.frPyObjects`.
          im_h (int): Mask height, used for that conversion.
          im_w (int): Mask width, used for that conversion.
          im_scale_x (float): Horizontal scale factor (cv2.resize `fx`).
          im_scale_y (float): Vertical scale factor (cv2.resize `fy`).
          interp (int): OpenCV interpolation flag passed to cv2.resize.

      Returns:
          The resized mask re-encoded with `mask_util.encode`.
      """
      import pycocotools.mask as mask_util
      # isinstance instead of an exact type() comparison also accepts list
      # subclasses.
      if 'counts' in rle and isinstance(rle['counts'], list):
          rle = mask_util.frPyObjects(rle, im_h, im_w)
      mask = mask_util.decode(rle)
      mask = cv2.resize(
          mask, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=interp)
      return mask_util.encode(np.array(mask, order='F', dtype=np.uint8))
  def to_uint8(im, is_linear=False):
      """Convert a raster to uint8.

      Non-uint8 input is first histogram-equalized band by band (skimage
      `exposure.equalize_hist`) and scaled to 0-255.

      Args:
          im (np.ndarray): The image, 2-D or 3-D with bands on the last axis.
          is_linear (bool, optional): Apply a 2% linear stretch afterwards.
              Defaults to False.

      Returns:
          np.ndarray: Image on uint8. A uint8 input with `is_linear=False` is
          returned as-is.
      """

      def _two_percent_linear(image, max_out=255, min_out=0):
          # Stretch each band linearly between its 2nd and 98th percentiles.
          def _gray_process(gray, maxout=max_out, minout=min_out):
              high_value = np.percentile(gray, 98)
              low_value = np.percentile(gray, 2)
              span = high_value - low_value
              if span == 0:
                  # Constant band: the stretch would be 0/0 (NaN), whose uint8
                  # cast is undefined. Map every pixel to the lower bound.
                  return np.full(gray.shape, minout, dtype=np.uint8)
              truncated_gray = np.clip(gray, a_min=low_value, a_max=high_value)
              processed_gray = (truncated_gray - low_value) / span * \
                  (maxout - minout) + minout
              return np.uint8(processed_gray)

          if len(image.shape) == 3:
              result = np.stack(
                  [_gray_process(image[:, :, b]) for b in range(image.shape[-1])],
                  axis=2)
          else:  # 2-D image
              result = _gray_process(image)
          return np.uint8(result)

      def _sample_norm(image):
          # Histogram-equalize each band with skimage and scale to 0-255.
          if len(image.shape) == 3:
              stretches = []
              for b in range(image.shape[-1]):
                  stretched = exposure.equalize_hist(image[:, :, b])
                  stretched /= float(np.max(stretched))
                  stretches.append(stretched)
              stretched_img = np.stack(stretches, axis=2)
          else:  # 2-D image
              stretched_img = exposure.equalize_hist(image)
          return np.uint8(stretched_img * 255)

      if im.dtype.name != "uint8":
          im = _sample_norm(im)
      if is_linear:
          im = _two_percent_linear(im)
      return im
  def to_intensity(im):
      """Calculate the intensity (magnitude) image of SAR data.

      Args:
          im (np.ndarray): 2-D SAR image, complex-valued or already real.

      Returns:
          np.ndarray: `abs(im)` for complex input. Real input is returned
          unchanged.

      Raises:
          ValueError: If `im` is not 2-D.
      """
      if len(im.shape) != 2:
          raise ValueError("im's shape must be 2.")
      # Decide by dtype: a complex array means SAR data, reduced to its
      # magnitude. (Testing isinstance() on a *type object* is always False.)
      if np.iscomplexobj(im):
          im = np.abs(im)
      return im
  def select_bands(im, band_list=(1, 2, 3)):
      """Select bands from an image.

      Args:
          im (np.ndarray): The image, bands on the last axis. A 2-D
              (single-band) image is returned unchanged.
          band_list (list|tuple, optional): 1-based indices of the bands to
              keep, in output order. Defaults to (1, 2, 3).

      Returns:
          np.ndarray: The selected bands stacked on the last axis.

      Raises:
          TypeError: If `band_list` is not a non-empty list or tuple.
          ValueError: If an index lies outside [1, number of bands].
      """
      if len(im.shape) == 2:  # just have one channel
          return im
      if not isinstance(band_list, (list, tuple)) or len(band_list) == 0:
          raise TypeError("band_list must be non empty list or tuple.")
      total_band = im.shape[-1]
      indices = []
      for band in band_list:
          index = int(band - 1)  # convert 1-based band number to 0-based index
          if index < 0 or index >= total_band:
              raise ValueError(
                  "The element in band_list must be >= 1 and <= {}.".format(
                      str(total_band)))
          indices.append(index)
      # Fancy indexing copies the chosen bands into a new array.
      return im[..., indices]
  def dehaze(im, gamma=False):
      """Single image haze removal using the dark channel prior.

      Args:
          im (np.ndarray): Input image with channels on axis 2. If any value
              exceeds 1 the image is divided by 255.
          gamma (bool, optional): Apply a gamma correction whose exponent maps
              the mean of the dehazed image to 0.5. Defaults to False.

      Returns:
          np.ndarray: The dehazed image as uint8.
      """

      def _guided_filter(I, p, r, eps):
          # Guided filter of p with guide I, built from r x r box means.
          m_I = cv2.boxFilter(I, -1, (r, r))
          m_p = cv2.boxFilter(p, -1, (r, r))
          m_Ip = cv2.boxFilter(I * p, -1, (r, r))
          cov_Ip = m_Ip - m_I * m_p
          m_II = cv2.boxFilter(I * I, -1, (r, r))
          var_I = m_II - m_I * m_I
          a = cov_Ip / (var_I + eps)
          b = m_p - a * m_I
          m_a = cv2.boxFilter(a, -1, (r, r))
          m_b = cv2.boxFilter(b, -1, (r, r))
          return m_a * I + m_b

      def _dehaze(im, r, w, maxatmo_mask, eps):
          # Per-pixel minimum over channels, eroded with a 15x15 window, then
          # refined by the guided filter.
          atmo_mask = np.min(im, 2)
          dark_channel = cv2.erode(atmo_mask, np.ones((15, 15)))
          atmo_mask = _guided_filter(atmo_mask, dark_channel, r, eps)
          # Atmospheric light: brightest channel-mean among pixels whose mask
          # value lies in roughly the top 0.1% of the histogram.
          bins = 2000
          ht = np.histogram(atmo_mask, bins)
          d = np.cumsum(ht[0]) / float(atmo_mask.size)
          for lmax in range(bins - 1, 0, -1):
              if d[lmax] <= 0.999:
                  break
          atmo_illum = np.mean(im, 2)[atmo_mask >= ht[1][lmax]].max()
          atmo_mask = np.minimum(atmo_mask * w, maxatmo_mask)
          return atmo_mask, atmo_illum

      if np.max(im) > 1:
          im = im / 255.
      mask_img, atmo_illum = _dehaze(
          im, r=81, w=0.95, maxatmo_mask=0.80, eps=1e-8)
      # Invert the haze model for every channel at once. Broadcasting the 2-D
      # mask over axis 2 handles any channel count, not just three. The cast
      # keeps the float64 result dtype of the former per-channel loop.
      mask = mask_img[:, :, np.newaxis]
      result = ((im - mask) / (1 - mask / atmo_illum)).astype(np.float64)
      result = np.clip(result, 0, 1)
      if gamma:
          mean = result.mean()
          # mean == 0 would give exponent -0.0 and turn every pixel to 1.0;
          # mean == 1 would divide by log(1) == 0. Neither needs correcting.
          if 0 < mean < 1:
              result = result**(np.log(0.5) / np.log(mean))
      return (result * 255).astype("uint8")
  def match_histograms(im, ref):
      """Match the cumulative histogram of `im` to that of `ref`.

      Thin wrapper around `skimage.exposure.match_histograms`. When `im` has
      more than two dimensions, its last axis is treated as the channel axis.

      Args:
          im (np.ndarray): Input image.
          ref (np.ndarray): Reference image whose histogram is matched. It must
              have the same number of channels as `im`.

      Returns:
          np.ndarray: Transformed input image.

      Raises:
          ValueError: When the number of channels of `ref` differs from that
              of `im` (raised by skimage).
      """
      # TODO: Check the data types of the inputs to see if they are supported by skimage
      channel_axis = None if im.ndim <= 2 else -1
      return exposure.match_histograms(im, ref, channel_axis=channel_axis)
  def match_by_regression(im, ref, pif_loc=None):
      """Match the brightness values of two images using linear regression.

      A 1-D linear model ref ~ a * im + b is fitted per channel and applied to
      every pixel of that channel.

      Args:
          im (np.ndarray): Input image.
          ref (np.ndarray): Reference image to match. It must have the same
              shape as `im`.
          pif_loc (tuple|None, optional): Spatial locations of pseudo-invariant
              features (PIFs) used as training samples. If None, all pixels
              are used. Otherwise it should be a tuple of np.ndarrays usable as
              an index into a single 2-D channel. Default: None.

      Returns:
          np.ndarray: Transformed input image, in the dtype of `im`. For
          integer dtypes, predictions are clipped to the dtype's range.

      Raises:
          ValueError: When the shape of `ref` differs from that of `im`.
      """

      def _linear_regress(im, ref, loc):
          regressor = LinearRegression()
          if loc is not None:
              x, y = im[loc], ref[loc]
          else:
              x, y = im, ref
          x, y = x.reshape(-1, 1), y.ravel()
          regressor.fit(x, y)
          matched = regressor.predict(im.reshape(-1, 1))
          return matched.reshape(im.shape)

      def _cast_like(values, dtype):
          # Casting out-of-range floats to an integer dtype is undefined in
          # numpy (results are platform-dependent), so clip first.
          if np.issubdtype(dtype, np.integer):
              info = np.iinfo(dtype)
              values = np.clip(values, info.min, info.max)
          return values.astype(dtype)

      if im.shape != ref.shape:
          raise ValueError("Image and Reference must have the same shape!")
      if im.ndim > 2:
          # Multiple channels: one regression per channel.
          matched = np.empty(im.shape, dtype=im.dtype)
          for ch in range(im.shape[-1]):
              matched[..., ch] = _cast_like(
                  _linear_regress(im[..., ch], ref[..., ch], pif_loc), im.dtype)
      else:
          # Single channel
          matched = _cast_like(_linear_regress(im, ref, pif_loc), im.dtype)
      return matched
  def inv_pca(im, joblib_path):
      """Restore an image from its PCA representation.

      Args:
          im (np.ndarray): PCA-transformed image of shape (H, W, C).
          joblib_path (str): Path to the *.joblib file holding the fitted PCA
              object; it must provide `inverse_transform`.

      Returns:
          np.ndarray: Reconstructed image of shape (H, W, -1). The last
          dimension is whatever `inverse_transform` yields per pixel.
      """
      # NOTE(security): joblib.load unpickles the file, which can execute
      # arbitrary code. Only pass trusted paths.
      pca = load(joblib_path)
      height, width, channels = im.shape
      pixels = np.reshape(im, (-1, channels))
      restored = pca.inverse_transform(pixels)
      return np.reshape(restored, (height, width, -1))
  497. return r_im