operators.py 69 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import copy
  16. import random
  17. from numbers import Number
  18. from functools import partial
  19. from operator import methodcaller
  20. try:
  21. from collections.abc import Sequence
  22. except Exception:
  23. from collections import Sequence
  24. import numpy as np
  25. import cv2
  26. import imghdr
  27. from PIL import Image
  28. from joblib import load
  29. import paddlers
  30. from .functions import (
  31. normalize, horizontal_flip, permute, vertical_flip, center_crop, is_poly,
  32. horizontal_flip_poly, horizontal_flip_rle, vertical_flip_poly,
  33. vertical_flip_rle, crop_poly, crop_rle, expand_poly, expand_rle,
  34. resize_poly, resize_rle, dehaze, select_bands, to_intensity, to_uint8,
  35. img_flip, img_simple_rotate, decode_seg_mask)
  36. __all__ = [
  37. "Compose", "DecodeImg", "Resize", "RandomResize", "ResizeByShort",
  38. "RandomResizeByShort", "ResizeByLong", "RandomHorizontalFlip",
  39. "RandomVerticalFlip", "Normalize", "CenterCrop", "RandomCrop",
  40. "RandomScaleAspect", "RandomExpand", "Pad", "MixupImage", "RandomDistort",
  41. "RandomBlur", "RandomSwap", "Dehaze", "ReduceDim", "SelectBand",
  42. "ArrangeSegmenter", "ArrangeChangeDetector", "ArrangeClassifier",
  43. "ArrangeDetector", "RandomFlipOrRotate", "ReloadMask"
  44. ]
  45. interp_dict = {
  46. 'NEAREST': cv2.INTER_NEAREST,
  47. 'LINEAR': cv2.INTER_LINEAR,
  48. 'CUBIC': cv2.INTER_CUBIC,
  49. 'AREA': cv2.INTER_AREA,
  50. 'LANCZOS4': cv2.INTER_LANCZOS4
  51. }
  52. class Compose(object):
  53. """
  54. Apply a series of data augmentation strategies to the input.
  55. All input images should be in Height-Width-Channel ([H, W, C]) format.
  56. Args:
  57. transforms (list[paddlers.transforms.Transform]): List of data preprocess or
  58. augmentation operators.
  59. Raises:
  60. TypeError: Invalid type of transforms.
  61. ValueError: Invalid length of transforms.
  62. """
  63. def __init__(self, transforms):
  64. super(Compose, self).__init__()
  65. if not isinstance(transforms, list):
  66. raise TypeError(
  67. "Type of transforms is invalid. Must be a list, but received is {}."
  68. .format(type(transforms)))
  69. if len(transforms) < 1:
  70. raise ValueError(
  71. "Length of transforms must not be less than 1, but received is {}."
  72. .format(len(transforms)))
  73. transforms = copy.deepcopy(transforms)
  74. self.arrange = self._pick_arrange(transforms)
  75. self.transforms = transforms
  76. def __call__(self, sample):
  77. """
  78. This is equivalent to sequentially calling compose_obj.apply_transforms()
  79. and compose_obj.arrange_outputs().
  80. """
  81. sample = self.apply_transforms(sample)
  82. sample = self.arrange_outputs(sample)
  83. return sample
  84. def apply_transforms(self, sample):
  85. for op in self.transforms:
  86. # Skip batch transforms amd mixup
  87. if isinstance(op, (paddlers.transforms.BatchRandomResize,
  88. paddlers.transforms.BatchRandomResizeByShort,
  89. MixupImage)):
  90. continue
  91. sample = op(sample)
  92. return sample
  93. def arrange_outputs(self, sample):
  94. if self.arrange is not None:
  95. sample = self.arrange(sample)
  96. return sample
  97. def _pick_arrange(self, transforms):
  98. arrange = None
  99. for idx, op in enumerate(transforms):
  100. if isinstance(op, Arrange):
  101. if idx != len(transforms) - 1:
  102. raise ValueError(
  103. "Arrange operator must be placed at the end of the list."
  104. )
  105. arrange = transforms.pop(idx)
  106. return arrange
  107. class Transform(object):
  108. """
  109. Parent class of all data augmentation operators.
  110. """
  111. def __init__(self):
  112. pass
  113. def apply_im(self, image):
  114. pass
  115. def apply_mask(self, mask):
  116. pass
  117. def apply_bbox(self, bbox):
  118. pass
  119. def apply_segm(self, segms):
  120. pass
  121. def apply(self, sample):
  122. if 'image' in sample:
  123. sample['image'] = self.apply_im(sample['image'])
  124. else: # image_tx
  125. sample['image'] = self.apply_im(sample['image_t1'])
  126. sample['image2'] = self.apply_im(sample['image_t2'])
  127. if 'mask' in sample:
  128. sample['mask'] = self.apply_mask(sample['mask'])
  129. if 'gt_bbox' in sample:
  130. sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'])
  131. if 'aux_masks' in sample:
  132. sample['aux_masks'] = list(
  133. map(self.apply_mask, sample['aux_masks']))
  134. return sample
  135. def __call__(self, sample):
  136. if isinstance(sample, Sequence):
  137. sample = [self.apply(s) for s in sample]
  138. else:
  139. sample = self.apply(sample)
  140. return sample
  141. class DecodeImg(Transform):
  142. """
  143. Decode image(s) in input.
  144. Args:
  145. to_rgb (bool, optional): If True, convert input image(s) from BGR format to
  146. RGB format. Defaults to True.
  147. to_uint8 (bool, optional): If True, quantize and convert decoded image(s) to
  148. uint8 type. Defaults to True.
  149. decode_bgr (bool, optional): If True, automatically interpret a non-geo image
  150. (e.g., jpeg images) as a BGR image. Defaults to True.
  151. decode_sar (bool, optional): If True, automatically interpret a two-channel
  152. geo image (e.g. geotiff images) as a SAR image, set this argument to
  153. True. Defaults to True.
  154. read_geo_info (bool, optional): If True, read geographical information from
  155. the image. Deafults to False.
  156. """
  157. def __init__(self,
  158. to_rgb=True,
  159. to_uint8=True,
  160. decode_bgr=True,
  161. decode_sar=True,
  162. read_geo_info=False):
  163. super(DecodeImg, self).__init__()
  164. self.to_rgb = to_rgb
  165. self.to_uint8 = to_uint8
  166. self.decode_bgr = decode_bgr
  167. self.decode_sar = decode_sar
  168. self.read_geo_info = False
  169. def read_img(self, img_path):
  170. img_format = imghdr.what(img_path)
  171. name, ext = os.path.splitext(img_path)
  172. geo_trans, geo_proj = None, None
  173. if img_format == 'tiff' or ext == '.img':
  174. try:
  175. import gdal
  176. except:
  177. try:
  178. from osgeo import gdal
  179. except ImportError:
  180. raise ImportError(
  181. "Failed to import gdal! Please install GDAL library according to the document."
  182. )
  183. dataset = gdal.Open(img_path)
  184. if dataset == None:
  185. raise IOError('Cannot open', img_path)
  186. im_data = dataset.ReadAsArray()
  187. if im_data.ndim == 2 and self.decode_sar:
  188. im_data = to_intensity(im_data)
  189. im_data = im_data[:, :, np.newaxis]
  190. else:
  191. if im_data.ndim == 3:
  192. im_data = im_data.transpose((1, 2, 0))
  193. if self.read_geo_info:
  194. geo_trans = dataset.GetGeoTransform()
  195. geo_proj = dataset.GetGeoProjection()
  196. elif img_format in ['jpeg', 'bmp', 'png', 'jpg']:
  197. if self.decode_bgr:
  198. im_data = cv2.imread(img_path, cv2.IMREAD_ANYDEPTH |
  199. cv2.IMREAD_ANYCOLOR | cv2.IMREAD_COLOR)
  200. else:
  201. im_data = cv2.imread(img_path, cv2.IMREAD_ANYDEPTH |
  202. cv2.IMREAD_ANYCOLOR)
  203. elif ext == '.npy':
  204. im_data = np.load(img_path)
  205. else:
  206. raise TypeError("Image format {} is not supported!".format(ext))
  207. if self.read_geo_info:
  208. return im_data, geo_trans, geo_proj
  209. else:
  210. return im_data
  211. def apply_im(self, im_path):
  212. if isinstance(im_path, str):
  213. try:
  214. data = self.read_img(im_path)
  215. except:
  216. raise ValueError("Cannot read the image file {}!".format(
  217. im_path))
  218. if self.read_geo_info:
  219. image, geo_trans, geo_proj = data
  220. geo_info_dict = {'geo_trans': geo_trans, 'geo_proj': geo_proj}
  221. else:
  222. image = data
  223. else:
  224. image = im_path
  225. if self.to_rgb and image.shape[-1] == 3:
  226. image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
  227. if self.to_uint8:
  228. image = to_uint8(image)
  229. if self.read_geo_info:
  230. return image, geo_info_dict
  231. else:
  232. return image
  233. def apply_mask(self, mask):
  234. try:
  235. mask = np.asarray(Image.open(mask))
  236. except:
  237. raise ValueError("Cannot read the mask file {}!".format(mask))
  238. if len(mask.shape) != 2:
  239. raise ValueError(
  240. "Mask should be a 1-channel image, but recevied is a {}-channel image.".
  241. format(mask.shape[2]))
  242. return mask
  243. def apply(self, sample):
  244. """
  245. Args:
  246. sample (dict): Input sample.
  247. Returns:
  248. dict: Sample with decoded images.
  249. """
  250. if 'image' in sample:
  251. if self.read_geo_info:
  252. image, geo_info_dict = self.apply_im(sample['image'])
  253. sample['image'] = image
  254. sample['geo_info_dict'] = geo_info_dict
  255. else:
  256. sample['image'] = self.apply_im(sample['image'])
  257. if 'image2' in sample:
  258. if self.read_geo_info:
  259. image2, geo_info_dict2 = self.apply_im(sample['image2'])
  260. sample['image2'] = image2
  261. sample['geo_info_dict2'] = geo_info_dict2
  262. else:
  263. sample['image2'] = self.apply_im(sample['image2'])
  264. if 'image_t1' in sample and not 'image' in sample:
  265. if not ('image_t2' in sample and 'image2' not in sample):
  266. raise ValueError
  267. if self.read_geo_info:
  268. image, geo_info_dict = self.apply_im(sample['image_t1'])
  269. sample['image'] = image
  270. sample['geo_info_dict'] = geo_info_dict
  271. else:
  272. sample['image'] = self.apply_im(sample['image_t1'])
  273. if self.read_geo_info:
  274. image2, geo_info_dict2 = self.apply_im(sample['image_t2'])
  275. sample['image2'] = image2
  276. sample['geo_info_dict2'] = geo_info_dict2
  277. else:
  278. sample['image2'] = self.apply_im(sample['image_t2'])
  279. if 'mask' in sample:
  280. sample['mask_ori'] = copy.deepcopy(sample['mask'])
  281. sample['mask'] = self.apply_mask(sample['mask'])
  282. im_height, im_width, _ = sample['image'].shape
  283. se_height, se_width = sample['mask'].shape
  284. if im_height != se_height or im_width != se_width:
  285. raise ValueError(
  286. "The height or width of the image is not same as the mask.")
  287. if 'aux_masks' in sample:
  288. sample['aux_masks_ori'] = copy.deepcopy(sample['aux_masks'])
  289. sample['aux_masks'] = list(
  290. map(self.apply_mask, sample['aux_masks']))
  291. # TODO: check the shape of auxiliary masks
  292. sample['im_shape'] = np.array(
  293. sample['image'].shape[:2], dtype=np.float32)
  294. sample['scale_factor'] = np.array([1., 1.], dtype=np.float32)
  295. return sample
  296. class Resize(Transform):
  297. """
  298. Resize input.
  299. - If `target_size` is an int, resize the image(s) to (`target_size`, `target_size`).
  300. - If `target_size` is a list or tuple, resize the image(s) to `target_size`.
  301. Attention: If `interp` is 'RANDOM', the interpolation method will be chosen randomly.
  302. Args:
  303. target_size (int | list[int] | tuple[int]): Target size. If it is an integer, the
  304. target height and width will be both set to `target_size`. Otherwise,
  305. `target_size` represents [target height, target width].
  306. interp (str, optional): Interpolation method for resizing image(s). One of
  307. {'NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4', 'RANDOM'}.
  308. Defaults to 'LINEAR'.
  309. keep_ratio (bool, optional): If True, the scaling factor of width and height will
  310. be set to same value, and height/width of the resized image will be not
  311. greater than the target width/height. Defaults to False.
  312. Raises:
  313. TypeError: Invalid type of target_size.
  314. ValueError: Invalid interpolation method.
  315. """
  316. def __init__(self, target_size, interp='LINEAR', keep_ratio=False):
  317. super(Resize, self).__init__()
  318. if not (interp == "RANDOM" or interp in interp_dict):
  319. raise ValueError("`interp` should be one of {}.".format(
  320. interp_dict.keys()))
  321. if isinstance(target_size, int):
  322. target_size = (target_size, target_size)
  323. else:
  324. if not (isinstance(target_size,
  325. (list, tuple)) and len(target_size) == 2):
  326. raise TypeError(
  327. "`target_size` should be an int or a list of length 2, but received {}.".
  328. format(target_size))
  329. # (height, width)
  330. self.target_size = target_size
  331. self.interp = interp
  332. self.keep_ratio = keep_ratio
  333. def apply_im(self, image, interp, target_size):
  334. flag = image.shape[2] == 1
  335. image = cv2.resize(image, target_size, interpolation=interp)
  336. if flag:
  337. image = image[:, :, np.newaxis]
  338. return image
  339. def apply_mask(self, mask, target_size):
  340. mask = cv2.resize(mask, target_size, interpolation=cv2.INTER_NEAREST)
  341. return mask
  342. def apply_bbox(self, bbox, scale, target_size):
  343. im_scale_x, im_scale_y = scale
  344. bbox[:, 0::2] *= im_scale_x
  345. bbox[:, 1::2] *= im_scale_y
  346. bbox[:, 0::2] = np.clip(bbox[:, 0::2], 0, target_size[0])
  347. bbox[:, 1::2] = np.clip(bbox[:, 1::2], 0, target_size[1])
  348. return bbox
  349. def apply_segm(self, segms, im_size, scale):
  350. im_h, im_w = im_size
  351. im_scale_x, im_scale_y = scale
  352. resized_segms = []
  353. for segm in segms:
  354. if is_poly(segm):
  355. # Polygon format
  356. resized_segms.append([
  357. resize_poly(poly, im_scale_x, im_scale_y) for poly in segm
  358. ])
  359. else:
  360. # RLE format
  361. resized_segms.append(
  362. resize_rle(segm, im_h, im_w, im_scale_x, im_scale_y))
  363. return resized_segms
  364. def apply(self, sample):
  365. if self.interp == "RANDOM":
  366. interp = random.choice(list(interp_dict.values()))
  367. else:
  368. interp = interp_dict[self.interp]
  369. im_h, im_w = sample['image'].shape[:2]
  370. im_scale_y = self.target_size[0] / im_h
  371. im_scale_x = self.target_size[1] / im_w
  372. target_size = (self.target_size[1], self.target_size[0])
  373. if self.keep_ratio:
  374. scale = min(im_scale_y, im_scale_x)
  375. target_w = int(round(im_w * scale))
  376. target_h = int(round(im_h * scale))
  377. target_size = (target_w, target_h)
  378. im_scale_y = target_h / im_h
  379. im_scale_x = target_w / im_w
  380. sample['image'] = self.apply_im(sample['image'], interp, target_size)
  381. if 'image2' in sample:
  382. sample['image2'] = self.apply_im(sample['image2'], interp,
  383. target_size)
  384. if 'mask' in sample:
  385. sample['mask'] = self.apply_mask(sample['mask'], target_size)
  386. if 'aux_masks' in sample:
  387. sample['aux_masks'] = list(
  388. map(partial(
  389. self.apply_mask, target_size=target_size),
  390. sample['aux_masks']))
  391. if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
  392. sample['gt_bbox'] = self.apply_bbox(
  393. sample['gt_bbox'], [im_scale_x, im_scale_y], target_size)
  394. if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
  395. sample['gt_poly'] = self.apply_segm(
  396. sample['gt_poly'], [im_h, im_w], [im_scale_x, im_scale_y])
  397. sample['im_shape'] = np.asarray(
  398. sample['image'].shape[:2], dtype=np.float32)
  399. if 'scale_factor' in sample:
  400. scale_factor = sample['scale_factor']
  401. sample['scale_factor'] = np.asarray(
  402. [scale_factor[0] * im_scale_y, scale_factor[1] * im_scale_x],
  403. dtype=np.float32)
  404. return sample
  405. class RandomResize(Transform):
  406. """
  407. Resize input to random sizes.
  408. Attention: If `interp` is 'RANDOM', the interpolation method will be chosen randomly.
  409. Args:
  410. target_sizes (list[int] | list[list|tuple] | tuple[list|tuple]):
  411. Multiple target sizes, each of which should be int, list, or tuple.
  412. interp (str, optional): Interpolation method for resizing image(s). One of
  413. {'NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4', 'RANDOM'}.
  414. Defaults to 'LINEAR'.
  415. Raises:
  416. TypeError: Invalid type of `target_size`.
  417. ValueError: Invalid interpolation method.
  418. """
  419. def __init__(self, target_sizes, interp='LINEAR'):
  420. super(RandomResize, self).__init__()
  421. if not (interp == "RANDOM" or interp in interp_dict):
  422. raise ValueError("`interp` should be one of {}.".format(
  423. interp_dict.keys()))
  424. self.interp = interp
  425. assert isinstance(target_sizes, list), \
  426. "`target_size` must be a list."
  427. for i, item in enumerate(target_sizes):
  428. if isinstance(item, int):
  429. target_sizes[i] = (item, item)
  430. self.target_size = target_sizes
  431. def apply(self, sample):
  432. height, width = random.choice(self.target_size)
  433. resizer = Resize((height, width), interp=self.interp)
  434. sample = resizer(sample)
  435. return sample
  436. class ResizeByShort(Transform):
  437. """
  438. Resize input while keeping the aspect ratio.
  439. Attention: If `interp` is 'RANDOM', the interpolation method will be chosen randomly.
  440. Args:
  441. short_size (int): Target size of the shorter side of the image(s).
  442. max_size (int, optional): Upper bound of longer side of the image(s). If
  443. `max_size` is -1, no upper bound will be applied. Defaults to -1.
  444. interp (str, optional): Interpolation method for resizing image(s). One of
  445. {'NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4', 'RANDOM'}.
  446. Defaults to 'LINEAR'.
  447. Raises:
  448. ValueError: Invalid interpolation method.
  449. """
  450. def __init__(self, short_size=256, max_size=-1, interp='LINEAR'):
  451. if not (interp == "RANDOM" or interp in interp_dict):
  452. raise ValueError("`interp` should be one of {}".format(
  453. interp_dict.keys()))
  454. super(ResizeByShort, self).__init__()
  455. self.short_size = short_size
  456. self.max_size = max_size
  457. self.interp = interp
  458. def apply(self, sample):
  459. im_h, im_w = sample['image'].shape[:2]
  460. im_short_size = min(im_h, im_w)
  461. im_long_size = max(im_h, im_w)
  462. scale = float(self.short_size) / float(im_short_size)
  463. if 0 < self.max_size < np.round(scale * im_long_size):
  464. scale = float(self.max_size) / float(im_long_size)
  465. target_w = int(round(im_w * scale))
  466. target_h = int(round(im_h * scale))
  467. sample = Resize(
  468. target_size=(target_h, target_w), interp=self.interp)(sample)
  469. return sample
  470. class RandomResizeByShort(Transform):
  471. """
  472. Resize input to random sizes while keeping the aspect ratio.
  473. Attention: If `interp` is 'RANDOM', the interpolation method will be chosen randomly.
  474. Args:
  475. short_sizes (list[int]): Target size of the shorter side of the image(s).
  476. max_size (int, optional): Upper bound of longer side of the image(s).
  477. If `max_size` is -1, no upper bound will be applied. Defaults to -1.
  478. interp (str, optional): Interpolation method for resizing image(s). One of
  479. {'NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4', 'RANDOM'}.
  480. Defaults to 'LINEAR'.
  481. Raises:
  482. TypeError: Invalid type of `target_size`.
  483. ValueError: Invalid interpolation method.
  484. See Also:
  485. ResizeByShort: Resize image(s) in input while keeping the aspect ratio.
  486. """
  487. def __init__(self, short_sizes, max_size=-1, interp='LINEAR'):
  488. super(RandomResizeByShort, self).__init__()
  489. if not (interp == "RANDOM" or interp in interp_dict):
  490. raise ValueError("`interp` should be one of {}".format(
  491. interp_dict.keys()))
  492. self.interp = interp
  493. assert isinstance(short_sizes, list), \
  494. "`short_sizes` must be a list."
  495. self.short_sizes = short_sizes
  496. self.max_size = max_size
  497. def apply(self, sample):
  498. short_size = random.choice(self.short_sizes)
  499. resizer = ResizeByShort(
  500. short_size=short_size, max_size=self.max_size, interp=self.interp)
  501. sample = resizer(sample)
  502. return sample
  503. class ResizeByLong(Transform):
  504. def __init__(self, long_size=256, interp='LINEAR'):
  505. super(ResizeByLong, self).__init__()
  506. self.long_size = long_size
  507. self.interp = interp
  508. def apply(self, sample):
  509. im_h, im_w = sample['image'].shape[:2]
  510. im_long_size = max(im_h, im_w)
  511. scale = float(self.long_size) / float(im_long_size)
  512. target_h = int(round(im_h * scale))
  513. target_w = int(round(im_w * scale))
  514. sample = Resize(
  515. target_size=(target_h, target_w), interp=self.interp)(sample)
  516. return sample
  517. class RandomFlipOrRotate(Transform):
  518. """
  519. Flip or Rotate an image in different directions with a certain probability.
  520. Args:
  521. probs (list[float]): Probabilities of performing flipping and rotation.
  522. Default: [0.35,0.25].
  523. probsf (list[float]): Probabilities of 5 flipping modes (horizontal,
  524. vertical, both horizontal diction and vertical, diagonal,
  525. anti-diagonal). Default: [0.3, 0.3, 0.2, 0.1, 0.1].
  526. probsr (list[float]): Probabilities of 3 rotation modes (90°, 180°, 270°
  527. clockwise). Default: [0.25,0.5,0.25].
  528. Examples:
  529. from paddlers import transforms as T
  530. # Define operators for data augmentation
  531. train_transforms = T.Compose([
  532. T.DecodeImg(),
  533. T.RandomFlipOrRotate(
  534. probs = [0.3, 0.2] # p=0.3 to flip the image,p=0.2 to rotate the image,p=0.5 to keep the image unchanged.
  535. probsf = [0.3, 0.25, 0, 0, 0] # p=0.3 and p=0.25 to perform horizontal and vertical flipping; probility of no-flipping is 0.45.
  536. probsr = [0, 0.65, 0]), # p=0.65 to rotate the image by 180°; probility of no-rotation is 0.35.
  537. T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
  538. ])
  539. """
  540. def __init__(self,
  541. probs=[0.35, 0.25],
  542. probsf=[0.3, 0.3, 0.2, 0.1, 0.1],
  543. probsr=[0.25, 0.5, 0.25]):
  544. super(RandomFlipOrRotate, self).__init__()
  545. # Change various probabilities into probability intervals, to judge in which mode to flip or rotate
  546. self.probs = [probs[0], probs[0] + probs[1]]
  547. self.probsf = self.get_probs_range(probsf)
  548. self.probsr = self.get_probs_range(probsr)
  549. def apply_im(self, image, mode_id, flip_mode=True):
  550. if flip_mode:
  551. image = img_flip(image, mode_id)
  552. else:
  553. image = img_simple_rotate(image, mode_id)
  554. return image
  555. def apply_mask(self, mask, mode_id, flip_mode=True):
  556. if flip_mode:
  557. mask = img_flip(mask, mode_id)
  558. else:
  559. mask = img_simple_rotate(mask, mode_id)
  560. return mask
  561. def apply_bbox(self, bbox, mode_id, flip_mode=True):
  562. raise TypeError(
  563. "Currently, RandomFlipOrRotate is not available for object detection tasks."
  564. )
  565. def apply_segm(self, bbox, mode_id, flip_mode=True):
  566. raise TypeError(
  567. "Currently, RandomFlipOrRotate is not available for object detection tasks."
  568. )
  569. def get_probs_range(self, probs):
  570. """
  571. Change list of probabilities into cumulative probability intervals.
  572. Args:
  573. probs (list[float]): Probabilities of different modes, shape: [n].
  574. Returns:
  575. list[list]: Probability intervals, shape: [n, 2].
  576. """
  577. ps = []
  578. last_prob = 0
  579. for prob in probs:
  580. p_s = last_prob
  581. cur_prob = prob / sum(probs)
  582. last_prob += cur_prob
  583. p_e = last_prob
  584. ps.append([p_s, p_e])
  585. return ps
  586. def judge_probs_range(self, p, probs):
  587. """
  588. Judge whether the value of `p` falls within the given probability interval.
  589. Args:
  590. p (float): Value between 0 and 1.
  591. probs (list[list]): Probability intervals, shape: [n, 2].
  592. Returns:
  593. int: Interval where the input probability falls into.
  594. """
  595. for id, id_range in enumerate(probs):
  596. if p > id_range[0] and p < id_range[1]:
  597. return id
  598. return -1
  599. def apply(self, sample):
  600. p_m = random.random()
  601. if p_m < self.probs[0]:
  602. mode_p = random.random()
  603. mode_id = self.judge_probs_range(mode_p, self.probsf)
  604. sample['image'] = self.apply_im(sample['image'], mode_id, True)
  605. if 'image2' in sample:
  606. sample['image2'] = self.apply_im(sample['image2'], mode_id,
  607. True)
  608. if 'mask' in sample:
  609. sample['mask'] = self.apply_mask(sample['mask'], mode_id, True)
  610. if 'aux_masks' in sample:
  611. sample['aux_masks'] = [
  612. self.apply_mask(aux_mask, mode_id, True)
  613. for aux_mask in sample['aux_masks']
  614. ]
  615. if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
  616. sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], mode_id,
  617. True)
  618. if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
  619. sample['gt_poly'] = self.apply_segm(sample['gt_poly'], mode_id,
  620. True)
  621. elif p_m < self.probs[1]:
  622. mode_p = random.random()
  623. mode_id = self.judge_probs_range(mode_p, self.probsr)
  624. sample['image'] = self.apply_im(sample['image'], mode_id, False)
  625. if 'image2' in sample:
  626. sample['image2'] = self.apply_im(sample['image2'], mode_id,
  627. False)
  628. if 'mask' in sample:
  629. sample['mask'] = self.apply_mask(sample['mask'], mode_id, False)
  630. if 'aux_masks' in sample:
  631. sample['aux_masks'] = [
  632. self.apply_mask(aux_mask, mode_id, False)
  633. for aux_mask in sample['aux_masks']
  634. ]
  635. if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
  636. sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], mode_id,
  637. False)
  638. if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
  639. sample['gt_poly'] = self.apply_segm(sample['gt_poly'], mode_id,
  640. False)
  641. return sample
  642. class RandomHorizontalFlip(Transform):
  643. """
  644. Randomly flip the input horizontally.
  645. Args:
  646. prob (float, optional): Probability of flipping the input. Defaults to .5.
  647. """
  648. def __init__(self, prob=0.5):
  649. super(RandomHorizontalFlip, self).__init__()
  650. self.prob = prob
  651. def apply_im(self, image):
  652. image = horizontal_flip(image)
  653. return image
  654. def apply_mask(self, mask):
  655. mask = horizontal_flip(mask)
  656. return mask
  657. def apply_bbox(self, bbox, width):
  658. oldx1 = bbox[:, 0].copy()
  659. oldx2 = bbox[:, 2].copy()
  660. bbox[:, 0] = width - oldx2
  661. bbox[:, 2] = width - oldx1
  662. return bbox
  663. def apply_segm(self, segms, height, width):
  664. flipped_segms = []
  665. for segm in segms:
  666. if is_poly(segm):
  667. # Polygon format
  668. flipped_segms.append(
  669. [horizontal_flip_poly(poly, width) for poly in segm])
  670. else:
  671. # RLE format
  672. flipped_segms.append(horizontal_flip_rle(segm, height, width))
  673. return flipped_segms
  674. def apply(self, sample):
  675. if random.random() < self.prob:
  676. im_h, im_w = sample['image'].shape[:2]
  677. sample['image'] = self.apply_im(sample['image'])
  678. if 'image2' in sample:
  679. sample['image2'] = self.apply_im(sample['image2'])
  680. if 'mask' in sample:
  681. sample['mask'] = self.apply_mask(sample['mask'])
  682. if 'aux_masks' in sample:
  683. sample['aux_masks'] = list(
  684. map(self.apply_mask, sample['aux_masks']))
  685. if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
  686. sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], im_w)
  687. if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
  688. sample['gt_poly'] = self.apply_segm(sample['gt_poly'], im_h,
  689. im_w)
  690. return sample
  691. class RandomVerticalFlip(Transform):
  692. """
  693. Randomly flip the input vertically.
  694. Args:
  695. prob (float, optional): Probability of flipping the input. Defaults to .5.
  696. """
  697. def __init__(self, prob=0.5):
  698. super(RandomVerticalFlip, self).__init__()
  699. self.prob = prob
  700. def apply_im(self, image):
  701. image = vertical_flip(image)
  702. return image
  703. def apply_mask(self, mask):
  704. mask = vertical_flip(mask)
  705. return mask
  706. def apply_bbox(self, bbox, height):
  707. oldy1 = bbox[:, 1].copy()
  708. oldy2 = bbox[:, 3].copy()
  709. bbox[:, 0] = height - oldy2
  710. bbox[:, 2] = height - oldy1
  711. return bbox
  712. def apply_segm(self, segms, height, width):
  713. flipped_segms = []
  714. for segm in segms:
  715. if is_poly(segm):
  716. # Polygon format
  717. flipped_segms.append(
  718. [vertical_flip_poly(poly, height) for poly in segm])
  719. else:
  720. # RLE format
  721. flipped_segms.append(vertical_flip_rle(segm, height, width))
  722. return flipped_segms
  723. def apply(self, sample):
  724. if random.random() < self.prob:
  725. im_h, im_w = sample['image'].shape[:2]
  726. sample['image'] = self.apply_im(sample['image'])
  727. if 'image2' in sample:
  728. sample['image2'] = self.apply_im(sample['image2'])
  729. if 'mask' in sample:
  730. sample['mask'] = self.apply_mask(sample['mask'])
  731. if 'aux_masks' in sample:
  732. sample['aux_masks'] = list(
  733. map(self.apply_mask, sample['aux_masks']))
  734. if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
  735. sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], im_h)
  736. if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
  737. sample['gt_poly'] = self.apply_segm(sample['gt_poly'], im_h,
  738. im_w)
  739. return sample
  740. class Normalize(Transform):
  741. """
  742. Apply normalization to the input image(s). The normalization steps are:
  743. 1. im = (im - min_value) * 1 / (max_value - min_value)
  744. 2. im = im - mean
  745. 3. im = im / std
  746. Args:
  747. mean (list[float] | tuple[float], optional): Mean of input image(s).
  748. Defaults to [0.485, 0.456, 0.406].
  749. std (list[float] | tuple[float], optional): Standard deviation of input
  750. image(s). Defaults to [0.229, 0.224, 0.225].
  751. min_val (list[float] | tuple[float], optional): Minimum value of input
  752. image(s). Defaults to [0, 0, 0, ].
  753. max_val (list[float] | tuple[float], optional): Max value of input image(s).
  754. Defaults to [255., 255., 255.].
  755. """
  756. def __init__(self,
  757. mean=[0.485, 0.456, 0.406],
  758. std=[0.229, 0.224, 0.225],
  759. min_val=None,
  760. max_val=None):
  761. super(Normalize, self).__init__()
  762. channel = len(mean)
  763. if min_val is None:
  764. min_val = [0] * channel
  765. if max_val is None:
  766. max_val = [255.] * channel
  767. from functools import reduce
  768. if reduce(lambda x, y: x * y, std) == 0:
  769. raise ValueError(
  770. "`std` should not contain 0, but received is {}.".format(std))
  771. if reduce(lambda x, y: x * y,
  772. [a - b for a, b in zip(max_val, min_val)]) == 0:
  773. raise ValueError(
  774. "(`max_val` - `min_val`) should not contain 0, but received is {}.".
  775. format((np.asarray(max_val) - np.asarray(min_val)).tolist()))
  776. self.mean = mean
  777. self.std = std
  778. self.min_val = min_val
  779. self.max_val = max_val
  780. def apply_im(self, image):
  781. image = image.astype(np.float32)
  782. mean = np.asarray(
  783. self.mean, dtype=np.float32)[np.newaxis, np.newaxis, :]
  784. std = np.asarray(self.std, dtype=np.float32)[np.newaxis, np.newaxis, :]
  785. image = normalize(image, mean, std, self.min_val, self.max_val)
  786. return image
  787. def apply(self, sample):
  788. sample['image'] = self.apply_im(sample['image'])
  789. if 'image2' in sample:
  790. sample['image2'] = self.apply_im(sample['image2'])
  791. return sample
  792. class CenterCrop(Transform):
  793. """
  794. Crop the input image(s) at the center.
  795. 1. Locate the center of the image.
  796. 2. Crop the image.
  797. Args:
  798. crop_size (int, optional): Target size of the cropped image(s).
  799. Defaults to 224.
  800. """
  801. def __init__(self, crop_size=224):
  802. super(CenterCrop, self).__init__()
  803. self.crop_size = crop_size
  804. def apply_im(self, image):
  805. image = center_crop(image, self.crop_size)
  806. return image
  807. def apply_mask(self, mask):
  808. mask = center_crop(mask, self.crop_size)
  809. return mask
  810. def apply(self, sample):
  811. sample['image'] = self.apply_im(sample['image'])
  812. if 'image2' in sample:
  813. sample['image2'] = self.apply_im(sample['image2'])
  814. if 'mask' in sample:
  815. sample['mask'] = self.apply_mask(sample['mask'])
  816. if 'aux_masks' in sample:
  817. sample['aux_masks'] = list(
  818. map(self.apply_mask, sample['aux_masks']))
  819. return sample
  820. class RandomCrop(Transform):
  821. """
  822. Randomly crop the input.
  823. 1. Compute the height and width of cropped area according to `aspect_ratio` and
  824. `scaling`.
  825. 2. Locate the upper left corner of cropped area randomly.
  826. 3. Crop the image(s).
  827. 4. Resize the cropped area to `crop_size` x `crop_size`.
  828. Args:
  829. crop_size (int | list[int] | tuple[int]): Target size of the cropped area. If
  830. None, the cropped area will not be resized. Defaults to None.
  831. aspect_ratio (list[float], optional): Aspect ratio of cropped region in
  832. [min, max] format. Defaults to [.5, 2.].
  833. thresholds (list[float], optional): Iou thresholds to decide a valid bbox
  834. crop. Defaults to [.0, .1, .3, .5, .7, .9].
  835. scaling (list[float], optional): Ratio between the cropped region and the
  836. original image in [min, max] format. Defaults to [.3, 1.].
  837. num_attempts (int, optional): Max number of tries before giving up.
  838. Defaults to 50.
  839. allow_no_crop (bool, optional): Whether returning without doing crop is
  840. allowed. Defaults to True.
  841. cover_all_box (bool, optional): Whether to ensure all bboxes be covered in
  842. the final crop. Defaults to False.
  843. """
  844. def __init__(self,
  845. crop_size=None,
  846. aspect_ratio=[.5, 2.],
  847. thresholds=[.0, .1, .3, .5, .7, .9],
  848. scaling=[.3, 1.],
  849. num_attempts=50,
  850. allow_no_crop=True,
  851. cover_all_box=False):
  852. super(RandomCrop, self).__init__()
  853. self.crop_size = crop_size
  854. self.aspect_ratio = aspect_ratio
  855. self.thresholds = thresholds
  856. self.scaling = scaling
  857. self.num_attempts = num_attempts
  858. self.allow_no_crop = allow_no_crop
  859. self.cover_all_box = cover_all_box
  860. def _generate_crop_info(self, sample):
  861. im_h, im_w = sample['image'].shape[:2]
  862. if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
  863. thresholds = self.thresholds
  864. if self.allow_no_crop:
  865. thresholds.append('no_crop')
  866. np.random.shuffle(thresholds)
  867. for thresh in thresholds:
  868. if thresh == 'no_crop':
  869. return None
  870. for i in range(self.num_attempts):
  871. crop_box = self._get_crop_box(im_h, im_w)
  872. if crop_box is None:
  873. continue
  874. iou = self._iou_matrix(
  875. sample['gt_bbox'],
  876. np.array(
  877. [crop_box], dtype=np.float32))
  878. if iou.max() < thresh:
  879. continue
  880. if self.cover_all_box and iou.min() < thresh:
  881. continue
  882. cropped_box, valid_ids = self._crop_box_with_center_constraint(
  883. sample['gt_bbox'], np.array(
  884. crop_box, dtype=np.float32))
  885. if valid_ids.size > 0:
  886. return crop_box, cropped_box, valid_ids
  887. else:
  888. for i in range(self.num_attempts):
  889. crop_box = self._get_crop_box(im_h, im_w)
  890. if crop_box is None:
  891. continue
  892. return crop_box, None, None
  893. return None
  894. def _get_crop_box(self, im_h, im_w):
  895. scale = np.random.uniform(*self.scaling)
  896. if self.aspect_ratio is not None:
  897. min_ar, max_ar = self.aspect_ratio
  898. aspect_ratio = np.random.uniform(
  899. max(min_ar, scale**2), min(max_ar, scale**-2))
  900. h_scale = scale / np.sqrt(aspect_ratio)
  901. w_scale = scale * np.sqrt(aspect_ratio)
  902. else:
  903. h_scale = np.random.uniform(*self.scaling)
  904. w_scale = np.random.uniform(*self.scaling)
  905. crop_h = im_h * h_scale
  906. crop_w = im_w * w_scale
  907. if self.aspect_ratio is None:
  908. if crop_h / crop_w < 0.5 or crop_h / crop_w > 2.0:
  909. return None
  910. crop_h = int(crop_h)
  911. crop_w = int(crop_w)
  912. crop_y = np.random.randint(0, im_h - crop_h)
  913. crop_x = np.random.randint(0, im_w - crop_w)
  914. return [crop_x, crop_y, crop_x + crop_w, crop_y + crop_h]
  915. def _iou_matrix(self, a, b):
  916. tl_i = np.maximum(a[:, np.newaxis, :2], b[:, :2])
  917. br_i = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
  918. area_i = np.prod(br_i - tl_i, axis=2) * (tl_i < br_i).all(axis=2)
  919. area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
  920. area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
  921. area_o = (area_a[:, np.newaxis] + area_b - area_i)
  922. return area_i / (area_o + 1e-10)
  923. def _crop_box_with_center_constraint(self, box, crop):
  924. cropped_box = box.copy()
  925. cropped_box[:, :2] = np.maximum(box[:, :2], crop[:2])
  926. cropped_box[:, 2:] = np.minimum(box[:, 2:], crop[2:])
  927. cropped_box[:, :2] -= crop[:2]
  928. cropped_box[:, 2:] -= crop[:2]
  929. centers = (box[:, :2] + box[:, 2:]) / 2
  930. valid = np.logical_and(crop[:2] <= centers,
  931. centers < crop[2:]).all(axis=1)
  932. valid = np.logical_and(
  933. valid, (cropped_box[:, :2] < cropped_box[:, 2:]).all(axis=1))
  934. return cropped_box, np.where(valid)[0]
  935. def _crop_segm(self, segms, valid_ids, crop, height, width):
  936. crop_segms = []
  937. for id in valid_ids:
  938. segm = segms[id]
  939. if is_poly(segm):
  940. # Polygon format
  941. crop_segms.append(crop_poly(segm, crop))
  942. else:
  943. # RLE format
  944. crop_segms.append(crop_rle(segm, crop, height, width))
  945. return crop_segms
  946. def apply_im(self, image, crop):
  947. x1, y1, x2, y2 = crop
  948. return image[y1:y2, x1:x2, :]
  949. def apply_mask(self, mask, crop):
  950. x1, y1, x2, y2 = crop
  951. return mask[y1:y2, x1:x2, ...]
  952. def apply(self, sample):
  953. crop_info = self._generate_crop_info(sample)
  954. if crop_info is not None:
  955. crop_box, cropped_box, valid_ids = crop_info
  956. im_h, im_w = sample['image'].shape[:2]
  957. sample['image'] = self.apply_im(sample['image'], crop_box)
  958. if 'image2' in sample:
  959. sample['image2'] = self.apply_im(sample['image2'], crop_box)
  960. if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
  961. crop_polys = self._crop_segm(
  962. sample['gt_poly'],
  963. valid_ids,
  964. np.array(
  965. crop_box, dtype=np.int64),
  966. im_h,
  967. im_w)
  968. if [] in crop_polys:
  969. delete_id = list()
  970. valid_polys = list()
  971. for idx, poly in enumerate(crop_polys):
  972. if not crop_poly:
  973. delete_id.append(idx)
  974. else:
  975. valid_polys.append(poly)
  976. valid_ids = np.delete(valid_ids, delete_id)
  977. if not valid_polys:
  978. return sample
  979. sample['gt_poly'] = valid_polys
  980. else:
  981. sample['gt_poly'] = crop_polys
  982. if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
  983. sample['gt_bbox'] = np.take(cropped_box, valid_ids, axis=0)
  984. sample['gt_class'] = np.take(
  985. sample['gt_class'], valid_ids, axis=0)
  986. if 'gt_score' in sample:
  987. sample['gt_score'] = np.take(
  988. sample['gt_score'], valid_ids, axis=0)
  989. if 'is_crowd' in sample:
  990. sample['is_crowd'] = np.take(
  991. sample['is_crowd'], valid_ids, axis=0)
  992. if 'mask' in sample:
  993. sample['mask'] = self.apply_mask(sample['mask'], crop_box)
  994. if 'aux_masks' in sample:
  995. sample['aux_masks'] = list(
  996. map(partial(
  997. self.apply_mask, crop=crop_box),
  998. sample['aux_masks']))
  999. if self.crop_size is not None:
  1000. sample = Resize(self.crop_size)(sample)
  1001. return sample
  1002. class RandomScaleAspect(Transform):
  1003. """
  1004. Crop input image(s) and resize back to original sizes.
  1005. Args:
  1006. min_scale (float): Minimum ratio between the cropped region and the original
  1007. image. If 0, image(s) will not be cropped. Defaults to .5.
  1008. aspect_ratio (float): Aspect ratio of cropped region. Defaults to .33.
  1009. """
  1010. def __init__(self, min_scale=0.5, aspect_ratio=0.33):
  1011. super(RandomScaleAspect, self).__init__()
  1012. self.min_scale = min_scale
  1013. self.aspect_ratio = aspect_ratio
  1014. def apply(self, sample):
  1015. if self.min_scale != 0 and self.aspect_ratio != 0:
  1016. img_height, img_width = sample['image'].shape[:2]
  1017. sample = RandomCrop(
  1018. crop_size=(img_height, img_width),
  1019. aspect_ratio=[self.aspect_ratio, 1. / self.aspect_ratio],
  1020. scaling=[self.min_scale, 1.],
  1021. num_attempts=10,
  1022. allow_no_crop=False)(sample)
  1023. return sample
  1024. class RandomExpand(Transform):
  1025. """
  1026. Randomly expand the input by padding according to random offsets.
  1027. Args:
  1028. upper_ratio (float, optional): Maximum ratio to which the original image
  1029. is expanded. Defaults to 4..
  1030. prob (float, optional): Probability of apply expanding. Defaults to .5.
  1031. im_padding_value (list[float] | tuple[float], optional): RGB filling value
  1032. for the image. Defaults to (127.5, 127.5, 127.5).
  1033. label_padding_value (int, optional): Filling value for the mask.
  1034. Defaults to 255.
  1035. See Also:
  1036. paddlers.transforms.Pad
  1037. """
  1038. def __init__(self,
  1039. upper_ratio=4.,
  1040. prob=.5,
  1041. im_padding_value=127.5,
  1042. label_padding_value=255):
  1043. super(RandomExpand, self).__init__()
  1044. assert upper_ratio > 1.01, "`upper_ratio` must be larger than 1.01."
  1045. self.upper_ratio = upper_ratio
  1046. self.prob = prob
  1047. assert isinstance(im_padding_value, (Number, Sequence)), \
  1048. "Value to fill must be either float or sequence."
  1049. self.im_padding_value = im_padding_value
  1050. self.label_padding_value = label_padding_value
  1051. def apply(self, sample):
  1052. if random.random() < self.prob:
  1053. im_h, im_w = sample['image'].shape[:2]
  1054. ratio = np.random.uniform(1., self.upper_ratio)
  1055. h = int(im_h * ratio)
  1056. w = int(im_w * ratio)
  1057. if h > im_h and w > im_w:
  1058. y = np.random.randint(0, h - im_h)
  1059. x = np.random.randint(0, w - im_w)
  1060. target_size = (h, w)
  1061. offsets = (x, y)
  1062. sample = Pad(
  1063. target_size=target_size,
  1064. pad_mode=-1,
  1065. offsets=offsets,
  1066. im_padding_value=self.im_padding_value,
  1067. label_padding_value=self.label_padding_value)(sample)
  1068. return sample
  1069. class Pad(Transform):
  1070. def __init__(self,
  1071. target_size=None,
  1072. pad_mode=0,
  1073. offsets=None,
  1074. im_padding_value=127.5,
  1075. label_padding_value=255,
  1076. size_divisor=32):
  1077. """
  1078. Pad image to a specified size or multiple of `size_divisor`.
  1079. Args:
  1080. target_size (list[int] | tuple[int], optional): Image target size, if None, pad to
  1081. multiple of size_divisor. Defaults to None.
  1082. pad_mode (int, optional): Pad mode. Currently only four modes are supported:
  1083. [-1, 0, 1, 2]. if -1, use specified offsets. If 0, only pad to right and bottom
  1084. If 1, pad according to center. If 2, only pad left and top. Defaults to 0.
  1085. im_padding_value (list[float] | tuple[float]): RGB value of padded area.
  1086. Defaults to (127.5, 127.5, 127.5).
  1087. label_padding_value (int, optional): Filling value for the mask.
  1088. Defaults to 255.
  1089. size_divisor (int): Image width and height after padding will be a multiple of
  1090. `size_divisor`.
  1091. """
  1092. super(Pad, self).__init__()
  1093. if isinstance(target_size, (list, tuple)):
  1094. if len(target_size) != 2:
  1095. raise ValueError(
  1096. "`target_size` should contain 2 elements, but it is {}.".
  1097. format(target_size))
  1098. if isinstance(target_size, int):
  1099. target_size = [target_size] * 2
  1100. assert pad_mode in [
  1101. -1, 0, 1, 2
  1102. ], "Currently only four modes are supported: [-1, 0, 1, 2]."
  1103. if pad_mode == -1:
  1104. assert offsets, "if `pad_mode` is -1, `offsets` should not be None."
  1105. self.target_size = target_size
  1106. self.size_divisor = size_divisor
  1107. self.pad_mode = pad_mode
  1108. self.offsets = offsets
  1109. self.im_padding_value = im_padding_value
  1110. self.label_padding_value = label_padding_value
  1111. def apply_im(self, image, offsets, target_size):
  1112. x, y = offsets
  1113. h, w = target_size
  1114. im_h, im_w, channel = image.shape[:3]
  1115. canvas = np.ones((h, w, channel), dtype=np.float32)
  1116. canvas *= np.array(self.im_padding_value, dtype=np.float32)
  1117. canvas[y:y + im_h, x:x + im_w, :] = image.astype(np.float32)
  1118. return canvas
  1119. def apply_mask(self, mask, offsets, target_size):
  1120. x, y = offsets
  1121. im_h, im_w = mask.shape[:2]
  1122. h, w = target_size
  1123. canvas = np.ones((h, w), dtype=np.float32)
  1124. canvas *= np.array(self.label_padding_value, dtype=np.float32)
  1125. canvas[y:y + im_h, x:x + im_w] = mask.astype(np.float32)
  1126. return canvas
  1127. def apply_bbox(self, bbox, offsets):
  1128. return bbox + np.array(offsets * 2, dtype=np.float32)
  1129. def apply_segm(self, segms, offsets, im_size, size):
  1130. x, y = offsets
  1131. height, width = im_size
  1132. h, w = size
  1133. expanded_segms = []
  1134. for segm in segms:
  1135. if is_poly(segm):
  1136. # Polygon format
  1137. expanded_segms.append(
  1138. [expand_poly(poly, x, y) for poly in segm])
  1139. else:
  1140. # RLE format
  1141. expanded_segms.append(
  1142. expand_rle(segm, x, y, height, width, h, w))
  1143. return expanded_segms
  1144. def apply(self, sample):
  1145. im_h, im_w = sample['image'].shape[:2]
  1146. if self.target_size:
  1147. h, w = self.target_size
  1148. assert (
  1149. im_h <= h and im_w <= w
  1150. ), 'target size ({}, {}) cannot be less than image size ({}, {})'\
  1151. .format(h, w, im_h, im_w)
  1152. else:
  1153. h = (np.ceil(im_h / self.size_divisor) *
  1154. self.size_divisor).astype(int)
  1155. w = (np.ceil(im_w / self.size_divisor) *
  1156. self.size_divisor).astype(int)
  1157. if h == im_h and w == im_w:
  1158. return sample
  1159. if self.pad_mode == -1:
  1160. offsets = self.offsets
  1161. elif self.pad_mode == 0:
  1162. offsets = [0, 0]
  1163. elif self.pad_mode == 1:
  1164. offsets = [(w - im_w) // 2, (h - im_h) // 2]
  1165. else:
  1166. offsets = [w - im_w, h - im_h]
  1167. sample['image'] = self.apply_im(sample['image'], offsets, (h, w))
  1168. if 'image2' in sample:
  1169. sample['image2'] = self.apply_im(sample['image2'], offsets, (h, w))
  1170. if 'mask' in sample:
  1171. sample['mask'] = self.apply_mask(sample['mask'], offsets, (h, w))
  1172. if 'aux_masks' in sample:
  1173. sample['aux_masks'] = list(
  1174. map(partial(
  1175. self.apply_mask, offsets=offsets, target_size=(h, w)),
  1176. sample['aux_masks']))
  1177. if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
  1178. sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], offsets)
  1179. if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
  1180. sample['gt_poly'] = self.apply_segm(
  1181. sample['gt_poly'], offsets, im_size=[im_h, im_w], size=[h, w])
  1182. return sample
  1183. class MixupImage(Transform):
  1184. def __init__(self, alpha=1.5, beta=1.5, mixup_epoch=-1):
  1185. """
  1186. Mixup two images and their gt_bbbox/gt_score.
  1187. Args:
  1188. alpha (float, optional): Alpha parameter of beta distribution.
  1189. Defaults to 1.5.
  1190. beta (float, optional): Beta parameter of beta distribution.
  1191. Defaults to 1.5.
  1192. """
  1193. super(MixupImage, self).__init__()
  1194. if alpha <= 0.0:
  1195. raise ValueError("`alpha` should be positive in MixupImage.")
  1196. if beta <= 0.0:
  1197. raise ValueError("`beta` should be positive in MixupImage.")
  1198. self.alpha = alpha
  1199. self.beta = beta
  1200. self.mixup_epoch = mixup_epoch
  1201. def apply_im(self, image1, image2, factor):
  1202. h = max(image1.shape[0], image2.shape[0])
  1203. w = max(image1.shape[1], image2.shape[1])
  1204. img = np.zeros((h, w, image1.shape[2]), 'float32')
  1205. img[:image1.shape[0], :image1.shape[1], :] = \
  1206. image1.astype('float32') * factor
  1207. img[:image2.shape[0], :image2.shape[1], :] += \
  1208. image2.astype('float32') * (1.0 - factor)
  1209. return img.astype('uint8')
  1210. def __call__(self, sample):
  1211. if not isinstance(sample, Sequence):
  1212. return sample
  1213. assert len(sample) == 2, 'mixup need two samples'
  1214. factor = np.random.beta(self.alpha, self.beta)
  1215. factor = max(0.0, min(1.0, factor))
  1216. if factor >= 1.0:
  1217. return sample[0]
  1218. if factor <= 0.0:
  1219. return sample[1]
  1220. image = self.apply_im(sample[0]['image'], sample[1]['image'], factor)
  1221. result = copy.deepcopy(sample[0])
  1222. result['image'] = image
  1223. # Apply bbox and score
  1224. if 'gt_bbox' in sample[0]:
  1225. gt_bbox1 = sample[0]['gt_bbox']
  1226. gt_bbox2 = sample[1]['gt_bbox']
  1227. gt_bbox = np.concatenate((gt_bbox1, gt_bbox2), axis=0)
  1228. result['gt_bbox'] = gt_bbox
  1229. if 'gt_poly' in sample[0]:
  1230. gt_poly1 = sample[0]['gt_poly']
  1231. gt_poly2 = sample[1]['gt_poly']
  1232. gt_poly = gt_poly1 + gt_poly2
  1233. result['gt_poly'] = gt_poly
  1234. if 'gt_class' in sample[0]:
  1235. gt_class1 = sample[0]['gt_class']
  1236. gt_class2 = sample[1]['gt_class']
  1237. gt_class = np.concatenate((gt_class1, gt_class2), axis=0)
  1238. result['gt_class'] = gt_class
  1239. gt_score1 = np.ones_like(sample[0]['gt_class'])
  1240. gt_score2 = np.ones_like(sample[1]['gt_class'])
  1241. gt_score = np.concatenate(
  1242. (gt_score1 * factor, gt_score2 * (1. - factor)), axis=0)
  1243. result['gt_score'] = gt_score
  1244. if 'is_crowd' in sample[0]:
  1245. is_crowd1 = sample[0]['is_crowd']
  1246. is_crowd2 = sample[1]['is_crowd']
  1247. is_crowd = np.concatenate((is_crowd1, is_crowd2), axis=0)
  1248. result['is_crowd'] = is_crowd
  1249. if 'difficult' in sample[0]:
  1250. is_difficult1 = sample[0]['difficult']
  1251. is_difficult2 = sample[1]['difficult']
  1252. is_difficult = np.concatenate(
  1253. (is_difficult1, is_difficult2), axis=0)
  1254. result['difficult'] = is_difficult
  1255. return result
  1256. class RandomDistort(Transform):
  1257. """
  1258. Random color distortion.
  1259. Args:
  1260. brightness_range (float, optional): Range of brightness distortion.
  1261. Defaults to .5.
  1262. brightness_prob (float, optional): Probability of brightness distortion.
  1263. Defaults to .5.
  1264. contrast_range (float, optional): Range of contrast distortion.
  1265. Defaults to .5.
  1266. contrast_prob (float, optional): Probability of contrast distortion.
  1267. Defaults to .5.
  1268. saturation_range (float, optional): Range of saturation distortion.
  1269. Defaults to .5.
  1270. saturation_prob (float, optional): Probability of saturation distortion.
  1271. Defaults to .5.
  1272. hue_range (float, optional): Range of hue distortion. Defaults to .5.
  1273. hue_prob (float, optional): Probability of hue distortion. Defaults to .5.
  1274. random_apply (bool, optional): Apply the transformation in random (yolo) or
  1275. fixed (SSD) order. Defaults to True.
  1276. count (int, optional): Number of distortions to apply. Defaults to 4.
  1277. shuffle_channel (bool, optional): Whether to swap channels randomly.
  1278. Defaults to False.
  1279. """
  1280. def __init__(self,
  1281. brightness_range=0.5,
  1282. brightness_prob=0.5,
  1283. contrast_range=0.5,
  1284. contrast_prob=0.5,
  1285. saturation_range=0.5,
  1286. saturation_prob=0.5,
  1287. hue_range=18,
  1288. hue_prob=0.5,
  1289. random_apply=True,
  1290. count=4,
  1291. shuffle_channel=False):
  1292. super(RandomDistort, self).__init__()
  1293. self.brightness_range = [1 - brightness_range, 1 + brightness_range]
  1294. self.brightness_prob = brightness_prob
  1295. self.contrast_range = [1 - contrast_range, 1 + contrast_range]
  1296. self.contrast_prob = contrast_prob
  1297. self.saturation_range = [1 - saturation_range, 1 + saturation_range]
  1298. self.saturation_prob = saturation_prob
  1299. self.hue_range = [1 - hue_range, 1 + hue_range]
  1300. self.hue_prob = hue_prob
  1301. self.random_apply = random_apply
  1302. self.count = count
  1303. self.shuffle_channel = shuffle_channel
  1304. def apply_hue(self, image):
  1305. low, high = self.hue_range
  1306. if np.random.uniform(0., 1.) < self.hue_prob:
  1307. return image
  1308. # It works, but the result differs from HSV version.
  1309. delta = np.random.uniform(low, high)
  1310. u = np.cos(delta * np.pi)
  1311. w = np.sin(delta * np.pi)
  1312. bt = np.array([[1.0, 0.0, 0.0], [0.0, u, -w], [0.0, w, u]])
  1313. tyiq = np.array([[0.299, 0.587, 0.114], [0.596, -0.274, -0.321],
  1314. [0.211, -0.523, 0.311]])
  1315. ityiq = np.array([[1.0, 0.956, 0.621], [1.0, -0.272, -0.647],
  1316. [1.0, -1.107, 1.705]])
  1317. t = np.dot(np.dot(ityiq, bt), tyiq).T
  1318. res_list = []
  1319. channel = image.shape[2]
  1320. for i in range(channel // 3):
  1321. sub_img = image[:, :, 3 * i:3 * (i + 1)]
  1322. sub_img = sub_img.astype(np.float32)
  1323. sub_img = np.dot(image, t)
  1324. res_list.append(sub_img)
  1325. if channel % 3 != 0:
  1326. i = channel % 3
  1327. res_list.append(image[:, :, -i:])
  1328. return np.concatenate(res_list, axis=2)
  1329. def apply_saturation(self, image):
  1330. low, high = self.saturation_range
  1331. delta = np.random.uniform(low, high)
  1332. if np.random.uniform(0., 1.) < self.saturation_prob:
  1333. return image
  1334. res_list = []
  1335. channel = image.shape[2]
  1336. for i in range(channel // 3):
  1337. sub_img = image[:, :, 3 * i:3 * (i + 1)]
  1338. sub_img = sub_img.astype(np.float32)
  1339. # It works, but the result differs from HSV version.
  1340. gray = sub_img * np.array(
  1341. [[[0.299, 0.587, 0.114]]], dtype=np.float32)
  1342. gray = gray.sum(axis=2, keepdims=True)
  1343. gray *= (1.0 - delta)
  1344. sub_img *= delta
  1345. sub_img += gray
  1346. res_list.append(sub_img)
  1347. if channel % 3 != 0:
  1348. i = channel % 3
  1349. res_list.append(image[:, :, -i:])
  1350. return np.concatenate(res_list, axis=2)
  1351. def apply_contrast(self, image):
  1352. low, high = self.contrast_range
  1353. if np.random.uniform(0., 1.) < self.contrast_prob:
  1354. return image
  1355. delta = np.random.uniform(low, high)
  1356. image = image.astype(np.float32)
  1357. image *= delta
  1358. return image
  1359. def apply_brightness(self, image):
  1360. low, high = self.brightness_range
  1361. if np.random.uniform(0., 1.) < self.brightness_prob:
  1362. return image
  1363. delta = np.random.uniform(low, high)
  1364. image = image.astype(np.float32)
  1365. image += delta
  1366. return image
  1367. def apply(self, sample):
  1368. if self.random_apply:
  1369. functions = [
  1370. self.apply_brightness, self.apply_contrast,
  1371. self.apply_saturation, self.apply_hue
  1372. ]
  1373. distortions = np.random.permutation(functions)[:self.count]
  1374. for func in distortions:
  1375. sample['image'] = func(sample['image'])
  1376. if 'image2' in sample:
  1377. sample['image2'] = func(sample['image2'])
  1378. return sample
  1379. sample['image'] = self.apply_brightness(sample['image'])
  1380. if 'image2' in sample:
  1381. sample['image2'] = self.apply_brightness(sample['image2'])
  1382. mode = np.random.randint(0, 2)
  1383. if mode:
  1384. sample['image'] = self.apply_contrast(sample['image'])
  1385. if 'image2' in sample:
  1386. sample['image2'] = self.apply_contrast(sample['image2'])
  1387. sample['image'] = self.apply_saturation(sample['image'])
  1388. sample['image'] = self.apply_hue(sample['image'])
  1389. if 'image2' in sample:
  1390. sample['image2'] = self.apply_saturation(sample['image2'])
  1391. sample['image2'] = self.apply_hue(sample['image2'])
  1392. if not mode:
  1393. sample['image'] = self.apply_contrast(sample['image'])
  1394. if 'image2' in sample:
  1395. sample['image2'] = self.apply_contrast(sample['image2'])
  1396. if self.shuffle_channel:
  1397. if np.random.randint(0, 2):
  1398. sample['image'] = sample['image'][..., np.random.permutation(3)]
  1399. if 'image2' in sample:
  1400. sample['image2'] = sample['image2'][
  1401. ..., np.random.permutation(3)]
  1402. return sample
  1403. class RandomBlur(Transform):
  1404. """
  1405. Randomly blur input image(s).
  1406. Args:
  1407. prob (float): Probability of blurring.
  1408. """
  1409. def __init__(self, prob=0.1):
  1410. super(RandomBlur, self).__init__()
  1411. self.prob = prob
  1412. def apply_im(self, image, radius):
  1413. image = cv2.GaussianBlur(image, (radius, radius), 0, 0)
  1414. return image
  1415. def apply(self, sample):
  1416. if self.prob <= 0:
  1417. n = 0
  1418. elif self.prob >= 1:
  1419. n = 1
  1420. else:
  1421. n = int(1.0 / self.prob)
  1422. if n > 0:
  1423. if np.random.randint(0, n) == 0:
  1424. radius = np.random.randint(3, 10)
  1425. if radius % 2 != 1:
  1426. radius = radius + 1
  1427. if radius > 9:
  1428. radius = 9
  1429. sample['image'] = self.apply_im(sample['image'], radius)
  1430. if 'image2' in sample:
  1431. sample['image2'] = self.apply_im(sample['image2'], radius)
  1432. return sample
  1433. class Dehaze(Transform):
  1434. """
  1435. Dehaze input image(s).
  1436. Args:
  1437. gamma (bool, optional): Use gamma correction or not. Defaults to False.
  1438. """
  1439. def __init__(self, gamma=False):
  1440. super(Dehaze, self).__init__()
  1441. self.gamma = gamma
  1442. def apply_im(self, image):
  1443. image = dehaze(image, self.gamma)
  1444. return image
  1445. def apply(self, sample):
  1446. sample['image'] = self.apply_im(sample['image'])
  1447. if 'image2' in sample:
  1448. sample['image2'] = self.apply_im(sample['image2'])
  1449. return sample
  1450. class ReduceDim(Transform):
  1451. """
  1452. Use PCA to reduce the dimension of input image(s).
  1453. Args:
  1454. joblib_path (str): Path of *.joblib file of PCA.
  1455. """
  1456. def __init__(self, joblib_path):
  1457. super(ReduceDim, self).__init__()
  1458. ext = joblib_path.split(".")[-1]
  1459. if ext != "joblib":
  1460. raise ValueError("`joblib_path` must be *.joblib, not *.{}.".format(
  1461. ext))
  1462. self.pca = load(joblib_path)
  1463. def apply_im(self, image):
  1464. H, W, C = image.shape
  1465. n_im = np.reshape(image, (-1, C))
  1466. im_pca = self.pca.transform(n_im)
  1467. result = np.reshape(im_pca, (H, W, -1))
  1468. return result
  1469. def apply(self, sample):
  1470. sample['image'] = self.apply_im(sample['image'])
  1471. if 'image2' in sample:
  1472. sample['image2'] = self.apply_im(sample['image2'])
  1473. return sample
  1474. class SelectBand(Transform):
  1475. """
  1476. Select a set of bands of input image(s).
  1477. Args:
  1478. band_list (list, optional): Bands to select (band index starts from 1).
  1479. Defaults to [1, 2, 3].
  1480. """
  1481. def __init__(self, band_list=[1, 2, 3]):
  1482. super(SelectBand, self).__init__()
  1483. self.band_list = band_list
  1484. def apply_im(self, image):
  1485. image = select_bands(image, self.band_list)
  1486. return image
  1487. def apply(self, sample):
  1488. sample['image'] = self.apply_im(sample['image'])
  1489. if 'image2' in sample:
  1490. sample['image2'] = self.apply_im(sample['image2'])
  1491. return sample
  1492. class _PadBox(Transform):
  1493. def __init__(self, num_max_boxes=50):
  1494. """
  1495. Pad zeros to bboxes if number of bboxes is less than `num_max_boxes`.
  1496. Args:
  1497. num_max_boxes (int, optional): Max number of bboxes. Defaults to 50.
  1498. """
  1499. self.num_max_boxes = num_max_boxes
  1500. super(_PadBox, self).__init__()
  1501. def apply(self, sample):
  1502. gt_num = min(self.num_max_boxes, len(sample['gt_bbox']))
  1503. num_max = self.num_max_boxes
  1504. pad_bbox = np.zeros((num_max, 4), dtype=np.float32)
  1505. if gt_num > 0:
  1506. pad_bbox[:gt_num, :] = sample['gt_bbox'][:gt_num, :]
  1507. sample['gt_bbox'] = pad_bbox
  1508. if 'gt_class' in sample:
  1509. pad_class = np.zeros((num_max, ), dtype=np.int32)
  1510. if gt_num > 0:
  1511. pad_class[:gt_num] = sample['gt_class'][:gt_num, 0]
  1512. sample['gt_class'] = pad_class
  1513. if 'gt_score' in sample:
  1514. pad_score = np.zeros((num_max, ), dtype=np.float32)
  1515. if gt_num > 0:
  1516. pad_score[:gt_num] = sample['gt_score'][:gt_num, 0]
  1517. sample['gt_score'] = pad_score
  1518. # In training, for example in op ExpandImage,
  1519. # bbox and gt_class are expanded, but difficult is not,
  1520. # so judge by its length.
  1521. if 'difficult' in sample:
  1522. pad_diff = np.zeros((num_max, ), dtype=np.int32)
  1523. if gt_num > 0:
  1524. pad_diff[:gt_num] = sample['difficult'][:gt_num, 0]
  1525. sample['difficult'] = pad_diff
  1526. if 'is_crowd' in sample:
  1527. pad_crowd = np.zeros((num_max, ), dtype=np.int32)
  1528. if gt_num > 0:
  1529. pad_crowd[:gt_num] = sample['is_crowd'][:gt_num, 0]
  1530. sample['is_crowd'] = pad_crowd
  1531. return sample
  1532. class _NormalizeBox(Transform):
  1533. def __init__(self):
  1534. super(_NormalizeBox, self).__init__()
  1535. def apply(self, sample):
  1536. height, width = sample['image'].shape[:2]
  1537. for i in range(sample['gt_bbox'].shape[0]):
  1538. sample['gt_bbox'][i][0] = sample['gt_bbox'][i][0] / width
  1539. sample['gt_bbox'][i][1] = sample['gt_bbox'][i][1] / height
  1540. sample['gt_bbox'][i][2] = sample['gt_bbox'][i][2] / width
  1541. sample['gt_bbox'][i][3] = sample['gt_bbox'][i][3] / height
  1542. return sample
  1543. class _BboxXYXY2XYWH(Transform):
  1544. """
  1545. Convert bbox XYXY format to XYWH format.
  1546. """
  1547. def __init__(self):
  1548. super(_BboxXYXY2XYWH, self).__init__()
  1549. def apply(self, sample):
  1550. bbox = sample['gt_bbox']
  1551. bbox[:, 2:4] = bbox[:, 2:4] - bbox[:, :2]
  1552. bbox[:, :2] = bbox[:, :2] + bbox[:, 2:4] / 2.
  1553. sample['gt_bbox'] = bbox
  1554. return sample
  1555. class _Permute(Transform):
  1556. def __init__(self):
  1557. super(_Permute, self).__init__()
  1558. def apply(self, sample):
  1559. sample['image'] = permute(sample['image'], False)
  1560. if 'image2' in sample:
  1561. sample['image2'] = permute(sample['image2'], False)
  1562. return sample
  1563. class RandomSwap(Transform):
  1564. """
  1565. Randomly swap multi-temporal images.
  1566. Args:
  1567. prob (float, optional): Probability of swapping the input images.
  1568. Default: 0.2.
  1569. """
  1570. def __init__(self, prob=0.2):
  1571. super(RandomSwap, self).__init__()
  1572. self.prob = prob
  1573. def apply(self, sample):
  1574. if 'image2' not in sample:
  1575. raise ValueError("'image2' is not found in the sample.")
  1576. if random.random() < self.prob:
  1577. sample['image'], sample['image2'] = sample['image2'], sample[
  1578. 'image']
  1579. return sample
  1580. class ReloadMask(Transform):
  1581. def apply(self, sample):
  1582. sample['mask'] = decode_seg_mask(sample['mask_ori'])
  1583. if 'aux_masks' in sample:
  1584. sample['aux_masks'] = list(
  1585. map(decode_seg_mask, sample['aux_masks_ori']))
  1586. return sample
  1587. class Arrange(Transform):
  1588. def __init__(self, mode):
  1589. super().__init__()
  1590. if mode not in ['train', 'eval', 'test', 'quant']:
  1591. raise ValueError(
  1592. "`mode` should be defined as one of ['train', 'eval', 'test', 'quant']!"
  1593. )
  1594. self.mode = mode
  1595. class ArrangeSegmenter(Arrange):
  1596. def apply(self, sample):
  1597. if 'mask' in sample:
  1598. mask = sample['mask']
  1599. mask = mask.astype('int64')
  1600. image = permute(sample['image'], False)
  1601. if self.mode == 'train':
  1602. return image, mask
  1603. if self.mode == 'eval':
  1604. return image, mask
  1605. if self.mode == 'test':
  1606. return image,
  1607. class ArrangeChangeDetector(Arrange):
  1608. def apply(self, sample):
  1609. if 'mask' in sample:
  1610. mask = sample['mask']
  1611. mask = mask.astype('int64')
  1612. image_t1 = permute(sample['image'], False)
  1613. image_t2 = permute(sample['image2'], False)
  1614. if self.mode == 'train':
  1615. masks = [mask]
  1616. if 'aux_masks' in sample:
  1617. masks.extend(
  1618. map(methodcaller('astype', 'int64'), sample['aux_masks']))
  1619. return (
  1620. image_t1,
  1621. image_t2, ) + tuple(masks)
  1622. if self.mode == 'eval':
  1623. return image_t1, image_t2, mask
  1624. if self.mode == 'test':
  1625. return image_t1, image_t2,
  1626. class ArrangeClassifier(Arrange):
  1627. def apply(self, sample):
  1628. image = permute(sample['image'], False)
  1629. if self.mode in ['train', 'eval']:
  1630. return image, sample['label']
  1631. else:
  1632. return image
  1633. class ArrangeDetector(Arrange):
  1634. def apply(self, sample):
  1635. if self.mode == 'eval' and 'gt_poly' in sample:
  1636. del sample['gt_poly']
  1637. return sample