operators.py 65 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import copy
  16. import random
  17. from numbers import Number
  18. from functools import partial
  19. from operator import methodcaller
  20. try:
  21. from collections.abc import Sequence
  22. except Exception:
  23. from collections import Sequence
  24. import numpy as np
  25. import cv2
  26. import imghdr
  27. from PIL import Image
  28. from joblib import load
  29. import paddlers
  30. from .functions import normalize, horizontal_flip, permute, vertical_flip, center_crop, is_poly, \
  31. horizontal_flip_poly, horizontal_flip_rle, vertical_flip_poly, vertical_flip_rle, crop_poly, \
  32. crop_rle, expand_poly, expand_rle, resize_poly, resize_rle, de_haze, select_bands, \
  33. to_intensity, to_uint8, img_flip, img_simple_rotate
# Public operators exported by this module.
__all__ = [
    "Compose",
    "ImgDecoder",
    "Resize",
    "RandomResize",
    "ResizeByShort",
    "RandomResizeByShort",
    "ResizeByLong",
    "RandomHorizontalFlip",
    "RandomVerticalFlip",
    "Normalize",
    "CenterCrop",
    "RandomCrop",
    "RandomScaleAspect",
    "RandomExpand",
    "Padding",
    "MixupImage",
    "RandomDistort",
    "RandomBlur",
    "RandomSwap",
    "Defogging",
    "DimReducing",
    "BandSelecting",
    "ArrangeSegmenter",
    "ArrangeChangeDetector",
    "ArrangeClassifier",
    "ArrangeDetector",
    "RandomFlipOrRotation",
]

# Mapping from interpolation-method name to the corresponding OpenCV flag.
interp_dict = {
    'NEAREST': cv2.INTER_NEAREST,
    'LINEAR': cv2.INTER_LINEAR,
    'CUBIC': cv2.INTER_CUBIC,
    'AREA': cv2.INTER_AREA,
    'LANCZOS4': cv2.INTER_LANCZOS4
}
  70. class Transform(object):
  71. """
  72. Parent class of all data augmentation operations
  73. """
  74. def __init__(self):
  75. pass
  76. def apply_im(self, image):
  77. pass
  78. def apply_mask(self, mask):
  79. pass
  80. def apply_bbox(self, bbox):
  81. pass
  82. def apply_segm(self, segms):
  83. pass
  84. def apply(self, sample):
  85. if 'image' in sample:
  86. sample['image'] = self.apply_im(sample['image'])
  87. else: # image_tx
  88. sample['image'] = self.apply_im(sample['image_t1'])
  89. sample['image2'] = self.apply_im(sample['image_t2'])
  90. if 'mask' in sample:
  91. sample['mask'] = self.apply_mask(sample['mask'])
  92. if 'gt_bbox' in sample:
  93. sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'])
  94. if 'aux_masks' in sample:
  95. sample['aux_masks'] = list(
  96. map(self.apply_mask, sample['aux_masks']))
  97. return sample
  98. def __call__(self, sample):
  99. if isinstance(sample, Sequence):
  100. sample = [self.apply(s) for s in sample]
  101. else:
  102. sample = self.apply(sample)
  103. return sample
  104. class ImgDecoder(Transform):
  105. """
  106. Decode image(s) in input.
  107. Args:
  108. to_rgb (bool, optional): If True, convert input images from BGR format to RGB format. Defaults to True.
  109. """
  110. def __init__(self, to_rgb=True, to_uint8=True):
  111. super(ImgDecoder, self).__init__()
  112. self.to_rgb = to_rgb
  113. self.to_uint8 = to_uint8
  114. def read_img(self, img_path, input_channel=3):
  115. img_format = imghdr.what(img_path)
  116. name, ext = os.path.splitext(img_path)
  117. if img_format == 'tiff' or ext == '.img':
  118. try:
  119. import gdal
  120. except:
  121. try:
  122. from osgeo import gdal
  123. except:
  124. raise Exception(
  125. "Failed to import gdal! You can try use conda to install gdal"
  126. )
  127. six.reraise(*sys.exc_info())
  128. dataset = gdal.Open(img_path)
  129. if dataset == None:
  130. raise Exception('Can not open', img_path)
  131. im_data = dataset.ReadAsArray()
  132. if im_data.ndim == 2:
  133. im_data = to_intensity(im_data) # is read SAR
  134. im_data = im_data[:, :, np.newaxis]
  135. elif im_data.ndim == 3:
  136. im_data = im_data.transpose((1, 2, 0))
  137. return im_data
  138. elif img_format in ['jpeg', 'bmp', 'png', 'jpg']:
  139. if input_channel == 3:
  140. return cv2.imread(img_path, cv2.IMREAD_ANYDEPTH |
  141. cv2.IMREAD_ANYCOLOR | cv2.IMREAD_COLOR)
  142. else:
  143. return cv2.imread(img_path, cv2.IMREAD_ANYDEPTH |
  144. cv2.IMREAD_ANYCOLOR)
  145. elif ext == '.npy':
  146. return np.load(img_path)
  147. else:
  148. raise Exception('Image format {} is not supported!'.format(ext))
  149. def apply_im(self, im_path):
  150. if isinstance(im_path, str):
  151. try:
  152. image = self.read_img(im_path)
  153. except:
  154. raise ValueError('Cannot read the image file {}!'.format(
  155. im_path))
  156. else:
  157. image = im_path
  158. if self.to_rgb and image.shape[-1] == 3:
  159. image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
  160. if self.to_uint8:
  161. image = to_uint8(image)
  162. return image
  163. def apply_mask(self, mask):
  164. try:
  165. mask = np.asarray(Image.open(mask))
  166. except:
  167. raise ValueError("Cannot read the mask file {}!".format(mask))
  168. if len(mask.shape) != 2:
  169. raise Exception(
  170. "Mask should be a 1-channel image, but recevied is a {}-channel image.".
  171. format(mask.shape[2]))
  172. return mask
  173. def apply(self, sample):
  174. """
  175. Args:
  176. sample (dict): Input sample.
  177. Returns:
  178. dict: Decoded sample.
  179. """
  180. if 'image' in sample:
  181. sample['image'] = self.apply_im(sample['image'])
  182. if 'image2' in sample:
  183. sample['image2'] = self.apply_im(sample['image2'])
  184. if 'image_t1' in sample and not 'image' in sample:
  185. if not ('image_t2' in sample and 'image2' not in sample):
  186. raise ValueError
  187. sample['image'] = self.apply_im(sample['image_t1'])
  188. sample['image2'] = self.apply_im(sample['image_t2'])
  189. if 'mask' in sample:
  190. sample['mask'] = self.apply_mask(sample['mask'])
  191. im_height, im_width, _ = sample['image'].shape
  192. se_height, se_width = sample['mask'].shape
  193. if im_height != se_height or im_width != se_width:
  194. raise Exception(
  195. "The height or width of the im is not same as the mask")
  196. if 'aux_masks' in sample:
  197. sample['aux_masks'] = list(
  198. map(self.apply_mask, sample['aux_masks']))
  199. # TODO: check the shape of auxiliary masks
  200. sample['im_shape'] = np.array(
  201. sample['image'].shape[:2], dtype=np.float32)
  202. sample['scale_factor'] = np.array([1., 1.], dtype=np.float32)
  203. return sample
  204. class Compose(Transform):
  205. """
  206. Apply a series of data augmentation to the input.
  207. All input images are in Height-Width-Channel ([H, W, C]) format.
  208. Args:
  209. transforms (List[paddlers.transforms.Transform]): List of data preprocess or augmentations.
  210. Raises:
  211. TypeError: Invalid type of transforms.
  212. ValueError: Invalid length of transforms.
  213. """
  214. def __init__(self, transforms, to_uint8=True):
  215. super(Compose, self).__init__()
  216. if not isinstance(transforms, list):
  217. raise TypeError(
  218. 'Type of transforms is invalid. Must be List, but received is {}'
  219. .format(type(transforms)))
  220. if len(transforms) < 1:
  221. raise ValueError(
  222. 'Length of transforms must not be less than 1, but received is {}'
  223. .format(len(transforms)))
  224. self.transforms = transforms
  225. self.decode_image = ImgDecoder(to_uint8=to_uint8)
  226. self.arrange_outputs = None
  227. self.apply_im_only = False
  228. def __call__(self, sample):
  229. if self.apply_im_only:
  230. if 'mask' in sample:
  231. mask_backup = copy.deepcopy(sample['mask'])
  232. del sample['mask']
  233. if 'aux_masks' in sample:
  234. aux_masks = copy.deepcopy(sample['aux_masks'])
  235. sample = self.decode_image(sample)
  236. for op in self.transforms:
  237. # skip batch transforms amd mixup
  238. if isinstance(op, (paddlers.transforms.BatchRandomResize,
  239. paddlers.transforms.BatchRandomResizeByShort,
  240. MixupImage)):
  241. continue
  242. sample = op(sample)
  243. if self.arrange_outputs is not None:
  244. if self.apply_im_only:
  245. sample['mask'] = mask_backup
  246. if 'aux_masks' in locals():
  247. sample['aux_masks'] = aux_masks
  248. sample = self.arrange_outputs(sample)
  249. return sample
class Resize(Transform):
    """
    Resize input.
    - If target_size is an int, resize the image(s) to (target_size, target_size).
    - If target_size is a list or tuple, resize the image(s) to target_size.
    Attention: If interp is 'RANDOM', the interpolation method will be chose randomly.

    Args:
        target_size (int, List[int] or Tuple[int]): Target size. If int, the height and width share the same target_size.
            Otherwise, target_size represents [target height, target width].
        interp ({'NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4', 'RANDOM'}, optional):
            Interpolation method of resize. Defaults to 'LINEAR'.
        keep_ratio (bool): the resize scale of width/height is same and width/height after resized is not greater
            than target width/height. Defaults to False.

    Raises:
        TypeError: Invalid type of target_size.
        ValueError: Invalid interpolation method.
    """

    def __init__(self, target_size, interp='LINEAR', keep_ratio=False):
        super(Resize, self).__init__()
        if not (interp == "RANDOM" or interp in interp_dict):
            raise ValueError("interp should be one of {}".format(
                interp_dict.keys()))
        if isinstance(target_size, int):
            target_size = (target_size, target_size)
        else:
            if not (isinstance(target_size,
                               (list, tuple)) and len(target_size) == 2):
                raise TypeError(
                    "target_size should be an int or a list of length 2, but received {}".
                    format(target_size))
        # (height, width)
        self.target_size = target_size
        self.interp = interp
        self.keep_ratio = keep_ratio

    def apply_im(self, image, interp, target_size):
        # cv2.resize drops a trailing singleton channel axis; remember
        # whether the input was single-channel so it can be restored.
        flag = image.shape[2] == 1
        image = cv2.resize(image, target_size, interpolation=interp)
        if flag:
            image = image[:, :, np.newaxis]
        return image

    def apply_mask(self, mask, target_size):
        # Masks hold class ids, so nearest-neighbour keeps labels intact.
        mask = cv2.resize(mask, target_size, interpolation=cv2.INTER_NEAREST)
        return mask

    def apply_bbox(self, bbox, scale, target_size):
        # bbox columns are [x1, y1, x2, y2]; target_size here is (w, h),
        # so x values clip against target_size[0] and y against [1].
        im_scale_x, im_scale_y = scale
        bbox[:, 0::2] *= im_scale_x
        bbox[:, 1::2] *= im_scale_y
        bbox[:, 0::2] = np.clip(bbox[:, 0::2], 0, target_size[0])
        bbox[:, 1::2] = np.clip(bbox[:, 1::2], 0, target_size[1])
        return bbox

    def apply_segm(self, segms, im_size, scale):
        # Rescale segmentation annotations (polygons or RLE) by the same
        # factors applied to the image.
        im_h, im_w = im_size
        im_scale_x, im_scale_y = scale
        resized_segms = []
        for segm in segms:
            if is_poly(segm):
                # Polygon format
                resized_segms.append([
                    resize_poly(poly, im_scale_x, im_scale_y) for poly in segm
                ])
            else:
                # RLE format
                resized_segms.append(
                    resize_rle(segm, im_h, im_w, im_scale_x, im_scale_y))
        return resized_segms

    def apply(self, sample):
        if self.interp == "RANDOM":
            interp = random.choice(list(interp_dict.values()))
        else:
            interp = interp_dict[self.interp]
        im_h, im_w = sample['image'].shape[:2]
        im_scale_y = self.target_size[0] / im_h
        im_scale_x = self.target_size[1] / im_w
        # cv2 expects (width, height) while self.target_size is (h, w).
        target_size = (self.target_size[1], self.target_size[0])
        if self.keep_ratio:
            # Use one uniform factor so neither side exceeds the target,
            # then recompute the effective per-axis scales.
            scale = min(im_scale_y, im_scale_x)
            target_w = int(round(im_w * scale))
            target_h = int(round(im_h * scale))
            target_size = (target_w, target_h)
            im_scale_y = target_h / im_h
            im_scale_x = target_w / im_w
        sample['image'] = self.apply_im(sample['image'], interp, target_size)
        if 'image2' in sample:
            sample['image2'] = self.apply_im(sample['image2'], interp,
                                             target_size)
        if 'mask' in sample:
            sample['mask'] = self.apply_mask(sample['mask'], target_size)
        if 'aux_masks' in sample:
            sample['aux_masks'] = list(
                map(partial(
                    self.apply_mask, target_size=target_size),
                    sample['aux_masks']))
        if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
            sample['gt_bbox'] = self.apply_bbox(
                sample['gt_bbox'], [im_scale_x, im_scale_y], target_size)
        if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
            sample['gt_poly'] = self.apply_segm(
                sample['gt_poly'], [im_h, im_w], [im_scale_x, im_scale_y])
        sample['im_shape'] = np.asarray(
            sample['image'].shape[:2], dtype=np.float32)
        if 'scale_factor' in sample:
            # Accumulate this resize into the running scale factor.
            scale_factor = sample['scale_factor']
            sample['scale_factor'] = np.asarray(
                [scale_factor[0] * im_scale_y, scale_factor[1] * im_scale_x],
                dtype=np.float32)
        return sample
  356. class RandomResize(Transform):
  357. """
  358. Resize input to random sizes.
  359. Attention: If interp is 'RANDOM', the interpolation method will be chose randomly.
  360. Args:
  361. target_sizes (List[int], List[list or tuple] or Tuple[list or tuple]):
  362. Multiple target sizes, each target size is an int or list/tuple.
  363. interp ({'NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4', 'RANDOM'}, optional):
  364. Interpolation method of resize. Defaults to 'LINEAR'.
  365. Raises:
  366. TypeError: Invalid type of target_size.
  367. ValueError: Invalid interpolation method.
  368. See Also:
  369. Resize input to a specific size.
  370. """
  371. def __init__(self, target_sizes, interp='LINEAR'):
  372. super(RandomResize, self).__init__()
  373. if not (interp == "RANDOM" or interp in interp_dict):
  374. raise ValueError("interp should be one of {}".format(
  375. interp_dict.keys()))
  376. self.interp = interp
  377. assert isinstance(target_sizes, list), \
  378. "target_size must be List"
  379. for i, item in enumerate(target_sizes):
  380. if isinstance(item, int):
  381. target_sizes[i] = (item, item)
  382. self.target_size = target_sizes
  383. def apply(self, sample):
  384. height, width = random.choice(self.target_size)
  385. resizer = Resize((height, width), interp=self.interp)
  386. sample = resizer(sample)
  387. return sample
  388. class ResizeByShort(Transform):
  389. """
  390. Resize input with keeping the aspect ratio.
  391. Attention: If interp is 'RANDOM', the interpolation method will be chose randomly.
  392. Args:
  393. short_size (int): Target size of the shorter side of the image(s).
  394. max_size (int, optional): The upper bound of longer side of the image(s). If max_size is -1, no upper bound is applied. Defaults to -1.
  395. interp ({'NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4', 'RANDOM'}, optional): Interpolation method of resize. Defaults to 'LINEAR'.
  396. Raises:
  397. ValueError: Invalid interpolation method.
  398. """
  399. def __init__(self, short_size=256, max_size=-1, interp='LINEAR'):
  400. if not (interp == "RANDOM" or interp in interp_dict):
  401. raise ValueError("interp should be one of {}".format(
  402. interp_dict.keys()))
  403. super(ResizeByShort, self).__init__()
  404. self.short_size = short_size
  405. self.max_size = max_size
  406. self.interp = interp
  407. def apply(self, sample):
  408. im_h, im_w = sample['image'].shape[:2]
  409. im_short_size = min(im_h, im_w)
  410. im_long_size = max(im_h, im_w)
  411. scale = float(self.short_size) / float(im_short_size)
  412. if 0 < self.max_size < np.round(scale * im_long_size):
  413. scale = float(self.max_size) / float(im_long_size)
  414. target_w = int(round(im_w * scale))
  415. target_h = int(round(im_h * scale))
  416. sample = Resize(
  417. target_size=(target_h, target_w), interp=self.interp)(sample)
  418. return sample
  419. class RandomResizeByShort(Transform):
  420. """
  421. Resize input to random sizes with keeping the aspect ratio.
  422. Attention: If interp is 'RANDOM', the interpolation method will be chose randomly.
  423. Args:
  424. short_sizes (List[int]): Target size of the shorter side of the image(s).
  425. max_size (int, optional): The upper bound of longer side of the image(s). If max_size is -1, no upper bound is applied. Defaults to -1.
  426. interp ({'NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4', 'RANDOM'}, optional): Interpolation method of resize. Defaults to 'LINEAR'.
  427. Raises:
  428. TypeError: Invalid type of target_size.
  429. ValueError: Invalid interpolation method.
  430. See Also:
  431. ResizeByShort: Resize image(s) in input with keeping the aspect ratio.
  432. """
  433. def __init__(self, short_sizes, max_size=-1, interp='LINEAR'):
  434. super(RandomResizeByShort, self).__init__()
  435. if not (interp == "RANDOM" or interp in interp_dict):
  436. raise ValueError("interp should be one of {}".format(
  437. interp_dict.keys()))
  438. self.interp = interp
  439. assert isinstance(short_sizes, list), \
  440. "short_sizes must be List"
  441. self.short_sizes = short_sizes
  442. self.max_size = max_size
  443. def apply(self, sample):
  444. short_size = random.choice(self.short_sizes)
  445. resizer = ResizeByShort(
  446. short_size=short_size, max_size=self.max_size, interp=self.interp)
  447. sample = resizer(sample)
  448. return sample
  449. class ResizeByLong(Transform):
  450. def __init__(self, long_size=256, interp='LINEAR'):
  451. super(ResizeByLong, self).__init__()
  452. self.long_size = long_size
  453. self.interp = interp
  454. def apply(self, sample):
  455. im_h, im_w = sample['image'].shape[:2]
  456. im_long_size = max(im_h, im_w)
  457. scale = float(self.long_size) / float(im_long_size)
  458. target_h = int(round(im_h * scale))
  459. target_w = int(round(im_w * scale))
  460. sample = Resize(
  461. target_size=(target_h, target_w), interp=self.interp)(sample)
  462. return sample
  463. class RandomFlipOrRotation(Transform):
  464. """
  465. Flip or Rotate an image in different ways with a certain probability.
  466. Args:
  467. probs (list of float): Probabilities of flipping and rotation. Default: [0.35,0.25].
  468. probsf (list of float): Probabilities of 5 flipping mode
  469. (horizontal, vertical, both horizontal diction and vertical, diagonal, anti-diagonal).
  470. Default: [0.3, 0.3, 0.2, 0.1, 0.1].
  471. probsr (list of float): Probabilities of 3 rotation mode(90°, 180°, 270° clockwise). Default: [0.25,0.5,0.25].
  472. Examples:
  473. from paddlers import transforms as T
  474. # 定义数据增强
  475. train_transforms = T.Compose([
  476. T.RandomFlipOrRotation(
  477. probs = [0.3, 0.2] # 进行flip增强的概率是0.3,进行rotate增强的概率是0.2,不变的概率是0.5
  478. probsf = [0.3, 0.25, 0, 0, 0] # flip增强时,使用水平flip、垂直flip的概率分别是0.3、0.25,水平且垂直flip、对角线flip、反对角线flip概率均为0,不变的概率是0.45
  479. probsr = [0, 0.65, 0]), # rotate增强时,顺时针旋转90度的概率是0,顺时针旋转180度的概率是0.65,顺时针旋转90度的概率是0,不变的概率是0.35
  480. T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  481. ])
  482. """
  483. def __init__(self,
  484. probs=[0.35, 0.25],
  485. probsf=[0.3, 0.3, 0.2, 0.1, 0.1],
  486. probsr=[0.25, 0.5, 0.25]):
  487. super(RandomFlipOrRotation, self).__init__()
  488. # Change various probabilities into probability intervals, to judge in which mode to flip or rotate
  489. self.probs = [probs[0], probs[0] + probs[1]]
  490. self.probsf = self.get_probs_range(probsf)
  491. self.probsr = self.get_probs_range(probsr)
  492. def apply_im(self, image, mode_id, flip_mode=True):
  493. if flip_mode:
  494. image = img_flip(image, mode_id)
  495. else:
  496. image = img_simple_rotate(image, mode_id)
  497. return image
  498. def apply_mask(self, mask, mode_id, flip_mode=True):
  499. if flip_mode:
  500. mask = img_flip(mask, mode_id)
  501. else:
  502. mask = img_simple_rotate(mask, mode_id)
  503. return mask
  504. def get_probs_range(self, probs):
  505. '''
  506. Change various probabilities into cumulative probabilities
  507. Args:
  508. probs(list of float): probabilities of different mode, shape:[n]
  509. Returns:
  510. probability intervals(list of binary list): shape:[n, 2]
  511. '''
  512. ps = []
  513. last_prob = 0
  514. for prob in probs:
  515. p_s = last_prob
  516. cur_prob = prob / sum(probs)
  517. last_prob += cur_prob
  518. p_e = last_prob
  519. ps.append([p_s, p_e])
  520. return ps
  521. def judge_probs_range(self, p, probs):
  522. '''
  523. Judge whether a probability value falls within the given probability interval
  524. Args:
  525. p(float): probability
  526. probs(list of binary list): probability intervals, shape:[n, 2]
  527. Returns:
  528. mode id(int):the probability interval number where the input probability falls,
  529. if return -1, the image will remain as it is and will not be processed
  530. '''
  531. for id, id_range in enumerate(probs):
  532. if p > id_range[0] and p < id_range[1]:
  533. return id
  534. return -1
  535. def apply(self, sample):
  536. p_m = random.random()
  537. if p_m < self.probs[0]:
  538. mode_p = random.random()
  539. mode_id = self.judge_probs_range(mode_p, self.probsf)
  540. sample['image'] = self.apply_im(sample['image'], mode_id, True)
  541. if 'mask' in sample:
  542. sample['mask'] = self.apply_mask(sample['mask'], mode_id, True)
  543. elif p_m < self.probs[1]:
  544. mode_p = random.random()
  545. mode_id = self.judge_probs_range(mode_p, self.probsr)
  546. sample['image'] = self.apply_im(sample['image'], mode_id, False)
  547. if 'mask' in sample:
  548. sample['mask'] = self.apply_mask(sample['mask'], mode_id, False)
  549. return sample
  550. class RandomHorizontalFlip(Transform):
  551. """
  552. Randomly flip the input horizontally.
  553. Args:
  554. prob(float, optional): Probability of flipping the input. Defaults to .5.
  555. """
  556. def __init__(self, prob=0.5):
  557. super(RandomHorizontalFlip, self).__init__()
  558. self.prob = prob
  559. def apply_im(self, image):
  560. image = horizontal_flip(image)
  561. return image
  562. def apply_mask(self, mask):
  563. mask = horizontal_flip(mask)
  564. return mask
  565. def apply_bbox(self, bbox, width):
  566. oldx1 = bbox[:, 0].copy()
  567. oldx2 = bbox[:, 2].copy()
  568. bbox[:, 0] = width - oldx2
  569. bbox[:, 2] = width - oldx1
  570. return bbox
  571. def apply_segm(self, segms, height, width):
  572. flipped_segms = []
  573. for segm in segms:
  574. if is_poly(segm):
  575. # Polygon format
  576. flipped_segms.append(
  577. [horizontal_flip_poly(poly, width) for poly in segm])
  578. else:
  579. # RLE format
  580. flipped_segms.append(horizontal_flip_rle(segm, height, width))
  581. return flipped_segms
  582. def apply(self, sample):
  583. if random.random() < self.prob:
  584. im_h, im_w = sample['image'].shape[:2]
  585. sample['image'] = self.apply_im(sample['image'])
  586. if 'image2' in sample:
  587. sample['image2'] = self.apply_im(sample['image2'])
  588. if 'mask' in sample:
  589. sample['mask'] = self.apply_mask(sample['mask'])
  590. if 'aux_masks' in sample:
  591. sample['aux_masks'] = list(
  592. map(self.apply_mask, sample['aux_masks']))
  593. if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
  594. sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], im_w)
  595. if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
  596. sample['gt_poly'] = self.apply_segm(sample['gt_poly'], im_h,
  597. im_w)
  598. return sample
  599. class RandomVerticalFlip(Transform):
  600. """
  601. Randomly flip the input vertically.
  602. Args:
  603. prob(float, optional): Probability of flipping the input. Defaults to .5.
  604. """
  605. def __init__(self, prob=0.5):
  606. super(RandomVerticalFlip, self).__init__()
  607. self.prob = prob
  608. def apply_im(self, image):
  609. image = vertical_flip(image)
  610. return image
  611. def apply_mask(self, mask):
  612. mask = vertical_flip(mask)
  613. return mask
  614. def apply_bbox(self, bbox, height):
  615. oldy1 = bbox[:, 1].copy()
  616. oldy2 = bbox[:, 3].copy()
  617. bbox[:, 0] = height - oldy2
  618. bbox[:, 2] = height - oldy1
  619. return bbox
  620. def apply_segm(self, segms, height, width):
  621. flipped_segms = []
  622. for segm in segms:
  623. if is_poly(segm):
  624. # Polygon format
  625. flipped_segms.append(
  626. [vertical_flip_poly(poly, height) for poly in segm])
  627. else:
  628. # RLE format
  629. flipped_segms.append(vertical_flip_rle(segm, height, width))
  630. return flipped_segms
  631. def apply(self, sample):
  632. if random.random() < self.prob:
  633. im_h, im_w = sample['image'].shape[:2]
  634. sample['image'] = self.apply_im(sample['image'])
  635. if 'image2' in sample:
  636. sample['image2'] = self.apply_im(sample['image2'])
  637. if 'mask' in sample:
  638. sample['mask'] = self.apply_mask(sample['mask'])
  639. if 'aux_masks' in sample:
  640. sample['aux_masks'] = list(
  641. map(self.apply_mask, sample['aux_masks']))
  642. if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
  643. sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], im_h)
  644. if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
  645. sample['gt_poly'] = self.apply_segm(sample['gt_poly'], im_h,
  646. im_w)
  647. return sample
  648. class Normalize(Transform):
  649. """
  650. Apply min-max normalization to the image(s) in input.
  651. 1. im = (im - min_value) * 1 / (max_value - min_value)
  652. 2. im = im - mean
  653. 3. im = im / std
  654. Args:
  655. mean(List[float] or Tuple[float], optional): Mean of input image(s). Defaults to [0.485, 0.456, 0.406].
  656. std(List[float] or Tuple[float], optional): Standard deviation of input image(s). Defaults to [0.229, 0.224, 0.225].
  657. min_val(List[float] or Tuple[float], optional): Minimum value of input image(s). Defaults to [0, 0, 0, ].
  658. max_val(List[float] or Tuple[float], optional): Max value of input image(s). Defaults to [255., 255., 255.].
  659. """
  660. def __init__(self,
  661. mean=[0.485, 0.456, 0.406],
  662. std=[0.229, 0.224, 0.225],
  663. min_val=None,
  664. max_val=None):
  665. super(Normalize, self).__init__()
  666. channel = len(mean)
  667. if min_val is None:
  668. min_val = [0] * channel
  669. if max_val is None:
  670. max_val = [255.] * channel
  671. from functools import reduce
  672. if reduce(lambda x, y: x * y, std) == 0:
  673. raise ValueError(
  674. 'Std should not contain 0, but received is {}.'.format(std))
  675. if reduce(lambda x, y: x * y,
  676. [a - b for a, b in zip(max_val, min_val)]) == 0:
  677. raise ValueError(
  678. '(max_val - min_val) should not contain 0, but received is {}.'.
  679. format((np.asarray(max_val) - np.asarray(min_val)).tolist()))
  680. self.mean = mean
  681. self.std = std
  682. self.min_val = min_val
  683. self.max_val = max_val
  684. def apply_im(self, image):
  685. image = image.astype(np.float32)
  686. mean = np.asarray(
  687. self.mean, dtype=np.float32)[np.newaxis, np.newaxis, :]
  688. std = np.asarray(self.std, dtype=np.float32)[np.newaxis, np.newaxis, :]
  689. image = normalize(image, mean, std, self.min_val, self.max_val)
  690. return image
  691. def apply(self, sample):
  692. sample['image'] = self.apply_im(sample['image'])
  693. if 'image2' in sample:
  694. sample['image2'] = self.apply_im(sample['image2'])
  695. return sample
  696. class CenterCrop(Transform):
  697. """
  698. Crop the input at the center.
  699. 1. Locate the center of the image.
  700. 2. Crop the sample.
  701. Args:
  702. crop_size(int, optional): target size of the cropped image(s). Defaults to 224.
  703. """
  704. def __init__(self, crop_size=224):
  705. super(CenterCrop, self).__init__()
  706. self.crop_size = crop_size
  707. def apply_im(self, image):
  708. image = center_crop(image, self.crop_size)
  709. return image
  710. def apply_mask(self, mask):
  711. mask = center_crop(mask, self.crop_size)
  712. return mask
  713. def apply(self, sample):
  714. sample['image'] = self.apply_im(sample['image'])
  715. if 'image2' in sample:
  716. sample['image2'] = self.apply_im(sample['image2'])
  717. if 'mask' in sample:
  718. sample['mask'] = self.apply_mask(sample['mask'])
  719. if 'aux_masks' in sample:
  720. sample['aux_masks'] = list(
  721. map(self.apply_mask, sample['aux_masks']))
  722. return sample
  723. class RandomCrop(Transform):
  724. """
  725. Randomly crop the input.
  726. 1. Compute the height and width of cropped area according to aspect_ratio and scaling.
  727. 2. Locate the upper left corner of cropped area randomly.
  728. 3. Crop the image(s).
  729. 4. Resize the cropped area to crop_size by crop_size.
  730. Args:
  731. crop_size(int, List[int] or Tuple[int]): Target size of the cropped area. If None, the cropped area will not be
  732. resized. Defaults to None.
  733. aspect_ratio (List[float], optional): Aspect ratio of cropped region in [min, max] format. Defaults to [.5, 2.].
  734. thresholds (List[float], optional): Iou thresholds to decide a valid bbox crop.
  735. Defaults to [.0, .1, .3, .5, .7, .9].
  736. scaling (List[float], optional): Ratio between the cropped region and the original image in [min, max] format.
  737. Defaults to [.3, 1.].
  738. num_attempts (int, optional): The number of tries before giving up. Defaults to 50.
  739. allow_no_crop (bool, optional): Whether returning without doing crop is allowed. Defaults to True.
  740. cover_all_box (bool, optional): Whether to ensure all bboxes are covered in the final crop. Defaults to False.
  741. """
  742. def __init__(self,
  743. crop_size=None,
  744. aspect_ratio=[.5, 2.],
  745. thresholds=[.0, .1, .3, .5, .7, .9],
  746. scaling=[.3, 1.],
  747. num_attempts=50,
  748. allow_no_crop=True,
  749. cover_all_box=False):
  750. super(RandomCrop, self).__init__()
  751. self.crop_size = crop_size
  752. self.aspect_ratio = aspect_ratio
  753. self.thresholds = thresholds
  754. self.scaling = scaling
  755. self.num_attempts = num_attempts
  756. self.allow_no_crop = allow_no_crop
  757. self.cover_all_box = cover_all_box
  758. def _generate_crop_info(self, sample):
  759. im_h, im_w = sample['image'].shape[:2]
  760. if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
  761. thresholds = self.thresholds
  762. if self.allow_no_crop:
  763. thresholds.append('no_crop')
  764. np.random.shuffle(thresholds)
  765. for thresh in thresholds:
  766. if thresh == 'no_crop':
  767. return None
  768. for i in range(self.num_attempts):
  769. crop_box = self._get_crop_box(im_h, im_w)
  770. if crop_box is None:
  771. continue
  772. iou = self._iou_matrix(
  773. sample['gt_bbox'],
  774. np.array(
  775. [crop_box], dtype=np.float32))
  776. if iou.max() < thresh:
  777. continue
  778. if self.cover_all_box and iou.min() < thresh:
  779. continue
  780. cropped_box, valid_ids = self._crop_box_with_center_constraint(
  781. sample['gt_bbox'], np.array(
  782. crop_box, dtype=np.float32))
  783. if valid_ids.size > 0:
  784. return crop_box, cropped_box, valid_ids
  785. else:
  786. for i in range(self.num_attempts):
  787. crop_box = self._get_crop_box(im_h, im_w)
  788. if crop_box is None:
  789. continue
  790. return crop_box, None, None
  791. return None
  792. def _get_crop_box(self, im_h, im_w):
  793. scale = np.random.uniform(*self.scaling)
  794. if self.aspect_ratio is not None:
  795. min_ar, max_ar = self.aspect_ratio
  796. aspect_ratio = np.random.uniform(
  797. max(min_ar, scale**2), min(max_ar, scale**-2))
  798. h_scale = scale / np.sqrt(aspect_ratio)
  799. w_scale = scale * np.sqrt(aspect_ratio)
  800. else:
  801. h_scale = np.random.uniform(*self.scaling)
  802. w_scale = np.random.uniform(*self.scaling)
  803. crop_h = im_h * h_scale
  804. crop_w = im_w * w_scale
  805. if self.aspect_ratio is None:
  806. if crop_h / crop_w < 0.5 or crop_h / crop_w > 2.0:
  807. return None
  808. crop_h = int(crop_h)
  809. crop_w = int(crop_w)
  810. crop_y = np.random.randint(0, im_h - crop_h)
  811. crop_x = np.random.randint(0, im_w - crop_w)
  812. return [crop_x, crop_y, crop_x + crop_w, crop_y + crop_h]
  813. def _iou_matrix(self, a, b):
  814. tl_i = np.maximum(a[:, np.newaxis, :2], b[:, :2])
  815. br_i = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
  816. area_i = np.prod(br_i - tl_i, axis=2) * (tl_i < br_i).all(axis=2)
  817. area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
  818. area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
  819. area_o = (area_a[:, np.newaxis] + area_b - area_i)
  820. return area_i / (area_o + 1e-10)
  821. def _crop_box_with_center_constraint(self, box, crop):
  822. cropped_box = box.copy()
  823. cropped_box[:, :2] = np.maximum(box[:, :2], crop[:2])
  824. cropped_box[:, 2:] = np.minimum(box[:, 2:], crop[2:])
  825. cropped_box[:, :2] -= crop[:2]
  826. cropped_box[:, 2:] -= crop[:2]
  827. centers = (box[:, :2] + box[:, 2:]) / 2
  828. valid = np.logical_and(crop[:2] <= centers,
  829. centers < crop[2:]).all(axis=1)
  830. valid = np.logical_and(
  831. valid, (cropped_box[:, :2] < cropped_box[:, 2:]).all(axis=1))
  832. return cropped_box, np.where(valid)[0]
  833. def _crop_segm(self, segms, valid_ids, crop, height, width):
  834. crop_segms = []
  835. for id in valid_ids:
  836. segm = segms[id]
  837. if is_poly(segm):
  838. # Polygon format
  839. crop_segms.append(crop_poly(segm, crop))
  840. else:
  841. # RLE format
  842. crop_segms.append(crop_rle(segm, crop, height, width))
  843. return crop_segms
  844. def apply_im(self, image, crop):
  845. x1, y1, x2, y2 = crop
  846. return image[y1:y2, x1:x2, :]
  847. def apply_mask(self, mask, crop):
  848. x1, y1, x2, y2 = crop
  849. return mask[y1:y2, x1:x2, ...]
  850. def apply(self, sample):
  851. crop_info = self._generate_crop_info(sample)
  852. if crop_info is not None:
  853. crop_box, cropped_box, valid_ids = crop_info
  854. im_h, im_w = sample['image'].shape[:2]
  855. sample['image'] = self.apply_im(sample['image'], crop_box)
  856. if 'image2' in sample:
  857. sample['image2'] = self.apply_im(sample['image2'], crop_box)
  858. if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
  859. crop_polys = self._crop_segm(
  860. sample['gt_poly'],
  861. valid_ids,
  862. np.array(
  863. crop_box, dtype=np.int64),
  864. im_h,
  865. im_w)
  866. if [] in crop_polys:
  867. delete_id = list()
  868. valid_polys = list()
  869. for idx, poly in enumerate(crop_polys):
  870. if not crop_poly:
  871. delete_id.append(idx)
  872. else:
  873. valid_polys.append(poly)
  874. valid_ids = np.delete(valid_ids, delete_id)
  875. if not valid_polys:
  876. return sample
  877. sample['gt_poly'] = valid_polys
  878. else:
  879. sample['gt_poly'] = crop_polys
  880. if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
  881. sample['gt_bbox'] = np.take(cropped_box, valid_ids, axis=0)
  882. sample['gt_class'] = np.take(
  883. sample['gt_class'], valid_ids, axis=0)
  884. if 'gt_score' in sample:
  885. sample['gt_score'] = np.take(
  886. sample['gt_score'], valid_ids, axis=0)
  887. if 'is_crowd' in sample:
  888. sample['is_crowd'] = np.take(
  889. sample['is_crowd'], valid_ids, axis=0)
  890. if 'mask' in sample:
  891. sample['mask'] = self.apply_mask(sample['mask'], crop_box)
  892. if 'aux_masks' in sample:
  893. sample['aux_masks'] = list(
  894. map(partial(
  895. self.apply_mask, crop=crop_box),
  896. sample['aux_masks']))
  897. if self.crop_size is not None:
  898. sample = Resize(self.crop_size)(sample)
  899. return sample
  900. class RandomScaleAspect(Transform):
  901. """
  902. Crop input image(s) and resize back to original sizes.
  903. Args:
  904. min_scale (float): Minimum ratio between the cropped region and the original image.
  905. If 0, image(s) will not be cropped. Defaults to .5.
  906. aspect_ratio (float): Aspect ratio of cropped region. Defaults to .33.
  907. """
  908. def __init__(self, min_scale=0.5, aspect_ratio=0.33):
  909. super(RandomScaleAspect, self).__init__()
  910. self.min_scale = min_scale
  911. self.aspect_ratio = aspect_ratio
  912. def apply(self, sample):
  913. if self.min_scale != 0 and self.aspect_ratio != 0:
  914. img_height, img_width = sample['image'].shape[:2]
  915. sample = RandomCrop(
  916. crop_size=(img_height, img_width),
  917. aspect_ratio=[self.aspect_ratio, 1. / self.aspect_ratio],
  918. scaling=[self.min_scale, 1.],
  919. num_attempts=10,
  920. allow_no_crop=False)(sample)
  921. return sample
  922. class RandomExpand(Transform):
  923. """
  924. Randomly expand the input by padding according to random offsets.
  925. Args:
  926. upper_ratio(float, optional): The maximum ratio to which the original image is expanded. Defaults to 4..
  927. prob(float, optional): The probability of apply expanding. Defaults to .5.
  928. im_padding_value(List[float] or Tuple[float], optional): RGB filling value for the image. Defaults to (127.5, 127.5, 127.5).
  929. label_padding_value(int, optional): Filling value for the mask. Defaults to 255.
  930. See Also:
  931. paddlers.transforms.Padding
  932. """
  933. def __init__(self,
  934. upper_ratio=4.,
  935. prob=.5,
  936. im_padding_value=127.5,
  937. label_padding_value=255):
  938. super(RandomExpand, self).__init__()
  939. assert upper_ratio > 1.01, "expand ratio must be larger than 1.01"
  940. self.upper_ratio = upper_ratio
  941. self.prob = prob
  942. assert isinstance(im_padding_value, (Number, Sequence)), \
  943. "fill value must be either float or sequence"
  944. self.im_padding_value = im_padding_value
  945. self.label_padding_value = label_padding_value
  946. def apply(self, sample):
  947. if random.random() < self.prob:
  948. im_h, im_w = sample['image'].shape[:2]
  949. ratio = np.random.uniform(1., self.upper_ratio)
  950. h = int(im_h * ratio)
  951. w = int(im_w * ratio)
  952. if h > im_h and w > im_w:
  953. y = np.random.randint(0, h - im_h)
  954. x = np.random.randint(0, w - im_w)
  955. target_size = (h, w)
  956. offsets = (x, y)
  957. sample = Padding(
  958. target_size=target_size,
  959. pad_mode=-1,
  960. offsets=offsets,
  961. im_padding_value=self.im_padding_value,
  962. label_padding_value=self.label_padding_value)(sample)
  963. return sample
class Padding(Transform):
    def __init__(self,
                 target_size=None,
                 pad_mode=0,
                 offsets=None,
                 im_padding_value=127.5,
                 label_padding_value=255,
                 size_divisor=32):
        """
        Pad image to a specified size or multiple of size_divisor.

        Args:
            target_size(int, Sequence, optional): Image target size, if None, pad to multiple of size_divisor. Defaults to None.
            pad_mode({-1, 0, 1, 2}, optional): Pad mode, currently only supports four modes [-1, 0, 1, 2]. If -1, use specified
                offsets. If 0, only pad to right and bottom. If 1, pad according to center. If 2, only pad left and top.
                Defaults to 0.
            offsets(Sequence[int], optional): (x, y) position of the original image on the padded canvas; required
                iff pad_mode is -1. Defaults to None.
            im_padding_value(Sequence[float]): RGB value of pad area. Defaults to (127.5, 127.5, 127.5).
            label_padding_value(int, optional): Filling value for the mask. Defaults to 255.
            size_divisor(int): Image width and height after padding is a multiple of coarsest_stride.
        """
        super(Padding, self).__init__()
        if isinstance(target_size, (list, tuple)):
            if len(target_size) != 2:
                raise ValueError(
                    '`target_size` should include 2 elements, but it is {}'.
                    format(target_size))
        if isinstance(target_size, int):
            # A scalar target size means a square (h == w) canvas.
            target_size = [target_size] * 2
        assert pad_mode in [
            -1, 0, 1, 2
        ], 'currently only supports four modes [-1, 0, 1, 2]'
        if pad_mode == -1:
            assert offsets, 'if pad_mode is -1, offsets should not be None'
        self.target_size = target_size
        self.size_divisor = size_divisor
        self.pad_mode = pad_mode
        self.offsets = offsets
        self.im_padding_value = im_padding_value
        self.label_padding_value = label_padding_value

    def apply_im(self, image, offsets, target_size):
        # Paste the image at (x, y) on a canvas filled with im_padding_value.
        x, y = offsets
        h, w = target_size
        im_h, im_w, channel = image.shape[:3]
        canvas = np.ones((h, w, channel), dtype=np.float32)
        canvas *= np.array(self.im_padding_value, dtype=np.float32)
        canvas[y:y + im_h, x:x + im_w, :] = image.astype(np.float32)
        return canvas

    def apply_mask(self, mask, offsets, target_size):
        # 2-D analogue of apply_im, filled with label_padding_value.
        x, y = offsets
        im_h, im_w = mask.shape[:2]
        h, w = target_size
        canvas = np.ones((h, w), dtype=np.float32)
        canvas *= np.array(self.label_padding_value, dtype=np.float32)
        canvas[y:y + im_h, x:x + im_w] = mask.astype(np.float32)
        return canvas

    def apply_bbox(self, bbox, offsets):
        # `offsets * 2` repeats the [x, y] list -> [x, y, x, y], shifting
        # both corners of every xyxy box by the padding offset.
        return bbox + np.array(offsets * 2, dtype=np.float32)

    def apply_segm(self, segms, offsets, im_size, size):
        # Shift segmentations by (x, y); RLE masks additionally need the
        # original (height, width) and padded (h, w) sizes to be re-encoded.
        x, y = offsets
        height, width = im_size
        h, w = size
        expanded_segms = []
        for segm in segms:
            if is_poly(segm):
                # Polygon format
                expanded_segms.append(
                    [expand_poly(poly, x, y) for poly in segm])
            else:
                # RLE format
                expanded_segms.append(
                    expand_rle(segm, x, y, height, width, h, w))
        return expanded_segms

    def apply(self, sample):
        im_h, im_w = sample['image'].shape[:2]
        if self.target_size:
            h, w = self.target_size
            assert (
                im_h <= h and im_w <= w
            ), 'target size ({}, {}) cannot be less than image size ({}, {})'\
                .format(h, w, im_h, im_w)
        else:
            # No explicit target: round each dimension up to the nearest
            # multiple of size_divisor.
            h = (np.ceil(im_h / self.size_divisor) *
                 self.size_divisor).astype(int)
            w = (np.ceil(im_w / self.size_divisor) *
                 self.size_divisor).astype(int)
        if h == im_h and w == im_w:
            # Already the target size: nothing to pad.
            return sample
        # Translate pad_mode into the (x, y) offset of the original image.
        if self.pad_mode == -1:
            offsets = self.offsets
        elif self.pad_mode == 0:
            offsets = [0, 0]
        elif self.pad_mode == 1:
            offsets = [(w - im_w) // 2, (h - im_h) // 2]
        else:
            offsets = [w - im_w, h - im_h]
        sample['image'] = self.apply_im(sample['image'], offsets, (h, w))
        if 'image2' in sample:
            sample['image2'] = self.apply_im(sample['image2'], offsets, (h, w))
        if 'mask' in sample:
            sample['mask'] = self.apply_mask(sample['mask'], offsets, (h, w))
        if 'aux_masks' in sample:
            sample['aux_masks'] = list(
                map(partial(
                    self.apply_mask, offsets=offsets, target_size=(h, w)),
                    sample['aux_masks']))
        if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
            sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], offsets)
        if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
            sample['gt_poly'] = self.apply_segm(
                sample['gt_poly'], offsets, im_size=[im_h, im_w], size=[h, w])
        return sample
  1073. class MixupImage(Transform):
  1074. def __init__(self, alpha=1.5, beta=1.5, mixup_epoch=-1):
  1075. """
  1076. Mixup two images and their gt_bbbox/gt_score.
  1077. Args:
  1078. alpha (float, optional): Alpha parameter of beta distribution. Defaults to 1.5.
  1079. beta (float, optional): Beta parameter of beta distribution. Defaults to 1.5.
  1080. """
  1081. super(MixupImage, self).__init__()
  1082. if alpha <= 0.0:
  1083. raise ValueError("alpha should be positive in {}".format(self))
  1084. if beta <= 0.0:
  1085. raise ValueError("beta should be positive in {}".format(self))
  1086. self.alpha = alpha
  1087. self.beta = beta
  1088. self.mixup_epoch = mixup_epoch
  1089. def apply_im(self, image1, image2, factor):
  1090. h = max(image1.shape[0], image2.shape[0])
  1091. w = max(image1.shape[1], image2.shape[1])
  1092. img = np.zeros((h, w, image1.shape[2]), 'float32')
  1093. img[:image1.shape[0], :image1.shape[1], :] = \
  1094. image1.astype('float32') * factor
  1095. img[:image2.shape[0], :image2.shape[1], :] += \
  1096. image2.astype('float32') * (1.0 - factor)
  1097. return img.astype('uint8')
  1098. def __call__(self, sample):
  1099. if not isinstance(sample, Sequence):
  1100. return sample
  1101. assert len(sample) == 2, 'mixup need two samples'
  1102. factor = np.random.beta(self.alpha, self.beta)
  1103. factor = max(0.0, min(1.0, factor))
  1104. if factor >= 1.0:
  1105. return sample[0]
  1106. if factor <= 0.0:
  1107. return sample[1]
  1108. image = self.apply_im(sample[0]['image'], sample[1]['image'], factor)
  1109. result = copy.deepcopy(sample[0])
  1110. result['image'] = image
  1111. # apply bbox and score
  1112. if 'gt_bbox' in sample[0]:
  1113. gt_bbox1 = sample[0]['gt_bbox']
  1114. gt_bbox2 = sample[1]['gt_bbox']
  1115. gt_bbox = np.concatenate((gt_bbox1, gt_bbox2), axis=0)
  1116. result['gt_bbox'] = gt_bbox
  1117. if 'gt_poly' in sample[0]:
  1118. gt_poly1 = sample[0]['gt_poly']
  1119. gt_poly2 = sample[1]['gt_poly']
  1120. gt_poly = gt_poly1 + gt_poly2
  1121. result['gt_poly'] = gt_poly
  1122. if 'gt_class' in sample[0]:
  1123. gt_class1 = sample[0]['gt_class']
  1124. gt_class2 = sample[1]['gt_class']
  1125. gt_class = np.concatenate((gt_class1, gt_class2), axis=0)
  1126. result['gt_class'] = gt_class
  1127. gt_score1 = np.ones_like(sample[0]['gt_class'])
  1128. gt_score2 = np.ones_like(sample[1]['gt_class'])
  1129. gt_score = np.concatenate(
  1130. (gt_score1 * factor, gt_score2 * (1. - factor)), axis=0)
  1131. result['gt_score'] = gt_score
  1132. if 'is_crowd' in sample[0]:
  1133. is_crowd1 = sample[0]['is_crowd']
  1134. is_crowd2 = sample[1]['is_crowd']
  1135. is_crowd = np.concatenate((is_crowd1, is_crowd2), axis=0)
  1136. result['is_crowd'] = is_crowd
  1137. if 'difficult' in sample[0]:
  1138. is_difficult1 = sample[0]['difficult']
  1139. is_difficult2 = sample[1]['difficult']
  1140. is_difficult = np.concatenate(
  1141. (is_difficult1, is_difficult2), axis=0)
  1142. result['difficult'] = is_difficult
  1143. return result
  1144. class RandomDistort(Transform):
  1145. """
  1146. Random color distortion.
  1147. Args:
  1148. brightness_range(float, optional): Range of brightness distortion. Defaults to .5.
  1149. brightness_prob(float, optional): Probability of brightness distortion. Defaults to .5.
  1150. contrast_range(float, optional): Range of contrast distortion. Defaults to .5.
  1151. contrast_prob(float, optional): Probability of contrast distortion. Defaults to .5.
  1152. saturation_range(float, optional): Range of saturation distortion. Defaults to .5.
  1153. saturation_prob(float, optional): Probability of saturation distortion. Defaults to .5.
  1154. hue_range(float, optional): Range of hue distortion. Defaults to .5.
  1155. hue_prob(float, optional): Probability of hue distortion. Defaults to .5.
  1156. random_apply (bool, optional): whether to apply in random (yolo) or fixed (SSD)
  1157. order. Defaults to True.
  1158. count (int, optional): the number of doing distortion. Defaults to 4.
  1159. shuffle_channel (bool, optional): whether to swap channels randomly. Defaults to False.
  1160. """
  1161. def __init__(self,
  1162. brightness_range=0.5,
  1163. brightness_prob=0.5,
  1164. contrast_range=0.5,
  1165. contrast_prob=0.5,
  1166. saturation_range=0.5,
  1167. saturation_prob=0.5,
  1168. hue_range=18,
  1169. hue_prob=0.5,
  1170. random_apply=True,
  1171. count=4,
  1172. shuffle_channel=False):
  1173. super(RandomDistort, self).__init__()
  1174. self.brightness_range = [1 - brightness_range, 1 + brightness_range]
  1175. self.brightness_prob = brightness_prob
  1176. self.contrast_range = [1 - contrast_range, 1 + contrast_range]
  1177. self.contrast_prob = contrast_prob
  1178. self.saturation_range = [1 - saturation_range, 1 + saturation_range]
  1179. self.saturation_prob = saturation_prob
  1180. self.hue_range = [1 - hue_range, 1 + hue_range]
  1181. self.hue_prob = hue_prob
  1182. self.random_apply = random_apply
  1183. self.count = count
  1184. self.shuffle_channel = shuffle_channel
  1185. def apply_hue(self, image):
  1186. low, high = self.hue_range
  1187. if np.random.uniform(0., 1.) < self.hue_prob:
  1188. return image
  1189. # it works, but result differ from HSV version
  1190. delta = np.random.uniform(low, high)
  1191. u = np.cos(delta * np.pi)
  1192. w = np.sin(delta * np.pi)
  1193. bt = np.array([[1.0, 0.0, 0.0], [0.0, u, -w], [0.0, w, u]])
  1194. tyiq = np.array([[0.299, 0.587, 0.114], [0.596, -0.274, -0.321],
  1195. [0.211, -0.523, 0.311]])
  1196. ityiq = np.array([[1.0, 0.956, 0.621], [1.0, -0.272, -0.647],
  1197. [1.0, -1.107, 1.705]])
  1198. t = np.dot(np.dot(ityiq, bt), tyiq).T
  1199. res_list = []
  1200. channel = image.shape[2]
  1201. for i in range(channel // 3):
  1202. sub_img = image[:, :, 3 * i:3 * (i + 1)]
  1203. sub_img = sub_img.astype(np.float32)
  1204. sub_img = np.dot(image, t)
  1205. res_list.append(sub_img)
  1206. if channel % 3 != 0:
  1207. i = channel % 3
  1208. res_list.append(image[:, :, -i:])
  1209. return np.concatenate(res_list, axis=2)
  1210. def apply_saturation(self, image):
  1211. low, high = self.saturation_range
  1212. delta = np.random.uniform(low, high)
  1213. if np.random.uniform(0., 1.) < self.saturation_prob:
  1214. return image
  1215. res_list = []
  1216. channel = image.shape[2]
  1217. for i in range(channel // 3):
  1218. sub_img = image[:, :, 3 * i:3 * (i + 1)]
  1219. sub_img = sub_img.astype(np.float32)
  1220. # it works, but result differ from HSV version
  1221. gray = sub_img * np.array(
  1222. [[[0.299, 0.587, 0.114]]], dtype=np.float32)
  1223. gray = gray.sum(axis=2, keepdims=True)
  1224. gray *= (1.0 - delta)
  1225. sub_img *= delta
  1226. sub_img += gray
  1227. res_list.append(sub_img)
  1228. if channel % 3 != 0:
  1229. i = channel % 3
  1230. res_list.append(image[:, :, -i:])
  1231. return np.concatenate(res_list, axis=2)
  1232. def apply_contrast(self, image):
  1233. low, high = self.contrast_range
  1234. if np.random.uniform(0., 1.) < self.contrast_prob:
  1235. return image
  1236. delta = np.random.uniform(low, high)
  1237. image = image.astype(np.float32)
  1238. image *= delta
  1239. return image
  1240. def apply_brightness(self, image):
  1241. low, high = self.brightness_range
  1242. if np.random.uniform(0., 1.) < self.brightness_prob:
  1243. return image
  1244. delta = np.random.uniform(low, high)
  1245. image = image.astype(np.float32)
  1246. image += delta
  1247. return image
  1248. def apply(self, sample):
  1249. if self.random_apply:
  1250. functions = [
  1251. self.apply_brightness, self.apply_contrast,
  1252. self.apply_saturation, self.apply_hue
  1253. ]
  1254. distortions = np.random.permutation(functions)[:self.count]
  1255. for func in distortions:
  1256. sample['image'] = func(sample['image'])
  1257. if 'image2' in sample:
  1258. sample['image2'] = func(sample['image2'])
  1259. return sample
  1260. sample['image'] = self.apply_brightness(sample['image'])
  1261. if 'image2' in sample:
  1262. sample['image2'] = self.apply_brightness(sample['image2'])
  1263. mode = np.random.randint(0, 2)
  1264. if mode:
  1265. sample['image'] = self.apply_contrast(sample['image'])
  1266. if 'image2' in sample:
  1267. sample['image2'] = self.apply_contrast(sample['image2'])
  1268. sample['image'] = self.apply_saturation(sample['image'])
  1269. sample['image'] = self.apply_hue(sample['image'])
  1270. if 'image2' in sample:
  1271. sample['image2'] = self.apply_saturation(sample['image2'])
  1272. sample['image2'] = self.apply_hue(sample['image2'])
  1273. if not mode:
  1274. sample['image'] = self.apply_contrast(sample['image'])
  1275. if 'image2' in sample:
  1276. sample['image2'] = self.apply_contrast(sample['image2'])
  1277. if self.shuffle_channel:
  1278. if np.random.randint(0, 2):
  1279. sample['image'] = sample['image'][..., np.random.permutation(3)]
  1280. if 'image2' in sample:
  1281. sample['image2'] = sample['image2'][
  1282. ..., np.random.permutation(3)]
  1283. return sample
  1284. class RandomBlur(Transform):
  1285. """
  1286. Randomly blur input image(s).
  1287. Args:
  1288. prob (float): Probability of blurring.
  1289. """
  1290. def __init__(self, prob=0.1):
  1291. super(RandomBlur, self).__init__()
  1292. self.prob = prob
  1293. def apply_im(self, image, radius):
  1294. image = cv2.GaussianBlur(image, (radius, radius), 0, 0)
  1295. return image
  1296. def apply(self, sample):
  1297. if self.prob <= 0:
  1298. n = 0
  1299. elif self.prob >= 1:
  1300. n = 1
  1301. else:
  1302. n = int(1.0 / self.prob)
  1303. if n > 0:
  1304. if np.random.randint(0, n) == 0:
  1305. radius = np.random.randint(3, 10)
  1306. if radius % 2 != 1:
  1307. radius = radius + 1
  1308. if radius > 9:
  1309. radius = 9
  1310. sample['image'] = self.apply_im(sample['image'], radius)
  1311. if 'image2' in sample:
  1312. sample['image2'] = self.apply_im(sample['image2'], radius)
  1313. return sample
  1314. class Defogging(Transform):
  1315. """
  1316. Defog input image(s).
  1317. Args:
  1318. gamma (bool, optional): Use gamma correction or not. Defaults to False.
  1319. """
  1320. def __init__(self, gamma=False):
  1321. super(Defogging, self).__init__()
  1322. self.gamma = gamma
  1323. def apply_im(self, image):
  1324. image = de_haze(image, self.gamma)
  1325. return image
  1326. def apply(self, sample):
  1327. sample['image'] = self.apply_im(sample['image'])
  1328. if 'image2' in sample:
  1329. sample['image2'] = self.apply_im(sample['image2'])
  1330. return sample
  1331. class DimReducing(Transform):
  1332. """
  1333. Use PCA to reduce input image(s) dimension.
  1334. Args:
  1335. joblib_path (str): Path of *.joblib about PCA.
  1336. """
  1337. def __init__(self, joblib_path):
  1338. super(DimReducing, self).__init__()
  1339. ext = joblib_path.split(".")[-1]
  1340. if ext != "joblib":
  1341. raise ValueError("`joblib_path` must be *.joblib, not *.{}.".format(ext))
  1342. self.pca = load(joblib_path)
  1343. def apply_im(self, image):
  1344. H, W, C = image.shape
  1345. n_im = np.reshape(image, (-1, C))
  1346. im_pca = self.pca.transform(n_im)
  1347. result = np.reshape(im_pca, (H, W, -1))
  1348. return result
  1349. def apply(self, sample):
  1350. sample['image'] = self.apply_im(sample['image'])
  1351. if 'image2' in sample:
  1352. sample['image2'] = self.apply_im(sample['image2'])
  1353. return sample
  1354. class BandSelecting(Transform):
  1355. """
  1356. Select the band of the input image(s).
  1357. Args:
  1358. band_list (list, optional): Bands of selected (Start with 1). Defaults to [1, 2, 3].
  1359. """
  1360. def __init__(self, band_list=[1, 2, 3]):
  1361. super(BandSelecting, self).__init__()
  1362. self.band_list = band_list
  1363. def apply_im(self, image):
  1364. image = select_bands(image, self.band_list)
  1365. return image
  1366. def apply(self, sample):
  1367. sample['image'] = self.apply_im(sample['image'])
  1368. if 'image2' in sample:
  1369. sample['image2'] = self.apply_im(sample['image2'])
  1370. return sample
  1371. class _PadBox(Transform):
  1372. def __init__(self, num_max_boxes=50):
  1373. """
  1374. Pad zeros to bboxes if number of bboxes is less than num_max_boxes.
  1375. Args:
  1376. num_max_boxes (int, optional): the max number of bboxes. Defaults to 50.
  1377. """
  1378. self.num_max_boxes = num_max_boxes
  1379. super(_PadBox, self).__init__()
  1380. def apply(self, sample):
  1381. gt_num = min(self.num_max_boxes, len(sample['gt_bbox']))
  1382. num_max = self.num_max_boxes
  1383. pad_bbox = np.zeros((num_max, 4), dtype=np.float32)
  1384. if gt_num > 0:
  1385. pad_bbox[:gt_num, :] = sample['gt_bbox'][:gt_num, :]
  1386. sample['gt_bbox'] = pad_bbox
  1387. if 'gt_class' in sample:
  1388. pad_class = np.zeros((num_max, ), dtype=np.int32)
  1389. if gt_num > 0:
  1390. pad_class[:gt_num] = sample['gt_class'][:gt_num, 0]
  1391. sample['gt_class'] = pad_class
  1392. if 'gt_score' in sample:
  1393. pad_score = np.zeros((num_max, ), dtype=np.float32)
  1394. if gt_num > 0:
  1395. pad_score[:gt_num] = sample['gt_score'][:gt_num, 0]
  1396. sample['gt_score'] = pad_score
  1397. # in training, for example in op ExpandImage,
  1398. # the bbox and gt_class is expanded, but the difficult is not,
  1399. # so, judging by it's length
  1400. if 'difficult' in sample:
  1401. pad_diff = np.zeros((num_max, ), dtype=np.int32)
  1402. if gt_num > 0:
  1403. pad_diff[:gt_num] = sample['difficult'][:gt_num, 0]
  1404. sample['difficult'] = pad_diff
  1405. if 'is_crowd' in sample:
  1406. pad_crowd = np.zeros((num_max, ), dtype=np.int32)
  1407. if gt_num > 0:
  1408. pad_crowd[:gt_num] = sample['is_crowd'][:gt_num, 0]
  1409. sample['is_crowd'] = pad_crowd
  1410. return sample
  1411. class _NormalizeBox(Transform):
  1412. def __init__(self):
  1413. super(_NormalizeBox, self).__init__()
  1414. def apply(self, sample):
  1415. height, width = sample['image'].shape[:2]
  1416. for i in range(sample['gt_bbox'].shape[0]):
  1417. sample['gt_bbox'][i][0] = sample['gt_bbox'][i][0] / width
  1418. sample['gt_bbox'][i][1] = sample['gt_bbox'][i][1] / height
  1419. sample['gt_bbox'][i][2] = sample['gt_bbox'][i][2] / width
  1420. sample['gt_bbox'][i][3] = sample['gt_bbox'][i][3] / height
  1421. return sample
  1422. class _BboxXYXY2XYWH(Transform):
  1423. """
  1424. Convert bbox XYXY format to XYWH format.
  1425. """
  1426. def __init__(self):
  1427. super(_BboxXYXY2XYWH, self).__init__()
  1428. def apply(self, sample):
  1429. bbox = sample['gt_bbox']
  1430. bbox[:, 2:4] = bbox[:, 2:4] - bbox[:, :2]
  1431. bbox[:, :2] = bbox[:, :2] + bbox[:, 2:4] / 2.
  1432. sample['gt_bbox'] = bbox
  1433. return sample
  1434. class _Permute(Transform):
  1435. def __init__(self):
  1436. super(_Permute, self).__init__()
  1437. def apply(self, sample):
  1438. sample['image'] = permute(sample['image'], False)
  1439. if 'image2' in sample:
  1440. sample['image2'] = permute(sample['image2'], False)
  1441. return sample
  1442. class RandomSwap(Transform):
  1443. """
  1444. Randomly swap multi-temporal images.
  1445. Args:
  1446. prob (float, optional): Probability of swapping the input images. Default: 0.2.
  1447. """
  1448. def __init__(self, prob=0.2):
  1449. super(RandomSwap, self).__init__()
  1450. self.prob = prob
  1451. def apply(self, sample):
  1452. if 'image2' not in sample:
  1453. raise ValueError('image2 is not found in the sample.')
  1454. if random.random() < self.prob:
  1455. sample['image'], sample['image2'] = sample['image2'], sample[
  1456. 'image']
  1457. return sample
  1458. class ArrangeSegmenter(Transform):
  1459. def __init__(self, mode):
  1460. super(ArrangeSegmenter, self).__init__()
  1461. if mode not in ['train', 'eval', 'test', 'quant']:
  1462. raise ValueError(
  1463. "mode should be defined as one of ['train', 'eval', 'test', 'quant']!"
  1464. )
  1465. self.mode = mode
  1466. def apply(self, sample):
  1467. if 'mask' in sample:
  1468. mask = sample['mask']
  1469. image = permute(sample['image'], False)
  1470. if self.mode == 'train':
  1471. mask = mask.astype('int64')
  1472. return image, mask
  1473. if self.mode == 'eval':
  1474. mask = np.asarray(Image.open(mask))
  1475. mask = mask[np.newaxis, :, :].astype('int64')
  1476. return image, mask
  1477. if self.mode == 'test':
  1478. return image,
  1479. class ArrangeChangeDetector(Transform):
  1480. def __init__(self, mode):
  1481. super(ArrangeChangeDetector, self).__init__()
  1482. if mode not in ['train', 'eval', 'test', 'quant']:
  1483. raise ValueError(
  1484. "mode should be defined as one of ['train', 'eval', 'test', 'quant']!"
  1485. )
  1486. self.mode = mode
  1487. def apply(self, sample):
  1488. if 'mask' in sample:
  1489. mask = sample['mask']
  1490. image_t1 = permute(sample['image'], False)
  1491. image_t2 = permute(sample['image2'], False)
  1492. if self.mode == 'train':
  1493. mask = mask.astype('int64')
  1494. masks = [mask]
  1495. if 'aux_masks' in sample:
  1496. masks.extend(
  1497. map(methodcaller('astype', 'int64'), sample['aux_masks']))
  1498. return (
  1499. image_t1,
  1500. image_t2, ) + tuple(masks)
  1501. if self.mode == 'eval':
  1502. mask = np.asarray(Image.open(mask))
  1503. mask = mask[np.newaxis, :, :].astype('int64')
  1504. return image_t1, image_t2, mask
  1505. if self.mode == 'test':
  1506. return image_t1, image_t2,
  1507. class ArrangeClassifier(Transform):
  1508. def __init__(self, mode):
  1509. super(ArrangeClassifier, self).__init__()
  1510. if mode not in ['train', 'eval', 'test', 'quant']:
  1511. raise ValueError(
  1512. "mode should be defined as one of ['train', 'eval', 'test', 'quant']!"
  1513. )
  1514. self.mode = mode
  1515. def apply(self, sample):
  1516. image = permute(sample['image'], False)
  1517. if self.mode in ['train', 'eval']:
  1518. return image, sample['label']
  1519. else:
  1520. return image
  1521. class ArrangeDetector(Transform):
  1522. def __init__(self, mode):
  1523. super(ArrangeDetector, self).__init__()
  1524. if mode not in ['train', 'eval', 'test', 'quant']:
  1525. raise ValueError(
  1526. "mode should be defined as one of ['train', 'eval', 'test', 'quant']!"
  1527. )
  1528. self.mode = mode
  1529. def apply(self, sample):
  1530. if self.mode == 'eval' and 'gt_poly' in sample:
  1531. del sample['gt_poly']
  1532. return sample