operators.py
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import copy
import random
from numbers import Number
from functools import partial, reduce
from operator import methodcaller
from collections import OrderedDict
from collections.abc import Sequence

import numpy as np
import cv2
import imghdr
from PIL import Image
from joblib import load

import paddlers
import paddlers.transforms.functions as F
import paddlers.transforms.indices as indices
import paddlers.transforms.satellites as satellites

__all__ = [
    "construct_sample",
    "construct_sample_from_dict",
    "Compose",
    "DecodeImg",
    "Resize",
    "RandomResize",
    "ResizeByShort",
    "RandomResizeByShort",
    "ResizeByLong",
    "RandomHorizontalFlip",
    "RandomVerticalFlip",
    "Normalize",
    "CenterCrop",
    "RandomCrop",
    "RandomScaleAspect",
    "RandomExpand",
    "Pad",
    "MixupImage",
    "RandomDistort",
    "RandomBlur",
    "RandomSwap",
    "Dehaze",
    "ReduceDim",
    "SelectBand",
    "RandomFlipOrRotate",
    "ReloadMask",
    "AppendIndex",
    "MatchRadiance",
    "ArrangeRestorer",
    "ArrangeSegmenter",
    "ArrangeChangeDetector",
    "ArrangeClassifier",
    "ArrangeDetector",
]

interp_dict = {
    'NEAREST': cv2.INTER_NEAREST,
    'LINEAR': cv2.INTER_LINEAR,
    'CUBIC': cv2.INTER_CUBIC,
    'AREA': cv2.INTER_AREA,
    'LANCZOS4': cv2.INTER_LANCZOS4
}

def construct_sample(**kwargs):
    sample = OrderedDict()
    for k, v in kwargs.items():
        sample[k] = v
    if 'trans_info' not in sample:
        sample['trans_info'] = []
    return sample


def construct_sample_from_dict(dict_like_obj):
    return construct_sample(**dict_like_obj)

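
# Hedged usage sketch (not part of the original module): builds a sample from
# a plain dict. The zero-filled array is synthetic stand-in data.
def _example_construct_sample():
    im = np.zeros((64, 64, 3), dtype=np.uint8)
    sample = construct_sample_from_dict({'image': im})
    # `trans_info` is always initialized so that operators can record
    # shape-changing steps (e.g. resizing and padding).
    assert sample['trans_info'] == []
    return sample
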
class Compose(object):
    """
    Apply a series of data augmentation strategies to the input.
    All input images should be in Height-Width-Channel ([H, W, C]) format.

    Args:
        transforms (list[paddlers.transforms.Transform]): List of data preprocessing
            or augmentation operators.

    Raises:
        TypeError: Invalid type of `transforms`.
        ValueError: Invalid length of `transforms`.
    """

    def __init__(self, transforms):
        super(Compose, self).__init__()
        if not isinstance(transforms, list):
            raise TypeError(
                "Type of transforms is invalid. Must be a list, but received {}."
                .format(type(transforms)))
        if len(transforms) < 1:
            raise ValueError(
                "Length of transforms must not be less than 1, but received {}."
                .format(len(transforms)))
        transforms = copy.deepcopy(transforms)
        # We will have to do a late binding of `self.arrange`.
        self.arrange = None
        self.transforms = transforms

    def __call__(self, sample):
        """
        This is equivalent to sequentially calling compose_obj.apply_transforms()
        and compose_obj.arrange_outputs().
        """
        if 'trans_info' not in sample:
            sample['trans_info'] = []
        sample = self.apply_transforms(sample)
        trans_info = sample['trans_info']
        sample = self.arrange_outputs(sample)
        return sample, trans_info

    def apply_transforms(self, sample):
        for op in self.transforms:
            # Skip batch transforms
            if getattr(op, 'is_batch_transform', False):
                continue
            sample = op(sample)
        return sample

    def arrange_outputs(self, sample):
        if self.arrange is not None:
            sample = self.arrange(sample)
        return sample

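
# Hedged usage sketch: `Compose` chains per-sample operators and returns the
# transformed sample together with its `trans_info` record. The input array
# is synthetic; no files are read.
def _example_compose():
    pipeline = Compose([Resize(target_size=32), Normalize()])
    sample = construct_sample(image=np.zeros((64, 48, 3), dtype=np.uint8))
    sample, trans_info = pipeline(sample)
    # Resize records the pre-resize shape: [('resize', (64, 48))].
    assert trans_info == [('resize', (64, 48))]
    return sample
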
class Transform(object):
    """
    Parent class of all data augmentation operators.
    """

    def __init__(self):
        pass

    def apply_im(self, image):
        return image

    def apply_mask(self, mask):
        return mask

    def apply_bbox(self, bbox):
        return bbox

    def apply_segm(self, segms):
        return segms

    def apply(self, sample):
        if 'image' in sample:
            sample['image'] = self.apply_im(sample['image'])
        else:  # image_tx
            sample['image'] = self.apply_im(sample['image_t1'])
            sample['image2'] = self.apply_im(sample['image_t2'])
        if 'mask' in sample:
            sample['mask'] = self.apply_mask(sample['mask'])
        if 'gt_bbox' in sample:
            sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'])
        if 'aux_masks' in sample:
            sample['aux_masks'] = list(
                map(self.apply_mask, sample['aux_masks']))
        if 'target' in sample:
            sample['target'] = self.apply_im(sample['target'])
        return sample

    def __call__(self, sample):
        if isinstance(sample, Sequence):
            sample = [self.apply(s) for s in sample]
        else:
            sample = self.apply(sample)
        return sample

    def get_attrs_for_serialization(self):
        return self.__dict__

class DecodeImg(Transform):
    """
    Decode image(s) in input.

    Args:
        to_rgb (bool, optional): If True, convert input image(s) from BGR format to
            RGB format. Defaults to True.
        to_uint8 (bool, optional): If True, quantize and convert decoded image(s) to
            uint8 type. Defaults to True.
        decode_bgr (bool, optional): If True, automatically interpret a non-geo image
            (e.g., jpeg images) as a BGR image. Defaults to True.
        decode_sar (bool, optional): If True, automatically interpret a single-channel
            geo image (e.g., GeoTIFF images) as a SAR image. Defaults to True.
        read_geo_info (bool, optional): If True, read geographical information from
            the image. Defaults to False.
        use_stretch (bool, optional): Whether to apply a 2% linear stretch. Valid only
            if `to_uint8` is True. Defaults to False.
    """

    def __init__(self,
                 to_rgb=True,
                 to_uint8=True,
                 decode_bgr=True,
                 decode_sar=True,
                 read_geo_info=False,
                 use_stretch=False):
        super(DecodeImg, self).__init__()
        self.to_rgb = to_rgb
        self.to_uint8 = to_uint8
        self.decode_bgr = decode_bgr
        self.decode_sar = decode_sar
        self.read_geo_info = read_geo_info
        self.use_stretch = use_stretch

    def read_img(self, img_path):
        img_format = imghdr.what(img_path)
        name, ext = os.path.splitext(img_path)
        geo_trans, geo_proj = None, None
        if img_format == 'tiff' or ext == '.img':
            try:
                import gdal
            except ImportError:
                try:
                    from osgeo import gdal
                except ImportError:
                    raise ImportError(
                        "Failed to import gdal! Please install the GDAL library according to the documentation."
                    )
            dataset = gdal.Open(img_path)
            if dataset is None:
                raise IOError("Cannot open {}.".format(img_path))
            im_data = dataset.ReadAsArray()
            if im_data.ndim == 2 and self.decode_sar:
                im_data = F.to_intensity(im_data)
                im_data = im_data[:, :, np.newaxis]
            else:
                if im_data.ndim == 3:
                    im_data = im_data.transpose((1, 2, 0))
            if self.read_geo_info:
                geo_trans = dataset.GetGeoTransform()
                geo_proj = dataset.GetProjection()
        elif img_format in ['jpeg', 'bmp', 'png', 'jpg']:
            if self.decode_bgr:
                im_data = cv2.imread(img_path, cv2.IMREAD_ANYDEPTH |
                                     cv2.IMREAD_ANYCOLOR | cv2.IMREAD_COLOR)
            else:
                im_data = cv2.imread(img_path, cv2.IMREAD_ANYDEPTH |
                                     cv2.IMREAD_ANYCOLOR)
            if self.to_rgb and im_data.shape[-1] == 3:
                im_data = cv2.cvtColor(im_data, cv2.COLOR_BGR2RGB)
        elif ext == '.npy':
            im_data = np.load(img_path)
        else:
            raise TypeError("Image format {} is not supported!".format(ext))
        if self.read_geo_info:
            return im_data, geo_trans, geo_proj
        else:
            return im_data

    def apply_im(self, im_path):
        if isinstance(im_path, str):
            try:
                data = self.read_img(im_path)
            except Exception:
                raise ValueError("Cannot read the image file {}!".format(
                    im_path))
            if self.read_geo_info:
                image, geo_trans, geo_proj = data
                geo_info_dict = {'geo_trans': geo_trans, 'geo_proj': geo_proj}
            else:
                image = data
        else:
            image = im_path
        if self.to_uint8:
            if self.use_stretch:
                image = F.to_uint8(image, norm=False, stretch=True)
            else:
                image = F.to_uint8(image, norm=True, stretch=False)
        if self.read_geo_info:
            return image, geo_info_dict
        else:
            return image

    def apply_mask(self, mask):
        try:
            mask = np.asarray(Image.open(mask))
        except Exception:
            raise ValueError("Cannot read the mask file {}!".format(mask))
        if len(mask.shape) != 2:
            raise ValueError(
                "Mask should be a 1-channel image, but received a {}-channel image.".
                format(mask.shape[2]))
        return mask

    def apply(self, sample):
        """
        Args:
            sample (dict): Input sample.

        Returns:
            dict: Sample with decoded images.
        """
        if 'image' in sample:
            if self.read_geo_info:
                image, geo_info_dict = self.apply_im(sample['image'])
                sample['image'] = image
                sample['geo_info_dict'] = geo_info_dict
            else:
                sample['image'] = self.apply_im(sample['image'])
        if 'image2' in sample:
            if self.read_geo_info:
                image2, geo_info_dict2 = self.apply_im(sample['image2'])
                sample['image2'] = image2
                sample['geo_info_dict2'] = geo_info_dict2
            else:
                sample['image2'] = self.apply_im(sample['image2'])
        if 'image_t1' in sample and 'image' not in sample:
            if not ('image_t2' in sample and 'image2' not in sample):
                raise ValueError(
                    "`image_t2` should be provided together with `image_t1`.")
            if self.read_geo_info:
                image, geo_info_dict = self.apply_im(sample['image_t1'])
                sample['image'] = image
                sample['geo_info_dict'] = geo_info_dict
            else:
                sample['image'] = self.apply_im(sample['image_t1'])
            if self.read_geo_info:
                image2, geo_info_dict2 = self.apply_im(sample['image_t2'])
                sample['image2'] = image2
                sample['geo_info_dict2'] = geo_info_dict2
            else:
                sample['image2'] = self.apply_im(sample['image_t2'])
        if 'mask' in sample:
            sample['mask_ori'] = copy.deepcopy(sample['mask'])
            sample['mask'] = self.apply_mask(sample['mask'])
            im_height, im_width, _ = sample['image'].shape
            se_height, se_width = sample['mask'].shape
            if im_height != se_height or im_width != se_width:
                raise ValueError(
                    "The height or width of the image is not the same as that of the mask.")
        if 'aux_masks' in sample:
            sample['aux_masks_ori'] = copy.deepcopy(sample['aux_masks'])
            sample['aux_masks'] = list(
                map(self.apply_mask, sample['aux_masks']))
            # TODO: check the shape of auxiliary masks
        if 'target' in sample:
            if self.read_geo_info:
                target, geo_info_dict = self.apply_im(sample['target'])
                sample['target'] = target
                sample['geo_info_dict_tar'] = geo_info_dict
            else:
                sample['target'] = self.apply_im(sample['target'])
        sample['im_shape'] = np.array(
            sample['image'].shape[:2], dtype=np.float32)
        sample['scale_factor'] = np.array([1., 1.], dtype=np.float32)
        return sample

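
# Hedged usage sketch: `DecodeImg` also accepts an already-loaded array, in
# which case decoding is skipped and only the uint8 conversion applies. The
# dtype assertion assumes F.to_uint8 returns a uint8 array, as its name and
# the class docstring indicate.
def _example_decode_img():
    decoder = DecodeImg()
    sample = construct_sample(
        image=np.random.rand(32, 32, 3).astype(np.float32))
    sample = decoder(sample)
    # `im_shape` and `scale_factor` are initialized as side effects.
    assert sample['image'].dtype == np.uint8
    return sample
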
class Resize(Transform):
    """
    Resize input.

    - If `target_size` is an int, resize the image(s) to (`target_size`, `target_size`).
    - If `target_size` is a list or tuple, resize the image(s) to `target_size`.

    Attention: If `interp` is 'RANDOM', the interpolation method will be chosen randomly.

    Args:
        target_size (int | list[int] | tuple[int]): Target size. If it is an integer, the
            target height and width will be both set to `target_size`. Otherwise,
            `target_size` represents [target height, target width].
        interp (str, optional): Interpolation method for resizing image(s). One of
            {'NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4', 'RANDOM'}.
            Defaults to 'LINEAR'.
        keep_ratio (bool, optional): If True, the scaling factors of width and height
            will be set to the same value, and the height/width of the resized image
            will be no greater than the target height/width. Defaults to False.

    Raises:
        TypeError: Invalid type of `target_size`.
        ValueError: Invalid interpolation method.
    """

    def __init__(self, target_size, interp='LINEAR', keep_ratio=False):
        super(Resize, self).__init__()
        if not (interp == "RANDOM" or interp in interp_dict):
            raise ValueError("`interp` should be one of {}.".format(
                interp_dict.keys()))
        if isinstance(target_size, int):
            target_size = (target_size, target_size)
        else:
            if not (isinstance(target_size,
                               (list, tuple)) and len(target_size) == 2):
                raise TypeError(
                    "`target_size` should be an int or a list of length 2, but received {}.".
                    format(target_size))
        # (height, width)
        self.target_size = target_size
        self.interp = interp
        self.keep_ratio = keep_ratio

    def apply_im(self, image, interp, target_size):
        flag = image.shape[2] == 1
        image = cv2.resize(image, target_size, interpolation=interp)
        if flag:
            image = image[:, :, np.newaxis]
        return image

    def apply_mask(self, mask, target_size):
        mask = cv2.resize(mask, target_size, interpolation=cv2.INTER_NEAREST)
        return mask

    def apply_bbox(self, bbox, scale, target_size):
        im_scale_x, im_scale_y = scale
        bbox[:, 0::2] *= im_scale_x
        bbox[:, 1::2] *= im_scale_y
        bbox[:, 0::2] = np.clip(bbox[:, 0::2], 0, target_size[0])
        bbox[:, 1::2] = np.clip(bbox[:, 1::2], 0, target_size[1])
        return bbox

    def apply_segm(self, segms, im_size, scale):
        im_h, im_w = im_size
        im_scale_x, im_scale_y = scale
        resized_segms = []
        for segm in segms:
            if F.is_poly(segm):
                # Polygon format
                resized_segms.append([
                    F.resize_poly(poly, im_scale_x, im_scale_y) for poly in segm
                ])
            else:
                # RLE format
                resized_segms.append(
                    F.resize_rle(segm, im_h, im_w, im_scale_x, im_scale_y))
        return resized_segms

    def apply(self, sample):
        sample['trans_info'].append(('resize', sample['image'].shape[0:2]))
        if self.interp == "RANDOM":
            interp = random.choice(list(interp_dict.values()))
        else:
            interp = interp_dict[self.interp]
        im_h, im_w = sample['image'].shape[:2]
        im_scale_y = self.target_size[0] / im_h
        im_scale_x = self.target_size[1] / im_w
        # cv2.resize() expects (width, height).
        target_size = (self.target_size[1], self.target_size[0])
        if self.keep_ratio:
            scale = min(im_scale_y, im_scale_x)
            target_w = int(round(im_w * scale))
            target_h = int(round(im_h * scale))
            target_size = (target_w, target_h)
            im_scale_y = target_h / im_h
            im_scale_x = target_w / im_w
        sample['image'] = self.apply_im(sample['image'], interp, target_size)
        if 'image2' in sample:
            sample['image2'] = self.apply_im(sample['image2'], interp,
                                             target_size)
        if 'mask' in sample:
            sample['mask'] = self.apply_mask(sample['mask'], target_size)
        if 'aux_masks' in sample:
            sample['aux_masks'] = list(
                map(partial(
                    self.apply_mask, target_size=target_size),
                    sample['aux_masks']))
        if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
            sample['gt_bbox'] = self.apply_bbox(
                sample['gt_bbox'], [im_scale_x, im_scale_y], target_size)
        if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
            sample['gt_poly'] = self.apply_segm(
                sample['gt_poly'], [im_h, im_w], [im_scale_x, im_scale_y])
        if 'target' in sample:
            if 'sr_factor' in sample:
                # For SR tasks
                sample['target'] = self.apply_im(
                    sample['target'], interp,
                    F.calc_hr_shape(target_size, sample['sr_factor']))
            else:
                # For non-SR tasks
                sample['target'] = self.apply_im(sample['target'], interp,
                                                 target_size)
        sample['im_shape'] = np.asarray(
            sample['image'].shape[:2], dtype=np.float32)
        if 'scale_factor' in sample:
            scale_factor = sample['scale_factor']
            sample['scale_factor'] = np.asarray(
                [scale_factor[0] * im_scale_y, scale_factor[1] * im_scale_x],
                dtype=np.float32)
        return sample

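
# Hedged usage sketch: with `keep_ratio=True`, Resize scales both sides by
# the smaller of the two scale factors, so the output fits inside
# `target_size`. Synthetic input only.
def _example_resize_keep_ratio():
    op = Resize(target_size=(100, 200), keep_ratio=True)
    sample = construct_sample(image=np.zeros((50, 50, 3), dtype=np.uint8))
    sample = op(sample)
    # min(100/50, 200/50) == 2, so the 50x50 image becomes 100x100.
    assert sample['image'].shape[:2] == (100, 100)
    return sample
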
class RandomResize(Transform):
    """
    Resize input to random sizes.

    Attention: If `interp` is 'RANDOM', the interpolation method will be chosen randomly.

    Args:
        target_sizes (list[int] | list[list|tuple] | tuple[list|tuple]):
            Multiple target sizes, each of which should be an int, list, or tuple.
        interp (str, optional): Interpolation method for resizing image(s). One of
            {'NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4', 'RANDOM'}.
            Defaults to 'LINEAR'.

    Raises:
        TypeError: Invalid type of `target_sizes`.
        ValueError: Invalid interpolation method.
    """

    def __init__(self, target_sizes, interp='LINEAR'):
        super(RandomResize, self).__init__()
        if not (interp == "RANDOM" or interp in interp_dict):
            raise ValueError("`interp` should be one of {}.".format(
                interp_dict.keys()))
        self.interp = interp
        assert isinstance(target_sizes, list), \
            "`target_sizes` must be a list."
        for i, item in enumerate(target_sizes):
            if isinstance(item, int):
                target_sizes[i] = (item, item)
        self.target_size = target_sizes

    def apply(self, sample):
        height, width = random.choice(self.target_size)
        resizer = Resize((height, width), interp=self.interp)
        sample = resizer(sample)
        return sample

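
# Hedged usage sketch: RandomResize picks one of the candidate sizes per
# call; ints are expanded to square (size, size) targets at construction.
def _example_random_resize():
    op = RandomResize(target_sizes=[32, (48, 64)])
    sample = construct_sample(image=np.zeros((16, 16, 3), dtype=np.uint8))
    sample = op(sample)
    assert sample['image'].shape[:2] in [(32, 32), (48, 64)]
    return sample
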
class ResizeByShort(Transform):
    """
    Resize input while keeping the aspect ratio.

    Attention: If `interp` is 'RANDOM', the interpolation method will be chosen randomly.

    Args:
        short_size (int): Target size of the shorter side of the image(s).
        max_size (int, optional): Upper bound of the longer side of the image(s). If
            `max_size` is -1, no upper bound will be applied. Defaults to -1.
        interp (str, optional): Interpolation method for resizing image(s). One of
            {'NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4', 'RANDOM'}.
            Defaults to 'LINEAR'.

    Raises:
        ValueError: Invalid interpolation method.
    """

    def __init__(self, short_size=256, max_size=-1, interp='LINEAR'):
        if not (interp == "RANDOM" or interp in interp_dict):
            raise ValueError("`interp` should be one of {}.".format(
                interp_dict.keys()))
        super(ResizeByShort, self).__init__()
        self.short_size = short_size
        self.max_size = max_size
        self.interp = interp

    def apply(self, sample):
        im_h, im_w = sample['image'].shape[:2]
        im_short_size = min(im_h, im_w)
        im_long_size = max(im_h, im_w)
        scale = float(self.short_size) / float(im_short_size)
        if 0 < self.max_size < np.round(scale * im_long_size):
            scale = float(self.max_size) / float(im_long_size)
        target_w = int(round(im_w * scale))
        target_h = int(round(im_h * scale))
        sample = Resize(
            target_size=(target_h, target_w), interp=self.interp)(sample)
        return sample

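
# Hedged usage sketch: the shorter side is scaled to `short_size`, and
# `max_size` caps the longer side, shrinking the scale factor when plain
# short-side scaling would overshoot.
def _example_resize_by_short():
    op = ResizeByShort(short_size=64, max_size=100)
    sample = construct_sample(image=np.zeros((32, 128, 3), dtype=np.uint8))
    sample = op(sample)
    # Plain short-side scale would be 64/32 = 2 (long side 256 > 100),
    # so the capped scale is 100/128, giving a 25x100 output.
    assert sample['image'].shape[:2] == (25, 100)
    return sample
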
class RandomResizeByShort(Transform):
    """
    Resize input to random sizes while keeping the aspect ratio.

    Attention: If `interp` is 'RANDOM', the interpolation method will be chosen randomly.

    Args:
        short_sizes (list[int]): Candidate target sizes of the shorter side of the
            image(s).
        max_size (int, optional): Upper bound of the longer side of the image(s).
            If `max_size` is -1, no upper bound will be applied. Defaults to -1.
        interp (str, optional): Interpolation method for resizing image(s). One of
            {'NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4', 'RANDOM'}.
            Defaults to 'LINEAR'.

    Raises:
        TypeError: Invalid type of `short_sizes`.
        ValueError: Invalid interpolation method.

    See Also:
        ResizeByShort: Resize image(s) in input while keeping the aspect ratio.
    """

    def __init__(self, short_sizes, max_size=-1, interp='LINEAR'):
        super(RandomResizeByShort, self).__init__()
        if not (interp == "RANDOM" or interp in interp_dict):
            raise ValueError("`interp` should be one of {}.".format(
                interp_dict.keys()))
        self.interp = interp
        assert isinstance(short_sizes, list), \
            "`short_sizes` must be a list."
        self.short_sizes = short_sizes
        self.max_size = max_size

    def apply(self, sample):
        short_size = random.choice(self.short_sizes)
        resizer = ResizeByShort(
            short_size=short_size, max_size=self.max_size, interp=self.interp)
        sample = resizer(sample)
        return sample


class ResizeByLong(Transform):
    """
    Resize input while keeping the aspect ratio, such that the longer side of
    the image(s) matches `long_size`.

    Args:
        long_size (int, optional): Target size of the longer side of the image(s).
            Defaults to 256.
        interp (str, optional): Interpolation method for resizing image(s).
            Defaults to 'LINEAR'.
    """

    def __init__(self, long_size=256, interp='LINEAR'):
        super(ResizeByLong, self).__init__()
        self.long_size = long_size
        self.interp = interp

    def apply(self, sample):
        im_h, im_w = sample['image'].shape[:2]
        im_long_size = max(im_h, im_w)
        scale = float(self.long_size) / float(im_long_size)
        target_h = int(round(im_h * scale))
        target_w = int(round(im_w * scale))
        sample = Resize(
            target_size=(target_h, target_w), interp=self.interp)(sample)
        return sample

class RandomFlipOrRotate(Transform):
    """
    Flip or rotate an image in different directions with a certain probability.

    Args:
        probs (list[float]): Probabilities of performing flipping and rotation.
            Default: [0.35, 0.25].
        probsf (list[float]): Probabilities of 5 flipping modes (horizontal,
            vertical, both horizontal and vertical, diagonal, anti-diagonal).
            Default: [0.3, 0.3, 0.2, 0.1, 0.1].
        probsr (list[float]): Probabilities of 3 rotation modes (90°, 180°, 270°
            clockwise). Default: [0.25, 0.5, 0.25].

    Examples:

        from paddlers import transforms as T

        # Define operators for data augmentation
        train_transforms = T.Compose([
            T.DecodeImg(),
            T.RandomFlipOrRotate(
                # p=0.3 to flip the image, p=0.2 to rotate the image,
                # p=0.5 to keep the image unchanged.
                probs=[0.3, 0.2],
                # p=0.3 and p=0.25 to perform horizontal and vertical flipping;
                # the probability of no flipping is 0.45.
                probsf=[0.3, 0.25, 0, 0, 0],
                # p=0.65 to rotate the image by 180°; the probability of no
                # rotation is 0.35.
                probsr=[0, 0.65, 0]),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
    """

    def __init__(self,
                 probs=[0.35, 0.25],
                 probsf=[0.3, 0.3, 0.2, 0.1, 0.1],
                 probsr=[0.25, 0.5, 0.25]):
        super(RandomFlipOrRotate, self).__init__()
        # Change various probabilities into probability intervals, to judge in
        # which mode to flip or rotate.
        self.probs = [probs[0], probs[0] + probs[1]]
        self.probsf = self.get_probs_range(probsf)
        self.probsr = self.get_probs_range(probsr)

    def apply_im(self, image, mode_id, flip_mode=True):
        if flip_mode:
            image = F.img_flip(image, mode_id)
        else:
            image = F.img_simple_rotate(image, mode_id)
        return image

    def apply_mask(self, mask, mode_id, flip_mode=True):
        if flip_mode:
            mask = F.img_flip(mask, mode_id)
        else:
            mask = F.img_simple_rotate(mask, mode_id)
        return mask

    def apply_bbox(self, bbox, mode_id, flip_mode=True):
        raise TypeError(
            "Currently, RandomFlipOrRotate is not available for object detection tasks."
        )

    def apply_segm(self, segms, mode_id, flip_mode=True):
        raise TypeError(
            "Currently, RandomFlipOrRotate is not available for object detection tasks."
        )

    def get_probs_range(self, probs):
        """
        Change a list of probabilities into cumulative probability intervals.

        Args:
            probs (list[float]): Probabilities of different modes, shape: [n].

        Returns:
            list[list]: Probability intervals, shape: [n, 2].
        """
        ps = []
        last_prob = 0
        for prob in probs:
            p_s = last_prob
            cur_prob = prob / sum(probs)
            last_prob += cur_prob
            p_e = last_prob
            ps.append([p_s, p_e])
        return ps

    def judge_probs_range(self, p, probs):
        """
        Judge which of the given probability intervals the value of `p` falls into.

        Args:
            p (float): Value between 0 and 1.
            probs (list[list]): Probability intervals, shape: [n, 2].

        Returns:
            int: Index of the interval where the input probability falls,
                or -1 if none matches.
        """
        for idx, id_range in enumerate(probs):
            if id_range[0] < p < id_range[1]:
                return idx
        return -1

    def apply(self, sample):
        p_m = random.random()
        if p_m < self.probs[0]:
            mode_p = random.random()
            mode_id = self.judge_probs_range(mode_p, self.probsf)
            sample['image'] = self.apply_im(sample['image'], mode_id, True)
            if 'image2' in sample:
                sample['image2'] = self.apply_im(sample['image2'], mode_id,
                                                 True)
            if 'mask' in sample:
                sample['mask'] = self.apply_mask(sample['mask'], mode_id, True)
            if 'aux_masks' in sample:
                sample['aux_masks'] = [
                    self.apply_mask(aux_mask, mode_id, True)
                    for aux_mask in sample['aux_masks']
                ]
            if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
                sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], mode_id,
                                                    True)
            if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
                sample['gt_poly'] = self.apply_segm(sample['gt_poly'], mode_id,
                                                    True)
            if 'target' in sample:
                sample['target'] = self.apply_im(sample['target'], mode_id,
                                                 True)
        elif p_m < self.probs[1]:
            mode_p = random.random()
            mode_id = self.judge_probs_range(mode_p, self.probsr)
            sample['image'] = self.apply_im(sample['image'], mode_id, False)
            if 'image2' in sample:
                sample['image2'] = self.apply_im(sample['image2'], mode_id,
                                                 False)
            if 'mask' in sample:
                sample['mask'] = self.apply_mask(sample['mask'], mode_id,
                                                 False)
            if 'aux_masks' in sample:
                sample['aux_masks'] = [
                    self.apply_mask(aux_mask, mode_id, False)
                    for aux_mask in sample['aux_masks']
                ]
            if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
                sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], mode_id,
                                                    False)
            if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
                sample['gt_poly'] = self.apply_segm(sample['gt_poly'], mode_id,
                                                    False)
            if 'target' in sample:
                sample['target'] = self.apply_im(sample['target'], mode_id,
                                                 False)
        return sample

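
# Hedged usage sketch: `get_probs_range` turns mode weights into cumulative
# intervals, e.g. [0.3, 0.3, 0.4] -> [[0, 0.3], [0.3, 0.6], [0.6, 1.0]].
def _example_probs_range():
    op = RandomFlipOrRotate()
    intervals = op.get_probs_range([1, 1, 2])
    # Weights are normalized first: [0.25, 0.25, 0.5].
    assert intervals == [[0, 0.25], [0.25, 0.5], [0.5, 1.0]]
    return intervals
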
class RandomHorizontalFlip(Transform):
    """
    Randomly flip the input horizontally.

    Args:
        prob (float, optional): Probability of flipping the input. Defaults to .5.
    """

    def __init__(self, prob=0.5):
        super(RandomHorizontalFlip, self).__init__()
        self.prob = prob

    def apply_im(self, image):
        image = F.horizontal_flip(image)
        return image

    def apply_mask(self, mask):
        mask = F.horizontal_flip(mask)
        return mask

    def apply_bbox(self, bbox, width):
        oldx1 = bbox[:, 0].copy()
        oldx2 = bbox[:, 2].copy()
        bbox[:, 0] = width - oldx2
        bbox[:, 2] = width - oldx1
        return bbox

    def apply_segm(self, segms, height, width):
        flipped_segms = []
        for segm in segms:
            if F.is_poly(segm):
                # Polygon format
                flipped_segms.append(
                    [F.horizontal_flip_poly(poly, width) for poly in segm])
            else:
                # RLE format
                flipped_segms.append(F.horizontal_flip_rle(segm, height,
                                                           width))
        return flipped_segms

    def apply(self, sample):
        if random.random() < self.prob:
            im_h, im_w = sample['image'].shape[:2]
            sample['image'] = self.apply_im(sample['image'])
            if 'image2' in sample:
                sample['image2'] = self.apply_im(sample['image2'])
            if 'mask' in sample:
                sample['mask'] = self.apply_mask(sample['mask'])
            if 'aux_masks' in sample:
                sample['aux_masks'] = list(
                    map(self.apply_mask, sample['aux_masks']))
            if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
                sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], im_w)
            if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
                sample['gt_poly'] = self.apply_segm(sample['gt_poly'], im_h,
                                                    im_w)
            if 'target' in sample:
                sample['target'] = self.apply_im(sample['target'])
        return sample

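
# Hedged sketch of the horizontal bbox flip: x-coordinates are mirrored
# about the image width while y-coordinates are untouched.
def _example_hflip_bbox():
    op = RandomHorizontalFlip(prob=1.0)
    bbox = np.array([[10., 5., 30., 25.]])  # x1, y1, x2, y2
    flipped = op.apply_bbox(bbox.copy(), width=100)
    assert (flipped == np.array([[70., 5., 90., 25.]])).all()
    return flipped
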
class RandomVerticalFlip(Transform):
    """
    Randomly flip the input vertically.

    Args:
        prob (float, optional): Probability of flipping the input. Defaults to .5.
    """

    def __init__(self, prob=0.5):
        super(RandomVerticalFlip, self).__init__()
        self.prob = prob

    def apply_im(self, image):
        image = F.vertical_flip(image)
        return image

    def apply_mask(self, mask):
        mask = F.vertical_flip(mask)
        return mask

    def apply_bbox(self, bbox, height):
        # Mirror the y-coordinates (columns 1 and 3 of [x1, y1, x2, y2]).
        oldy1 = bbox[:, 1].copy()
        oldy2 = bbox[:, 3].copy()
        bbox[:, 1] = height - oldy2
        bbox[:, 3] = height - oldy1
        return bbox

    def apply_segm(self, segms, height, width):
        flipped_segms = []
        for segm in segms:
            if F.is_poly(segm):
                # Polygon format
                flipped_segms.append(
                    [F.vertical_flip_poly(poly, height) for poly in segm])
            else:
                # RLE format
                flipped_segms.append(F.vertical_flip_rle(segm, height, width))
        return flipped_segms

    def apply(self, sample):
        if random.random() < self.prob:
            im_h, im_w = sample['image'].shape[:2]
            sample['image'] = self.apply_im(sample['image'])
            if 'image2' in sample:
                sample['image2'] = self.apply_im(sample['image2'])
            if 'mask' in sample:
                sample['mask'] = self.apply_mask(sample['mask'])
            if 'aux_masks' in sample:
                sample['aux_masks'] = list(
                    map(self.apply_mask, sample['aux_masks']))
            if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
                sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], im_h)
            if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
                sample['gt_poly'] = self.apply_segm(sample['gt_poly'], im_h,
                                                    im_w)
            if 'target' in sample:
                sample['target'] = self.apply_im(sample['target'])
        return sample

class Normalize(Transform):
    """
    Apply normalization to the input image(s). The normalization steps are:

    1. im = (im - min_value) / (max_value - min_value)
    2. im = im - mean
    3. im = im / std

    Args:
        mean (list[float] | tuple[float], optional): Mean of input image(s).
            Defaults to [0.485, 0.456, 0.406].
        std (list[float] | tuple[float], optional): Standard deviation of input
            image(s). Defaults to [0.229, 0.224, 0.225].
        min_val (list[float] | tuple[float], optional): Minimum value of input
            image(s). If None, use 0 for all channels. Defaults to None.
        max_val (list[float] | tuple[float], optional): Maximum value of input
            image(s). If None, use 255. for all channels. Defaults to None.
        apply_to_tar (bool, optional): Whether to apply the transformation to the
            target image. Defaults to True.
    """

    def __init__(self,
                 mean=[0.485, 0.456, 0.406],
                 std=[0.229, 0.224, 0.225],
                 min_val=None,
                 max_val=None,
                 apply_to_tar=True):
        super(Normalize, self).__init__()
        channel = len(mean)
        if min_val is None:
            min_val = [0] * channel
        if max_val is None:
            max_val = [255.] * channel
        if reduce(lambda x, y: x * y, std) == 0:
            raise ValueError(
                "`std` should not contain 0, but received {}.".format(std))
        if reduce(lambda x, y: x * y,
                  [a - b for a, b in zip(max_val, min_val)]) == 0:
            raise ValueError(
                "(`max_val` - `min_val`) should not contain 0, but received {}.".
                format((np.asarray(max_val) - np.asarray(min_val)).tolist()))
        self.mean = mean
        self.std = std
        self.min_val = min_val
        self.max_val = max_val
        self.apply_to_tar = apply_to_tar

    def apply_im(self, image):
        image = image.astype(np.float32)
        mean = np.asarray(
            self.mean, dtype=np.float32)[np.newaxis, np.newaxis, :]
        std = np.asarray(self.std, dtype=np.float32)[np.newaxis, np.newaxis, :]
        image = F.normalize(image, mean, std, self.min_val, self.max_val)
        return image

    def apply(self, sample):
        sample['image'] = self.apply_im(sample['image'])
        if 'image2' in sample:
            sample['image2'] = self.apply_im(sample['image2'])
        if 'target' in sample and self.apply_to_tar:
            sample['target'] = self.apply_im(sample['target'])
        return sample

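
# Hedged sketch of the Normalize arithmetic on a single pixel, assuming
# F.normalize implements the three steps listed in the class docstring:
# (im - min) / (max - min), then subtract mean, then divide by std.
def _example_normalize():
    op = Normalize(mean=[0.5], std=[0.25])
    sample = construct_sample(
        image=np.full((2, 2, 1), 127.5, dtype=np.float32))
    sample = op(sample)
    # (127.5 - 0) / 255 = 0.5; (0.5 - 0.5) / 0.25 = 0.
    assert np.allclose(sample['image'], 0.)
    return sample
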
class CenterCrop(Transform):
    """
    Crop the input image(s) at the center.

    1. Locate the center of the image.
    2. Crop the image.

    Args:
        crop_size (int, optional): Target size of the cropped image(s).
            Defaults to 224.
    """

    def __init__(self, crop_size=224):
        super(CenterCrop, self).__init__()
        self.crop_size = crop_size

    def apply_im(self, image):
        image = F.center_crop(image, self.crop_size)
        return image

    def apply_mask(self, mask):
        mask = F.center_crop(mask, self.crop_size)
        return mask

    def apply(self, sample):
        sample['image'] = self.apply_im(sample['image'])
        if 'image2' in sample:
            sample['image2'] = self.apply_im(sample['image2'])
        if 'mask' in sample:
            sample['mask'] = self.apply_mask(sample['mask'])
        if 'aux_masks' in sample:
            sample['aux_masks'] = list(
                map(self.apply_mask, sample['aux_masks']))
        if 'target' in sample:
            sample['target'] = self.apply_im(sample['target'])
        return sample

class RandomCrop(Transform):
    """
    Randomly crop the input.

    1. Compute the height and width of the cropped area according to
        `aspect_ratio` and `scaling`.
    2. Locate the upper-left corner of the cropped area randomly.
    3. Crop the image(s).
    4. Resize the cropped area to `crop_size` x `crop_size`.

    Args:
        crop_size (int | list[int] | tuple[int], optional): Target size of the
            cropped area. If None, the cropped area will not be resized.
            Defaults to None.
        aspect_ratio (list[float], optional): Aspect ratio of the cropped region in
            [min, max] format. Defaults to [.5, 2.].
        thresholds (list[float], optional): IoU thresholds to decide a valid bbox
            crop. Defaults to [.0, .1, .3, .5, .7, .9].
        scaling (list[float], optional): Ratio between the cropped region and the
            original image in [min, max] format. Defaults to [.3, 1.].
        num_attempts (int, optional): Max number of tries before giving up.
            Defaults to 50.
        allow_no_crop (bool, optional): Whether returning without doing the crop is
            allowed. Defaults to True.
        cover_all_box (bool, optional): Whether to ensure all bboxes are covered in
            the final crop. Defaults to False.
    """

    def __init__(self,
                 crop_size=None,
                 aspect_ratio=[.5, 2.],
                 thresholds=[.0, .1, .3, .5, .7, .9],
                 scaling=[.3, 1.],
                 num_attempts=50,
                 allow_no_crop=True,
                 cover_all_box=False):
        super(RandomCrop, self).__init__()
        self.crop_size = crop_size
        self.aspect_ratio = aspect_ratio
        self.thresholds = thresholds
        self.scaling = scaling
        self.num_attempts = num_attempts
        self.allow_no_crop = allow_no_crop
        self.cover_all_box = cover_all_box

    def _generate_crop_info(self, sample):
        im_h, im_w = sample['image'].shape[:2]
        if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
            # Work on a copy so that repeated calls do not keep appending
            # 'no_crop' to `self.thresholds`.
            thresholds = list(self.thresholds)
            if self.allow_no_crop:
                thresholds.append('no_crop')
            np.random.shuffle(thresholds)
            for thresh in thresholds:
                if thresh == 'no_crop':
                    return None
                for i in range(self.num_attempts):
                    crop_box = self._get_crop_box(im_h, im_w)
                    if crop_box is None:
                        continue
                    iou = self._iou_matrix(
                        sample['gt_bbox'],
                        np.array(
                            [crop_box], dtype=np.float32))
                    if iou.max() < thresh:
                        continue
                    if self.cover_all_box and iou.min() < thresh:
                        continue
                    cropped_box, valid_ids = self._crop_box_with_center_constraint(
                        sample['gt_bbox'], np.array(
                            crop_box, dtype=np.float32))
                    if valid_ids.size > 0:
                        return crop_box, cropped_box, valid_ids
        else:
            for i in range(self.num_attempts):
                crop_box = self._get_crop_box(im_h, im_w)
                if crop_box is None:
                    continue
                return crop_box, None, None
        return None

    def _get_crop_box(self, im_h, im_w):
        scale = np.random.uniform(*self.scaling)
        if self.aspect_ratio is not None:
            min_ar, max_ar = self.aspect_ratio
            aspect_ratio = np.random.uniform(
                max(min_ar, scale**2), min(max_ar, scale**-2))
            h_scale = scale / np.sqrt(aspect_ratio)
            w_scale = scale * np.sqrt(aspect_ratio)
        else:
            h_scale = np.random.uniform(*self.scaling)
            w_scale = np.random.uniform(*self.scaling)
        crop_h = im_h * h_scale
        crop_w = im_w * w_scale
        if self.aspect_ratio is None:
            if crop_h / crop_w < 0.5 or crop_h / crop_w > 2.0:
                return None
        crop_h = int(crop_h)
        crop_w = int(crop_w)
        crop_y = np.random.randint(0, im_h - crop_h)
        crop_x = np.random.randint(0, im_w - crop_w)
        return [crop_x, crop_y, crop_x + crop_w, crop_y + crop_h]

    def _iou_matrix(self, a, b):
        tl_i = np.maximum(a[:, np.newaxis, :2], b[:, :2])
        br_i = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
        area_i = np.prod(br_i - tl_i, axis=2) * (tl_i < br_i).all(axis=2)
        area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
        area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
        area_o = (area_a[:, np.newaxis] + area_b - area_i)
        return area_i / (area_o + 1e-10)

    def _crop_box_with_center_constraint(self, box, crop):
        cropped_box = box.copy()
        cropped_box[:, :2] = np.maximum(box[:, :2], crop[:2])
        cropped_box[:, 2:] = np.minimum(box[:, 2:], crop[2:])
        cropped_box[:, :2] -= crop[:2]
        cropped_box[:, 2:] -= crop[:2]
        centers = (box[:, :2] + box[:, 2:]) / 2
        valid = np.logical_and(crop[:2] <= centers,
                               centers < crop[2:]).all(axis=1)
        valid = np.logical_and(
            valid, (cropped_box[:, :2] < cropped_box[:, 2:]).all(axis=1))
        return cropped_box, np.where(valid)[0]

    def _crop_segm(self, segms, valid_ids, crop, height, width):
        crop_segms = []
        for id in valid_ids:
            segm = segms[id]
            if F.is_poly(segm):
                # Polygon format
                crop_segms.append(F.crop_poly(segm, crop))
            else:
                # RLE format
                crop_segms.append(F.crop_rle(segm, crop, height, width))
        return crop_segms

    def apply_im(self, image, crop):
        x1, y1, x2, y2 = crop
        return image[y1:y2, x1:x2, :]

    def apply_mask(self, mask, crop):
        x1, y1, x2, y2 = crop
        return mask[y1:y2, x1:x2, ...]

    def apply(self, sample):
        crop_info = self._generate_crop_info(sample)
        if crop_info is not None:
            crop_box, cropped_box, valid_ids = crop_info
            im_h, im_w = sample['image'].shape[:2]
            sample['image'] = self.apply_im(sample['image'], crop_box)
            if 'image2' in sample:
                sample['image2'] = self.apply_im(sample['image2'], crop_box)
            if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
                crop_polys = self._crop_segm(
                    sample['gt_poly'],
                    valid_ids,
                    np.array(
                        crop_box, dtype=np.int64),
                    im_h,
                    im_w)
                if [] in crop_polys:
                    delete_id = list()
                    valid_polys = list()
                    for idx, poly in enumerate(crop_polys):
                        if not poly:
                            delete_id.append(idx)
                        else:
                            valid_polys.append(poly)
                    valid_ids = np.delete(valid_ids, delete_id)
                    if not valid_polys:
                        return sample
                    sample['gt_poly'] = valid_polys
                else:
                    sample['gt_poly'] = crop_polys
            if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
                sample['gt_bbox'] = np.take(cropped_box, valid_ids, axis=0)
                sample['gt_class'] = np.take(
                    sample['gt_class'], valid_ids, axis=0)
                if 'gt_score' in sample:
                    sample['gt_score'] = np.take(
                        sample['gt_score'], valid_ids, axis=0)
                if 'is_crowd' in sample:
                    sample['is_crowd'] = np.take(
                        sample['is_crowd'], valid_ids, axis=0)
            if 'mask' in sample:
                sample['mask'] = self.apply_mask(sample['mask'], crop_box)
            if 'aux_masks' in sample:
                sample['aux_masks'] = list(
                    map(partial(
                        self.apply_mask, crop=crop_box),
                        sample['aux_masks']))
            if 'target' in sample:
                if 'sr_factor' in sample:
                    sample['target'] = self.apply_im(
                        sample['target'],
                        F.calc_hr_shape(crop_box, sample['sr_factor']))
                else:
                    sample['target'] = self.apply_im(sample['target'],
                                                     crop_box)
        if self.crop_size is not None:
            sample = Resize(self.crop_size)(sample)
        return sample

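
# Hedged usage sketch: without bboxes, RandomCrop samples one crop box and,
# since `crop_size` is given, resizes the crop back to a fixed size.
def _example_random_crop():
    op = RandomCrop(crop_size=32)
    sample = construct_sample(image=np.zeros((64, 64, 3), dtype=np.uint8))
    sample = op(sample)
    assert sample['image'].shape[:2] == (32, 32)
    return sample
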
class RandomScaleAspect(Transform):
    """
    Crop input image(s) and resize back to original sizes.

    Args:
        min_scale (float): Minimum ratio between the cropped region and the original
            image. If 0, image(s) will not be cropped. Defaults to .5.
        aspect_ratio (float): Aspect ratio of the cropped region. Defaults to .33.
    """

    def __init__(self, min_scale=0.5, aspect_ratio=0.33):
        super(RandomScaleAspect, self).__init__()
        self.min_scale = min_scale
        self.aspect_ratio = aspect_ratio

    def apply(self, sample):
        if self.min_scale != 0 and self.aspect_ratio != 0:
            img_height, img_width = sample['image'].shape[:2]
            sample = RandomCrop(
                crop_size=(img_height, img_width),
                aspect_ratio=[self.aspect_ratio, 1. / self.aspect_ratio],
                scaling=[self.min_scale, 1.],
                num_attempts=10,
                allow_no_crop=False)(sample)
        return sample

class RandomExpand(Transform):
    """
    Randomly expand the input by padding according to random offsets.

    Args:
        upper_ratio (float, optional): Maximum ratio to which the original image
            is expanded. Defaults to 4..
        prob (float, optional): Probability of expanding. Defaults to .5.
        im_padding_value (Number | list[float] | tuple[float], optional): Filling
            value for the image. Defaults to 127.5.
        label_padding_value (int, optional): Filling value for the mask.
            Defaults to 255.

    See Also:
        paddlers.transforms.Pad
    """

    def __init__(self,
                 upper_ratio=4.,
                 prob=.5,
                 im_padding_value=127.5,
                 label_padding_value=255):
        super(RandomExpand, self).__init__()
        assert upper_ratio > 1.01, "`upper_ratio` must be larger than 1.01."
        self.upper_ratio = upper_ratio
        self.prob = prob
        assert isinstance(im_padding_value, (Number, Sequence)), \
            "Value to fill must be either a number or a sequence."
        self.im_padding_value = im_padding_value
        self.label_padding_value = label_padding_value

    def apply(self, sample):
        if random.random() < self.prob:
            im_h, im_w = sample['image'].shape[:2]
            ratio = np.random.uniform(1., self.upper_ratio)
            h = int(im_h * ratio)
            w = int(im_w * ratio)
            if h > im_h and w > im_w:
                y = np.random.randint(0, h - im_h)
                x = np.random.randint(0, w - im_w)
                target_size = (h, w)
                offsets = (x, y)
                sample = Pad(
                    target_size=target_size,
                    pad_mode=-1,
                    offsets=offsets,
                    im_padding_value=self.im_padding_value,
                    label_padding_value=self.label_padding_value)(sample)
        return sample

  1145. class Pad(Transform):
  1146. def __init__(self,
  1147. target_size=None,
  1148. pad_mode=0,
  1149. offsets=None,
  1150. im_padding_value=127.5,
  1151. label_padding_value=255,
  1152. size_divisor=32):
  1153. """
  1154. Pad image to a specified size or multiple of `size_divisor`.
  1155. Args:
  1156. target_size (list[int] | tuple[int], optional): Image target size, if None, pad to
  1157. multiple of size_divisor. Defaults to None.
  1158. pad_mode (int, optional): Pad mode. Currently only four modes are supported:
  1159. [-1, 0, 1, 2]. if -1, use specified offsets. If 0, only pad to right and bottom.
  1160. If 1, pad according to center. If 2, only pad left and top. Defaults to 0.
  1161. offsets (list[int]|None, optional): Padding offsets. Defaults to None.
  1162. im_padding_value (list[float] | tuple[float]): RGB value of padded area.
  1163. Defaults to (127.5, 127.5, 127.5).
  1164. label_padding_value (int, optional): Filling value for the mask.
  1165. Defaults to 255.
  1166. size_divisor (int): Image width and height after padding will be a multiple of
  1167. `size_divisor`.
  1168. """
  1169. super(Pad, self).__init__()
  1170. if isinstance(target_size, (list, tuple)):
  1171. if len(target_size) != 2:
  1172. raise ValueError(
  1173. "`target_size` should contain 2 elements, but it is {}.".
  1174. format(target_size))
  1175. if isinstance(target_size, int):
  1176. target_size = [target_size] * 2
  1177. assert pad_mode in [
  1178. -1, 0, 1, 2
  1179. ], "Currently only four modes are supported: [-1, 0, 1, 2]."
  1180. if pad_mode == -1:
  1181. assert offsets, "if `pad_mode` is -1, `offsets` should not be None."
  1182. self.target_size = target_size
  1183. self.size_divisor = size_divisor
  1184. self.pad_mode = pad_mode
  1185. self.offsets = offsets
  1186. self.im_padding_value = im_padding_value
  1187. self.label_padding_value = label_padding_value
  1188. def apply_im(self, image, offsets, target_size):
  1189. x, y = offsets
  1190. h, w = target_size
  1191. im_h, im_w, channel = image.shape[:3]
  1192. canvas = np.ones((h, w, channel), dtype=np.float32)
  1193. canvas *= np.array(self.im_padding_value, dtype=np.float32)
  1194. canvas[y:y + im_h, x:x + im_w, :] = image.astype(np.float32)
  1195. return canvas
  1196. def apply_mask(self, mask, offsets, target_size):
  1197. x, y = offsets
  1198. im_h, im_w = mask.shape[:2]
  1199. h, w = target_size
  1200. canvas = np.ones((h, w), dtype=np.float32)
  1201. canvas *= np.array(self.label_padding_value, dtype=np.float32)
  1202. canvas[y:y + im_h, x:x + im_w] = mask.astype(np.float32)
  1203. return canvas
  1204. def apply_bbox(self, bbox, offsets):
  1205. return bbox + np.array(offsets * 2, dtype=np.float32)
  1206. def apply_segm(self, segms, offsets, im_size, size):
  1207. x, y = offsets
  1208. height, width = im_size
  1209. h, w = size
  1210. expanded_segms = []
  1211. for segm in segms:
  1212. if F.is_poly(segm):
  1213. # Polygon format
  1214. expanded_segms.append(
  1215. [F.expand_poly(poly, x, y) for poly in segm])
  1216. else:
  1217. # RLE format
  1218. expanded_segms.append(
  1219. F.expand_rle(segm, x, y, height, width, h, w))
  1220. return expanded_segms
  1221. def _get_offsets(self, im_h, im_w, h, w):
  1222. if self.pad_mode == -1:
  1223. offsets = self.offsets
  1224. elif self.pad_mode == 0:
  1225. offsets = [0, 0]
  1226. elif self.pad_mode == 1:
  1227. offsets = [(w - im_w) // 2, (h - im_h) // 2]
  1228. else:
  1229. offsets = [w - im_w, h - im_h]
  1230. return offsets

    def apply(self, sample):
        im_h, im_w = sample['image'].shape[:2]
        if self.target_size:
            h, w = self.target_size
            assert (
                im_h <= h and im_w <= w
            ), 'target size ({}, {}) cannot be less than image size ({}, {})' \
                .format(h, w, im_h, im_w)
        else:
            # Round the height and width up to the nearest multiple of
            # `size_divisor`.
            h = (np.ceil(im_h / self.size_divisor) *
                 self.size_divisor).astype(int)
            w = (np.ceil(im_w / self.size_divisor) *
                 self.size_divisor).astype(int)
        if h == im_h and w == im_w:
            return sample

        offsets = self._get_offsets(im_h, im_w, h, w)
        sample['trans_info'].append(
            ('padding', sample['image'].shape[0:2], offsets))
        sample['image'] = self.apply_im(sample['image'], offsets, (h, w))
        if 'image2' in sample:
            sample['image2'] = self.apply_im(sample['image2'], offsets, (h, w))
        if 'mask' in sample:
            sample['mask'] = self.apply_mask(sample['mask'], offsets, (h, w))
        if 'aux_masks' in sample:
            sample['aux_masks'] = list(
                map(partial(
                    self.apply_mask, offsets=offsets, target_size=(h, w)),
                    sample['aux_masks']))
        if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
            sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], offsets)
        if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
            sample['gt_poly'] = self.apply_segm(
                sample['gt_poly'], offsets, im_size=[im_h, im_w], size=[h, w])
        if 'target' in sample:
            if 'sr_factor' in sample:
                # For super-resolution tasks, pad the HR target with offsets
                # recomputed for the HR shape.
                hr_shape = F.calc_hr_shape((h, w), sample['sr_factor'])
                hr_offsets = self._get_offsets(*sample['target'].shape[:2],
                                               *hr_shape)
                sample['target'] = self.apply_im(sample['target'], hr_offsets,
                                                 hr_shape)
            else:
                sample['target'] = self.apply_im(sample['target'], offsets,
                                                 (h, w))
        return sample
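

# A minimal usage sketch for `Pad` (illustrative, not part of the pipeline):
# padding a 300x500 image to multiples of 32 with mode 0 (pad right/bottom).
# The `sample` dict layout follows the conventions used throughout this module.
#
#     sample = {
#         'image': np.zeros((300, 500, 3), dtype=np.uint8),
#         'trans_info': [],
#     }
#     sample = Pad(size_divisor=32, pad_mode=0).apply(sample)
#     # 300 -> 320 and 500 -> 512, the next multiples of 32.
#     assert sample['image'].shape[:2] == (320, 512)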


class MixupImage(Transform):
    is_batch_transform = True

    def __init__(self, alpha=1.5, beta=1.5, mixup_epoch=-1):
        """
        Mixup two images and their gt_bbox/gt_score.

        Args:
            alpha (float, optional): Alpha parameter of beta distribution.
                Defaults to 1.5.
            beta (float, optional): Beta parameter of beta distribution.
                Defaults to 1.5.
            mixup_epoch (int, optional): Last epoch in which mixup is applied;
                -1 means mixup is always enabled. Defaults to -1.
        """
        super(MixupImage, self).__init__()
        if alpha <= 0.0:
            raise ValueError("`alpha` should be positive in MixupImage.")
        if beta <= 0.0:
            raise ValueError("`beta` should be positive in MixupImage.")
        self.alpha = alpha
        self.beta = beta
        self.mixup_epoch = mixup_epoch

    def apply_im(self, image1, image2, factor):
        h = max(image1.shape[0], image2.shape[0])
        w = max(image1.shape[1], image2.shape[1])
        # Blend the two images on a canvas large enough to hold both.
        img = np.zeros((h, w, image1.shape[2]), 'float32')
        img[:image1.shape[0], :image1.shape[1], :] = \
            image1.astype('float32') * factor
        img[:image2.shape[0], :image2.shape[1], :] += \
            image2.astype('float32') * (1.0 - factor)
        return img.astype('uint8')

    def __call__(self, sample):
        if not isinstance(sample, Sequence):
            return sample
        assert len(sample) == 2, 'Mixup needs two samples.'
        factor = np.random.beta(self.alpha, self.beta)
        factor = max(0.0, min(1.0, factor))
        if factor >= 1.0:
            return sample[0]
        if factor <= 0.0:
            return sample[1]
        image = self.apply_im(sample[0]['image'], sample[1]['image'], factor)
        result = copy.deepcopy(sample[0])
        result['image'] = image
        # Concatenate boxes and classes, and weight the scores by `factor`.
        if 'gt_bbox' in sample[0]:
            gt_bbox1 = sample[0]['gt_bbox']
            gt_bbox2 = sample[1]['gt_bbox']
            gt_bbox = np.concatenate((gt_bbox1, gt_bbox2), axis=0)
            result['gt_bbox'] = gt_bbox
        if 'gt_poly' in sample[0]:
            gt_poly1 = sample[0]['gt_poly']
            gt_poly2 = sample[1]['gt_poly']
            gt_poly = gt_poly1 + gt_poly2
            result['gt_poly'] = gt_poly
        if 'gt_class' in sample[0]:
            gt_class1 = sample[0]['gt_class']
            gt_class2 = sample[1]['gt_class']
            gt_class = np.concatenate((gt_class1, gt_class2), axis=0)
            result['gt_class'] = gt_class

            gt_score1 = np.ones_like(sample[0]['gt_class'])
            gt_score2 = np.ones_like(sample[1]['gt_class'])
            gt_score = np.concatenate(
                (gt_score1 * factor, gt_score2 * (1. - factor)), axis=0)
            result['gt_score'] = gt_score
        if 'is_crowd' in sample[0]:
            is_crowd1 = sample[0]['is_crowd']
            is_crowd2 = sample[1]['is_crowd']
            is_crowd = np.concatenate((is_crowd1, is_crowd2), axis=0)
            result['is_crowd'] = is_crowd
        if 'difficult' in sample[0]:
            is_difficult1 = sample[0]['difficult']
            is_difficult2 = sample[1]['difficult']
            is_difficult = np.concatenate(
                (is_difficult1, is_difficult2), axis=0)
            result['difficult'] = is_difficult
        return result
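

# Illustrative sketch of `MixupImage` on a batch of two detection samples
# (field values are dummies chosen for the example):
#
#     s1 = {'image': np.full((4, 4, 3), 255, dtype=np.uint8),
#           'gt_bbox': np.array([[0., 0., 2., 2.]], dtype=np.float32),
#           'gt_class': np.array([[1]], dtype=np.int32)}
#     s2 = {'image': np.zeros((4, 4, 3), dtype=np.uint8),
#           'gt_bbox': np.array([[1., 1., 3., 3.]], dtype=np.float32),
#           'gt_class': np.array([[2]], dtype=np.int32)}
#     out = MixupImage()([s1, s2])
#     # Boxes of both samples are kept; their scores are weighted by the
#     # blending factor drawn from Beta(alpha, beta).
#     assert out['gt_bbox'].shape[0] == 2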


class RandomDistort(Transform):
    """
    Random color distortion.

    Args:
        brightness_range (float, optional): Range of brightness distortion.
            Defaults to .5.
        brightness_prob (float, optional): Probability of brightness distortion.
            Defaults to .5.
        contrast_range (float, optional): Range of contrast distortion.
            Defaults to .5.
        contrast_prob (float, optional): Probability of contrast distortion.
            Defaults to .5.
        saturation_range (float, optional): Range of saturation distortion.
            Defaults to .5.
        saturation_prob (float, optional): Probability of saturation distortion.
            Defaults to .5.
        hue_range (float, optional): Range of hue distortion. Defaults to 18.
        hue_prob (float, optional): Probability of hue distortion. Defaults to .5.
        random_apply (bool, optional): Apply the transformations in random (YOLO)
            or fixed (SSD) order. Defaults to True.
        count (int, optional): Number of distortions to apply. Defaults to 4.
        shuffle_channel (bool, optional): Whether to permute channels randomly.
            Defaults to False.
    """

    def __init__(self,
                 brightness_range=0.5,
                 brightness_prob=0.5,
                 contrast_range=0.5,
                 contrast_prob=0.5,
                 saturation_range=0.5,
                 saturation_prob=0.5,
                 hue_range=18,
                 hue_prob=0.5,
                 random_apply=True,
                 count=4,
                 shuffle_channel=False):
        super(RandomDistort, self).__init__()
        self.brightness_range = [1 - brightness_range, 1 + brightness_range]
        self.brightness_prob = brightness_prob
        self.contrast_range = [1 - contrast_range, 1 + contrast_range]
        self.contrast_prob = contrast_prob
        self.saturation_range = [1 - saturation_range, 1 + saturation_range]
        self.saturation_prob = saturation_prob
        # Hue is an additive rotation, so use a symmetric range rather than
        # the multiplicative [1 - r, 1 + r] ranges used above.
        self.hue_range = [-hue_range, hue_range]
        self.hue_prob = hue_prob
        self.random_apply = random_apply
        self.count = count
        self.shuffle_channel = shuffle_channel

    def apply_hue(self, image):
        low, high = self.hue_range
        # Apply the distortion with probability `hue_prob`.
        if np.random.uniform(0., 1.) >= self.hue_prob:
            return image
        # It works, but the result differs from the HSV version: rotate the
        # chroma plane in YIQ space by a random angle.
        delta = np.random.uniform(low, high)
        u = np.cos(delta * np.pi)
        w = np.sin(delta * np.pi)
        bt = np.array([[1.0, 0.0, 0.0], [0.0, u, -w], [0.0, w, u]])
        tyiq = np.array([[0.299, 0.587, 0.114], [0.596, -0.274, -0.321],
                         [0.211, -0.523, 0.311]])
        ityiq = np.array([[1.0, 0.956, 0.621], [1.0, -0.272, -0.647],
                          [1.0, -1.107, 1.705]])
        t = np.dot(np.dot(ityiq, bt), tyiq).T
        res_list = []
        channel = image.shape[2]
        # Process the channels in groups of three (e.g., RGB triplets).
        for i in range(channel // 3):
            sub_img = image[:, :, 3 * i:3 * (i + 1)]
            sub_img = sub_img.astype(np.float32)
            sub_img = np.dot(sub_img, t)
            res_list.append(sub_img)
        if channel % 3 != 0:
            i = channel % 3
            res_list.append(image[:, :, -i:])
        return np.concatenate(res_list, axis=2)

    def apply_saturation(self, image):
        low, high = self.saturation_range
        if np.random.uniform(0., 1.) >= self.saturation_prob:
            return image
        delta = np.random.uniform(low, high)
        res_list = []
        channel = image.shape[2]
        for i in range(channel // 3):
            sub_img = image[:, :, 3 * i:3 * (i + 1)]
            sub_img = sub_img.astype(np.float32)
            # It works, but the result differs from the HSV version: blend
            # each pixel with its luma to change saturation.
            gray = sub_img * np.array(
                [[[0.299, 0.587, 0.114]]], dtype=np.float32)
            gray = gray.sum(axis=2, keepdims=True)
            gray *= (1.0 - delta)
            sub_img *= delta
            sub_img += gray
            res_list.append(sub_img)
        if channel % 3 != 0:
            i = channel % 3
            res_list.append(image[:, :, -i:])
        return np.concatenate(res_list, axis=2)

    def apply_contrast(self, image):
        low, high = self.contrast_range
        if np.random.uniform(0., 1.) >= self.contrast_prob:
            return image
        delta = np.random.uniform(low, high)
        image = image.astype(np.float32)
        image *= delta
        return image

    def apply_brightness(self, image):
        low, high = self.brightness_range
        if np.random.uniform(0., 1.) >= self.brightness_prob:
            return image
        delta = np.random.uniform(low, high)
        image = image.astype(np.float32)
        # Additive brightness shift.
        image += delta
        return image

    def apply(self, sample):
        if self.random_apply:
            functions = [
                self.apply_brightness, self.apply_contrast,
                self.apply_saturation, self.apply_hue
            ]
            distortions = np.random.permutation(functions)[:self.count]
            for func in distortions:
                sample['image'] = func(sample['image'])
                if 'image2' in sample:
                    sample['image2'] = func(sample['image2'])
            return sample

        sample['image'] = self.apply_brightness(sample['image'])
        if 'image2' in sample:
            sample['image2'] = self.apply_brightness(sample['image2'])
        mode = np.random.randint(0, 2)
        if mode:
            sample['image'] = self.apply_contrast(sample['image'])
            if 'image2' in sample:
                sample['image2'] = self.apply_contrast(sample['image2'])
        sample['image'] = self.apply_saturation(sample['image'])
        sample['image'] = self.apply_hue(sample['image'])
        if 'image2' in sample:
            sample['image2'] = self.apply_saturation(sample['image2'])
            sample['image2'] = self.apply_hue(sample['image2'])
        if not mode:
            sample['image'] = self.apply_contrast(sample['image'])
            if 'image2' in sample:
                sample['image2'] = self.apply_contrast(sample['image2'])
        if self.shuffle_channel:
            if np.random.randint(0, 2):
                sample['image'] = sample['image'][..., np.random.permutation(3)]
                if 'image2' in sample:
                    sample['image2'] = sample['image2'][
                        ..., np.random.permutation(3)]
        return sample
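

# Sketch of how `RandomDistort` is typically driven (values are arbitrary):
#
#     distort = RandomDistort(brightness_range=0.3, hue_range=9,
#                             random_apply=True, count=2)
#     sample = {'image': np.random.randint(
#         0, 256, (64, 64, 3), dtype=np.uint8)}
#     sample = distort.apply(sample)
#     # With `random_apply=True`, two of the four distortions are picked in
#     # random order (YOLO style); with False, the fixed SSD order is used.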


class RandomBlur(Transform):
    """
    Randomly blur input image(s).

    Args:
        prob (float): Probability of blurring.
    """

    def __init__(self, prob=0.1):
        super(RandomBlur, self).__init__()
        self.prob = prob

    def apply_im(self, image, radius):
        image = cv2.GaussianBlur(image, (radius, radius), 0, 0)
        return image

    def apply(self, sample):
        if self.prob <= 0:
            n = 0
        elif self.prob >= 1:
            n = 1
        else:
            n = int(1.0 / self.prob)
        if n > 0:
            if np.random.randint(0, n) == 0:
                # The Gaussian kernel size must be odd; cap it at 9.
                radius = np.random.randint(3, 10)
                if radius % 2 != 1:
                    radius = radius + 1
                if radius > 9:
                    radius = 9
                sample['image'] = self.apply_im(sample['image'], radius)
                if 'image2' in sample:
                    sample['image2'] = self.apply_im(sample['image2'], radius)
        return sample
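

# Note on the sampling above: the blur fires when `np.random.randint(0, n)`
# returns 0, i.e., with probability 1 / n = 1 / int(1 / prob), which only
# approximates `prob`. A quick check (illustrative only):
#
#     blur = RandomBlur(prob=0.3)  # n = int(1 / 0.3) = 3
#     hits = sum(np.random.randint(0, 3) == 0 for _ in range(10000))
#     # hits / 10000 is close to 1/3, not exactly 0.3.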


class Dehaze(Transform):
    """
    Dehaze input image(s).

    Args:
        gamma (bool, optional): Use gamma correction or not. Defaults to False.
    """

    def __init__(self, gamma=False):
        super(Dehaze, self).__init__()
        self.gamma = gamma

    def apply_im(self, image):
        image = F.dehaze(image, self.gamma)
        return image

    def apply(self, sample):
        sample['image'] = self.apply_im(sample['image'])
        if 'image2' in sample:
            sample['image2'] = self.apply_im(sample['image2'])
        return sample


class ReduceDim(Transform):
    """
    Use PCA to reduce the dimension of input image(s).

    Args:
        joblib_path (str): Path of the *.joblib file that stores the fitted
            PCA model.
        apply_to_tar (bool, optional): Whether to apply the transformation to
            the target image. Defaults to True.
    """

    def __init__(self, joblib_path, apply_to_tar=True):
        super(ReduceDim, self).__init__()
        ext = joblib_path.split(".")[-1]
        if ext != "joblib":
            raise ValueError("`joblib_path` must be *.joblib, not *.{}.".format(
                ext))
        self.pca = load(joblib_path)
        self.apply_to_tar = apply_to_tar

    def apply_im(self, image):
        # Flatten to (H * W, C), project with PCA, then restore the spatial
        # dimensions.
        H, W, C = image.shape
        n_im = np.reshape(image, (-1, C))
        im_pca = self.pca.transform(n_im)
        result = np.reshape(im_pca, (H, W, -1))
        return result

    def apply(self, sample):
        sample['image'] = self.apply_im(sample['image'])
        if 'image2' in sample:
            sample['image2'] = self.apply_im(sample['image2'])
        if 'target' in sample and self.apply_to_tar:
            sample['target'] = self.apply_im(sample['target'])
        return sample
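

# Sketch of producing the expected *.joblib file with scikit-learn (the PCA
# object only needs a `transform` method over (n_pixels, n_bands) arrays;
# the path and the names `train_images` / `n_bands` below are made up for
# the example):
#
#     from joblib import dump
#     from sklearn.decomposition import PCA
#
#     pixels = train_images.reshape(-1, n_bands)  # stack training pixels
#     pca = PCA(n_components=3).fit(pixels)
#     dump(pca, 'pca.joblib')
#     reduce_dim = ReduceDim('pca.joblib')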


class SelectBand(Transform):
    """
    Select a set of bands of input image(s).

    Args:
        band_list (list, optional): Bands to select (band index starts from 1).
            Defaults to [1, 2, 3].
        apply_to_tar (bool, optional): Whether to apply the transformation to
            the target image. Defaults to True.
    """

    def __init__(self, band_list=[1, 2, 3], apply_to_tar=True):
        super(SelectBand, self).__init__()
        self.band_list = band_list
        self.apply_to_tar = apply_to_tar

    def apply_im(self, image):
        image = F.select_bands(image, self.band_list)
        return image

    def apply(self, sample):
        sample['image'] = self.apply_im(sample['image'])
        if 'image2' in sample:
            sample['image2'] = self.apply_im(sample['image2'])
        if 'target' in sample and self.apply_to_tar:
            sample['target'] = self.apply_im(sample['target'])
        return sample
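

# For example, to keep only three bands of a 10-band image (indices are
# 1-based, per the docstring; the chosen indices are arbitrary):
#
#     select = SelectBand(band_list=[3, 2, 1])
#     sample = {'image': np.zeros((8, 8, 10), dtype=np.float32)}
#     assert select.apply(sample)['image'].shape[2] == 3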


class _PadBox(Transform):
    def __init__(self, num_max_boxes=50):
        """
        Pad zeros to bboxes if the number of bboxes is less than `num_max_boxes`.

        Args:
            num_max_boxes (int, optional): Max number of bboxes. Defaults to 50.
        """
        self.num_max_boxes = num_max_boxes
        super(_PadBox, self).__init__()

    def apply(self, sample):
        gt_num = min(self.num_max_boxes, len(sample['gt_bbox']))
        num_max = self.num_max_boxes
        pad_bbox = np.zeros((num_max, 4), dtype=np.float32)
        if gt_num > 0:
            pad_bbox[:gt_num, :] = sample['gt_bbox'][:gt_num, :]
        sample['gt_bbox'] = pad_bbox
        if 'gt_class' in sample:
            pad_class = np.zeros((num_max, ), dtype=np.int32)
            if gt_num > 0:
                pad_class[:gt_num] = sample['gt_class'][:gt_num, 0]
            sample['gt_class'] = pad_class
        if 'gt_score' in sample:
            pad_score = np.zeros((num_max, ), dtype=np.float32)
            if gt_num > 0:
                pad_score[:gt_num] = sample['gt_score'][:gt_num, 0]
            sample['gt_score'] = pad_score
        # In training, ops such as ExpandImage expand gt_bbox and gt_class
        # but not difficult, so handle it separately here.
        if 'difficult' in sample:
            pad_diff = np.zeros((num_max, ), dtype=np.int32)
            if gt_num > 0:
                pad_diff[:gt_num] = sample['difficult'][:gt_num, 0]
            sample['difficult'] = pad_diff
        if 'is_crowd' in sample:
            pad_crowd = np.zeros((num_max, ), dtype=np.int32)
            if gt_num > 0:
                pad_crowd[:gt_num] = sample['is_crowd'][:gt_num, 0]
            sample['is_crowd'] = pad_crowd
        return sample
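

# Illustrative effect of `_PadBox`: with `num_max_boxes=4`, two ground-truth
# boxes come out as a fixed (4, 4) array with zero rows appended.
#
#     sample = {'gt_bbox': np.array([[0., 0., 1., 1.],
#                                    [2., 2., 3., 3.]], dtype=np.float32)}
#     out = _PadBox(num_max_boxes=4).apply(sample)
#     assert out['gt_bbox'].shape == (4, 4)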


class _NormalizeBox(Transform):
    def __init__(self):
        super(_NormalizeBox, self).__init__()

    def apply(self, sample):
        # Normalize box coordinates to [0, 1] relative to the image size.
        height, width = sample['image'].shape[:2]
        for i in range(sample['gt_bbox'].shape[0]):
            sample['gt_bbox'][i][0] = sample['gt_bbox'][i][0] / width
            sample['gt_bbox'][i][1] = sample['gt_bbox'][i][1] / height
            sample['gt_bbox'][i][2] = sample['gt_bbox'][i][2] / width
            sample['gt_bbox'][i][3] = sample['gt_bbox'][i][3] / height
        return sample


class _BboxXYXY2XYWH(Transform):
    """
    Convert bboxes from XYXY format (x1, y1, x2, y2) to XYWH format
    (center_x, center_y, w, h).
    """

    def __init__(self):
        super(_BboxXYXY2XYWH, self).__init__()

    def apply(self, sample):
        bbox = sample['gt_bbox']
        bbox[:, 2:4] = bbox[:, 2:4] - bbox[:, :2]
        bbox[:, :2] = bbox[:, :2] + bbox[:, 2:4] / 2.
        sample['gt_bbox'] = bbox
        return sample
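

# Worked example of the conversion above: the corner box (0, 0, 4, 2) first
# gets width/height (4, 2), then the top-left corner is shifted to the
# center (2, 1), yielding (2, 1, 4, 2).
#
#     sample = {'gt_bbox': np.array([[0., 0., 4., 2.]], dtype=np.float32)}
#     out = _BboxXYXY2XYWH().apply(sample)
#     assert (out['gt_bbox'] == np.array([[2., 1., 4., 2.]])).all()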


class _Permute(Transform):
    def __init__(self):
        super(_Permute, self).__init__()

    def apply(self, sample):
        sample['image'] = F.permute(sample['image'], False)
        if 'image2' in sample:
            sample['image2'] = F.permute(sample['image2'], False)
        if 'target' in sample:
            sample['target'] = F.permute(sample['target'], False)
        return sample


class RandomSwap(Transform):
    """
    Randomly swap multi-temporal images.

    Args:
        prob (float, optional): Probability of swapping the input images.
            Default: 0.2.
    """

    def __init__(self, prob=0.2):
        super(RandomSwap, self).__init__()
        self.prob = prob

    def apply(self, sample):
        if 'image2' not in sample:
            raise ValueError("'image2' is not found in the sample.")
        if random.random() < self.prob:
            sample['image'], sample['image2'] = sample['image2'], sample[
                'image']
        return sample


class ReloadMask(Transform):
    def apply(self, sample):
        # Restore the masks from the cached original versions; check for the
        # `*_ori` keys, since those are the ones actually read.
        if 'mask_ori' in sample:
            sample['mask'] = F.decode_seg_mask(sample['mask_ori'])
        if 'aux_masks_ori' in sample:
            sample['aux_masks'] = list(
                map(F.decode_seg_mask, sample['aux_masks_ori']))
        return sample


class AppendIndex(Transform):
    """
    Append a remote sensing index to input image(s).

    Args:
        index_type (str): Type of remote sensing index. See the supported
            index types in
            https://github.com/PaddlePaddle/PaddleRS/tree/develop/paddlers/transforms/indices.py .
        band_indices (dict, optional): Mapping of band names to band indices
            (starting from 1). See the band names in
            https://github.com/PaddlePaddle/PaddleRS/tree/develop/paddlers/transforms/indices.py .
            Default: None.
        satellite (str, optional): Type of satellite. If set, band indices
            will be automatically determined accordingly. See the supported
            satellites in
            https://github.com/PaddlePaddle/PaddleRS/tree/develop/paddlers/transforms/satellites.py .
            Default: None.
    """

    def __init__(self, index_type, band_indices=None, satellite=None, **kwargs):
        super(AppendIndex, self).__init__()
        self.index_type = index_type
        self.band_indices = band_indices
        self.satellite = satellite
        self._index_args = kwargs
        cls = getattr(indices, self.index_type)
        if self.satellite is not None:
            # Look up the satellite's band layout and use it directly.
            satellite_bands = getattr(satellites, self.satellite)
            self._compute_index = cls(satellite_bands, **self._index_args)
        else:
            if self.band_indices is None:
                raise ValueError(
                    "At least one of `band_indices` and `satellite` must not be None."
                )
            else:
                self._compute_index = cls(self.band_indices, **self._index_args)

    def apply_im(self, image):
        index = self._compute_index(image)
        index = index[..., None].astype('float32')
        # Append the computed index as an extra channel.
        return np.concatenate([image, index], axis=-1)

    def apply(self, sample):
        sample['image'] = self.apply_im(sample['image'])
        if 'image2' in sample:
            sample['image2'] = self.apply_im(sample['image2'])
        return sample

    def get_attrs_for_serialization(self):
        return {
            'index_type': self.index_type,
            'band_indices': self.band_indices,
            'satellite': self.satellite,
            **self._index_args
        }
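

# Hypothetical usage sketch (the 'NDVI' index type and the 'r'/'n' band
# names are taken as examples; check indices.py for the exact names
# expected):
#
#     append_ndvi = AppendIndex(
#         index_type='NDVI',
#         band_indices={'r': 3, 'n': 4})  # red is band 3, NIR is band 4
#     sample = {'image': np.random.rand(32, 32, 4).astype('float32')}
#     sample = append_ndvi.apply(sample)
#     assert sample['image'].shape[2] == 5  # index appended as a new channel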


class MatchRadiance(Transform):
    """
    Perform relative radiometric correction between bi-temporal images.

    Args:
        method (str, optional): Method used to match the radiance of the
            bi-temporal images. Choices are {'hist', 'lsr', 'fft'}. 'hist'
            stands for histogram matching, 'lsr' stands for least-squares
            regression, and 'fft' replaces the low-frequency components of
            the image to match the reference image. Default: 'hist'.
    """

    def __init__(self, method='hist'):
        super(MatchRadiance, self).__init__()
        if method == 'hist':
            self._match_func = F.match_histograms
        elif method == 'lsr':
            self._match_func = F.match_by_regression
        elif method == 'fft':
            self._match_func = F.match_lf_components
        else:
            raise ValueError(
                "{} is not a supported radiometric correction method.".format(
                    method))
        self.method = method

    def apply(self, sample):
        if 'image2' not in sample:
            raise ValueError("'image2' is not found in the sample.")
        # Match the second temporal image to the first (the reference).
        sample['image2'] = self._match_func(sample['image2'], sample['image'])
        return sample
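

# Minimal sketch of `MatchRadiance` on a bi-temporal pair (dummy data):
#
#     match = MatchRadiance(method='hist')
#     sample = {
#         'image': np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8),
#         'image2': np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8),
#     }
#     # 'image2' is adjusted so that its radiometry matches 'image'.
#     sample = match.apply(sample)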


class Arrange(Transform):
    def __init__(self, mode):
        super().__init__()
        if mode not in ['train', 'eval', 'test', 'quant']:
            raise ValueError(
                "`mode` should be defined as one of ['train', 'eval', 'test', 'quant']!"
            )
        self.mode = mode


class ArrangeSegmenter(Arrange):
    def apply(self, sample):
        if 'mask' in sample:
            mask = sample['mask']
            mask = mask.astype('int64')
        image = F.permute(sample['image'], False)
        if self.mode in ('train', 'eval'):
            return image, mask
        if self.mode == 'test':
            return image,


class ArrangeChangeDetector(Arrange):
    def apply(self, sample):
        if 'mask' in sample:
            mask = sample['mask']
            mask = mask.astype('int64')
        image_t1 = F.permute(sample['image'], False)
        image_t2 = F.permute(sample['image2'], False)
        if self.mode == 'train':
            masks = [mask]
            if 'aux_masks' in sample:
                masks.extend(
                    map(methodcaller('astype', 'int64'), sample['aux_masks']))
            return (image_t1, image_t2) + tuple(masks)
        if self.mode == 'eval':
            return image_t1, image_t2, mask
        if self.mode == 'test':
            return image_t1, image_t2


class ArrangeClassifier(Arrange):
    def apply(self, sample):
        image = F.permute(sample['image'], False)
        if self.mode in ['train', 'eval']:
            return image, sample['label']
        else:
            return image,


class ArrangeDetector(Arrange):
    def apply(self, sample):
        if self.mode == 'eval' and 'gt_poly' in sample:
            del sample['gt_poly']
        return sample


class ArrangeRestorer(Arrange):
    def apply(self, sample):
        if 'target' in sample:
            target = F.permute(sample['target'], False)
        image = F.permute(sample['image'], False)
        if self.mode in ('train', 'eval'):
            return image, target
        if self.mode == 'test':
            return image,
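

# End-of-pipeline sketch: an `Arrange*` transform turns a sample dict into
# the tuple a model consumes (dummy segmentation sample below):
#
#     arrange = ArrangeSegmenter('train')
#     sample = {'image': np.zeros((32, 32, 3), dtype=np.float32),
#               'mask': np.zeros((32, 32), dtype=np.uint8)}
#     image, mask = arrange.apply(sample)  # image is permuted to CHW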