operators.py 76 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import copy
  16. import random
  17. from numbers import Number
  18. from functools import partial
  19. from operator import methodcaller
  20. from collections.abc import Sequence
  21. import numpy as np
  22. import cv2
  23. import imghdr
  24. from PIL import Image
  25. from joblib import load
  26. import paddlers
  27. import paddlers.transforms.indices as indices
  28. import paddlers.transforms.satellites as satellites
  29. from .functions import (
  30. normalize, horizontal_flip, permute, vertical_flip, center_crop, is_poly,
  31. horizontal_flip_poly, horizontal_flip_rle, vertical_flip_poly,
  32. vertical_flip_rle, crop_poly, crop_rle, expand_poly, expand_rle,
  33. resize_poly, resize_rle, dehaze, select_bands, to_intensity, to_uint8,
  34. img_flip, img_simple_rotate, decode_seg_mask, calc_hr_shape,
  35. match_by_regression, match_histograms)
  36. __all__ = [
  37. "Compose",
  38. "DecodeImg",
  39. "Resize",
  40. "RandomResize",
  41. "ResizeByShort",
  42. "RandomResizeByShort",
  43. "ResizeByLong",
  44. "RandomHorizontalFlip",
  45. "RandomVerticalFlip",
  46. "Normalize",
  47. "CenterCrop",
  48. "RandomCrop",
  49. "RandomScaleAspect",
  50. "RandomExpand",
  51. "Pad",
  52. "MixupImage",
  53. "RandomDistort",
  54. "RandomBlur",
  55. "RandomSwap",
  56. "Dehaze",
  57. "ReduceDim",
  58. "SelectBand",
  59. "RandomFlipOrRotate",
  60. "ReloadMask",
  61. "AppendIndex",
  62. "MatchRadiance",
  63. "ArrangeRestorer",
  64. "ArrangeSegmenter",
  65. "ArrangeChangeDetector",
  66. "ArrangeClassifier",
  67. "ArrangeDetector",
  68. ]
  69. interp_dict = {
  70. 'NEAREST': cv2.INTER_NEAREST,
  71. 'LINEAR': cv2.INTER_LINEAR,
  72. 'CUBIC': cv2.INTER_CUBIC,
  73. 'AREA': cv2.INTER_AREA,
  74. 'LANCZOS4': cv2.INTER_LANCZOS4
  75. }
  76. class Compose(object):
  77. """
  78. Apply a series of data augmentation strategies to the input.
  79. All input images should be in Height-Width-Channel ([H, W, C]) format.
  80. Args:
  81. transforms (list[paddlers.transforms.Transform]): List of data preprocess or
  82. augmentation operators.
  83. Raises:
  84. TypeError: Invalid type of transforms.
  85. ValueError: Invalid length of transforms.
  86. """
  87. def __init__(self, transforms):
  88. super(Compose, self).__init__()
  89. if not isinstance(transforms, list):
  90. raise TypeError(
  91. "Type of transforms is invalid. Must be a list, but received is {}."
  92. .format(type(transforms)))
  93. if len(transforms) < 1:
  94. raise ValueError(
  95. "Length of transforms must not be less than 1, but received is {}."
  96. .format(len(transforms)))
  97. transforms = copy.deepcopy(transforms)
  98. self.arrange = self._pick_arrange(transforms)
  99. self.transforms = transforms
  100. def __call__(self, sample):
  101. """
  102. This is equivalent to sequentially calling compose_obj.apply_transforms()
  103. and compose_obj.arrange_outputs().
  104. """
  105. sample = self.apply_transforms(sample)
  106. sample = self.arrange_outputs(sample)
  107. return sample
  108. def apply_transforms(self, sample):
  109. for op in self.transforms:
  110. # Skip batch transforms amd mixup
  111. if isinstance(op, (paddlers.transforms.BatchRandomResize,
  112. paddlers.transforms.BatchRandomResizeByShort,
  113. MixupImage)):
  114. continue
  115. sample = op(sample)
  116. return sample
  117. def arrange_outputs(self, sample):
  118. if self.arrange is not None:
  119. sample = self.arrange(sample)
  120. return sample
  121. def _pick_arrange(self, transforms):
  122. arrange = None
  123. for idx, op in enumerate(transforms):
  124. if isinstance(op, Arrange):
  125. if idx != len(transforms) - 1:
  126. raise ValueError(
  127. "Arrange operator must be placed at the end of the list."
  128. )
  129. arrange = transforms.pop(idx)
  130. return arrange
  131. class Transform(object):
  132. """
  133. Parent class of all data augmentation operators.
  134. """
  135. def __init__(self):
  136. pass
  137. def apply_im(self, image):
  138. return image
  139. def apply_mask(self, mask):
  140. return mask
  141. def apply_bbox(self, bbox):
  142. return bbox
  143. def apply_segm(self, segms):
  144. return segms
  145. def apply(self, sample):
  146. if 'image' in sample:
  147. sample['image'] = self.apply_im(sample['image'])
  148. else: # image_tx
  149. sample['image'] = self.apply_im(sample['image_t1'])
  150. sample['image2'] = self.apply_im(sample['image_t2'])
  151. if 'mask' in sample:
  152. sample['mask'] = self.apply_mask(sample['mask'])
  153. if 'gt_bbox' in sample:
  154. sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'])
  155. if 'aux_masks' in sample:
  156. sample['aux_masks'] = list(
  157. map(self.apply_mask, sample['aux_masks']))
  158. if 'target' in sample:
  159. sample['target'] = self.apply_im(sample['target'])
  160. return sample
  161. def __call__(self, sample):
  162. if isinstance(sample, Sequence):
  163. sample = [self.apply(s) for s in sample]
  164. else:
  165. sample = self.apply(sample)
  166. return sample
  167. class DecodeImg(Transform):
  168. """
  169. Decode image(s) in input.
  170. Args:
  171. to_rgb (bool, optional): If True, convert input image(s) from BGR format to
  172. RGB format. Defaults to True.
  173. to_uint8 (bool, optional): If True, quantize and convert decoded image(s) to
  174. uint8 type. Defaults to True.
  175. decode_bgr (bool, optional): If True, automatically interpret a non-geo image
  176. (e.g., jpeg images) as a BGR image. Defaults to True.
  177. decode_sar (bool, optional): If True, automatically interpret a single-channel
  178. geo image (e.g. geotiff images) as a SAR image, set this argument to
  179. True. Defaults to True.
  180. read_geo_info (bool, optional): If True, read geographical information from
  181. the image. Deafults to False.
  182. use_stretch (bool, optional): Whether to apply 2% linear stretch. Valid only if
  183. `to_uint8` is True. Defaults to False.
  184. """
  185. def __init__(self,
  186. to_rgb=True,
  187. to_uint8=True,
  188. decode_bgr=True,
  189. decode_sar=True,
  190. read_geo_info=False,
  191. use_stretch=False):
  192. super(DecodeImg, self).__init__()
  193. self.to_rgb = to_rgb
  194. self.to_uint8 = to_uint8
  195. self.decode_bgr = decode_bgr
  196. self.decode_sar = decode_sar
  197. self.read_geo_info = read_geo_info
  198. self.use_stretch = use_stretch
  199. def read_img(self, img_path):
  200. img_format = imghdr.what(img_path)
  201. name, ext = os.path.splitext(img_path)
  202. geo_trans, geo_proj = None, None
  203. if img_format == 'tiff' or ext == '.img':
  204. try:
  205. import gdal
  206. except:
  207. try:
  208. from osgeo import gdal
  209. except ImportError:
  210. raise ImportError(
  211. "Failed to import gdal! Please install GDAL library according to the document."
  212. )
  213. dataset = gdal.Open(img_path)
  214. if dataset == None:
  215. raise IOError('Cannot open', img_path)
  216. im_data = dataset.ReadAsArray()
  217. if im_data.ndim == 2 and self.decode_sar:
  218. im_data = to_intensity(im_data)
  219. im_data = im_data[:, :, np.newaxis]
  220. else:
  221. if im_data.ndim == 3:
  222. im_data = im_data.transpose((1, 2, 0))
  223. if self.read_geo_info:
  224. geo_trans = dataset.GetGeoTransform()
  225. geo_proj = dataset.GetProjection()
  226. elif img_format in ['jpeg', 'bmp', 'png', 'jpg']:
  227. if self.decode_bgr:
  228. im_data = cv2.imread(img_path, cv2.IMREAD_ANYDEPTH |
  229. cv2.IMREAD_ANYCOLOR | cv2.IMREAD_COLOR)
  230. else:
  231. im_data = cv2.imread(img_path, cv2.IMREAD_ANYDEPTH |
  232. cv2.IMREAD_ANYCOLOR)
  233. elif ext == '.npy':
  234. im_data = np.load(img_path)
  235. else:
  236. raise TypeError("Image format {} is not supported!".format(ext))
  237. if self.read_geo_info:
  238. return im_data, geo_trans, geo_proj
  239. else:
  240. return im_data
  241. def apply_im(self, im_path):
  242. if isinstance(im_path, str):
  243. try:
  244. data = self.read_img(im_path)
  245. except:
  246. raise ValueError("Cannot read the image file {}!".format(
  247. im_path))
  248. if self.read_geo_info:
  249. image, geo_trans, geo_proj = data
  250. geo_info_dict = {'geo_trans': geo_trans, 'geo_proj': geo_proj}
  251. else:
  252. image = data
  253. else:
  254. image = im_path
  255. if self.to_rgb and image.shape[-1] == 3:
  256. image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
  257. if self.to_uint8:
  258. image = to_uint8(image, stretch=self.use_stretch)
  259. if self.read_geo_info:
  260. return image, geo_info_dict
  261. else:
  262. return image
  263. def apply_mask(self, mask):
  264. try:
  265. mask = np.asarray(Image.open(mask))
  266. except:
  267. raise ValueError("Cannot read the mask file {}!".format(mask))
  268. if len(mask.shape) != 2:
  269. raise ValueError(
  270. "Mask should be a 1-channel image, but recevied is a {}-channel image.".
  271. format(mask.shape[2]))
  272. return mask
  273. def apply(self, sample):
  274. """
  275. Args:
  276. sample (dict): Input sample.
  277. Returns:
  278. dict: Sample with decoded images.
  279. """
  280. if 'image' in sample:
  281. if self.read_geo_info:
  282. image, geo_info_dict = self.apply_im(sample['image'])
  283. sample['image'] = image
  284. sample['geo_info_dict'] = geo_info_dict
  285. else:
  286. sample['image'] = self.apply_im(sample['image'])
  287. if 'image2' in sample:
  288. if self.read_geo_info:
  289. image2, geo_info_dict2 = self.apply_im(sample['image2'])
  290. sample['image2'] = image2
  291. sample['geo_info_dict2'] = geo_info_dict2
  292. else:
  293. sample['image2'] = self.apply_im(sample['image2'])
  294. if 'image_t1' in sample and not 'image' in sample:
  295. if not ('image_t2' in sample and 'image2' not in sample):
  296. raise ValueError
  297. if self.read_geo_info:
  298. image, geo_info_dict = self.apply_im(sample['image_t1'])
  299. sample['image'] = image
  300. sample['geo_info_dict'] = geo_info_dict
  301. else:
  302. sample['image'] = self.apply_im(sample['image_t1'])
  303. if self.read_geo_info:
  304. image2, geo_info_dict2 = self.apply_im(sample['image_t2'])
  305. sample['image2'] = image2
  306. sample['geo_info_dict2'] = geo_info_dict2
  307. else:
  308. sample['image2'] = self.apply_im(sample['image_t2'])
  309. if 'mask' in sample:
  310. sample['mask_ori'] = copy.deepcopy(sample['mask'])
  311. sample['mask'] = self.apply_mask(sample['mask'])
  312. im_height, im_width, _ = sample['image'].shape
  313. se_height, se_width = sample['mask'].shape
  314. if im_height != se_height or im_width != se_width:
  315. raise ValueError(
  316. "The height or width of the image is not same as the mask.")
  317. if 'aux_masks' in sample:
  318. sample['aux_masks_ori'] = copy.deepcopy(sample['aux_masks'])
  319. sample['aux_masks'] = list(
  320. map(self.apply_mask, sample['aux_masks']))
  321. # TODO: check the shape of auxiliary masks
  322. if 'target' in sample:
  323. if self.read_geo_info:
  324. target, geo_info_dict = self.apply_im(sample['target'])
  325. sample['target'] = target
  326. sample['geo_info_dict_tar'] = geo_info_dict
  327. else:
  328. sample['target'] = self.apply_im(sample['target'])
  329. sample['im_shape'] = np.array(
  330. sample['image'].shape[:2], dtype=np.float32)
  331. sample['scale_factor'] = np.array([1., 1.], dtype=np.float32)
  332. return sample
  333. class Resize(Transform):
  334. """
  335. Resize input.
  336. - If `target_size` is an int, resize the image(s) to (`target_size`, `target_size`).
  337. - If `target_size` is a list or tuple, resize the image(s) to `target_size`.
  338. Attention: If `interp` is 'RANDOM', the interpolation method will be chosen randomly.
  339. Args:
  340. target_size (int | list[int] | tuple[int]): Target size. If it is an integer, the
  341. target height and width will be both set to `target_size`. Otherwise,
  342. `target_size` represents [target height, target width].
  343. interp (str, optional): Interpolation method for resizing image(s). One of
  344. {'NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4', 'RANDOM'}.
  345. Defaults to 'LINEAR'.
  346. keep_ratio (bool, optional): If True, the scaling factor of width and height will
  347. be set to same value, and height/width of the resized image will be not
  348. greater than the target width/height. Defaults to False.
  349. Raises:
  350. TypeError: Invalid type of target_size.
  351. ValueError: Invalid interpolation method.
  352. """
  353. def __init__(self, target_size, interp='LINEAR', keep_ratio=False):
  354. super(Resize, self).__init__()
  355. if not (interp == "RANDOM" or interp in interp_dict):
  356. raise ValueError("`interp` should be one of {}.".format(
  357. interp_dict.keys()))
  358. if isinstance(target_size, int):
  359. target_size = (target_size, target_size)
  360. else:
  361. if not (isinstance(target_size,
  362. (list, tuple)) and len(target_size) == 2):
  363. raise TypeError(
  364. "`target_size` should be an int or a list of length 2, but received {}.".
  365. format(target_size))
  366. # (height, width)
  367. self.target_size = target_size
  368. self.interp = interp
  369. self.keep_ratio = keep_ratio
  370. def apply_im(self, image, interp, target_size):
  371. flag = image.shape[2] == 1
  372. image = cv2.resize(image, target_size, interpolation=interp)
  373. if flag:
  374. image = image[:, :, np.newaxis]
  375. return image
  376. def apply_mask(self, mask, target_size):
  377. mask = cv2.resize(mask, target_size, interpolation=cv2.INTER_NEAREST)
  378. return mask
  379. def apply_bbox(self, bbox, scale, target_size):
  380. im_scale_x, im_scale_y = scale
  381. bbox[:, 0::2] *= im_scale_x
  382. bbox[:, 1::2] *= im_scale_y
  383. bbox[:, 0::2] = np.clip(bbox[:, 0::2], 0, target_size[0])
  384. bbox[:, 1::2] = np.clip(bbox[:, 1::2], 0, target_size[1])
  385. return bbox
  386. def apply_segm(self, segms, im_size, scale):
  387. im_h, im_w = im_size
  388. im_scale_x, im_scale_y = scale
  389. resized_segms = []
  390. for segm in segms:
  391. if is_poly(segm):
  392. # Polygon format
  393. resized_segms.append([
  394. resize_poly(poly, im_scale_x, im_scale_y) for poly in segm
  395. ])
  396. else:
  397. # RLE format
  398. resized_segms.append(
  399. resize_rle(segm, im_h, im_w, im_scale_x, im_scale_y))
  400. return resized_segms
  401. def apply(self, sample):
  402. if self.interp == "RANDOM":
  403. interp = random.choice(list(interp_dict.values()))
  404. else:
  405. interp = interp_dict[self.interp]
  406. im_h, im_w = sample['image'].shape[:2]
  407. im_scale_y = self.target_size[0] / im_h
  408. im_scale_x = self.target_size[1] / im_w
  409. target_size = (self.target_size[1], self.target_size[0])
  410. if self.keep_ratio:
  411. scale = min(im_scale_y, im_scale_x)
  412. target_w = int(round(im_w * scale))
  413. target_h = int(round(im_h * scale))
  414. target_size = (target_w, target_h)
  415. im_scale_y = target_h / im_h
  416. im_scale_x = target_w / im_w
  417. sample['image'] = self.apply_im(sample['image'], interp, target_size)
  418. if 'image2' in sample:
  419. sample['image2'] = self.apply_im(sample['image2'], interp,
  420. target_size)
  421. if 'mask' in sample:
  422. sample['mask'] = self.apply_mask(sample['mask'], target_size)
  423. if 'aux_masks' in sample:
  424. sample['aux_masks'] = list(
  425. map(partial(
  426. self.apply_mask, target_size=target_size),
  427. sample['aux_masks']))
  428. if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
  429. sample['gt_bbox'] = self.apply_bbox(
  430. sample['gt_bbox'], [im_scale_x, im_scale_y], target_size)
  431. if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
  432. sample['gt_poly'] = self.apply_segm(
  433. sample['gt_poly'], [im_h, im_w], [im_scale_x, im_scale_y])
  434. if 'target' in sample:
  435. if 'sr_factor' in sample:
  436. # For SR tasks
  437. sample['target'] = self.apply_im(
  438. sample['target'], interp,
  439. calc_hr_shape(target_size, sample['sr_factor']))
  440. else:
  441. # For non-SR tasks
  442. sample['target'] = self.apply_im(sample['target'], interp,
  443. target_size)
  444. sample['im_shape'] = np.asarray(
  445. sample['image'].shape[:2], dtype=np.float32)
  446. if 'scale_factor' in sample:
  447. scale_factor = sample['scale_factor']
  448. sample['scale_factor'] = np.asarray(
  449. [scale_factor[0] * im_scale_y, scale_factor[1] * im_scale_x],
  450. dtype=np.float32)
  451. return sample
  452. class RandomResize(Transform):
  453. """
  454. Resize input to random sizes.
  455. Attention: If `interp` is 'RANDOM', the interpolation method will be chosen randomly.
  456. Args:
  457. target_sizes (list[int] | list[list|tuple] | tuple[list|tuple]):
  458. Multiple target sizes, each of which should be int, list, or tuple.
  459. interp (str, optional): Interpolation method for resizing image(s). One of
  460. {'NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4', 'RANDOM'}.
  461. Defaults to 'LINEAR'.
  462. Raises:
  463. TypeError: Invalid type of `target_size`.
  464. ValueError: Invalid interpolation method.
  465. """
  466. def __init__(self, target_sizes, interp='LINEAR'):
  467. super(RandomResize, self).__init__()
  468. if not (interp == "RANDOM" or interp in interp_dict):
  469. raise ValueError("`interp` should be one of {}.".format(
  470. interp_dict.keys()))
  471. self.interp = interp
  472. assert isinstance(target_sizes, list), \
  473. "`target_size` must be a list."
  474. for i, item in enumerate(target_sizes):
  475. if isinstance(item, int):
  476. target_sizes[i] = (item, item)
  477. self.target_size = target_sizes
  478. def apply(self, sample):
  479. height, width = random.choice(self.target_size)
  480. resizer = Resize((height, width), interp=self.interp)
  481. sample = resizer(sample)
  482. return sample
  483. class ResizeByShort(Transform):
  484. """
  485. Resize input while keeping the aspect ratio.
  486. Attention: If `interp` is 'RANDOM', the interpolation method will be chosen randomly.
  487. Args:
  488. short_size (int): Target size of the shorter side of the image(s).
  489. max_size (int, optional): Upper bound of longer side of the image(s). If
  490. `max_size` is -1, no upper bound will be applied. Defaults to -1.
  491. interp (str, optional): Interpolation method for resizing image(s). One of
  492. {'NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4', 'RANDOM'}.
  493. Defaults to 'LINEAR'.
  494. Raises:
  495. ValueError: Invalid interpolation method.
  496. """
  497. def __init__(self, short_size=256, max_size=-1, interp='LINEAR'):
  498. if not (interp == "RANDOM" or interp in interp_dict):
  499. raise ValueError("`interp` should be one of {}".format(
  500. interp_dict.keys()))
  501. super(ResizeByShort, self).__init__()
  502. self.short_size = short_size
  503. self.max_size = max_size
  504. self.interp = interp
  505. def apply(self, sample):
  506. im_h, im_w = sample['image'].shape[:2]
  507. im_short_size = min(im_h, im_w)
  508. im_long_size = max(im_h, im_w)
  509. scale = float(self.short_size) / float(im_short_size)
  510. if 0 < self.max_size < np.round(scale * im_long_size):
  511. scale = float(self.max_size) / float(im_long_size)
  512. target_w = int(round(im_w * scale))
  513. target_h = int(round(im_h * scale))
  514. sample = Resize(
  515. target_size=(target_h, target_w), interp=self.interp)(sample)
  516. return sample
  517. class RandomResizeByShort(Transform):
  518. """
  519. Resize input to random sizes while keeping the aspect ratio.
  520. Attention: If `interp` is 'RANDOM', the interpolation method will be chosen randomly.
  521. Args:
  522. short_sizes (list[int]): Target size of the shorter side of the image(s).
  523. max_size (int, optional): Upper bound of longer side of the image(s).
  524. If `max_size` is -1, no upper bound will be applied. Defaults to -1.
  525. interp (str, optional): Interpolation method for resizing image(s). One of
  526. {'NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4', 'RANDOM'}.
  527. Defaults to 'LINEAR'.
  528. Raises:
  529. TypeError: Invalid type of `target_size`.
  530. ValueError: Invalid interpolation method.
  531. See Also:
  532. ResizeByShort: Resize image(s) in input while keeping the aspect ratio.
  533. """
  534. def __init__(self, short_sizes, max_size=-1, interp='LINEAR'):
  535. super(RandomResizeByShort, self).__init__()
  536. if not (interp == "RANDOM" or interp in interp_dict):
  537. raise ValueError("`interp` should be one of {}".format(
  538. interp_dict.keys()))
  539. self.interp = interp
  540. assert isinstance(short_sizes, list), \
  541. "`short_sizes` must be a list."
  542. self.short_sizes = short_sizes
  543. self.max_size = max_size
  544. def apply(self, sample):
  545. short_size = random.choice(self.short_sizes)
  546. resizer = ResizeByShort(
  547. short_size=short_size, max_size=self.max_size, interp=self.interp)
  548. sample = resizer(sample)
  549. return sample
  550. class ResizeByLong(Transform):
  551. def __init__(self, long_size=256, interp='LINEAR'):
  552. super(ResizeByLong, self).__init__()
  553. self.long_size = long_size
  554. self.interp = interp
  555. def apply(self, sample):
  556. im_h, im_w = sample['image'].shape[:2]
  557. im_long_size = max(im_h, im_w)
  558. scale = float(self.long_size) / float(im_long_size)
  559. target_h = int(round(im_h * scale))
  560. target_w = int(round(im_w * scale))
  561. sample = Resize(
  562. target_size=(target_h, target_w), interp=self.interp)(sample)
  563. return sample
  564. class RandomFlipOrRotate(Transform):
  565. """
  566. Flip or Rotate an image in different directions with a certain probability.
  567. Args:
  568. probs (list[float]): Probabilities of performing flipping and rotation.
  569. Default: [0.35,0.25].
  570. probsf (list[float]): Probabilities of 5 flipping modes (horizontal,
  571. vertical, both horizontal and vertical, diagonal, anti-diagonal).
  572. Default: [0.3, 0.3, 0.2, 0.1, 0.1].
  573. probsr (list[float]): Probabilities of 3 rotation modes (90°, 180°, 270°
  574. clockwise). Default: [0.25, 0.5, 0.25].
  575. Examples:
  576. from paddlers import transforms as T
  577. # Define operators for data augmentation
  578. train_transforms = T.Compose([
  579. T.DecodeImg(),
  580. T.RandomFlipOrRotate(
  581. probs = [0.3, 0.2] # p=0.3 to flip the image,p=0.2 to rotate the image,p=0.5 to keep the image unchanged.
  582. probsf = [0.3, 0.25, 0, 0, 0] # p=0.3 and p=0.25 to perform horizontal and vertical flipping; probility of no-flipping is 0.45.
  583. probsr = [0, 0.65, 0]), # p=0.65 to rotate the image by 180°; probility of no-rotation is 0.35.
  584. T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
  585. ])
  586. """
  587. def __init__(self,
  588. probs=[0.35, 0.25],
  589. probsf=[0.3, 0.3, 0.2, 0.1, 0.1],
  590. probsr=[0.25, 0.5, 0.25]):
  591. super(RandomFlipOrRotate, self).__init__()
  592. # Change various probabilities into probability intervals, to judge in which mode to flip or rotate
  593. self.probs = [probs[0], probs[0] + probs[1]]
  594. self.probsf = self.get_probs_range(probsf)
  595. self.probsr = self.get_probs_range(probsr)
  596. def apply_im(self, image, mode_id, flip_mode=True):
  597. if flip_mode:
  598. image = img_flip(image, mode_id)
  599. else:
  600. image = img_simple_rotate(image, mode_id)
  601. return image
  602. def apply_mask(self, mask, mode_id, flip_mode=True):
  603. if flip_mode:
  604. mask = img_flip(mask, mode_id)
  605. else:
  606. mask = img_simple_rotate(mask, mode_id)
  607. return mask
  608. def apply_bbox(self, bbox, mode_id, flip_mode=True):
  609. raise TypeError(
  610. "Currently, RandomFlipOrRotate is not available for object detection tasks."
  611. )
  612. def apply_segm(self, bbox, mode_id, flip_mode=True):
  613. raise TypeError(
  614. "Currently, RandomFlipOrRotate is not available for object detection tasks."
  615. )
  616. def get_probs_range(self, probs):
  617. """
  618. Change list of probabilities into cumulative probability intervals.
  619. Args:
  620. probs (list[float]): Probabilities of different modes, shape: [n].
  621. Returns:
  622. list[list]: Probability intervals, shape: [n, 2].
  623. """
  624. ps = []
  625. last_prob = 0
  626. for prob in probs:
  627. p_s = last_prob
  628. cur_prob = prob / sum(probs)
  629. last_prob += cur_prob
  630. p_e = last_prob
  631. ps.append([p_s, p_e])
  632. return ps
  633. def judge_probs_range(self, p, probs):
  634. """
  635. Judge whether the value of `p` falls within the given probability interval.
  636. Args:
  637. p (float): Value between 0 and 1.
  638. probs (list[list]): Probability intervals, shape: [n, 2].
  639. Returns:
  640. int: Interval where the input probability falls into.
  641. """
  642. for id, id_range in enumerate(probs):
  643. if p > id_range[0] and p < id_range[1]:
  644. return id
  645. return -1
  646. def apply(self, sample):
  647. p_m = random.random()
  648. if p_m < self.probs[0]:
  649. mode_p = random.random()
  650. mode_id = self.judge_probs_range(mode_p, self.probsf)
  651. sample['image'] = self.apply_im(sample['image'], mode_id, True)
  652. if 'image2' in sample:
  653. sample['image2'] = self.apply_im(sample['image2'], mode_id,
  654. True)
  655. if 'mask' in sample:
  656. sample['mask'] = self.apply_mask(sample['mask'], mode_id, True)
  657. if 'aux_masks' in sample:
  658. sample['aux_masks'] = [
  659. self.apply_mask(aux_mask, mode_id, True)
  660. for aux_mask in sample['aux_masks']
  661. ]
  662. if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
  663. sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], mode_id,
  664. True)
  665. if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
  666. sample['gt_poly'] = self.apply_segm(sample['gt_poly'], mode_id,
  667. True)
  668. if 'target' in sample:
  669. sample['target'] = self.apply_im(sample['target'], mode_id,
  670. True)
  671. elif p_m < self.probs[1]:
  672. mode_p = random.random()
  673. mode_id = self.judge_probs_range(mode_p, self.probsr)
  674. sample['image'] = self.apply_im(sample['image'], mode_id, False)
  675. if 'image2' in sample:
  676. sample['image2'] = self.apply_im(sample['image2'], mode_id,
  677. False)
  678. if 'mask' in sample:
  679. sample['mask'] = self.apply_mask(sample['mask'], mode_id, False)
  680. if 'aux_masks' in sample:
  681. sample['aux_masks'] = [
  682. self.apply_mask(aux_mask, mode_id, False)
  683. for aux_mask in sample['aux_masks']
  684. ]
  685. if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
  686. sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], mode_id,
  687. False)
  688. if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
  689. sample['gt_poly'] = self.apply_segm(sample['gt_poly'], mode_id,
  690. False)
  691. if 'target' in sample:
  692. sample['target'] = self.apply_im(sample['target'], mode_id,
  693. False)
  694. return sample
  695. class RandomHorizontalFlip(Transform):
  696. """
  697. Randomly flip the input horizontally.
  698. Args:
  699. prob (float, optional): Probability of flipping the input. Defaults to .5.
  700. """
  701. def __init__(self, prob=0.5):
  702. super(RandomHorizontalFlip, self).__init__()
  703. self.prob = prob
  704. def apply_im(self, image):
  705. image = horizontal_flip(image)
  706. return image
  707. def apply_mask(self, mask):
  708. mask = horizontal_flip(mask)
  709. return mask
  710. def apply_bbox(self, bbox, width):
  711. oldx1 = bbox[:, 0].copy()
  712. oldx2 = bbox[:, 2].copy()
  713. bbox[:, 0] = width - oldx2
  714. bbox[:, 2] = width - oldx1
  715. return bbox
  716. def apply_segm(self, segms, height, width):
  717. flipped_segms = []
  718. for segm in segms:
  719. if is_poly(segm):
  720. # Polygon format
  721. flipped_segms.append(
  722. [horizontal_flip_poly(poly, width) for poly in segm])
  723. else:
  724. # RLE format
  725. flipped_segms.append(horizontal_flip_rle(segm, height, width))
  726. return flipped_segms
  727. def apply(self, sample):
  728. if random.random() < self.prob:
  729. im_h, im_w = sample['image'].shape[:2]
  730. sample['image'] = self.apply_im(sample['image'])
  731. if 'image2' in sample:
  732. sample['image2'] = self.apply_im(sample['image2'])
  733. if 'mask' in sample:
  734. sample['mask'] = self.apply_mask(sample['mask'])
  735. if 'aux_masks' in sample:
  736. sample['aux_masks'] = list(
  737. map(self.apply_mask, sample['aux_masks']))
  738. if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
  739. sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], im_w)
  740. if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
  741. sample['gt_poly'] = self.apply_segm(sample['gt_poly'], im_h,
  742. im_w)
  743. if 'target' in sample:
  744. sample['target'] = self.apply_im(sample['target'])
  745. return sample
  746. class RandomVerticalFlip(Transform):
  747. """
  748. Randomly flip the input vertically.
  749. Args:
  750. prob (float, optional): Probability of flipping the input. Defaults to .5.
  751. """
  752. def __init__(self, prob=0.5):
  753. super(RandomVerticalFlip, self).__init__()
  754. self.prob = prob
  755. def apply_im(self, image):
  756. image = vertical_flip(image)
  757. return image
  758. def apply_mask(self, mask):
  759. mask = vertical_flip(mask)
  760. return mask
  761. def apply_bbox(self, bbox, height):
  762. oldy1 = bbox[:, 1].copy()
  763. oldy2 = bbox[:, 3].copy()
  764. bbox[:, 0] = height - oldy2
  765. bbox[:, 2] = height - oldy1
  766. return bbox
  767. def apply_segm(self, segms, height, width):
  768. flipped_segms = []
  769. for segm in segms:
  770. if is_poly(segm):
  771. # Polygon format
  772. flipped_segms.append(
  773. [vertical_flip_poly(poly, height) for poly in segm])
  774. else:
  775. # RLE format
  776. flipped_segms.append(vertical_flip_rle(segm, height, width))
  777. return flipped_segms
  778. def apply(self, sample):
  779. if random.random() < self.prob:
  780. im_h, im_w = sample['image'].shape[:2]
  781. sample['image'] = self.apply_im(sample['image'])
  782. if 'image2' in sample:
  783. sample['image2'] = self.apply_im(sample['image2'])
  784. if 'mask' in sample:
  785. sample['mask'] = self.apply_mask(sample['mask'])
  786. if 'aux_masks' in sample:
  787. sample['aux_masks'] = list(
  788. map(self.apply_mask, sample['aux_masks']))
  789. if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
  790. sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], im_h)
  791. if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
  792. sample['gt_poly'] = self.apply_segm(sample['gt_poly'], im_h,
  793. im_w)
  794. if 'target' in sample:
  795. sample['target'] = self.apply_im(sample['target'])
  796. return sample
  797. class Normalize(Transform):
  798. """
  799. Apply normalization to the input image(s). The normalization steps are:
  800. 1. im = (im - min_value) * 1 / (max_value - min_value)
  801. 2. im = im - mean
  802. 3. im = im / std
  803. Args:
  804. mean (list[float] | tuple[float], optional): Mean of input image(s).
  805. Defaults to [0.485, 0.456, 0.406].
  806. std (list[float] | tuple[float], optional): Standard deviation of input
  807. image(s). Defaults to [0.229, 0.224, 0.225].
  808. min_val (list[float] | tuple[float], optional): Minimum value of input
  809. image(s). If None, use 0 for all channels. Defaults to None.
  810. max_val (list[float] | tuple[float], optional): Maximum value of input
  811. image(s). If None, use 255. for all channels. Defaults to None.
  812. apply_to_tar (bool, optional): Whether to apply transformation to the target
  813. image. Defaults to True.
  814. """
  815. def __init__(self,
  816. mean=[0.485, 0.456, 0.406],
  817. std=[0.229, 0.224, 0.225],
  818. min_val=None,
  819. max_val=None,
  820. apply_to_tar=True):
  821. super(Normalize, self).__init__()
  822. channel = len(mean)
  823. if min_val is None:
  824. min_val = [0] * channel
  825. if max_val is None:
  826. max_val = [255.] * channel
  827. from functools import reduce
  828. if reduce(lambda x, y: x * y, std) == 0:
  829. raise ValueError(
  830. "`std` should not contain 0, but received is {}.".format(std))
  831. if reduce(lambda x, y: x * y,
  832. [a - b for a, b in zip(max_val, min_val)]) == 0:
  833. raise ValueError(
  834. "(`max_val` - `min_val`) should not contain 0, but received is {}.".
  835. format((np.asarray(max_val) - np.asarray(min_val)).tolist()))
  836. self.mean = mean
  837. self.std = std
  838. self.min_val = min_val
  839. self.max_val = max_val
  840. self.apply_to_tar = apply_to_tar
  841. def apply_im(self, image):
  842. image = image.astype(np.float32)
  843. mean = np.asarray(
  844. self.mean, dtype=np.float32)[np.newaxis, np.newaxis, :]
  845. std = np.asarray(self.std, dtype=np.float32)[np.newaxis, np.newaxis, :]
  846. image = normalize(image, mean, std, self.min_val, self.max_val)
  847. return image
  848. def apply(self, sample):
  849. sample['image'] = self.apply_im(sample['image'])
  850. if 'image2' in sample:
  851. sample['image2'] = self.apply_im(sample['image2'])
  852. if 'target' in sample and self.apply_to_tar:
  853. sample['target'] = self.apply_im(sample['target'])
  854. return sample
  855. class CenterCrop(Transform):
  856. """
  857. Crop the input image(s) at the center.
  858. 1. Locate the center of the image.
  859. 2. Crop the image.
  860. Args:
  861. crop_size (int, optional): Target size of the cropped image(s).
  862. Defaults to 224.
  863. """
  864. def __init__(self, crop_size=224):
  865. super(CenterCrop, self).__init__()
  866. self.crop_size = crop_size
  867. def apply_im(self, image):
  868. image = center_crop(image, self.crop_size)
  869. return image
  870. def apply_mask(self, mask):
  871. mask = center_crop(mask, self.crop_size)
  872. return mask
  873. def apply(self, sample):
  874. sample['image'] = self.apply_im(sample['image'])
  875. if 'image2' in sample:
  876. sample['image2'] = self.apply_im(sample['image2'])
  877. if 'mask' in sample:
  878. sample['mask'] = self.apply_mask(sample['mask'])
  879. if 'aux_masks' in sample:
  880. sample['aux_masks'] = list(
  881. map(self.apply_mask, sample['aux_masks']))
  882. if 'target' in sample:
  883. sample['target'] = self.apply_im(sample['target'])
  884. return sample
  885. class RandomCrop(Transform):
  886. """
  887. Randomly crop the input.
  888. 1. Compute the height and width of cropped area according to `aspect_ratio` and
  889. `scaling`.
  890. 2. Locate the upper left corner of cropped area randomly.
  891. 3. Crop the image(s).
  892. 4. Resize the cropped area to `crop_size` x `crop_size`.
  893. Args:
  894. crop_size (int | list[int] | tuple[int]): Target size of the cropped area. If
  895. None, the cropped area will not be resized. Defaults to None.
  896. aspect_ratio (list[float], optional): Aspect ratio of cropped region in
  897. [min, max] format. Defaults to [.5, 2.].
  898. thresholds (list[float], optional): Iou thresholds to decide a valid bbox
  899. crop. Defaults to [.0, .1, .3, .5, .7, .9].
  900. scaling (list[float], optional): Ratio between the cropped region and the
  901. original image in [min, max] format. Defaults to [.3, 1.].
  902. num_attempts (int, optional): Max number of tries before giving up.
  903. Defaults to 50.
  904. allow_no_crop (bool, optional): Whether returning without doing crop is
  905. allowed. Defaults to True.
  906. cover_all_box (bool, optional): Whether to ensure all bboxes be covered in
  907. the final crop. Defaults to False.
  908. """
  909. def __init__(self,
  910. crop_size=None,
  911. aspect_ratio=[.5, 2.],
  912. thresholds=[.0, .1, .3, .5, .7, .9],
  913. scaling=[.3, 1.],
  914. num_attempts=50,
  915. allow_no_crop=True,
  916. cover_all_box=False):
  917. super(RandomCrop, self).__init__()
  918. self.crop_size = crop_size
  919. self.aspect_ratio = aspect_ratio
  920. self.thresholds = thresholds
  921. self.scaling = scaling
  922. self.num_attempts = num_attempts
  923. self.allow_no_crop = allow_no_crop
  924. self.cover_all_box = cover_all_box
  925. def _generate_crop_info(self, sample):
  926. im_h, im_w = sample['image'].shape[:2]
  927. if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
  928. thresholds = self.thresholds
  929. if self.allow_no_crop:
  930. thresholds.append('no_crop')
  931. np.random.shuffle(thresholds)
  932. for thresh in thresholds:
  933. if thresh == 'no_crop':
  934. return None
  935. for i in range(self.num_attempts):
  936. crop_box = self._get_crop_box(im_h, im_w)
  937. if crop_box is None:
  938. continue
  939. iou = self._iou_matrix(
  940. sample['gt_bbox'],
  941. np.array(
  942. [crop_box], dtype=np.float32))
  943. if iou.max() < thresh:
  944. continue
  945. if self.cover_all_box and iou.min() < thresh:
  946. continue
  947. cropped_box, valid_ids = self._crop_box_with_center_constraint(
  948. sample['gt_bbox'], np.array(
  949. crop_box, dtype=np.float32))
  950. if valid_ids.size > 0:
  951. return crop_box, cropped_box, valid_ids
  952. else:
  953. for i in range(self.num_attempts):
  954. crop_box = self._get_crop_box(im_h, im_w)
  955. if crop_box is None:
  956. continue
  957. return crop_box, None, None
  958. return None
  959. def _get_crop_box(self, im_h, im_w):
  960. scale = np.random.uniform(*self.scaling)
  961. if self.aspect_ratio is not None:
  962. min_ar, max_ar = self.aspect_ratio
  963. aspect_ratio = np.random.uniform(
  964. max(min_ar, scale**2), min(max_ar, scale**-2))
  965. h_scale = scale / np.sqrt(aspect_ratio)
  966. w_scale = scale * np.sqrt(aspect_ratio)
  967. else:
  968. h_scale = np.random.uniform(*self.scaling)
  969. w_scale = np.random.uniform(*self.scaling)
  970. crop_h = im_h * h_scale
  971. crop_w = im_w * w_scale
  972. if self.aspect_ratio is None:
  973. if crop_h / crop_w < 0.5 or crop_h / crop_w > 2.0:
  974. return None
  975. crop_h = int(crop_h)
  976. crop_w = int(crop_w)
  977. crop_y = np.random.randint(0, im_h - crop_h)
  978. crop_x = np.random.randint(0, im_w - crop_w)
  979. return [crop_x, crop_y, crop_x + crop_w, crop_y + crop_h]
  980. def _iou_matrix(self, a, b):
  981. tl_i = np.maximum(a[:, np.newaxis, :2], b[:, :2])
  982. br_i = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
  983. area_i = np.prod(br_i - tl_i, axis=2) * (tl_i < br_i).all(axis=2)
  984. area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
  985. area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
  986. area_o = (area_a[:, np.newaxis] + area_b - area_i)
  987. return area_i / (area_o + 1e-10)
  988. def _crop_box_with_center_constraint(self, box, crop):
  989. cropped_box = box.copy()
  990. cropped_box[:, :2] = np.maximum(box[:, :2], crop[:2])
  991. cropped_box[:, 2:] = np.minimum(box[:, 2:], crop[2:])
  992. cropped_box[:, :2] -= crop[:2]
  993. cropped_box[:, 2:] -= crop[:2]
  994. centers = (box[:, :2] + box[:, 2:]) / 2
  995. valid = np.logical_and(crop[:2] <= centers,
  996. centers < crop[2:]).all(axis=1)
  997. valid = np.logical_and(
  998. valid, (cropped_box[:, :2] < cropped_box[:, 2:]).all(axis=1))
  999. return cropped_box, np.where(valid)[0]
  1000. def _crop_segm(self, segms, valid_ids, crop, height, width):
  1001. crop_segms = []
  1002. for id in valid_ids:
  1003. segm = segms[id]
  1004. if is_poly(segm):
  1005. # Polygon format
  1006. crop_segms.append(crop_poly(segm, crop))
  1007. else:
  1008. # RLE format
  1009. crop_segms.append(crop_rle(segm, crop, height, width))
  1010. return crop_segms
  1011. def apply_im(self, image, crop):
  1012. x1, y1, x2, y2 = crop
  1013. return image[y1:y2, x1:x2, :]
  1014. def apply_mask(self, mask, crop):
  1015. x1, y1, x2, y2 = crop
  1016. return mask[y1:y2, x1:x2, ...]
  1017. def apply(self, sample):
  1018. crop_info = self._generate_crop_info(sample)
  1019. if crop_info is not None:
  1020. crop_box, cropped_box, valid_ids = crop_info
  1021. im_h, im_w = sample['image'].shape[:2]
  1022. sample['image'] = self.apply_im(sample['image'], crop_box)
  1023. if 'image2' in sample:
  1024. sample['image2'] = self.apply_im(sample['image2'], crop_box)
  1025. if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
  1026. crop_polys = self._crop_segm(
  1027. sample['gt_poly'],
  1028. valid_ids,
  1029. np.array(
  1030. crop_box, dtype=np.int64),
  1031. im_h,
  1032. im_w)
  1033. if [] in crop_polys:
  1034. delete_id = list()
  1035. valid_polys = list()
  1036. for idx, poly in enumerate(crop_polys):
  1037. if not crop_poly:
  1038. delete_id.append(idx)
  1039. else:
  1040. valid_polys.append(poly)
  1041. valid_ids = np.delete(valid_ids, delete_id)
  1042. if not valid_polys:
  1043. return sample
  1044. sample['gt_poly'] = valid_polys
  1045. else:
  1046. sample['gt_poly'] = crop_polys
  1047. if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
  1048. sample['gt_bbox'] = np.take(cropped_box, valid_ids, axis=0)
  1049. sample['gt_class'] = np.take(
  1050. sample['gt_class'], valid_ids, axis=0)
  1051. if 'gt_score' in sample:
  1052. sample['gt_score'] = np.take(
  1053. sample['gt_score'], valid_ids, axis=0)
  1054. if 'is_crowd' in sample:
  1055. sample['is_crowd'] = np.take(
  1056. sample['is_crowd'], valid_ids, axis=0)
  1057. if 'mask' in sample:
  1058. sample['mask'] = self.apply_mask(sample['mask'], crop_box)
  1059. if 'aux_masks' in sample:
  1060. sample['aux_masks'] = list(
  1061. map(partial(
  1062. self.apply_mask, crop=crop_box),
  1063. sample['aux_masks']))
  1064. if 'target' in sample:
  1065. if 'sr_factor' in sample:
  1066. sample['target'] = self.apply_im(
  1067. sample['target'],
  1068. calc_hr_shape(crop_box, sample['sr_factor']))
  1069. else:
  1070. sample['target'] = self.apply_im(sample['image'], crop_box)
  1071. if self.crop_size is not None:
  1072. sample = Resize(self.crop_size)(sample)
  1073. return sample
  1074. class RandomScaleAspect(Transform):
  1075. """
  1076. Crop input image(s) and resize back to original sizes.
  1077. Args:
  1078. min_scale (float): Minimum ratio between the cropped region and the original
  1079. image. If 0, image(s) will not be cropped. Defaults to .5.
  1080. aspect_ratio (float): Aspect ratio of cropped region. Defaults to .33.
  1081. """
  1082. def __init__(self, min_scale=0.5, aspect_ratio=0.33):
  1083. super(RandomScaleAspect, self).__init__()
  1084. self.min_scale = min_scale
  1085. self.aspect_ratio = aspect_ratio
  1086. def apply(self, sample):
  1087. if self.min_scale != 0 and self.aspect_ratio != 0:
  1088. img_height, img_width = sample['image'].shape[:2]
  1089. sample = RandomCrop(
  1090. crop_size=(img_height, img_width),
  1091. aspect_ratio=[self.aspect_ratio, 1. / self.aspect_ratio],
  1092. scaling=[self.min_scale, 1.],
  1093. num_attempts=10,
  1094. allow_no_crop=False)(sample)
  1095. return sample
  1096. class RandomExpand(Transform):
  1097. """
  1098. Randomly expand the input by padding according to random offsets.
  1099. Args:
  1100. upper_ratio (float, optional): Maximum ratio to which the original image
  1101. is expanded. Defaults to 4..
  1102. prob (float, optional): Probability of apply expanding. Defaults to .5.
  1103. im_padding_value (list[float] | tuple[float], optional): RGB filling value
  1104. for the image. Defaults to (127.5, 127.5, 127.5).
  1105. label_padding_value (int, optional): Filling value for the mask.
  1106. Defaults to 255.
  1107. See Also:
  1108. paddlers.transforms.Pad
  1109. """
  1110. def __init__(self,
  1111. upper_ratio=4.,
  1112. prob=.5,
  1113. im_padding_value=127.5,
  1114. label_padding_value=255):
  1115. super(RandomExpand, self).__init__()
  1116. assert upper_ratio > 1.01, "`upper_ratio` must be larger than 1.01."
  1117. self.upper_ratio = upper_ratio
  1118. self.prob = prob
  1119. assert isinstance(im_padding_value, (Number, Sequence)), \
  1120. "Value to fill must be either float or sequence."
  1121. self.im_padding_value = im_padding_value
  1122. self.label_padding_value = label_padding_value
  1123. def apply(self, sample):
  1124. if random.random() < self.prob:
  1125. im_h, im_w = sample['image'].shape[:2]
  1126. ratio = np.random.uniform(1., self.upper_ratio)
  1127. h = int(im_h * ratio)
  1128. w = int(im_w * ratio)
  1129. if h > im_h and w > im_w:
  1130. y = np.random.randint(0, h - im_h)
  1131. x = np.random.randint(0, w - im_w)
  1132. target_size = (h, w)
  1133. offsets = (x, y)
  1134. sample = Pad(
  1135. target_size=target_size,
  1136. pad_mode=-1,
  1137. offsets=offsets,
  1138. im_padding_value=self.im_padding_value,
  1139. label_padding_value=self.label_padding_value)(sample)
  1140. return sample
  1141. class Pad(Transform):
  1142. def __init__(self,
  1143. target_size=None,
  1144. pad_mode=0,
  1145. offsets=None,
  1146. im_padding_value=127.5,
  1147. label_padding_value=255,
  1148. size_divisor=32):
  1149. """
  1150. Pad image to a specified size or multiple of `size_divisor`.
  1151. Args:
  1152. target_size (list[int] | tuple[int], optional): Image target size, if None, pad to
  1153. multiple of size_divisor. Defaults to None.
  1154. pad_mode (int, optional): Pad mode. Currently only four modes are supported:
  1155. [-1, 0, 1, 2]. if -1, use specified offsets. If 0, only pad to right and bottom
  1156. If 1, pad according to center. If 2, only pad left and top. Defaults to 0.
  1157. offsets (list[int]|None, optional): Padding offsets. Defaults to None.
  1158. im_padding_value (list[float] | tuple[float]): RGB value of padded area.
  1159. Defaults to (127.5, 127.5, 127.5).
  1160. label_padding_value (int, optional): Filling value for the mask.
  1161. Defaults to 255.
  1162. size_divisor (int): Image width and height after padding will be a multiple of
  1163. `size_divisor`.
  1164. """
  1165. super(Pad, self).__init__()
  1166. if isinstance(target_size, (list, tuple)):
  1167. if len(target_size) != 2:
  1168. raise ValueError(
  1169. "`target_size` should contain 2 elements, but it is {}.".
  1170. format(target_size))
  1171. if isinstance(target_size, int):
  1172. target_size = [target_size] * 2
  1173. assert pad_mode in [
  1174. -1, 0, 1, 2
  1175. ], "Currently only four modes are supported: [-1, 0, 1, 2]."
  1176. if pad_mode == -1:
  1177. assert offsets, "if `pad_mode` is -1, `offsets` should not be None."
  1178. self.target_size = target_size
  1179. self.size_divisor = size_divisor
  1180. self.pad_mode = pad_mode
  1181. self.offsets = offsets
  1182. self.im_padding_value = im_padding_value
  1183. self.label_padding_value = label_padding_value
  1184. def apply_im(self, image, offsets, target_size):
  1185. x, y = offsets
  1186. h, w = target_size
  1187. im_h, im_w, channel = image.shape[:3]
  1188. canvas = np.ones((h, w, channel), dtype=np.float32)
  1189. canvas *= np.array(self.im_padding_value, dtype=np.float32)
  1190. canvas[y:y + im_h, x:x + im_w, :] = image.astype(np.float32)
  1191. return canvas
  1192. def apply_mask(self, mask, offsets, target_size):
  1193. x, y = offsets
  1194. im_h, im_w = mask.shape[:2]
  1195. h, w = target_size
  1196. canvas = np.ones((h, w), dtype=np.float32)
  1197. canvas *= np.array(self.label_padding_value, dtype=np.float32)
  1198. canvas[y:y + im_h, x:x + im_w] = mask.astype(np.float32)
  1199. return canvas
  1200. def apply_bbox(self, bbox, offsets):
  1201. return bbox + np.array(offsets * 2, dtype=np.float32)
  1202. def apply_segm(self, segms, offsets, im_size, size):
  1203. x, y = offsets
  1204. height, width = im_size
  1205. h, w = size
  1206. expanded_segms = []
  1207. for segm in segms:
  1208. if is_poly(segm):
  1209. # Polygon format
  1210. expanded_segms.append(
  1211. [expand_poly(poly, x, y) for poly in segm])
  1212. else:
  1213. # RLE format
  1214. expanded_segms.append(
  1215. expand_rle(segm, x, y, height, width, h, w))
  1216. return expanded_segms
  1217. def _get_offsets(self, im_h, im_w, h, w):
  1218. if self.pad_mode == -1:
  1219. offsets = self.offsets
  1220. elif self.pad_mode == 0:
  1221. offsets = [0, 0]
  1222. elif self.pad_mode == 1:
  1223. offsets = [(w - im_w) // 2, (h - im_h) // 2]
  1224. else:
  1225. offsets = [w - im_w, h - im_h]
  1226. return offsets
  1227. def apply(self, sample):
  1228. im_h, im_w = sample['image'].shape[:2]
  1229. if self.target_size:
  1230. h, w = self.target_size
  1231. assert (
  1232. im_h <= h and im_w <= w
  1233. ), 'target size ({}, {}) cannot be less than image size ({}, {})'\
  1234. .format(h, w, im_h, im_w)
  1235. else:
  1236. h = (np.ceil(im_h / self.size_divisor) *
  1237. self.size_divisor).astype(int)
  1238. w = (np.ceil(im_w / self.size_divisor) *
  1239. self.size_divisor).astype(int)
  1240. if h == im_h and w == im_w:
  1241. return sample
  1242. offsets = self._get_offsets(im_h, im_w, h, w)
  1243. sample['image'] = self.apply_im(sample['image'], offsets, (h, w))
  1244. if 'image2' in sample:
  1245. sample['image2'] = self.apply_im(sample['image2'], offsets, (h, w))
  1246. if 'mask' in sample:
  1247. sample['mask'] = self.apply_mask(sample['mask'], offsets, (h, w))
  1248. if 'aux_masks' in sample:
  1249. sample['aux_masks'] = list(
  1250. map(partial(
  1251. self.apply_mask, offsets=offsets, target_size=(h, w)),
  1252. sample['aux_masks']))
  1253. if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
  1254. sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], offsets)
  1255. if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
  1256. sample['gt_poly'] = self.apply_segm(
  1257. sample['gt_poly'], offsets, im_size=[im_h, im_w], size=[h, w])
  1258. if 'target' in sample:
  1259. if 'sr_factor' in sample:
  1260. hr_shape = calc_hr_shape((h, w), sample['sr_factor'])
  1261. hr_offsets = self._get_offsets(*sample['target'].shape[:2],
  1262. *hr_shape)
  1263. sample['target'] = self.apply_im(sample['target'], hr_offsets,
  1264. hr_shape)
  1265. else:
  1266. sample['target'] = self.apply_im(sample['target'], offsets,
  1267. (h, w))
  1268. return sample
  1269. class MixupImage(Transform):
  1270. def __init__(self, alpha=1.5, beta=1.5, mixup_epoch=-1):
  1271. """
  1272. Mixup two images and their gt_bbbox/gt_score.
  1273. Args:
  1274. alpha (float, optional): Alpha parameter of beta distribution.
  1275. Defaults to 1.5.
  1276. beta (float, optional): Beta parameter of beta distribution.
  1277. Defaults to 1.5.
  1278. """
  1279. super(MixupImage, self).__init__()
  1280. if alpha <= 0.0:
  1281. raise ValueError("`alpha` should be positive in MixupImage.")
  1282. if beta <= 0.0:
  1283. raise ValueError("`beta` should be positive in MixupImage.")
  1284. self.alpha = alpha
  1285. self.beta = beta
  1286. self.mixup_epoch = mixup_epoch
  1287. def apply_im(self, image1, image2, factor):
  1288. h = max(image1.shape[0], image2.shape[0])
  1289. w = max(image1.shape[1], image2.shape[1])
  1290. img = np.zeros((h, w, image1.shape[2]), 'float32')
  1291. img[:image1.shape[0], :image1.shape[1], :] = \
  1292. image1.astype('float32') * factor
  1293. img[:image2.shape[0], :image2.shape[1], :] += \
  1294. image2.astype('float32') * (1.0 - factor)
  1295. return img.astype('uint8')
  1296. def __call__(self, sample):
  1297. if not isinstance(sample, Sequence):
  1298. return sample
  1299. assert len(sample) == 2, 'mixup need two samples'
  1300. factor = np.random.beta(self.alpha, self.beta)
  1301. factor = max(0.0, min(1.0, factor))
  1302. if factor >= 1.0:
  1303. return sample[0]
  1304. if factor <= 0.0:
  1305. return sample[1]
  1306. image = self.apply_im(sample[0]['image'], sample[1]['image'], factor)
  1307. result = copy.deepcopy(sample[0])
  1308. result['image'] = image
  1309. # Apply bbox and score
  1310. if 'gt_bbox' in sample[0]:
  1311. gt_bbox1 = sample[0]['gt_bbox']
  1312. gt_bbox2 = sample[1]['gt_bbox']
  1313. gt_bbox = np.concatenate((gt_bbox1, gt_bbox2), axis=0)
  1314. result['gt_bbox'] = gt_bbox
  1315. if 'gt_poly' in sample[0]:
  1316. gt_poly1 = sample[0]['gt_poly']
  1317. gt_poly2 = sample[1]['gt_poly']
  1318. gt_poly = gt_poly1 + gt_poly2
  1319. result['gt_poly'] = gt_poly
  1320. if 'gt_class' in sample[0]:
  1321. gt_class1 = sample[0]['gt_class']
  1322. gt_class2 = sample[1]['gt_class']
  1323. gt_class = np.concatenate((gt_class1, gt_class2), axis=0)
  1324. result['gt_class'] = gt_class
  1325. gt_score1 = np.ones_like(sample[0]['gt_class'])
  1326. gt_score2 = np.ones_like(sample[1]['gt_class'])
  1327. gt_score = np.concatenate(
  1328. (gt_score1 * factor, gt_score2 * (1. - factor)), axis=0)
  1329. result['gt_score'] = gt_score
  1330. if 'is_crowd' in sample[0]:
  1331. is_crowd1 = sample[0]['is_crowd']
  1332. is_crowd2 = sample[1]['is_crowd']
  1333. is_crowd = np.concatenate((is_crowd1, is_crowd2), axis=0)
  1334. result['is_crowd'] = is_crowd
  1335. if 'difficult' in sample[0]:
  1336. is_difficult1 = sample[0]['difficult']
  1337. is_difficult2 = sample[1]['difficult']
  1338. is_difficult = np.concatenate(
  1339. (is_difficult1, is_difficult2), axis=0)
  1340. result['difficult'] = is_difficult
  1341. return result
  1342. class RandomDistort(Transform):
  1343. """
  1344. Random color distortion.
  1345. Args:
  1346. brightness_range (float, optional): Range of brightness distortion.
  1347. Defaults to .5.
  1348. brightness_prob (float, optional): Probability of brightness distortion.
  1349. Defaults to .5.
  1350. contrast_range (float, optional): Range of contrast distortion.
  1351. Defaults to .5.
  1352. contrast_prob (float, optional): Probability of contrast distortion.
  1353. Defaults to .5.
  1354. saturation_range (float, optional): Range of saturation distortion.
  1355. Defaults to .5.
  1356. saturation_prob (float, optional): Probability of saturation distortion.
  1357. Defaults to .5.
  1358. hue_range (float, optional): Range of hue distortion. Defaults to .5.
  1359. hue_prob (float, optional): Probability of hue distortion. Defaults to .5.
  1360. random_apply (bool, optional): Apply the transformation in random (yolo) or
  1361. fixed (SSD) order. Defaults to True.
  1362. count (int, optional): Number of distortions to apply. Defaults to 4.
  1363. shuffle_channel (bool, optional): Whether to swap channels randomly.
  1364. Defaults to False.
  1365. """
  1366. def __init__(self,
  1367. brightness_range=0.5,
  1368. brightness_prob=0.5,
  1369. contrast_range=0.5,
  1370. contrast_prob=0.5,
  1371. saturation_range=0.5,
  1372. saturation_prob=0.5,
  1373. hue_range=18,
  1374. hue_prob=0.5,
  1375. random_apply=True,
  1376. count=4,
  1377. shuffle_channel=False):
  1378. super(RandomDistort, self).__init__()
  1379. self.brightness_range = [1 - brightness_range, 1 + brightness_range]
  1380. self.brightness_prob = brightness_prob
  1381. self.contrast_range = [1 - contrast_range, 1 + contrast_range]
  1382. self.contrast_prob = contrast_prob
  1383. self.saturation_range = [1 - saturation_range, 1 + saturation_range]
  1384. self.saturation_prob = saturation_prob
  1385. self.hue_range = [1 - hue_range, 1 + hue_range]
  1386. self.hue_prob = hue_prob
  1387. self.random_apply = random_apply
  1388. self.count = count
  1389. self.shuffle_channel = shuffle_channel
  1390. def apply_hue(self, image):
  1391. low, high = self.hue_range
  1392. if np.random.uniform(0., 1.) < self.hue_prob:
  1393. return image
  1394. # It works, but the result differs from HSV version.
  1395. delta = np.random.uniform(low, high)
  1396. u = np.cos(delta * np.pi)
  1397. w = np.sin(delta * np.pi)
  1398. bt = np.array([[1.0, 0.0, 0.0], [0.0, u, -w], [0.0, w, u]])
  1399. tyiq = np.array([[0.299, 0.587, 0.114], [0.596, -0.274, -0.321],
  1400. [0.211, -0.523, 0.311]])
  1401. ityiq = np.array([[1.0, 0.956, 0.621], [1.0, -0.272, -0.647],
  1402. [1.0, -1.107, 1.705]])
  1403. t = np.dot(np.dot(ityiq, bt), tyiq).T
  1404. res_list = []
  1405. channel = image.shape[2]
  1406. for i in range(channel // 3):
  1407. sub_img = image[:, :, 3 * i:3 * (i + 1)]
  1408. sub_img = sub_img.astype(np.float32)
  1409. sub_img = np.dot(image, t)
  1410. res_list.append(sub_img)
  1411. if channel % 3 != 0:
  1412. i = channel % 3
  1413. res_list.append(image[:, :, -i:])
  1414. return np.concatenate(res_list, axis=2)
  1415. def apply_saturation(self, image):
  1416. low, high = self.saturation_range
  1417. delta = np.random.uniform(low, high)
  1418. if np.random.uniform(0., 1.) < self.saturation_prob:
  1419. return image
  1420. res_list = []
  1421. channel = image.shape[2]
  1422. for i in range(channel // 3):
  1423. sub_img = image[:, :, 3 * i:3 * (i + 1)]
  1424. sub_img = sub_img.astype(np.float32)
  1425. # It works, but the result differs from HSV version.
  1426. gray = sub_img * np.array(
  1427. [[[0.299, 0.587, 0.114]]], dtype=np.float32)
  1428. gray = gray.sum(axis=2, keepdims=True)
  1429. gray *= (1.0 - delta)
  1430. sub_img *= delta
  1431. sub_img += gray
  1432. res_list.append(sub_img)
  1433. if channel % 3 != 0:
  1434. i = channel % 3
  1435. res_list.append(image[:, :, -i:])
  1436. return np.concatenate(res_list, axis=2)
  1437. def apply_contrast(self, image):
  1438. low, high = self.contrast_range
  1439. if np.random.uniform(0., 1.) < self.contrast_prob:
  1440. return image
  1441. delta = np.random.uniform(low, high)
  1442. image = image.astype(np.float32)
  1443. image *= delta
  1444. return image
  1445. def apply_brightness(self, image):
  1446. low, high = self.brightness_range
  1447. if np.random.uniform(0., 1.) < self.brightness_prob:
  1448. return image
  1449. delta = np.random.uniform(low, high)
  1450. image = image.astype(np.float32)
  1451. image += delta
  1452. return image
  1453. def apply(self, sample):
  1454. if self.random_apply:
  1455. functions = [
  1456. self.apply_brightness, self.apply_contrast,
  1457. self.apply_saturation, self.apply_hue
  1458. ]
  1459. distortions = np.random.permutation(functions)[:self.count]
  1460. for func in distortions:
  1461. sample['image'] = func(sample['image'])
  1462. if 'image2' in sample:
  1463. sample['image2'] = func(sample['image2'])
  1464. return sample
  1465. sample['image'] = self.apply_brightness(sample['image'])
  1466. if 'image2' in sample:
  1467. sample['image2'] = self.apply_brightness(sample['image2'])
  1468. mode = np.random.randint(0, 2)
  1469. if mode:
  1470. sample['image'] = self.apply_contrast(sample['image'])
  1471. if 'image2' in sample:
  1472. sample['image2'] = self.apply_contrast(sample['image2'])
  1473. sample['image'] = self.apply_saturation(sample['image'])
  1474. sample['image'] = self.apply_hue(sample['image'])
  1475. if 'image2' in sample:
  1476. sample['image2'] = self.apply_saturation(sample['image2'])
  1477. sample['image2'] = self.apply_hue(sample['image2'])
  1478. if not mode:
  1479. sample['image'] = self.apply_contrast(sample['image'])
  1480. if 'image2' in sample:
  1481. sample['image2'] = self.apply_contrast(sample['image2'])
  1482. if self.shuffle_channel:
  1483. if np.random.randint(0, 2):
  1484. sample['image'] = sample['image'][..., np.random.permutation(3)]
  1485. if 'image2' in sample:
  1486. sample['image2'] = sample['image2'][
  1487. ..., np.random.permutation(3)]
  1488. return sample
  1489. class RandomBlur(Transform):
  1490. """
  1491. Randomly blur input image(s).
  1492. Args:
  1493. prob (float): Probability of blurring.
  1494. """
  1495. def __init__(self, prob=0.1):
  1496. super(RandomBlur, self).__init__()
  1497. self.prob = prob
  1498. def apply_im(self, image, radius):
  1499. image = cv2.GaussianBlur(image, (radius, radius), 0, 0)
  1500. return image
  1501. def apply(self, sample):
  1502. if self.prob <= 0:
  1503. n = 0
  1504. elif self.prob >= 1:
  1505. n = 1
  1506. else:
  1507. n = int(1.0 / self.prob)
  1508. if n > 0:
  1509. if np.random.randint(0, n) == 0:
  1510. radius = np.random.randint(3, 10)
  1511. if radius % 2 != 1:
  1512. radius = radius + 1
  1513. if radius > 9:
  1514. radius = 9
  1515. sample['image'] = self.apply_im(sample['image'], radius)
  1516. if 'image2' in sample:
  1517. sample['image2'] = self.apply_im(sample['image2'], radius)
  1518. return sample
  1519. class Dehaze(Transform):
  1520. """
  1521. Dehaze input image(s).
  1522. Args:
  1523. gamma (bool, optional): Use gamma correction or not. Defaults to False.
  1524. """
  1525. def __init__(self, gamma=False):
  1526. super(Dehaze, self).__init__()
  1527. self.gamma = gamma
  1528. def apply_im(self, image):
  1529. image = dehaze(image, self.gamma)
  1530. return image
  1531. def apply(self, sample):
  1532. sample['image'] = self.apply_im(sample['image'])
  1533. if 'image2' in sample:
  1534. sample['image2'] = self.apply_im(sample['image2'])
  1535. return sample
  1536. class ReduceDim(Transform):
  1537. """
  1538. Use PCA to reduce the dimension of input image(s).
  1539. Args:
  1540. joblib_path (str): Path of *.joblib file of PCA.
  1541. apply_to_tar (bool, optional): Whether to apply transformation to the target
  1542. image. Defaults to True.
  1543. """
  1544. def __init__(self, joblib_path, apply_to_tar=True):
  1545. super(ReduceDim, self).__init__()
  1546. ext = joblib_path.split(".")[-1]
  1547. if ext != "joblib":
  1548. raise ValueError("`joblib_path` must be *.joblib, not *.{}.".format(
  1549. ext))
  1550. self.pca = load(joblib_path)
  1551. self.apply_to_tar = apply_to_tar
  1552. def apply_im(self, image):
  1553. H, W, C = image.shape
  1554. n_im = np.reshape(image, (-1, C))
  1555. im_pca = self.pca.transform(n_im)
  1556. result = np.reshape(im_pca, (H, W, -1))
  1557. return result
  1558. def apply(self, sample):
  1559. sample['image'] = self.apply_im(sample['image'])
  1560. if 'image2' in sample:
  1561. sample['image2'] = self.apply_im(sample['image2'])
  1562. if 'target' in sample and self.apply_to_tar:
  1563. sample['target'] = self.apply_im(sample['target'])
  1564. return sample
  1565. class SelectBand(Transform):
  1566. """
  1567. Select a set of bands of input image(s).
  1568. Args:
  1569. band_list (list, optional): Bands to select (band index starts from 1).
  1570. Defaults to [1, 2, 3].
  1571. apply_to_tar (bool, optional): Whether to apply transformation to the target
  1572. image. Defaults to True.
  1573. """
  1574. def __init__(self, band_list=[1, 2, 3], apply_to_tar=True):
  1575. super(SelectBand, self).__init__()
  1576. self.band_list = band_list
  1577. self.apply_to_tar = apply_to_tar
  1578. def apply_im(self, image):
  1579. image = select_bands(image, self.band_list)
  1580. return image
  1581. def apply(self, sample):
  1582. sample['image'] = self.apply_im(sample['image'])
  1583. if 'image2' in sample:
  1584. sample['image2'] = self.apply_im(sample['image2'])
  1585. if 'target' in sample and self.apply_to_tar:
  1586. sample['target'] = self.apply_im(sample['target'])
  1587. return sample
  1588. class _PadBox(Transform):
  1589. def __init__(self, num_max_boxes=50):
  1590. """
  1591. Pad zeros to bboxes if number of bboxes is less than `num_max_boxes`.
  1592. Args:
  1593. num_max_boxes (int, optional): Max number of bboxes. Defaults to 50.
  1594. """
  1595. self.num_max_boxes = num_max_boxes
  1596. super(_PadBox, self).__init__()
  1597. def apply(self, sample):
  1598. gt_num = min(self.num_max_boxes, len(sample['gt_bbox']))
  1599. num_max = self.num_max_boxes
  1600. pad_bbox = np.zeros((num_max, 4), dtype=np.float32)
  1601. if gt_num > 0:
  1602. pad_bbox[:gt_num, :] = sample['gt_bbox'][:gt_num, :]
  1603. sample['gt_bbox'] = pad_bbox
  1604. if 'gt_class' in sample:
  1605. pad_class = np.zeros((num_max, ), dtype=np.int32)
  1606. if gt_num > 0:
  1607. pad_class[:gt_num] = sample['gt_class'][:gt_num, 0]
  1608. sample['gt_class'] = pad_class
  1609. if 'gt_score' in sample:
  1610. pad_score = np.zeros((num_max, ), dtype=np.float32)
  1611. if gt_num > 0:
  1612. pad_score[:gt_num] = sample['gt_score'][:gt_num, 0]
  1613. sample['gt_score'] = pad_score
  1614. # In training, for example in op ExpandImage,
  1615. # bbox and gt_class are expanded, but difficult is not,
  1616. # so judge by its length.
  1617. if 'difficult' in sample:
  1618. pad_diff = np.zeros((num_max, ), dtype=np.int32)
  1619. if gt_num > 0:
  1620. pad_diff[:gt_num] = sample['difficult'][:gt_num, 0]
  1621. sample['difficult'] = pad_diff
  1622. if 'is_crowd' in sample:
  1623. pad_crowd = np.zeros((num_max, ), dtype=np.int32)
  1624. if gt_num > 0:
  1625. pad_crowd[:gt_num] = sample['is_crowd'][:gt_num, 0]
  1626. sample['is_crowd'] = pad_crowd
  1627. return sample
  1628. class _NormalizeBox(Transform):
  1629. def __init__(self):
  1630. super(_NormalizeBox, self).__init__()
  1631. def apply(self, sample):
  1632. height, width = sample['image'].shape[:2]
  1633. for i in range(sample['gt_bbox'].shape[0]):
  1634. sample['gt_bbox'][i][0] = sample['gt_bbox'][i][0] / width
  1635. sample['gt_bbox'][i][1] = sample['gt_bbox'][i][1] / height
  1636. sample['gt_bbox'][i][2] = sample['gt_bbox'][i][2] / width
  1637. sample['gt_bbox'][i][3] = sample['gt_bbox'][i][3] / height
  1638. return sample
  1639. class _BboxXYXY2XYWH(Transform):
  1640. """
  1641. Convert bbox XYXY format to XYWH format.
  1642. """
  1643. def __init__(self):
  1644. super(_BboxXYXY2XYWH, self).__init__()
  1645. def apply(self, sample):
  1646. bbox = sample['gt_bbox']
  1647. bbox[:, 2:4] = bbox[:, 2:4] - bbox[:, :2]
  1648. bbox[:, :2] = bbox[:, :2] + bbox[:, 2:4] / 2.
  1649. sample['gt_bbox'] = bbox
  1650. return sample
  1651. class _Permute(Transform):
  1652. def __init__(self):
  1653. super(_Permute, self).__init__()
  1654. def apply(self, sample):
  1655. sample['image'] = permute(sample['image'], False)
  1656. if 'image2' in sample:
  1657. sample['image2'] = permute(sample['image2'], False)
  1658. if 'target' in sample:
  1659. sample['target'] = permute(sample['target'], False)
  1660. return sample
  1661. class RandomSwap(Transform):
  1662. """
  1663. Randomly swap multi-temporal images.
  1664. Args:
  1665. prob (float, optional): Probability of swapping the input images.
  1666. Default: 0.2.
  1667. """
  1668. def __init__(self, prob=0.2):
  1669. super(RandomSwap, self).__init__()
  1670. self.prob = prob
  1671. def apply(self, sample):
  1672. if 'image2' not in sample:
  1673. raise ValueError("'image2' is not found in the sample.")
  1674. if random.random() < self.prob:
  1675. sample['image'], sample['image2'] = sample['image2'], sample[
  1676. 'image']
  1677. return sample
  1678. class ReloadMask(Transform):
  1679. def apply(self, sample):
  1680. sample['mask'] = decode_seg_mask(sample['mask_ori'])
  1681. if 'aux_masks' in sample:
  1682. sample['aux_masks'] = list(
  1683. map(decode_seg_mask, sample['aux_masks_ori']))
  1684. return sample
  1685. class AppendIndex(Transform):
  1686. """
  1687. Append remote sensing index to input image(s).
  1688. Args:
  1689. index_type (str): Type of remote sensinng index. See supported
  1690. index types in
  1691. https://github.com/PaddlePaddle/PaddleRS/tree/develop/paddlers/transforms/indices.py .
  1692. band_indices (dict, optional): Mapping of band names to band indices
  1693. (starting from 1). See band names in
  1694. https://github.com/PaddlePaddle/PaddleRS/tree/develop/paddlers/transforms/indices.py .
  1695. Default: None.
  1696. satellite (str, optional): Type of satellite. If set,
  1697. band indices will be automatically determined accordingly. See supported satellites in
  1698. https://github.com/PaddlePaddle/PaddleRS/tree/develop/paddlers/transforms/satellites.py .
  1699. Default: None.
  1700. """
  1701. def __init__(self, index_type, band_indices=None, satellite=None, **kwargs):
  1702. super(AppendIndex, self).__init__()
  1703. cls = getattr(indices, index_type)
  1704. if satellite is not None:
  1705. satellite_bands = getattr(satellites, satellite)
  1706. self._compute_index = cls(satellite_bands, **kwargs)
  1707. else:
  1708. if band_indices is None:
  1709. raise ValueError(
  1710. "At least one of `band_indices` and `satellite` must not be None."
  1711. )
  1712. else:
  1713. self._compute_index = cls(band_indices, **kwargs)
  1714. def apply_im(self, image):
  1715. index = self._compute_index(image)
  1716. index = index[..., None].astype('float32')
  1717. return np.concatenate([image, index], axis=-1)
  1718. def apply(self, sample):
  1719. sample['image'] = self.apply_im(sample['image'])
  1720. if 'image2' in sample:
  1721. sample['image2'] = self.apply_im(sample['image2'])
  1722. return sample
  1723. class MatchRadiance(Transform):
  1724. """
  1725. Perform relative radiometric correction between bi-temporal images.
  1726. Args:
  1727. method (str, optional): Method used to match the radiance of the
  1728. bi-temporal images. Choices are {'hist', 'lsr'}. 'hist' stands
  1729. for histogram matching and 'lsr' stands for least-squares
  1730. regression. Default: 'hist'.
  1731. """
  1732. def __init__(self, method='hist'):
  1733. super(MatchRadiance, self).__init__()
  1734. if method == 'hist':
  1735. self._match_func = match_histograms
  1736. elif method == 'lsr':
  1737. self._match_func = match_by_regression
  1738. else:
  1739. raise ValueError(
  1740. "{} is not a supported radiometric correction method.".format(
  1741. method))
  1742. self.method = method
  1743. def apply(self, sample):
  1744. if 'image2' not in sample:
  1745. raise ValueError("'image2' is not found in the sample.")
  1746. sample['image2'] = self._match_func(sample['image2'], sample['image'])
  1747. return sample
  1748. class Arrange(Transform):
  1749. def __init__(self, mode):
  1750. super().__init__()
  1751. if mode not in ['train', 'eval', 'test', 'quant']:
  1752. raise ValueError(
  1753. "`mode` should be defined as one of ['train', 'eval', 'test', 'quant']!"
  1754. )
  1755. self.mode = mode
  1756. class ArrangeSegmenter(Arrange):
  1757. def apply(self, sample):
  1758. if 'mask' in sample:
  1759. mask = sample['mask']
  1760. mask = mask.astype('int64')
  1761. image = permute(sample['image'], False)
  1762. if self.mode == 'train':
  1763. return image, mask
  1764. if self.mode == 'eval':
  1765. return image, mask
  1766. if self.mode == 'test':
  1767. return image,
  1768. class ArrangeChangeDetector(Arrange):
  1769. def apply(self, sample):
  1770. if 'mask' in sample:
  1771. mask = sample['mask']
  1772. mask = mask.astype('int64')
  1773. image_t1 = permute(sample['image'], False)
  1774. image_t2 = permute(sample['image2'], False)
  1775. if self.mode == 'train':
  1776. masks = [mask]
  1777. if 'aux_masks' in sample:
  1778. masks.extend(
  1779. map(methodcaller('astype', 'int64'), sample['aux_masks']))
  1780. return (
  1781. image_t1,
  1782. image_t2, ) + tuple(masks)
  1783. if self.mode == 'eval':
  1784. return image_t1, image_t2, mask
  1785. if self.mode == 'test':
  1786. return image_t1, image_t2,
  1787. class ArrangeClassifier(Arrange):
  1788. def apply(self, sample):
  1789. image = permute(sample['image'], False)
  1790. if self.mode in ['train', 'eval']:
  1791. return image, sample['label']
  1792. else:
  1793. return image
  1794. class ArrangeDetector(Arrange):
  1795. def apply(self, sample):
  1796. if self.mode == 'eval' and 'gt_poly' in sample:
  1797. del sample['gt_poly']
  1798. return sample
  1799. class ArrangeRestorer(Arrange):
  1800. def apply(self, sample):
  1801. if 'target' in sample:
  1802. target = permute(sample['target'], False)
  1803. image = permute(sample['image'], False)
  1804. if self.mode == 'train':
  1805. return image, target
  1806. if self.mode == 'eval':
  1807. return image, target
  1808. if self.mode == 'test':
  1809. return image,