operators.py 66 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import copy
  16. import random
  17. from numbers import Number
  18. from functools import partial
  19. from operator import methodcaller
  20. try:
  21. from collections.abc import Sequence
  22. except Exception:
  23. from collections import Sequence
  24. import numpy as np
  25. import cv2
  26. import imghdr
  27. from PIL import Image
  28. from joblib import load
  29. import paddlers
  30. from .functions import normalize, horizontal_flip, permute, vertical_flip, center_crop, is_poly, \
  31. horizontal_flip_poly, horizontal_flip_rle, vertical_flip_poly, vertical_flip_rle, crop_poly, \
  32. crop_rle, expand_poly, expand_rle, resize_poly, resize_rle, dehaze, select_bands, \
  33. to_intensity, to_uint8, img_flip, img_simple_rotate
  34. __all__ = [
  35. "Compose",
  36. "DecodeImg",
  37. "Resize",
  38. "RandomResize",
  39. "ResizeByShort",
  40. "RandomResizeByShort",
  41. "ResizeByLong",
  42. "RandomHorizontalFlip",
  43. "RandomVerticalFlip",
  44. "Normalize",
  45. "CenterCrop",
  46. "RandomCrop",
  47. "RandomScaleAspect",
  48. "RandomExpand",
  49. "Pad",
  50. "MixupImage",
  51. "RandomDistort",
  52. "RandomBlur",
  53. "RandomSwap",
  54. "Dehaze",
  55. "ReduceDim",
  56. "SelectBand",
  57. "ArrangeSegmenter",
  58. "ArrangeChangeDetector",
  59. "ArrangeClassifier",
  60. "ArrangeDetector",
  61. "RandomFlipOrRotate",
  62. ]
  63. interp_dict = {
  64. 'NEAREST': cv2.INTER_NEAREST,
  65. 'LINEAR': cv2.INTER_LINEAR,
  66. 'CUBIC': cv2.INTER_CUBIC,
  67. 'AREA': cv2.INTER_AREA,
  68. 'LANCZOS4': cv2.INTER_LANCZOS4
  69. }
  70. class Compose(object):
  71. """
  72. Apply a series of data augmentation strategies to the input.
  73. All input images should be in Height-Width-Channel ([H, W, C]) format.
  74. Args:
  75. transforms (list[paddlers.transforms.Transform]): List of data preprocess or augmentation operators.
  76. arrange (list[paddlers.transforms.Arrange]|None, optional): If not None, the Arrange operator will be used to
  77. arrange the outputs of `transforms`. Defaults to None.
  78. Raises:
  79. TypeError: Invalid type of transforms.
  80. ValueError: Invalid length of transforms.
  81. """
  82. def __init__(self, transforms, arrange=None):
  83. super(Compose, self).__init__()
  84. if not isinstance(transforms, list):
  85. raise TypeError(
  86. "Type of transforms is invalid. Must be a list, but received is {}."
  87. .format(type(transforms)))
  88. if len(transforms) < 1:
  89. raise ValueError(
  90. "Length of transforms must not be less than 1, but received is {}."
  91. .format(len(transforms)))
  92. self.transforms = transforms
  93. self.arrange = arrange
  94. def __call__(self, sample):
  95. """
  96. This is equivalent to sequentially calling compose_obj.apply_transforms() and compose_obj.arrange_outputs().
  97. """
  98. sample = self.apply_transforms(sample)
  99. sample = self.arrange_outputs(sample)
  100. return sample
  101. def apply_transforms(self, sample):
  102. for op in self.transforms:
  103. # Skip batch transforms amd mixup
  104. if isinstance(op, (paddlers.transforms.BatchRandomResize,
  105. paddlers.transforms.BatchRandomResizeByShort,
  106. MixupImage)):
  107. continue
  108. sample = op(sample)
  109. return sample
  110. def arrange_outputs(self, sample):
  111. if self.arrange is not None:
  112. sample = self.arrange(sample)
  113. return sample
  114. class Transform(object):
  115. """
  116. Parent class of all data augmentation operations
  117. """
  118. def __init__(self):
  119. pass
  120. def apply_im(self, image):
  121. pass
  122. def apply_mask(self, mask):
  123. pass
  124. def apply_bbox(self, bbox):
  125. pass
  126. def apply_segm(self, segms):
  127. pass
  128. def apply(self, sample):
  129. if 'image' in sample:
  130. sample['image'] = self.apply_im(sample['image'])
  131. else: # image_tx
  132. sample['image'] = self.apply_im(sample['image_t1'])
  133. sample['image2'] = self.apply_im(sample['image_t2'])
  134. if 'mask' in sample:
  135. sample['mask'] = self.apply_mask(sample['mask'])
  136. if 'gt_bbox' in sample:
  137. sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'])
  138. if 'aux_masks' in sample:
  139. sample['aux_masks'] = list(
  140. map(self.apply_mask, sample['aux_masks']))
  141. return sample
  142. def __call__(self, sample):
  143. if isinstance(sample, Sequence):
  144. sample = [self.apply(s) for s in sample]
  145. else:
  146. sample = self.apply(sample)
  147. return sample
  148. class DecodeImg(Transform):
  149. """
  150. Decode image(s) in input.
  151. Args:
  152. to_rgb (bool, optional): If True, convert input image(s) from BGR format to RGB format. Defaults to True.
  153. to_uint8 (bool, optional): If True, quantize and convert decoded image(s) to uint8 type. Defaults to True.
  154. decode_bgr (bool, optional): If True, automatically interpret a non-geo image (e.g., jpeg images) as a BGR image.
  155. Defaults to True.
  156. decode_sar (bool, optional): If True, automatically interpret a two-channel geo image (e.g. geotiff images) as a
  157. SAR image, set this argument to True. Defaults to True.
  158. """
  159. def __init__(self,
  160. to_rgb=True,
  161. to_uint8=True,
  162. decode_bgr=True,
  163. decode_sar=True):
  164. super(DecodeImg, self).__init__()
  165. self.to_rgb = to_rgb
  166. self.to_uint8 = to_uint8
  167. self.decode_bgr = decode_bgr
  168. self.decode_sar = decode_sar
  169. def read_img(self, img_path):
  170. img_format = imghdr.what(img_path)
  171. name, ext = os.path.splitext(img_path)
  172. if img_format == 'tiff' or ext == '.img':
  173. try:
  174. import gdal
  175. except:
  176. try:
  177. from osgeo import gdal
  178. except ImportError:
  179. raise ImportError(
  180. "Failed to import gdal! Please install GDAL library according to the document."
  181. )
  182. dataset = gdal.Open(img_path)
  183. if dataset == None:
  184. raise IOError('Can not open', img_path)
  185. im_data = dataset.ReadAsArray()
  186. if im_data.ndim == 2 and self.decode_sar:
  187. im_data = to_intensity(im_data) # is read SAR
  188. im_data = im_data[:, :, np.newaxis]
  189. else:
  190. if im_data.ndim == 3:
  191. im_data = im_data.transpose((1, 2, 0))
  192. return im_data
  193. elif img_format in ['jpeg', 'bmp', 'png', 'jpg']:
  194. if self.decode_bgr:
  195. return cv2.imread(img_path, cv2.IMREAD_ANYDEPTH |
  196. cv2.IMREAD_ANYCOLOR | cv2.IMREAD_COLOR)
  197. else:
  198. return cv2.imread(img_path, cv2.IMREAD_ANYDEPTH |
  199. cv2.IMREAD_ANYCOLOR)
  200. elif ext == '.npy':
  201. return np.load(img_path)
  202. else:
  203. raise TypeError("Image format {} is not supported!".format(ext))
  204. def apply_im(self, im_path):
  205. if isinstance(im_path, str):
  206. try:
  207. image = self.read_img(im_path)
  208. except:
  209. raise ValueError("Cannot read the image file {}!".format(
  210. im_path))
  211. else:
  212. image = im_path
  213. if self.to_rgb and image.shape[-1] == 3:
  214. image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
  215. if self.to_uint8:
  216. image = to_uint8(image)
  217. return image
  218. def apply_mask(self, mask):
  219. try:
  220. mask = np.asarray(Image.open(mask))
  221. except:
  222. raise ValueError("Cannot read the mask file {}!".format(mask))
  223. if len(mask.shape) != 2:
  224. raise ValueError(
  225. "Mask should be a 1-channel image, but recevied is a {}-channel image.".
  226. format(mask.shape[2]))
  227. return mask
  228. def apply(self, sample):
  229. """
  230. Args:
  231. sample (dict): Input sample.
  232. Returns:
  233. dict: Decoded sample.
  234. """
  235. if 'image' in sample:
  236. sample['image'] = self.apply_im(sample['image'])
  237. if 'image2' in sample:
  238. sample['image2'] = self.apply_im(sample['image2'])
  239. if 'image_t1' in sample and not 'image' in sample:
  240. if not ('image_t2' in sample and 'image2' not in sample):
  241. raise ValueError
  242. sample['image'] = self.apply_im(sample['image_t1'])
  243. sample['image2'] = self.apply_im(sample['image_t2'])
  244. if 'mask' in sample:
  245. sample['mask'] = self.apply_mask(sample['mask'])
  246. im_height, im_width, _ = sample['image'].shape
  247. se_height, se_width = sample['mask'].shape
  248. if im_height != se_height or im_width != se_width:
  249. raise ValueError(
  250. "The height or width of the image is not same as the mask.")
  251. if 'aux_masks' in sample:
  252. sample['aux_masks'] = list(
  253. map(self.apply_mask, sample['aux_masks']))
  254. # TODO: check the shape of auxiliary masks
  255. sample['im_shape'] = np.array(
  256. sample['image'].shape[:2], dtype=np.float32)
  257. sample['scale_factor'] = np.array([1., 1.], dtype=np.float32)
  258. return sample
  259. class Resize(Transform):
  260. """
  261. Resize input.
  262. - If target_size is an int, resize the image(s) to (target_size, target_size).
  263. - If target_size is a list or tuple, resize the image(s) to target_size.
  264. Attention: If interp is 'RANDOM', the interpolation method will be chose randomly.
  265. Args:
  266. target_size (int, list[int] | tuple[int]): Target size. If int, the height and width share the same target_size.
  267. Otherwise, target_size represents [target height, target width].
  268. interp ({'NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4', 'RANDOM'}, optional):
  269. Interpolation method of resize. Defaults to 'LINEAR'.
  270. keep_ratio (bool): the resize scale of width/height is same and width/height after resized is not greater
  271. than target width/height. Defaults to False.
  272. Raises:
  273. TypeError: Invalid type of target_size.
  274. ValueError: Invalid interpolation method.
  275. """
  276. def __init__(self, target_size, interp='LINEAR', keep_ratio=False):
  277. super(Resize, self).__init__()
  278. if not (interp == "RANDOM" or interp in interp_dict):
  279. raise ValueError("`interp` should be one of {}.".format(
  280. interp_dict.keys()))
  281. if isinstance(target_size, int):
  282. target_size = (target_size, target_size)
  283. else:
  284. if not (isinstance(target_size,
  285. (list, tuple)) and len(target_size) == 2):
  286. raise TypeError(
  287. "`target_size` should be an int or a list of length 2, but received {}.".
  288. format(target_size))
  289. # (height, width)
  290. self.target_size = target_size
  291. self.interp = interp
  292. self.keep_ratio = keep_ratio
  293. def apply_im(self, image, interp, target_size):
  294. flag = image.shape[2] == 1
  295. image = cv2.resize(image, target_size, interpolation=interp)
  296. if flag:
  297. image = image[:, :, np.newaxis]
  298. return image
  299. def apply_mask(self, mask, target_size):
  300. mask = cv2.resize(mask, target_size, interpolation=cv2.INTER_NEAREST)
  301. return mask
  302. def apply_bbox(self, bbox, scale, target_size):
  303. im_scale_x, im_scale_y = scale
  304. bbox[:, 0::2] *= im_scale_x
  305. bbox[:, 1::2] *= im_scale_y
  306. bbox[:, 0::2] = np.clip(bbox[:, 0::2], 0, target_size[0])
  307. bbox[:, 1::2] = np.clip(bbox[:, 1::2], 0, target_size[1])
  308. return bbox
  309. def apply_segm(self, segms, im_size, scale):
  310. im_h, im_w = im_size
  311. im_scale_x, im_scale_y = scale
  312. resized_segms = []
  313. for segm in segms:
  314. if is_poly(segm):
  315. # Polygon format
  316. resized_segms.append([
  317. resize_poly(poly, im_scale_x, im_scale_y) for poly in segm
  318. ])
  319. else:
  320. # RLE format
  321. resized_segms.append(
  322. resize_rle(segm, im_h, im_w, im_scale_x, im_scale_y))
  323. return resized_segms
  324. def apply(self, sample):
  325. if self.interp == "RANDOM":
  326. interp = random.choice(list(interp_dict.values()))
  327. else:
  328. interp = interp_dict[self.interp]
  329. im_h, im_w = sample['image'].shape[:2]
  330. im_scale_y = self.target_size[0] / im_h
  331. im_scale_x = self.target_size[1] / im_w
  332. target_size = (self.target_size[1], self.target_size[0])
  333. if self.keep_ratio:
  334. scale = min(im_scale_y, im_scale_x)
  335. target_w = int(round(im_w * scale))
  336. target_h = int(round(im_h * scale))
  337. target_size = (target_w, target_h)
  338. im_scale_y = target_h / im_h
  339. im_scale_x = target_w / im_w
  340. sample['image'] = self.apply_im(sample['image'], interp, target_size)
  341. if 'image2' in sample:
  342. sample['image2'] = self.apply_im(sample['image2'], interp,
  343. target_size)
  344. if 'mask' in sample:
  345. sample['mask'] = self.apply_mask(sample['mask'], target_size)
  346. if 'aux_masks' in sample:
  347. sample['aux_masks'] = list(
  348. map(partial(
  349. self.apply_mask, target_size=target_size),
  350. sample['aux_masks']))
  351. if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
  352. sample['gt_bbox'] = self.apply_bbox(
  353. sample['gt_bbox'], [im_scale_x, im_scale_y], target_size)
  354. if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
  355. sample['gt_poly'] = self.apply_segm(
  356. sample['gt_poly'], [im_h, im_w], [im_scale_x, im_scale_y])
  357. sample['im_shape'] = np.asarray(
  358. sample['image'].shape[:2], dtype=np.float32)
  359. if 'scale_factor' in sample:
  360. scale_factor = sample['scale_factor']
  361. sample['scale_factor'] = np.asarray(
  362. [scale_factor[0] * im_scale_y, scale_factor[1] * im_scale_x],
  363. dtype=np.float32)
  364. return sample
  365. class RandomResize(Transform):
  366. """
  367. Resize input to random sizes.
  368. Attention: If interp is 'RANDOM', the interpolation method will be chose randomly.
  369. Args:
  370. target_sizes (list[int] | list[list | tuple] | tuple[list | tuple]):
  371. Multiple target sizes, each target size is an int or list/tuple.
  372. interp ({'NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4', 'RANDOM'}, optional):
  373. Interpolation method of resize. Defaults to 'LINEAR'.
  374. Raises:
  375. TypeError: Invalid type of target_size.
  376. ValueError: Invalid interpolation method.
  377. See Also:
  378. Resize input to a specific size.
  379. """
  380. def __init__(self, target_sizes, interp='LINEAR'):
  381. super(RandomResize, self).__init__()
  382. if not (interp == "RANDOM" or interp in interp_dict):
  383. raise ValueError("`interp` should be one of {}.".format(
  384. interp_dict.keys()))
  385. self.interp = interp
  386. assert isinstance(target_sizes, list), \
  387. "`target_size` must be a list."
  388. for i, item in enumerate(target_sizes):
  389. if isinstance(item, int):
  390. target_sizes[i] = (item, item)
  391. self.target_size = target_sizes
  392. def apply(self, sample):
  393. height, width = random.choice(self.target_size)
  394. resizer = Resize((height, width), interp=self.interp)
  395. sample = resizer(sample)
  396. return sample
  397. class ResizeByShort(Transform):
  398. """
  399. Resize input with keeping the aspect ratio.
  400. Attention: If interp is 'RANDOM', the interpolation method will be chose randomly.
  401. Args:
  402. short_size (int): Target size of the shorter side of the image(s).
  403. max_size (int, optional): The upper bound of longer side of the image(s). If max_size is -1, no upper bound is applied. Defaults to -1.
  404. interp ({'NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4', 'RANDOM'}, optional): Interpolation method of resize. Defaults to 'LINEAR'.
  405. Raises:
  406. ValueError: Invalid interpolation method.
  407. """
  408. def __init__(self, short_size=256, max_size=-1, interp='LINEAR'):
  409. if not (interp == "RANDOM" or interp in interp_dict):
  410. raise ValueError("`interp` should be one of {}".format(
  411. interp_dict.keys()))
  412. super(ResizeByShort, self).__init__()
  413. self.short_size = short_size
  414. self.max_size = max_size
  415. self.interp = interp
  416. def apply(self, sample):
  417. im_h, im_w = sample['image'].shape[:2]
  418. im_short_size = min(im_h, im_w)
  419. im_long_size = max(im_h, im_w)
  420. scale = float(self.short_size) / float(im_short_size)
  421. if 0 < self.max_size < np.round(scale * im_long_size):
  422. scale = float(self.max_size) / float(im_long_size)
  423. target_w = int(round(im_w * scale))
  424. target_h = int(round(im_h * scale))
  425. sample = Resize(
  426. target_size=(target_h, target_w), interp=self.interp)(sample)
  427. return sample
  428. class RandomResizeByShort(Transform):
  429. """
  430. Resize input to random sizes with keeping the aspect ratio.
  431. Attention: If interp is 'RANDOM', the interpolation method will be chose randomly.
  432. Args:
  433. short_sizes (list[int]): Target size of the shorter side of the image(s).
  434. max_size (int, optional): The upper bound of longer side of the image(s). If max_size is -1, no upper bound is applied. Defaults to -1.
  435. interp ({'NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4', 'RANDOM'}, optional): Interpolation method of resize. Defaults to 'LINEAR'.
  436. Raises:
  437. TypeError: Invalid type of target_size.
  438. ValueError: Invalid interpolation method.
  439. See Also:
  440. ResizeByShort: Resize image(s) in input with keeping the aspect ratio.
  441. """
  442. def __init__(self, short_sizes, max_size=-1, interp='LINEAR'):
  443. super(RandomResizeByShort, self).__init__()
  444. if not (interp == "RANDOM" or interp in interp_dict):
  445. raise ValueError("`interp` should be one of {}".format(
  446. interp_dict.keys()))
  447. self.interp = interp
  448. assert isinstance(short_sizes, list), \
  449. "`short_sizes` must be a list."
  450. self.short_sizes = short_sizes
  451. self.max_size = max_size
  452. def apply(self, sample):
  453. short_size = random.choice(self.short_sizes)
  454. resizer = ResizeByShort(
  455. short_size=short_size, max_size=self.max_size, interp=self.interp)
  456. sample = resizer(sample)
  457. return sample
  458. class ResizeByLong(Transform):
  459. def __init__(self, long_size=256, interp='LINEAR'):
  460. super(ResizeByLong, self).__init__()
  461. self.long_size = long_size
  462. self.interp = interp
  463. def apply(self, sample):
  464. im_h, im_w = sample['image'].shape[:2]
  465. im_long_size = max(im_h, im_w)
  466. scale = float(self.long_size) / float(im_long_size)
  467. target_h = int(round(im_h * scale))
  468. target_w = int(round(im_w * scale))
  469. sample = Resize(
  470. target_size=(target_h, target_w), interp=self.interp)(sample)
  471. return sample
  472. class RandomFlipOrRotate(Transform):
  473. """
  474. Flip or Rotate an image in different ways with a certain probability.
  475. Args:
  476. probs (list of float): Probabilities of flipping and rotation. Default: [0.35,0.25].
  477. probsf (list of float): Probabilities of 5 flipping mode
  478. (horizontal, vertical, both horizontal diction and vertical, diagonal, anti-diagonal).
  479. Default: [0.3, 0.3, 0.2, 0.1, 0.1].
  480. probsr (list of float): Probabilities of 3 rotation mode(90°, 180°, 270° clockwise). Default: [0.25,0.5,0.25].
  481. Examples:
  482. from paddlers import transforms as T
  483. # 定义数据增强
  484. train_transforms = T.Compose([
  485. T.DecodeImg(),
  486. T.RandomFlipOrRotate(
  487. probs = [0.3, 0.2] # 进行flip增强的概率是0.3,进行rotate增强的概率是0.2,不变的概率是0.5
  488. probsf = [0.3, 0.25, 0, 0, 0] # flip增强时,使用水平flip、垂直flip的概率分别是0.3、0.25,水平且垂直flip、对角线flip、反对角线flip概率均为0,不变的概率是0.45
  489. probsr = [0, 0.65, 0]), # rotate增强时,顺时针旋转90度的概率是0,顺时针旋转180度的概率是0.65,顺时针旋转90度的概率是0,不变的概率是0.35
  490. T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  491. ])
  492. """
  493. def __init__(self,
  494. probs=[0.35, 0.25],
  495. probsf=[0.3, 0.3, 0.2, 0.1, 0.1],
  496. probsr=[0.25, 0.5, 0.25]):
  497. super(RandomFlipOrRotate, self).__init__()
  498. # Change various probabilities into probability intervals, to judge in which mode to flip or rotate
  499. self.probs = [probs[0], probs[0] + probs[1]]
  500. self.probsf = self.get_probs_range(probsf)
  501. self.probsr = self.get_probs_range(probsr)
  502. def apply_im(self, image, mode_id, flip_mode=True):
  503. if flip_mode:
  504. image = img_flip(image, mode_id)
  505. else:
  506. image = img_simple_rotate(image, mode_id)
  507. return image
  508. def apply_mask(self, mask, mode_id, flip_mode=True):
  509. if flip_mode:
  510. mask = img_flip(mask, mode_id)
  511. else:
  512. mask = img_simple_rotate(mask, mode_id)
  513. return mask
  514. def apply_bbox(self, bbox, mode_id, flip_mode=True):
  515. raise TypeError(
  516. "Currently, RandomFlipOrRotate is not available for object detection tasks."
  517. )
  518. def apply_segm(self, bbox, mode_id, flip_mode=True):
  519. raise TypeError(
  520. "Currently, RandomFlipOrRotate is not available for object detection tasks."
  521. )
  522. def get_probs_range(self, probs):
  523. '''
  524. Change various probabilities into cumulative probabilities
  525. Args:
  526. probs(list of float): probabilities of different mode, shape:[n]
  527. Returns:
  528. probability intervals(list of binary list): shape:[n, 2]
  529. '''
  530. ps = []
  531. last_prob = 0
  532. for prob in probs:
  533. p_s = last_prob
  534. cur_prob = prob / sum(probs)
  535. last_prob += cur_prob
  536. p_e = last_prob
  537. ps.append([p_s, p_e])
  538. return ps
  539. def judge_probs_range(self, p, probs):
  540. '''
  541. Judge whether a probability value falls within the given probability interval
  542. Args:
  543. p(float): probability
  544. probs(list of binary list): probability intervals, shape:[n, 2]
  545. Returns:
  546. mode id(int):the probability interval number where the input probability falls,
  547. if return -1, the image will remain as it is and will not be processed
  548. '''
  549. for id, id_range in enumerate(probs):
  550. if p > id_range[0] and p < id_range[1]:
  551. return id
  552. return -1
  553. def apply(self, sample):
  554. p_m = random.random()
  555. if p_m < self.probs[0]:
  556. mode_p = random.random()
  557. mode_id = self.judge_probs_range(mode_p, self.probsf)
  558. sample['image'] = self.apply_im(sample['image'], mode_id, True)
  559. if 'image2' in sample:
  560. sample['image2'] = self.apply_im(sample['image2'], mode_id,
  561. True)
  562. if 'mask' in sample:
  563. sample['mask'] = self.apply_mask(sample['mask'], mode_id, True)
  564. if 'aux_masks' in sample:
  565. sample['aux_masks'] = [
  566. self.apply_mask(aux_mask, mode_id, True)
  567. for aux_mask in sample['aux_masks']
  568. ]
  569. if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
  570. sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], mode_id,
  571. True)
  572. if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
  573. sample['gt_poly'] = self.apply_segm(sample['gt_poly'], mode_id,
  574. True)
  575. elif p_m < self.probs[1]:
  576. mode_p = random.random()
  577. mode_id = self.judge_probs_range(mode_p, self.probsr)
  578. sample['image'] = self.apply_im(sample['image'], mode_id, False)
  579. if 'image2' in sample:
  580. sample['image2'] = self.apply_im(sample['image2'], mode_id,
  581. False)
  582. if 'mask' in sample:
  583. sample['mask'] = self.apply_mask(sample['mask'], mode_id, False)
  584. if 'aux_masks' in sample:
  585. sample['aux_masks'] = [
  586. self.apply_mask(aux_mask, mode_id, False)
  587. for aux_mask in sample['aux_masks']
  588. ]
  589. if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
  590. sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], mode_id,
  591. False)
  592. if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
  593. sample['gt_poly'] = self.apply_segm(sample['gt_poly'], mode_id,
  594. False)
  595. return sample
  596. class RandomHorizontalFlip(Transform):
  597. """
  598. Randomly flip the input horizontally.
  599. Args:
  600. prob(float, optional): Probability of flipping the input. Defaults to .5.
  601. """
  602. def __init__(self, prob=0.5):
  603. super(RandomHorizontalFlip, self).__init__()
  604. self.prob = prob
  605. def apply_im(self, image):
  606. image = horizontal_flip(image)
  607. return image
  608. def apply_mask(self, mask):
  609. mask = horizontal_flip(mask)
  610. return mask
  611. def apply_bbox(self, bbox, width):
  612. oldx1 = bbox[:, 0].copy()
  613. oldx2 = bbox[:, 2].copy()
  614. bbox[:, 0] = width - oldx2
  615. bbox[:, 2] = width - oldx1
  616. return bbox
  617. def apply_segm(self, segms, height, width):
  618. flipped_segms = []
  619. for segm in segms:
  620. if is_poly(segm):
  621. # Polygon format
  622. flipped_segms.append(
  623. [horizontal_flip_poly(poly, width) for poly in segm])
  624. else:
  625. # RLE format
  626. flipped_segms.append(horizontal_flip_rle(segm, height, width))
  627. return flipped_segms
  628. def apply(self, sample):
  629. if random.random() < self.prob:
  630. im_h, im_w = sample['image'].shape[:2]
  631. sample['image'] = self.apply_im(sample['image'])
  632. if 'image2' in sample:
  633. sample['image2'] = self.apply_im(sample['image2'])
  634. if 'mask' in sample:
  635. sample['mask'] = self.apply_mask(sample['mask'])
  636. if 'aux_masks' in sample:
  637. sample['aux_masks'] = list(
  638. map(self.apply_mask, sample['aux_masks']))
  639. if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
  640. sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], im_w)
  641. if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
  642. sample['gt_poly'] = self.apply_segm(sample['gt_poly'], im_h,
  643. im_w)
  644. return sample
  645. class RandomVerticalFlip(Transform):
  646. """
  647. Randomly flip the input vertically.
  648. Args:
  649. prob(float, optional): Probability of flipping the input. Defaults to .5.
  650. """
  651. def __init__(self, prob=0.5):
  652. super(RandomVerticalFlip, self).__init__()
  653. self.prob = prob
  654. def apply_im(self, image):
  655. image = vertical_flip(image)
  656. return image
  657. def apply_mask(self, mask):
  658. mask = vertical_flip(mask)
  659. return mask
  660. def apply_bbox(self, bbox, height):
  661. oldy1 = bbox[:, 1].copy()
  662. oldy2 = bbox[:, 3].copy()
  663. bbox[:, 0] = height - oldy2
  664. bbox[:, 2] = height - oldy1
  665. return bbox
  666. def apply_segm(self, segms, height, width):
  667. flipped_segms = []
  668. for segm in segms:
  669. if is_poly(segm):
  670. # Polygon format
  671. flipped_segms.append(
  672. [vertical_flip_poly(poly, height) for poly in segm])
  673. else:
  674. # RLE format
  675. flipped_segms.append(vertical_flip_rle(segm, height, width))
  676. return flipped_segms
  677. def apply(self, sample):
  678. if random.random() < self.prob:
  679. im_h, im_w = sample['image'].shape[:2]
  680. sample['image'] = self.apply_im(sample['image'])
  681. if 'image2' in sample:
  682. sample['image2'] = self.apply_im(sample['image2'])
  683. if 'mask' in sample:
  684. sample['mask'] = self.apply_mask(sample['mask'])
  685. if 'aux_masks' in sample:
  686. sample['aux_masks'] = list(
  687. map(self.apply_mask, sample['aux_masks']))
  688. if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
  689. sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], im_h)
  690. if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
  691. sample['gt_poly'] = self.apply_segm(sample['gt_poly'], im_h,
  692. im_w)
  693. return sample
  694. class Normalize(Transform):
  695. """
  696. Apply normalization to the input image(s). The normalization steps are:
  697. 1. im = (im - min_value) * 1 / (max_value - min_value)
  698. 2. im = im - mean
  699. 3. im = im / std
  700. Args:
  701. mean(list[float] | tuple[float], optional): Mean of input image(s). Defaults to [0.485, 0.456, 0.406].
  702. std(list[float] | tuple[float], optional): Standard deviation of input image(s). Defaults to [0.229, 0.224, 0.225].
  703. min_val(list[float] | tuple[float], optional): Minimum value of input image(s). Defaults to [0, 0, 0, ].
  704. max_val(list[float] | tuple[float], optional): Max value of input image(s). Defaults to [255., 255., 255.].
  705. """
  706. def __init__(self,
  707. mean=[0.485, 0.456, 0.406],
  708. std=[0.229, 0.224, 0.225],
  709. min_val=None,
  710. max_val=None):
  711. super(Normalize, self).__init__()
  712. channel = len(mean)
  713. if min_val is None:
  714. min_val = [0] * channel
  715. if max_val is None:
  716. max_val = [255.] * channel
  717. from functools import reduce
  718. if reduce(lambda x, y: x * y, std) == 0:
  719. raise ValueError(
  720. "`std` should not contain 0, but received is {}.".format(std))
  721. if reduce(lambda x, y: x * y,
  722. [a - b for a, b in zip(max_val, min_val)]) == 0:
  723. raise ValueError(
  724. "(`max_val` - `min_val`) should not contain 0, but received is {}.".
  725. format((np.asarray(max_val) - np.asarray(min_val)).tolist()))
  726. self.mean = mean
  727. self.std = std
  728. self.min_val = min_val
  729. self.max_val = max_val
  730. def apply_im(self, image):
  731. image = image.astype(np.float32)
  732. mean = np.asarray(
  733. self.mean, dtype=np.float32)[np.newaxis, np.newaxis, :]
  734. std = np.asarray(self.std, dtype=np.float32)[np.newaxis, np.newaxis, :]
  735. image = normalize(image, mean, std, self.min_val, self.max_val)
  736. return image
  737. def apply(self, sample):
  738. sample['image'] = self.apply_im(sample['image'])
  739. if 'image2' in sample:
  740. sample['image2'] = self.apply_im(sample['image2'])
  741. return sample
  742. class CenterCrop(Transform):
  743. """
  744. Crop the input at the center.
  745. 1. Locate the center of the image.
  746. 2. Crop the sample.
  747. Args:
  748. crop_size(int, optional): target size of the cropped image(s). Defaults to 224.
  749. """
  750. def __init__(self, crop_size=224):
  751. super(CenterCrop, self).__init__()
  752. self.crop_size = crop_size
  753. def apply_im(self, image):
  754. image = center_crop(image, self.crop_size)
  755. return image
  756. def apply_mask(self, mask):
  757. mask = center_crop(mask, self.crop_size)
  758. return mask
  759. def apply(self, sample):
  760. sample['image'] = self.apply_im(sample['image'])
  761. if 'image2' in sample:
  762. sample['image2'] = self.apply_im(sample['image2'])
  763. if 'mask' in sample:
  764. sample['mask'] = self.apply_mask(sample['mask'])
  765. if 'aux_masks' in sample:
  766. sample['aux_masks'] = list(
  767. map(self.apply_mask, sample['aux_masks']))
  768. return sample
  769. class RandomCrop(Transform):
  770. """
  771. Randomly crop the input.
  772. 1. Compute the height and width of cropped area according to aspect_ratio and scaling.
  773. 2. Locate the upper left corner of cropped area randomly.
  774. 3. Crop the image(s).
  775. 4. Resize the cropped area to crop_size by crop_size.
  776. Args:
  777. crop_size(int, list[int] | tuple[int]): Target size of the cropped area. If None, the cropped area will not be
  778. resized. Defaults to None.
  779. aspect_ratio (list[float], optional): Aspect ratio of cropped region in [min, max] format. Defaults to [.5, 2.].
  780. thresholds (list[float], optional): Iou thresholds to decide a valid bbox crop.
  781. Defaults to [.0, .1, .3, .5, .7, .9].
  782. scaling (list[float], optional): Ratio between the cropped region and the original image in [min, max] format.
  783. Defaults to [.3, 1.].
  784. num_attempts (int, optional): The number of tries before giving up. Defaults to 50.
  785. allow_no_crop (bool, optional): Whether returning without doing crop is allowed. Defaults to True.
  786. cover_all_box (bool, optional): Whether to ensure all bboxes are covered in the final crop. Defaults to False.
  787. """
  788. def __init__(self,
  789. crop_size=None,
  790. aspect_ratio=[.5, 2.],
  791. thresholds=[.0, .1, .3, .5, .7, .9],
  792. scaling=[.3, 1.],
  793. num_attempts=50,
  794. allow_no_crop=True,
  795. cover_all_box=False):
  796. super(RandomCrop, self).__init__()
  797. self.crop_size = crop_size
  798. self.aspect_ratio = aspect_ratio
  799. self.thresholds = thresholds
  800. self.scaling = scaling
  801. self.num_attempts = num_attempts
  802. self.allow_no_crop = allow_no_crop
  803. self.cover_all_box = cover_all_box
  804. def _generate_crop_info(self, sample):
  805. im_h, im_w = sample['image'].shape[:2]
  806. if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
  807. thresholds = self.thresholds
  808. if self.allow_no_crop:
  809. thresholds.append('no_crop')
  810. np.random.shuffle(thresholds)
  811. for thresh in thresholds:
  812. if thresh == 'no_crop':
  813. return None
  814. for i in range(self.num_attempts):
  815. crop_box = self._get_crop_box(im_h, im_w)
  816. if crop_box is None:
  817. continue
  818. iou = self._iou_matrix(
  819. sample['gt_bbox'],
  820. np.array(
  821. [crop_box], dtype=np.float32))
  822. if iou.max() < thresh:
  823. continue
  824. if self.cover_all_box and iou.min() < thresh:
  825. continue
  826. cropped_box, valid_ids = self._crop_box_with_center_constraint(
  827. sample['gt_bbox'], np.array(
  828. crop_box, dtype=np.float32))
  829. if valid_ids.size > 0:
  830. return crop_box, cropped_box, valid_ids
  831. else:
  832. for i in range(self.num_attempts):
  833. crop_box = self._get_crop_box(im_h, im_w)
  834. if crop_box is None:
  835. continue
  836. return crop_box, None, None
  837. return None
  838. def _get_crop_box(self, im_h, im_w):
  839. scale = np.random.uniform(*self.scaling)
  840. if self.aspect_ratio is not None:
  841. min_ar, max_ar = self.aspect_ratio
  842. aspect_ratio = np.random.uniform(
  843. max(min_ar, scale**2), min(max_ar, scale**-2))
  844. h_scale = scale / np.sqrt(aspect_ratio)
  845. w_scale = scale * np.sqrt(aspect_ratio)
  846. else:
  847. h_scale = np.random.uniform(*self.scaling)
  848. w_scale = np.random.uniform(*self.scaling)
  849. crop_h = im_h * h_scale
  850. crop_w = im_w * w_scale
  851. if self.aspect_ratio is None:
  852. if crop_h / crop_w < 0.5 or crop_h / crop_w > 2.0:
  853. return None
  854. crop_h = int(crop_h)
  855. crop_w = int(crop_w)
  856. crop_y = np.random.randint(0, im_h - crop_h)
  857. crop_x = np.random.randint(0, im_w - crop_w)
  858. return [crop_x, crop_y, crop_x + crop_w, crop_y + crop_h]
  859. def _iou_matrix(self, a, b):
  860. tl_i = np.maximum(a[:, np.newaxis, :2], b[:, :2])
  861. br_i = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
  862. area_i = np.prod(br_i - tl_i, axis=2) * (tl_i < br_i).all(axis=2)
  863. area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
  864. area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
  865. area_o = (area_a[:, np.newaxis] + area_b - area_i)
  866. return area_i / (area_o + 1e-10)
  867. def _crop_box_with_center_constraint(self, box, crop):
  868. cropped_box = box.copy()
  869. cropped_box[:, :2] = np.maximum(box[:, :2], crop[:2])
  870. cropped_box[:, 2:] = np.minimum(box[:, 2:], crop[2:])
  871. cropped_box[:, :2] -= crop[:2]
  872. cropped_box[:, 2:] -= crop[:2]
  873. centers = (box[:, :2] + box[:, 2:]) / 2
  874. valid = np.logical_and(crop[:2] <= centers,
  875. centers < crop[2:]).all(axis=1)
  876. valid = np.logical_and(
  877. valid, (cropped_box[:, :2] < cropped_box[:, 2:]).all(axis=1))
  878. return cropped_box, np.where(valid)[0]
  879. def _crop_segm(self, segms, valid_ids, crop, height, width):
  880. crop_segms = []
  881. for id in valid_ids:
  882. segm = segms[id]
  883. if is_poly(segm):
  884. # Polygon format
  885. crop_segms.append(crop_poly(segm, crop))
  886. else:
  887. # RLE format
  888. crop_segms.append(crop_rle(segm, crop, height, width))
  889. return crop_segms
  890. def apply_im(self, image, crop):
  891. x1, y1, x2, y2 = crop
  892. return image[y1:y2, x1:x2, :]
  893. def apply_mask(self, mask, crop):
  894. x1, y1, x2, y2 = crop
  895. return mask[y1:y2, x1:x2, ...]
  896. def apply(self, sample):
  897. crop_info = self._generate_crop_info(sample)
  898. if crop_info is not None:
  899. crop_box, cropped_box, valid_ids = crop_info
  900. im_h, im_w = sample['image'].shape[:2]
  901. sample['image'] = self.apply_im(sample['image'], crop_box)
  902. if 'image2' in sample:
  903. sample['image2'] = self.apply_im(sample['image2'], crop_box)
  904. if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
  905. crop_polys = self._crop_segm(
  906. sample['gt_poly'],
  907. valid_ids,
  908. np.array(
  909. crop_box, dtype=np.int64),
  910. im_h,
  911. im_w)
  912. if [] in crop_polys:
  913. delete_id = list()
  914. valid_polys = list()
  915. for idx, poly in enumerate(crop_polys):
  916. if not crop_poly:
  917. delete_id.append(idx)
  918. else:
  919. valid_polys.append(poly)
  920. valid_ids = np.delete(valid_ids, delete_id)
  921. if not valid_polys:
  922. return sample
  923. sample['gt_poly'] = valid_polys
  924. else:
  925. sample['gt_poly'] = crop_polys
  926. if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
  927. sample['gt_bbox'] = np.take(cropped_box, valid_ids, axis=0)
  928. sample['gt_class'] = np.take(
  929. sample['gt_class'], valid_ids, axis=0)
  930. if 'gt_score' in sample:
  931. sample['gt_score'] = np.take(
  932. sample['gt_score'], valid_ids, axis=0)
  933. if 'is_crowd' in sample:
  934. sample['is_crowd'] = np.take(
  935. sample['is_crowd'], valid_ids, axis=0)
  936. if 'mask' in sample:
  937. sample['mask'] = self.apply_mask(sample['mask'], crop_box)
  938. if 'aux_masks' in sample:
  939. sample['aux_masks'] = list(
  940. map(partial(
  941. self.apply_mask, crop=crop_box),
  942. sample['aux_masks']))
  943. if self.crop_size is not None:
  944. sample = Resize(self.crop_size)(sample)
  945. return sample
  946. class RandomScaleAspect(Transform):
  947. """
  948. Crop input image(s) and resize back to original sizes.
  949. Args:
  950. min_scale (float): Minimum ratio between the cropped region and the original image.
  951. If 0, image(s) will not be cropped. Defaults to .5.
  952. aspect_ratio (float): Aspect ratio of cropped region. Defaults to .33.
  953. """
  954. def __init__(self, min_scale=0.5, aspect_ratio=0.33):
  955. super(RandomScaleAspect, self).__init__()
  956. self.min_scale = min_scale
  957. self.aspect_ratio = aspect_ratio
  958. def apply(self, sample):
  959. if self.min_scale != 0 and self.aspect_ratio != 0:
  960. img_height, img_width = sample['image'].shape[:2]
  961. sample = RandomCrop(
  962. crop_size=(img_height, img_width),
  963. aspect_ratio=[self.aspect_ratio, 1. / self.aspect_ratio],
  964. scaling=[self.min_scale, 1.],
  965. num_attempts=10,
  966. allow_no_crop=False)(sample)
  967. return sample
  968. class RandomExpand(Transform):
  969. """
  970. Randomly expand the input by padding according to random offsets.
  971. Args:
  972. upper_ratio(float, optional): The maximum ratio to which the original image is expanded. Defaults to 4..
  973. prob(float, optional): The probability of apply expanding. Defaults to .5.
  974. im_padding_value(list[float] | tuple[float], optional): RGB filling value for the image. Defaults to (127.5, 127.5, 127.5).
  975. label_padding_value(int, optional): Filling value for the mask. Defaults to 255.
  976. See Also:
  977. paddlers.transforms.Pad
  978. """
  979. def __init__(self,
  980. upper_ratio=4.,
  981. prob=.5,
  982. im_padding_value=127.5,
  983. label_padding_value=255):
  984. super(RandomExpand, self).__init__()
  985. assert upper_ratio > 1.01, "`upper_ratio` must be larger than 1.01."
  986. self.upper_ratio = upper_ratio
  987. self.prob = prob
  988. assert isinstance(im_padding_value, (Number, Sequence)), \
  989. "Value to fill must be either float or sequence."
  990. self.im_padding_value = im_padding_value
  991. self.label_padding_value = label_padding_value
  992. def apply(self, sample):
  993. if random.random() < self.prob:
  994. im_h, im_w = sample['image'].shape[:2]
  995. ratio = np.random.uniform(1., self.upper_ratio)
  996. h = int(im_h * ratio)
  997. w = int(im_w * ratio)
  998. if h > im_h and w > im_w:
  999. y = np.random.randint(0, h - im_h)
  1000. x = np.random.randint(0, w - im_w)
  1001. target_size = (h, w)
  1002. offsets = (x, y)
  1003. sample = Pad(
  1004. target_size=target_size,
  1005. pad_mode=-1,
  1006. offsets=offsets,
  1007. im_padding_value=self.im_padding_value,
  1008. label_padding_value=self.label_padding_value)(sample)
  1009. return sample
class Pad(Transform):
    def __init__(self,
                 target_size=None,
                 pad_mode=0,
                 offsets=None,
                 im_padding_value=127.5,
                 label_padding_value=255,
                 size_divisor=32):
        """
        Pad image to a specified size or multiple of size_divisor.

        Args:
            target_size(int, Sequence, optional): Image target size, if None, pad to multiple of size_divisor. Defaults to None.
            pad_mode({-1, 0, 1, 2}, optional): Pad mode, currently only supports four modes [-1, 0, 1, 2]. If -1, use specified
                offsets. If 0, only pad to right and bottom. If 1, pad according to center. If 2, only pad left and top. Defaults to 0.
            offsets(Sequence[int], optional): (x, y) position of the image content on the padded canvas. Required (and only used)
                when `pad_mode` is -1.
            im_padding_value(Sequence[float]): RGB value of pad area. Defaults to (127.5, 127.5, 127.5).
            label_padding_value(int, optional): Filling value for the mask. Defaults to 255.
            size_divisor(int): Image width and height after padding are a multiple of size_divisor. Only used when
                `target_size` is None.
        """
        super(Pad, self).__init__()
        if isinstance(target_size, (list, tuple)):
            if len(target_size) != 2:
                raise ValueError(
                    "`target_size` should contain 2 elements, but it is {}.".
                    format(target_size))
        if isinstance(target_size, int):
            # A single int means a square target (h == w).
            target_size = [target_size] * 2
        assert pad_mode in [
            -1, 0, 1, 2
        ], "Currently only four modes are supported: [-1, 0, 1, 2]."
        if pad_mode == -1:
            assert offsets, "if `pad_mode` is -1, `offsets` should not be None."
        self.target_size = target_size
        self.size_divisor = size_divisor
        self.pad_mode = pad_mode
        self.offsets = offsets
        self.im_padding_value = im_padding_value
        self.label_padding_value = label_padding_value

    def apply_im(self, image, offsets, target_size):
        # Paste the image onto a canvas filled with im_padding_value.
        # NOTE: the returned canvas is float32 regardless of the input dtype.
        x, y = offsets
        h, w = target_size
        im_h, im_w, channel = image.shape[:3]
        canvas = np.ones((h, w, channel), dtype=np.float32)
        canvas *= np.array(self.im_padding_value, dtype=np.float32)
        canvas[y:y + im_h, x:x + im_w, :] = image.astype(np.float32)
        return canvas

    def apply_mask(self, mask, offsets, target_size):
        # Same as apply_im but 2-D, filled with label_padding_value.
        x, y = offsets
        im_h, im_w = mask.shape[:2]
        h, w = target_size
        canvas = np.ones((h, w), dtype=np.float32)
        canvas *= np.array(self.label_padding_value, dtype=np.float32)
        canvas[y:y + im_h, x:x + im_w] = mask.astype(np.float32)
        return canvas

    def apply_bbox(self, bbox, offsets):
        # Shift boxes by (x, y, x, y); `offsets * 2` repeats the 2-tuple.
        return bbox + np.array(offsets * 2, dtype=np.float32)

    def apply_segm(self, segms, offsets, im_size, size):
        # Shift each segmentation by the pad offsets; polygons move by
        # (x, y), RLEs are re-encoded on the enlarged canvas.
        x, y = offsets
        height, width = im_size
        h, w = size
        expanded_segms = []
        for segm in segms:
            if is_poly(segm):
                # Polygon format
                expanded_segms.append(
                    [expand_poly(poly, x, y) for poly in segm])
            else:
                # RLE format
                expanded_segms.append(
                    expand_rle(segm, x, y, height, width, h, w))
        return expanded_segms

    def apply(self, sample):
        im_h, im_w = sample['image'].shape[:2]
        if self.target_size:
            h, w = self.target_size
            assert (
                im_h <= h and im_w <= w
            ), 'target size ({}, {}) cannot be less than image size ({}, {})'\
                .format(h, w, im_h, im_w)
        else:
            # Round each dimension up to the next multiple of size_divisor.
            h = (np.ceil(im_h / self.size_divisor) *
                 self.size_divisor).astype(int)
            w = (np.ceil(im_w / self.size_divisor) *
                 self.size_divisor).astype(int)
        if h == im_h and w == im_w:
            # Already the right size — nothing to pad.
            return sample
        # Translate pad_mode into the (x, y) position of the image content.
        if self.pad_mode == -1:
            offsets = self.offsets
        elif self.pad_mode == 0:
            offsets = [0, 0]
        elif self.pad_mode == 1:
            offsets = [(w - im_w) // 2, (h - im_h) // 2]
        else:
            offsets = [w - im_w, h - im_h]
        sample['image'] = self.apply_im(sample['image'], offsets, (h, w))
        if 'image2' in sample:
            sample['image2'] = self.apply_im(sample['image2'], offsets, (h, w))
        if 'mask' in sample:
            sample['mask'] = self.apply_mask(sample['mask'], offsets, (h, w))
        if 'aux_masks' in sample:
            sample['aux_masks'] = list(
                map(partial(
                    self.apply_mask, offsets=offsets, target_size=(h, w)),
                    sample['aux_masks']))
        if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
            sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], offsets)
        if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
            sample['gt_poly'] = self.apply_segm(
                sample['gt_poly'], offsets, im_size=[im_h, im_w], size=[h, w])
        return sample
class MixupImage(Transform):
    def __init__(self, alpha=1.5, beta=1.5, mixup_epoch=-1):
        """
        Mixup two images and their gt_bbbox/gt_score.

        Args:
            alpha (float, optional): Alpha parameter of beta distribution. Defaults to 1.5.
            beta (float, optional): Beta parameter of beta distribution. Defaults to 1.5.
            mixup_epoch (int, optional): Stored on the instance; not used within this
                class — presumably consumed by the training loop to stop mixup after
                a given epoch (TODO confirm against callers). Defaults to -1.
        """
        super(MixupImage, self).__init__()
        if alpha <= 0.0:
            raise ValueError("`alpha` should be positive in MixupImage.")
        if beta <= 0.0:
            raise ValueError("`beta` should be positive in MixupImage.")
        self.alpha = alpha
        self.beta = beta
        self.mixup_epoch = mixup_epoch

    def apply_im(self, image1, image2, factor):
        # Blend the two images on a canvas large enough for both:
        # image1 weighted by `factor`, image2 by (1 - factor).
        h = max(image1.shape[0], image2.shape[0])
        w = max(image1.shape[1], image2.shape[1])
        img = np.zeros((h, w, image1.shape[2]), 'float32')
        img[:image1.shape[0], :image1.shape[1], :] = \
            image1.astype('float32') * factor
        img[:image2.shape[0], :image2.shape[1], :] += \
            image2.astype('float32') * (1.0 - factor)
        return img.astype('uint8')

    def __call__(self, sample):
        # Overrides __call__ (not apply): mixup consumes a PAIR of samples.
        # A non-sequence input is passed through unchanged.
        if not isinstance(sample, Sequence):
            return sample
        assert len(sample) == 2, 'mixup need two samples'
        factor = np.random.beta(self.alpha, self.beta)
        factor = max(0.0, min(1.0, factor))
        # Degenerate factors reduce to returning one of the two samples.
        if factor >= 1.0:
            return sample[0]
        if factor <= 0.0:
            return sample[1]
        image = self.apply_im(sample[0]['image'], sample[1]['image'], factor)
        # Deep-copy the first sample so its other fields (paths, metadata)
        # carry over without mutating the input.
        result = copy.deepcopy(sample[0])
        result['image'] = image
        # apply bbox and score
        if 'gt_bbox' in sample[0]:
            gt_bbox1 = sample[0]['gt_bbox']
            gt_bbox2 = sample[1]['gt_bbox']
            gt_bbox = np.concatenate((gt_bbox1, gt_bbox2), axis=0)
            result['gt_bbox'] = gt_bbox
        if 'gt_poly' in sample[0]:
            gt_poly1 = sample[0]['gt_poly']
            gt_poly2 = sample[1]['gt_poly']
            gt_poly = gt_poly1 + gt_poly2
            result['gt_poly'] = gt_poly
        if 'gt_class' in sample[0]:
            gt_class1 = sample[0]['gt_class']
            gt_class2 = sample[1]['gt_class']
            gt_class = np.concatenate((gt_class1, gt_class2), axis=0)
            result['gt_class'] = gt_class
            # Scores are synthesized (all ones) and weighted by the mixup
            # factor, regardless of any pre-existing gt_score values.
            gt_score1 = np.ones_like(sample[0]['gt_class'])
            gt_score2 = np.ones_like(sample[1]['gt_class'])
            gt_score = np.concatenate(
                (gt_score1 * factor, gt_score2 * (1. - factor)), axis=0)
            result['gt_score'] = gt_score
        if 'is_crowd' in sample[0]:
            is_crowd1 = sample[0]['is_crowd']
            is_crowd2 = sample[1]['is_crowd']
            is_crowd = np.concatenate((is_crowd1, is_crowd2), axis=0)
            result['is_crowd'] = is_crowd
        if 'difficult' in sample[0]:
            is_difficult1 = sample[0]['difficult']
            is_difficult2 = sample[1]['difficult']
            is_difficult = np.concatenate(
                (is_difficult1, is_difficult2), axis=0)
            result['difficult'] = is_difficult
        return result
  1190. class RandomDistort(Transform):
  1191. """
  1192. Random color distortion.
  1193. Args:
  1194. brightness_range(float, optional): Range of brightness distortion. Defaults to .5.
  1195. brightness_prob(float, optional): Probability of brightness distortion. Defaults to .5.
  1196. contrast_range(float, optional): Range of contrast distortion. Defaults to .5.
  1197. contrast_prob(float, optional): Probability of contrast distortion. Defaults to .5.
  1198. saturation_range(float, optional): Range of saturation distortion. Defaults to .5.
  1199. saturation_prob(float, optional): Probability of saturation distortion. Defaults to .5.
  1200. hue_range(float, optional): Range of hue distortion. Defaults to .5.
  1201. hue_prob(float, optional): Probability of hue distortion. Defaults to .5.
  1202. random_apply (bool, optional): whether to apply in random (yolo) or fixed (SSD)
  1203. order. Defaults to True.
  1204. count (int, optional): the number of doing distortion. Defaults to 4.
  1205. shuffle_channel (bool, optional): whether to swap channels randomly. Defaults to False.
  1206. """
  1207. def __init__(self,
  1208. brightness_range=0.5,
  1209. brightness_prob=0.5,
  1210. contrast_range=0.5,
  1211. contrast_prob=0.5,
  1212. saturation_range=0.5,
  1213. saturation_prob=0.5,
  1214. hue_range=18,
  1215. hue_prob=0.5,
  1216. random_apply=True,
  1217. count=4,
  1218. shuffle_channel=False):
  1219. super(RandomDistort, self).__init__()
  1220. self.brightness_range = [1 - brightness_range, 1 + brightness_range]
  1221. self.brightness_prob = brightness_prob
  1222. self.contrast_range = [1 - contrast_range, 1 + contrast_range]
  1223. self.contrast_prob = contrast_prob
  1224. self.saturation_range = [1 - saturation_range, 1 + saturation_range]
  1225. self.saturation_prob = saturation_prob
  1226. self.hue_range = [1 - hue_range, 1 + hue_range]
  1227. self.hue_prob = hue_prob
  1228. self.random_apply = random_apply
  1229. self.count = count
  1230. self.shuffle_channel = shuffle_channel
  1231. def apply_hue(self, image):
  1232. low, high = self.hue_range
  1233. if np.random.uniform(0., 1.) < self.hue_prob:
  1234. return image
  1235. # it works, but result differ from HSV version
  1236. delta = np.random.uniform(low, high)
  1237. u = np.cos(delta * np.pi)
  1238. w = np.sin(delta * np.pi)
  1239. bt = np.array([[1.0, 0.0, 0.0], [0.0, u, -w], [0.0, w, u]])
  1240. tyiq = np.array([[0.299, 0.587, 0.114], [0.596, -0.274, -0.321],
  1241. [0.211, -0.523, 0.311]])
  1242. ityiq = np.array([[1.0, 0.956, 0.621], [1.0, -0.272, -0.647],
  1243. [1.0, -1.107, 1.705]])
  1244. t = np.dot(np.dot(ityiq, bt), tyiq).T
  1245. res_list = []
  1246. channel = image.shape[2]
  1247. for i in range(channel // 3):
  1248. sub_img = image[:, :, 3 * i:3 * (i + 1)]
  1249. sub_img = sub_img.astype(np.float32)
  1250. sub_img = np.dot(image, t)
  1251. res_list.append(sub_img)
  1252. if channel % 3 != 0:
  1253. i = channel % 3
  1254. res_list.append(image[:, :, -i:])
  1255. return np.concatenate(res_list, axis=2)
  1256. def apply_saturation(self, image):
  1257. low, high = self.saturation_range
  1258. delta = np.random.uniform(low, high)
  1259. if np.random.uniform(0., 1.) < self.saturation_prob:
  1260. return image
  1261. res_list = []
  1262. channel = image.shape[2]
  1263. for i in range(channel // 3):
  1264. sub_img = image[:, :, 3 * i:3 * (i + 1)]
  1265. sub_img = sub_img.astype(np.float32)
  1266. # it works, but result differ from HSV version
  1267. gray = sub_img * np.array(
  1268. [[[0.299, 0.587, 0.114]]], dtype=np.float32)
  1269. gray = gray.sum(axis=2, keepdims=True)
  1270. gray *= (1.0 - delta)
  1271. sub_img *= delta
  1272. sub_img += gray
  1273. res_list.append(sub_img)
  1274. if channel % 3 != 0:
  1275. i = channel % 3
  1276. res_list.append(image[:, :, -i:])
  1277. return np.concatenate(res_list, axis=2)
  1278. def apply_contrast(self, image):
  1279. low, high = self.contrast_range
  1280. if np.random.uniform(0., 1.) < self.contrast_prob:
  1281. return image
  1282. delta = np.random.uniform(low, high)
  1283. image = image.astype(np.float32)
  1284. image *= delta
  1285. return image
  1286. def apply_brightness(self, image):
  1287. low, high = self.brightness_range
  1288. if np.random.uniform(0., 1.) < self.brightness_prob:
  1289. return image
  1290. delta = np.random.uniform(low, high)
  1291. image = image.astype(np.float32)
  1292. image += delta
  1293. return image
  1294. def apply(self, sample):
  1295. if self.random_apply:
  1296. functions = [
  1297. self.apply_brightness, self.apply_contrast,
  1298. self.apply_saturation, self.apply_hue
  1299. ]
  1300. distortions = np.random.permutation(functions)[:self.count]
  1301. for func in distortions:
  1302. sample['image'] = func(sample['image'])
  1303. if 'image2' in sample:
  1304. sample['image2'] = func(sample['image2'])
  1305. return sample
  1306. sample['image'] = self.apply_brightness(sample['image'])
  1307. if 'image2' in sample:
  1308. sample['image2'] = self.apply_brightness(sample['image2'])
  1309. mode = np.random.randint(0, 2)
  1310. if mode:
  1311. sample['image'] = self.apply_contrast(sample['image'])
  1312. if 'image2' in sample:
  1313. sample['image2'] = self.apply_contrast(sample['image2'])
  1314. sample['image'] = self.apply_saturation(sample['image'])
  1315. sample['image'] = self.apply_hue(sample['image'])
  1316. if 'image2' in sample:
  1317. sample['image2'] = self.apply_saturation(sample['image2'])
  1318. sample['image2'] = self.apply_hue(sample['image2'])
  1319. if not mode:
  1320. sample['image'] = self.apply_contrast(sample['image'])
  1321. if 'image2' in sample:
  1322. sample['image2'] = self.apply_contrast(sample['image2'])
  1323. if self.shuffle_channel:
  1324. if np.random.randint(0, 2):
  1325. sample['image'] = sample['image'][..., np.random.permutation(3)]
  1326. if 'image2' in sample:
  1327. sample['image2'] = sample['image2'][
  1328. ..., np.random.permutation(3)]
  1329. return sample
  1330. class RandomBlur(Transform):
  1331. """
  1332. Randomly blur input image(s).
  1333. Args:
  1334. prob (float): Probability of blurring.
  1335. """
  1336. def __init__(self, prob=0.1):
  1337. super(RandomBlur, self).__init__()
  1338. self.prob = prob
  1339. def apply_im(self, image, radius):
  1340. image = cv2.GaussianBlur(image, (radius, radius), 0, 0)
  1341. return image
  1342. def apply(self, sample):
  1343. if self.prob <= 0:
  1344. n = 0
  1345. elif self.prob >= 1:
  1346. n = 1
  1347. else:
  1348. n = int(1.0 / self.prob)
  1349. if n > 0:
  1350. if np.random.randint(0, n) == 0:
  1351. radius = np.random.randint(3, 10)
  1352. if radius % 2 != 1:
  1353. radius = radius + 1
  1354. if radius > 9:
  1355. radius = 9
  1356. sample['image'] = self.apply_im(sample['image'], radius)
  1357. if 'image2' in sample:
  1358. sample['image2'] = self.apply_im(sample['image2'], radius)
  1359. return sample
  1360. class Dehaze(Transform):
  1361. """
  1362. Dehaze input image(s).
  1363. Args:
  1364. gamma (bool, optional): Use gamma correction or not. Defaults to False.
  1365. """
  1366. def __init__(self, gamma=False):
  1367. super(Dehaze, self).__init__()
  1368. self.gamma = gamma
  1369. def apply_im(self, image):
  1370. image = dehaze(image, self.gamma)
  1371. return image
  1372. def apply(self, sample):
  1373. sample['image'] = self.apply_im(sample['image'])
  1374. if 'image2' in sample:
  1375. sample['image2'] = self.apply_im(sample['image2'])
  1376. return sample
  1377. class ReduceDim(Transform):
  1378. """
  1379. Use PCA to reduce the dimension of input image(s).
  1380. Args:
  1381. joblib_path (str): Path of *.joblib file of PCA.
  1382. """
  1383. def __init__(self, joblib_path):
  1384. super(ReduceDim, self).__init__()
  1385. ext = joblib_path.split(".")[-1]
  1386. if ext != "joblib":
  1387. raise ValueError("`joblib_path` must be *.joblib, not *.{}.".format(
  1388. ext))
  1389. self.pca = load(joblib_path)
  1390. def apply_im(self, image):
  1391. H, W, C = image.shape
  1392. n_im = np.reshape(image, (-1, C))
  1393. im_pca = self.pca.transform(n_im)
  1394. result = np.reshape(im_pca, (H, W, -1))
  1395. return result
  1396. def apply(self, sample):
  1397. sample['image'] = self.apply_im(sample['image'])
  1398. if 'image2' in sample:
  1399. sample['image2'] = self.apply_im(sample['image2'])
  1400. return sample
  1401. class SelectBand(Transform):
  1402. """
  1403. Select a set of bands of input image(s).
  1404. Args:
  1405. band_list (list, optional): Bands to select (the band index starts with 1). Defaults to [1, 2, 3].
  1406. """
  1407. def __init__(self, band_list=[1, 2, 3]):
  1408. super(SelectBand, self).__init__()
  1409. self.band_list = band_list
  1410. def apply_im(self, image):
  1411. image = select_bands(image, self.band_list)
  1412. return image
  1413. def apply(self, sample):
  1414. sample['image'] = self.apply_im(sample['image'])
  1415. if 'image2' in sample:
  1416. sample['image2'] = self.apply_im(sample['image2'])
  1417. return sample
  1418. class _PadBox(Transform):
  1419. def __init__(self, num_max_boxes=50):
  1420. """
  1421. Pad zeros to bboxes if number of bboxes is less than num_max_boxes.
  1422. Args:
  1423. num_max_boxes (int, optional): the max number of bboxes. Defaults to 50.
  1424. """
  1425. self.num_max_boxes = num_max_boxes
  1426. super(_PadBox, self).__init__()
  1427. def apply(self, sample):
  1428. gt_num = min(self.num_max_boxes, len(sample['gt_bbox']))
  1429. num_max = self.num_max_boxes
  1430. pad_bbox = np.zeros((num_max, 4), dtype=np.float32)
  1431. if gt_num > 0:
  1432. pad_bbox[:gt_num, :] = sample['gt_bbox'][:gt_num, :]
  1433. sample['gt_bbox'] = pad_bbox
  1434. if 'gt_class' in sample:
  1435. pad_class = np.zeros((num_max, ), dtype=np.int32)
  1436. if gt_num > 0:
  1437. pad_class[:gt_num] = sample['gt_class'][:gt_num, 0]
  1438. sample['gt_class'] = pad_class
  1439. if 'gt_score' in sample:
  1440. pad_score = np.zeros((num_max, ), dtype=np.float32)
  1441. if gt_num > 0:
  1442. pad_score[:gt_num] = sample['gt_score'][:gt_num, 0]
  1443. sample['gt_score'] = pad_score
  1444. # in training, for example in op ExpandImage,
  1445. # the bbox and gt_class is expanded, but the difficult is not,
  1446. # so, judging by it's length
  1447. if 'difficult' in sample:
  1448. pad_diff = np.zeros((num_max, ), dtype=np.int32)
  1449. if gt_num > 0:
  1450. pad_diff[:gt_num] = sample['difficult'][:gt_num, 0]
  1451. sample['difficult'] = pad_diff
  1452. if 'is_crowd' in sample:
  1453. pad_crowd = np.zeros((num_max, ), dtype=np.int32)
  1454. if gt_num > 0:
  1455. pad_crowd[:gt_num] = sample['is_crowd'][:gt_num, 0]
  1456. sample['is_crowd'] = pad_crowd
  1457. return sample
  1458. class _NormalizeBox(Transform):
  1459. def __init__(self):
  1460. super(_NormalizeBox, self).__init__()
  1461. def apply(self, sample):
  1462. height, width = sample['image'].shape[:2]
  1463. for i in range(sample['gt_bbox'].shape[0]):
  1464. sample['gt_bbox'][i][0] = sample['gt_bbox'][i][0] / width
  1465. sample['gt_bbox'][i][1] = sample['gt_bbox'][i][1] / height
  1466. sample['gt_bbox'][i][2] = sample['gt_bbox'][i][2] / width
  1467. sample['gt_bbox'][i][3] = sample['gt_bbox'][i][3] / height
  1468. return sample
  1469. class _BboxXYXY2XYWH(Transform):
  1470. """
  1471. Convert bbox XYXY format to XYWH format.
  1472. """
  1473. def __init__(self):
  1474. super(_BboxXYXY2XYWH, self).__init__()
  1475. def apply(self, sample):
  1476. bbox = sample['gt_bbox']
  1477. bbox[:, 2:4] = bbox[:, 2:4] - bbox[:, :2]
  1478. bbox[:, :2] = bbox[:, :2] + bbox[:, 2:4] / 2.
  1479. sample['gt_bbox'] = bbox
  1480. return sample
  1481. class _Permute(Transform):
  1482. def __init__(self):
  1483. super(_Permute, self).__init__()
  1484. def apply(self, sample):
  1485. sample['image'] = permute(sample['image'], False)
  1486. if 'image2' in sample:
  1487. sample['image2'] = permute(sample['image2'], False)
  1488. return sample
  1489. class RandomSwap(Transform):
  1490. """
  1491. Randomly swap multi-temporal images.
  1492. Args:
  1493. prob (float, optional): Probability of swapping the input images. Default: 0.2.
  1494. """
  1495. def __init__(self, prob=0.2):
  1496. super(RandomSwap, self).__init__()
  1497. self.prob = prob
  1498. def apply(self, sample):
  1499. if 'image2' not in sample:
  1500. raise ValueError("'image2' is not found in the sample.")
  1501. if random.random() < self.prob:
  1502. sample['image'], sample['image2'] = sample['image2'], sample[
  1503. 'image']
  1504. return sample
  1505. class Arrange(Transform):
  1506. def __init__(self, mode):
  1507. super().__init__()
  1508. if mode not in ['train', 'eval', 'test', 'quant']:
  1509. raise ValueError(
  1510. "`mode` should be defined as one of ['train', 'eval', 'test', 'quant']!"
  1511. )
  1512. self.mode = mode
  1513. class ArrangeSegmenter(Arrange):
  1514. def apply(self, sample):
  1515. if 'mask' in sample:
  1516. mask = sample['mask']
  1517. mask = mask.astype('int64')
  1518. image = permute(sample['image'], False)
  1519. if self.mode == 'train':
  1520. return image, mask
  1521. if self.mode == 'eval':
  1522. return image, mask
  1523. if self.mode == 'test':
  1524. return image,
  1525. class ArrangeChangeDetector(Arrange):
  1526. def apply(self, sample):
  1527. if 'mask' in sample:
  1528. mask = sample['mask']
  1529. mask = mask.astype('int64')
  1530. image_t1 = permute(sample['image'], False)
  1531. image_t2 = permute(sample['image2'], False)
  1532. if self.mode == 'train':
  1533. masks = [mask]
  1534. if 'aux_masks' in sample:
  1535. masks.extend(
  1536. map(methodcaller('astype', 'int64'), sample['aux_masks']))
  1537. return (
  1538. image_t1,
  1539. image_t2, ) + tuple(masks)
  1540. if self.mode == 'eval':
  1541. return image_t1, image_t2, mask
  1542. if self.mode == 'test':
  1543. return image_t1, image_t2,
  1544. class ArrangeClassifier(Arrange):
  1545. def apply(self, sample):
  1546. image = permute(sample['image'], False)
  1547. if self.mode in ['train', 'eval']:
  1548. return image, sample['label']
  1549. else:
  1550. return image
  1551. class ArrangeDetector(Arrange):
  1552. def apply(self, sample):
  1553. if self.mode == 'eval' and 'gt_poly' in sample:
  1554. del sample['gt_poly']
  1555. return sample