|
@@ -30,39 +30,21 @@ from PIL import Image
|
|
|
from joblib import load
|
|
|
|
|
|
import paddlers
|
|
|
-from .functions import normalize, horizontal_flip, permute, vertical_flip, center_crop, is_poly, \
|
|
|
- horizontal_flip_poly, horizontal_flip_rle, vertical_flip_poly, vertical_flip_rle, crop_poly, \
|
|
|
- crop_rle, expand_poly, expand_rle, resize_poly, resize_rle, dehaze, select_bands, \
|
|
|
- to_intensity, to_uint8, img_flip, img_simple_rotate
|
|
|
+from .functions import (
|
|
|
+ normalize, horizontal_flip, permute, vertical_flip, center_crop, is_poly,
|
|
|
+ horizontal_flip_poly, horizontal_flip_rle, vertical_flip_poly,
|
|
|
+ vertical_flip_rle, crop_poly, crop_rle, expand_poly, expand_rle,
|
|
|
+ resize_poly, resize_rle, dehaze, select_bands, to_intensity, to_uint8,
|
|
|
+ img_flip, img_simple_rotate, decode_seg_mask)
|
|
|
|
|
|
__all__ = [
|
|
|
- "Compose",
|
|
|
- "DecodeImg",
|
|
|
- "Resize",
|
|
|
- "RandomResize",
|
|
|
- "ResizeByShort",
|
|
|
- "RandomResizeByShort",
|
|
|
- "ResizeByLong",
|
|
|
- "RandomHorizontalFlip",
|
|
|
- "RandomVerticalFlip",
|
|
|
- "Normalize",
|
|
|
- "CenterCrop",
|
|
|
- "RandomCrop",
|
|
|
- "RandomScaleAspect",
|
|
|
- "RandomExpand",
|
|
|
- "Pad",
|
|
|
- "MixupImage",
|
|
|
- "RandomDistort",
|
|
|
- "RandomBlur",
|
|
|
- "RandomSwap",
|
|
|
- "Dehaze",
|
|
|
- "ReduceDim",
|
|
|
- "SelectBand",
|
|
|
- "ArrangeSegmenter",
|
|
|
- "ArrangeChangeDetector",
|
|
|
- "ArrangeClassifier",
|
|
|
- "ArrangeDetector",
|
|
|
- "RandomFlipOrRotate",
|
|
|
+ "Compose", "DecodeImg", "Resize", "RandomResize", "ResizeByShort",
|
|
|
+ "RandomResizeByShort", "ResizeByLong", "RandomHorizontalFlip",
|
|
|
+ "RandomVerticalFlip", "Normalize", "CenterCrop", "RandomCrop",
|
|
|
+ "RandomScaleAspect", "RandomExpand", "Pad", "MixupImage", "RandomDistort",
|
|
|
+ "RandomBlur", "RandomSwap", "Dehaze", "ReduceDim", "SelectBand",
|
|
|
+ "ArrangeSegmenter", "ArrangeChangeDetector", "ArrangeClassifier",
|
|
|
+ "ArrangeDetector", "RandomFlipOrRotate", "ReloadMask"
|
|
|
]
|
|
|
|
|
|
interp_dict = {
|
|
@@ -74,6 +56,71 @@ interp_dict = {
|
|
|
}
|
|
|
|
|
|
|
|
|
+class Compose(object):
|
|
|
+ """
|
|
|
+ Apply a series of data augmentation strategies to the input.
|
|
|
+ All input images should be in Height-Width-Channel ([H, W, C]) format.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ transforms (list[paddlers.transforms.Transform]): List of data preprocess or
|
|
|
+ augmentation operators.
|
|
|
+
|
|
|
+ Raises:
|
|
|
+ TypeError: Invalid type of transforms.
|
|
|
+ ValueError: Invalid length of transforms.
|
|
|
+ """
|
|
|
+
|
|
|
+ def __init__(self, transforms):
|
|
|
+ super(Compose, self).__init__()
|
|
|
+ if not isinstance(transforms, list):
|
|
|
+ raise TypeError(
|
|
|
+ "Type of transforms is invalid. Must be a list, but received is {}."
|
|
|
+ .format(type(transforms)))
|
|
|
+ if len(transforms) < 1:
|
|
|
+ raise ValueError(
|
|
|
+ "Length of transforms must not be less than 1, but received is {}."
|
|
|
+ .format(len(transforms)))
|
|
|
+ transforms = copy.deepcopy(transforms)
|
|
|
+ self.arrange = self._pick_arrange(transforms)
|
|
|
+ self.transforms = transforms
|
|
|
+
|
|
|
+ def __call__(self, sample):
|
|
|
+ """
|
|
|
+ This is equivalent to sequentially calling compose_obj.apply_transforms()
|
|
|
+ and compose_obj.arrange_outputs().
|
|
|
+ """
|
|
|
+
|
|
|
+ sample = self.apply_transforms(sample)
|
|
|
+ sample = self.arrange_outputs(sample)
|
|
|
+ return sample
|
|
|
+
|
|
|
+ def apply_transforms(self, sample):
|
|
|
+ for op in self.transforms:
|
|
|
+ # Skip batch transforms amd mixup
|
|
|
+ if isinstance(op, (paddlers.transforms.BatchRandomResize,
|
|
|
+ paddlers.transforms.BatchRandomResizeByShort,
|
|
|
+ MixupImage)):
|
|
|
+ continue
|
|
|
+ sample = op(sample)
|
|
|
+ return sample
|
|
|
+
|
|
|
+ def arrange_outputs(self, sample):
|
|
|
+ if self.arrange is not None:
|
|
|
+ sample = self.arrange(sample)
|
|
|
+ return sample
|
|
|
+
|
|
|
+ def _pick_arrange(self, transforms):
|
|
|
+ arrange = None
|
|
|
+ for idx, op in enumerate(transforms):
|
|
|
+ if isinstance(op, Arrange):
|
|
|
+ if idx != len(transforms) - 1:
|
|
|
+ raise ValueError(
|
|
|
+ "Arrange operator must be placed at the end of the list."
|
|
|
+ )
|
|
|
+ arrange = transforms.pop(idx)
|
|
|
+ return arrange
|
|
|
+
|
|
|
+
|
|
|
class Transform(object):
|
|
|
"""
|
|
|
Parent class of all data augmentation operations
|
|
@@ -178,14 +225,14 @@ class DecodeImg(Transform):
|
|
|
elif ext == '.npy':
|
|
|
return np.load(img_path)
|
|
|
else:
|
|
|
- raise TypeError('Image format {} is not supported!'.format(ext))
|
|
|
+ raise TypeError("Image format {} is not supported!".format(ext))
|
|
|
|
|
|
def apply_im(self, im_path):
|
|
|
if isinstance(im_path, str):
|
|
|
try:
|
|
|
image = self.read_img(im_path)
|
|
|
except:
|
|
|
- raise ValueError('Cannot read the image file {}!'.format(
|
|
|
+ raise ValueError("Cannot read the image file {}!".format(
|
|
|
im_path))
|
|
|
else:
|
|
|
image = im_path
|
|
@@ -217,7 +264,9 @@ class DecodeImg(Transform):
|
|
|
Returns:
|
|
|
dict: Decoded sample.
|
|
|
"""
|
|
|
+
|
|
|
if 'image' in sample:
|
|
|
+ sample['image_ori'] = copy.deepcopy(sample['image'])
|
|
|
sample['image'] = self.apply_im(sample['image'])
|
|
|
if 'image2' in sample:
|
|
|
sample['image2'] = self.apply_im(sample['image2'])
|
|
@@ -227,6 +276,7 @@ class DecodeImg(Transform):
|
|
|
sample['image'] = self.apply_im(sample['image_t1'])
|
|
|
sample['image2'] = self.apply_im(sample['image_t2'])
|
|
|
if 'mask' in sample:
|
|
|
+ sample['mask_ori'] = copy.deepcopy(sample['mask'])
|
|
|
sample['mask'] = self.apply_mask(sample['mask'])
|
|
|
im_height, im_width, _ = sample['image'].shape
|
|
|
se_height, se_width = sample['mask'].shape
|
|
@@ -234,6 +284,7 @@ class DecodeImg(Transform):
|
|
|
raise ValueError(
|
|
|
"The height or width of the image is not same as the mask.")
|
|
|
if 'aux_masks' in sample:
|
|
|
+ sample['aux_masks_ori'] = copy.deepcopy(sample['aux_masks_ori'])
|
|
|
sample['aux_masks'] = list(
|
|
|
map(self.apply_mask, sample['aux_masks']))
|
|
|
# TODO: check the shape of auxiliary masks
|
|
@@ -244,61 +295,6 @@ class DecodeImg(Transform):
|
|
|
return sample
|
|
|
|
|
|
|
|
|
-class Compose(Transform):
|
|
|
- """
|
|
|
- Apply a series of data augmentation to the input.
|
|
|
- All input images are in Height-Width-Channel ([H, W, C]) format.
|
|
|
-
|
|
|
- Args:
|
|
|
- transforms (list[paddlers.transforms.Transform]): List of data preprocess or augmentations.
|
|
|
- Raises:
|
|
|
- TypeError: Invalid type of transforms.
|
|
|
- ValueError: Invalid length of transforms.
|
|
|
- """
|
|
|
-
|
|
|
- def __init__(self, transforms, to_uint8=True):
|
|
|
- super(Compose, self).__init__()
|
|
|
- if not isinstance(transforms, list):
|
|
|
- raise TypeError(
|
|
|
- 'Type of transforms is invalid. Must be a list, but received is {}'
|
|
|
- .format(type(transforms)))
|
|
|
- if len(transforms) < 1:
|
|
|
- raise ValueError(
|
|
|
- 'Length of transforms must not be less than 1, but received is {}'
|
|
|
- .format(len(transforms)))
|
|
|
- self.transforms = transforms
|
|
|
- self.decode_image = DecodeImg(to_uint8=to_uint8)
|
|
|
- self.arrange_outputs = None
|
|
|
- self.apply_im_only = False
|
|
|
-
|
|
|
- def __call__(self, sample):
|
|
|
- if self.apply_im_only:
|
|
|
- if 'mask' in sample:
|
|
|
- mask_backup = copy.deepcopy(sample['mask'])
|
|
|
- del sample['mask']
|
|
|
- if 'aux_masks' in sample:
|
|
|
- aux_masks = copy.deepcopy(sample['aux_masks'])
|
|
|
-
|
|
|
- sample = self.decode_image(sample)
|
|
|
-
|
|
|
- for op in self.transforms:
|
|
|
- # skip batch transforms amd mixup
|
|
|
- if isinstance(op, (paddlers.transforms.BatchRandomResize,
|
|
|
- paddlers.transforms.BatchRandomResizeByShort,
|
|
|
- MixupImage)):
|
|
|
- continue
|
|
|
- sample = op(sample)
|
|
|
-
|
|
|
- if self.arrange_outputs is not None:
|
|
|
- if self.apply_im_only:
|
|
|
- sample['mask'] = mask_backup
|
|
|
- if 'aux_masks' in locals():
|
|
|
- sample['aux_masks'] = aux_masks
|
|
|
- sample = self.arrange_outputs(sample)
|
|
|
-
|
|
|
- return sample
|
|
|
-
|
|
|
-
|
|
|
class Resize(Transform):
|
|
|
"""
|
|
|
Resize input.
|
|
@@ -323,7 +319,7 @@ class Resize(Transform):
|
|
|
def __init__(self, target_size, interp='LINEAR', keep_ratio=False):
|
|
|
super(Resize, self).__init__()
|
|
|
if not (interp == "RANDOM" or interp in interp_dict):
|
|
|
- raise ValueError("interp should be one of {}".format(
|
|
|
+ raise ValueError("`interp` should be one of {}.".format(
|
|
|
interp_dict.keys()))
|
|
|
if isinstance(target_size, int):
|
|
|
target_size = (target_size, target_size)
|
|
@@ -331,7 +327,7 @@ class Resize(Transform):
|
|
|
if not (isinstance(target_size,
|
|
|
(list, tuple)) and len(target_size) == 2):
|
|
|
raise TypeError(
|
|
|
- "target_size should be an int or a list of length 2, but received {}".
|
|
|
+ "`target_size` should be an int or a list of length 2, but received {}.".
|
|
|
format(target_size))
|
|
|
# (height, width)
|
|
|
self.target_size = target_size
|
|
@@ -443,11 +439,11 @@ class RandomResize(Transform):
|
|
|
def __init__(self, target_sizes, interp='LINEAR'):
|
|
|
super(RandomResize, self).__init__()
|
|
|
if not (interp == "RANDOM" or interp in interp_dict):
|
|
|
- raise ValueError("interp should be one of {}".format(
|
|
|
+ raise ValueError("`interp` should be one of {}.".format(
|
|
|
interp_dict.keys()))
|
|
|
self.interp = interp
|
|
|
assert isinstance(target_sizes, list), \
|
|
|
- "target_size must be a list."
|
|
|
+ "`target_size` must be a list."
|
|
|
for i, item in enumerate(target_sizes):
|
|
|
if isinstance(item, int):
|
|
|
target_sizes[i] = (item, item)
|
|
@@ -478,7 +474,7 @@ class ResizeByShort(Transform):
|
|
|
|
|
|
def __init__(self, short_size=256, max_size=-1, interp='LINEAR'):
|
|
|
if not (interp == "RANDOM" or interp in interp_dict):
|
|
|
- raise ValueError("interp should be one of {}".format(
|
|
|
+ raise ValueError("`interp` should be one of {}".format(
|
|
|
interp_dict.keys()))
|
|
|
super(ResizeByShort, self).__init__()
|
|
|
self.short_size = short_size
|
|
@@ -522,11 +518,11 @@ class RandomResizeByShort(Transform):
|
|
|
def __init__(self, short_sizes, max_size=-1, interp='LINEAR'):
|
|
|
super(RandomResizeByShort, self).__init__()
|
|
|
if not (interp == "RANDOM" or interp in interp_dict):
|
|
|
- raise ValueError("interp should be one of {}".format(
|
|
|
+ raise ValueError("`interp` should be one of {}".format(
|
|
|
interp_dict.keys()))
|
|
|
self.interp = interp
|
|
|
assert isinstance(short_sizes, list), \
|
|
|
- "short_sizes must be a list."
|
|
|
+ "`short_sizes` must be a list."
|
|
|
|
|
|
self.short_sizes = short_sizes
|
|
|
self.max_size = max_size
|
|
@@ -574,6 +570,7 @@ class RandomFlipOrRotate(Transform):
|
|
|
|
|
|
# 定义数据增强
|
|
|
train_transforms = T.Compose([
|
|
|
+ T.DecodeImg(),
|
|
|
T.RandomFlipOrRotate(
|
|
|
probs = [0.3, 0.2] # 进行flip增强的概率是0.3,进行rotate增强的概率是0.2,不变的概率是0.5
|
|
|
probsf = [0.3, 0.25, 0, 0, 0] # flip增强时,使用水平flip、垂直flip的概率分别是0.3、0.25,水平且垂直flip、对角线flip、反对角线flip概率均为0,不变的概率是0.45
|
|
@@ -609,12 +606,12 @@ class RandomFlipOrRotate(Transform):
|
|
|
|
|
|
def apply_bbox(self, bbox, mode_id, flip_mode=True):
|
|
|
raise TypeError(
|
|
|
- "Currently, `paddlers.transforms.RandomFlipOrRotate` is not available for object detection tasks."
|
|
|
+ "Currently, RandomFlipOrRotate is not available for object detection tasks."
|
|
|
)
|
|
|
|
|
|
def apply_segm(self, bbox, mode_id, flip_mode=True):
|
|
|
raise TypeError(
|
|
|
- "Currently, `paddlers.transforms.RandomFlipOrRotate` is not available for object detection tasks."
|
|
|
+ "Currently, RandomFlipOrRotate is not available for object detection tasks."
|
|
|
)
|
|
|
|
|
|
def get_probs_range(self, probs):
|
|
@@ -845,11 +842,11 @@ class Normalize(Transform):
|
|
|
from functools import reduce
|
|
|
if reduce(lambda x, y: x * y, std) == 0:
|
|
|
raise ValueError(
|
|
|
- 'Std should not contain 0, but received is {}.'.format(std))
|
|
|
+ "`std` should not contain 0, but received is {}.".format(std))
|
|
|
if reduce(lambda x, y: x * y,
|
|
|
[a - b for a, b in zip(max_val, min_val)]) == 0:
|
|
|
raise ValueError(
|
|
|
- '(max_val - min_val) should not contain 0, but received is {}.'.
|
|
|
+ "(`max_val` - `min_val`) should not contain 0, but received is {}.".
|
|
|
format((np.asarray(max_val) - np.asarray(min_val)).tolist()))
|
|
|
|
|
|
self.mean = mean
|
|
@@ -1153,11 +1150,11 @@ class RandomExpand(Transform):
|
|
|
im_padding_value=127.5,
|
|
|
label_padding_value=255):
|
|
|
super(RandomExpand, self).__init__()
|
|
|
- assert upper_ratio > 1.01, "expand ratio must be larger than 1.01"
|
|
|
+ assert upper_ratio > 1.01, "`upper_ratio` must be larger than 1.01."
|
|
|
self.upper_ratio = upper_ratio
|
|
|
self.prob = prob
|
|
|
assert isinstance(im_padding_value, (Number, Sequence)), \
|
|
|
- "fill value must be either float or sequence"
|
|
|
+ "Value to fill must be either float or sequence."
|
|
|
self.im_padding_value = im_padding_value
|
|
|
self.label_padding_value = label_padding_value
|
|
|
|
|
@@ -1204,16 +1201,16 @@ class Pad(Transform):
|
|
|
if isinstance(target_size, (list, tuple)):
|
|
|
if len(target_size) != 2:
|
|
|
raise ValueError(
|
|
|
- '`target_size` should include 2 elements, but it is {}'.
|
|
|
+ "`target_size` should contain 2 elements, but it is {}.".
|
|
|
format(target_size))
|
|
|
if isinstance(target_size, int):
|
|
|
target_size = [target_size] * 2
|
|
|
|
|
|
assert pad_mode in [
|
|
|
-1, 0, 1, 2
|
|
|
- ], 'currently only supports four modes [-1, 0, 1, 2]'
|
|
|
+ ], "Currently only four modes are supported: [-1, 0, 1, 2]."
|
|
|
if pad_mode == -1:
|
|
|
- assert offsets, 'if pad_mode is -1, offsets should not be None'
|
|
|
+ assert offsets, "if `pad_mode` is -1, `offsets` should not be None."
|
|
|
|
|
|
self.target_size = target_size
|
|
|
self.size_divisor = size_divisor
|
|
@@ -1314,9 +1311,9 @@ class MixupImage(Transform):
|
|
|
"""
|
|
|
super(MixupImage, self).__init__()
|
|
|
if alpha <= 0.0:
|
|
|
- raise ValueError("alpha should be positive in {}".format(self))
|
|
|
+ raise ValueError("`alpha` should be positive in MixupImage.")
|
|
|
if beta <= 0.0:
|
|
|
- raise ValueError("beta should be positive in {}".format(self))
|
|
|
+ raise ValueError("`beta` should be positive in MixupImage.")
|
|
|
self.alpha = alpha
|
|
|
self.beta = beta
|
|
|
self.mixup_epoch = mixup_epoch
|
|
@@ -1753,55 +1750,56 @@ class RandomSwap(Transform):
|
|
|
|
|
|
def apply(self, sample):
|
|
|
if 'image2' not in sample:
|
|
|
- raise ValueError('image2 is not found in the sample.')
|
|
|
+ raise ValueError("'image2' is not found in the sample.")
|
|
|
if random.random() < self.prob:
|
|
|
sample['image'], sample['image2'] = sample['image2'], sample[
|
|
|
'image']
|
|
|
return sample
|
|
|
|
|
|
|
|
|
-class ArrangeSegmenter(Transform):
|
|
|
+class ReloadMask(Transform):
|
|
|
+ def apply(self, sample):
|
|
|
+ sample['mask'] = decode_seg_mask(sample['mask_ori'])
|
|
|
+ if 'aux_masks' in sample:
|
|
|
+ sample['aux_masks'] = list(
|
|
|
+ map(decode_seg_mask, sample['aux_masks_ori']))
|
|
|
+ return sample
|
|
|
+
|
|
|
+
|
|
|
+class Arrange(Transform):
|
|
|
def __init__(self, mode):
|
|
|
- super(ArrangeSegmenter, self).__init__()
|
|
|
+ super().__init__()
|
|
|
if mode not in ['train', 'eval', 'test', 'quant']:
|
|
|
raise ValueError(
|
|
|
- "mode should be defined as one of ['train', 'eval', 'test', 'quant']!"
|
|
|
+ "`mode` should be defined as one of ['train', 'eval', 'test', 'quant']!"
|
|
|
)
|
|
|
self.mode = mode
|
|
|
|
|
|
+
|
|
|
+class ArrangeSegmenter(Arrange):
|
|
|
def apply(self, sample):
|
|
|
if 'mask' in sample:
|
|
|
mask = sample['mask']
|
|
|
+ mask = mask.astype('int64')
|
|
|
|
|
|
image = permute(sample['image'], False)
|
|
|
if self.mode == 'train':
|
|
|
- mask = mask.astype('int64')
|
|
|
return image, mask
|
|
|
if self.mode == 'eval':
|
|
|
- mask = np.asarray(Image.open(mask))
|
|
|
- mask = mask[np.newaxis, :, :].astype('int64')
|
|
|
return image, mask
|
|
|
if self.mode == 'test':
|
|
|
return image,
|
|
|
|
|
|
|
|
|
-class ArrangeChangeDetector(Transform):
|
|
|
- def __init__(self, mode):
|
|
|
- super(ArrangeChangeDetector, self).__init__()
|
|
|
- if mode not in ['train', 'eval', 'test', 'quant']:
|
|
|
- raise ValueError(
|
|
|
- "mode should be defined as one of ['train', 'eval', 'test', 'quant']!"
|
|
|
- )
|
|
|
- self.mode = mode
|
|
|
-
|
|
|
+class ArrangeChangeDetector(Arrange):
|
|
|
def apply(self, sample):
|
|
|
if 'mask' in sample:
|
|
|
mask = sample['mask']
|
|
|
+ mask = mask.astype('int64')
|
|
|
|
|
|
image_t1 = permute(sample['image'], False)
|
|
|
image_t2 = permute(sample['image2'], False)
|
|
|
if self.mode == 'train':
|
|
|
- mask = mask.astype('int64')
|
|
|
masks = [mask]
|
|
|
if 'aux_masks' in sample:
|
|
|
masks.extend(
|
|
@@ -1810,22 +1808,12 @@ class ArrangeChangeDetector(Transform):
|
|
|
image_t1,
|
|
|
image_t2, ) + tuple(masks)
|
|
|
if self.mode == 'eval':
|
|
|
- mask = np.asarray(Image.open(mask))
|
|
|
- mask = mask[np.newaxis, :, :].astype('int64')
|
|
|
return image_t1, image_t2, mask
|
|
|
if self.mode == 'test':
|
|
|
return image_t1, image_t2,
|
|
|
|
|
|
|
|
|
-class ArrangeClassifier(Transform):
|
|
|
- def __init__(self, mode):
|
|
|
- super(ArrangeClassifier, self).__init__()
|
|
|
- if mode not in ['train', 'eval', 'test', 'quant']:
|
|
|
- raise ValueError(
|
|
|
- "mode should be defined as one of ['train', 'eval', 'test', 'quant']!"
|
|
|
- )
|
|
|
- self.mode = mode
|
|
|
-
|
|
|
+class ArrangeClassifier(Arrange):
|
|
|
def apply(self, sample):
|
|
|
image = permute(sample['image'], False)
|
|
|
if self.mode in ['train', 'eval']:
|
|
@@ -1834,15 +1822,7 @@ class ArrangeClassifier(Transform):
|
|
|
return image
|
|
|
|
|
|
|
|
|
-class ArrangeDetector(Transform):
|
|
|
- def __init__(self, mode):
|
|
|
- super(ArrangeDetector, self).__init__()
|
|
|
- if mode not in ['train', 'eval', 'test', 'quant']:
|
|
|
- raise ValueError(
|
|
|
- "mode should be defined as one of ['train', 'eval', 'test', 'quant']!"
|
|
|
- )
|
|
|
- self.mode = mode
|
|
|
-
|
|
|
+class ArrangeDetector(Arrange):
|
|
|
def apply(self, sample):
|
|
|
if self.mode == 'eval' and 'gt_poly' in sample:
|
|
|
del sample['gt_poly']
|