
Decouple compose

Bobholamovic 2 years ago
parent commit 474bbb6271

+ 14 - 9
paddlers/tasks/base.py

@@ -30,7 +30,6 @@ from paddleslim import L1NormFilterPruner, FPGMFilterPruner
 
 import paddlers
 import paddlers.utils.logging as logging
-from paddlers.transforms import arrange_transforms
 from paddlers.utils import (seconds_to_hms, get_single_card_bs, dict2str,
                             get_pretrain_weights, load_pretrain_weights,
                             load_checkpoint, SmoothedValue, TrainingStats,
@@ -302,10 +301,7 @@ class BaseModel(metaclass=ModelMeta):
                    early_stop=False,
                    early_stop_patience=5,
                    use_vdl=True):
-        arrange_transforms(
-            model_type=self.model_type,
-            transforms=train_dataset.transforms,
-            mode='train')
+        self._check_transforms(train_dataset.transforms, 'train')
 
         if "RCNN" in self.__class__.__name__ and train_dataset.pos_num < len(
                 train_dataset.file_list):
@@ -488,10 +484,7 @@ class BaseModel(metaclass=ModelMeta):
 
         assert criterion in {'l1_norm', 'fpgm'}, \
             "Pruning criterion {} is not supported. Please choose from ['l1_norm', 'fpgm']"
-        arrange_transforms(
-            model_type=self.model_type,
-            transforms=dataset.transforms,
-            mode='eval')
+        self._check_transforms(dataset.transforms, 'eval')
         if self.model_type == 'detector':
             self.net.eval()
         else:
@@ -670,3 +663,15 @@ class BaseModel(metaclass=ModelMeta):
         open(osp.join(save_dir, '.success'), 'w').close()
         logging.info("The model for the inference deployment is saved in {}.".
                      format(save_dir))
+
+    def _check_transforms(self, transforms, mode):
+        # NOTE: Check `transforms` and `transforms.arrange`, and give user-friendly error messages.
+        if not isinstance(transforms, paddlers.transforms.Compose):
+            raise TypeError("`transforms` must be a paddlers.transforms.Compose object.")
+        arrange_obj = transforms.arrange
+        if not isinstance(arrange_obj, paddlers.transforms.operators.Arrange):
+            raise TypeError("`transforms.arrange` must be an Arrange object.")
+        if arrange_obj.mode != mode:
+            raise ValueError(
+                f"Incorrect arrange mode! Expected {mode} but got {arrange_obj.mode}."
+            )

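A minimal sketch of the contract enforced by the new `_check_transforms`, assuming only the API shown in this commit (the file path below is hypothetical):

    # Datasets must now carry a Compose whose `arrange` matches the running
    # mode; `arrange_transforms` no longer rewrites the pipeline in place.
    import paddlers.transforms as T

    train_transforms = T.Compose(
        [T.DecodeImg(), T.Resize(target_size=512)],
        arrange=T.ArrangeSegmenter('train'))

    # During model.train(), _check_transforms(train_dataset.transforms, 'train')
    # raises TypeError for a non-Compose object or a non-Arrange `arrange`, and
    # ValueError("Incorrect arrange mode! Expected train but got eval.") if the
    # Arrange object was built with mode='eval'.
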
+ 14 - 7
paddlers/tasks/change_detector.py

@@ -28,7 +28,6 @@ import paddlers
 import paddlers.custom_models.cd as cmcd
 import paddlers.utils.logging as logging
 import paddlers.models.ppseg as paddleseg
-from paddlers.transforms import arrange_transforms
 from paddlers.transforms import Resize, decode_image
 from paddlers.utils import get_single_card_bs, DisablePrint
 from paddlers.utils.checkpoint import seg_pretrain_weights_dict
@@ -137,6 +136,11 @@ class BaseChangeDetector(BaseModel):
             else:
                 pred = paddle.argmax(logit, axis=1, keepdim=True, dtype='int32')
             label = inputs[2]
+            if label.ndim == 3:
+                paddle.unsqueeze_(label, axis=1)
+            if label.ndim != 4:
+                raise ValueError("Expected label.ndim == 4 but got {}".format(
+                    label.ndim))
             origin_shape = [label.shape[-2:]]
             pred = self._postprocess(
                 pred, origin_shape, transforms=inputs[3])[0]  # NCHW
@@ -396,10 +400,7 @@ class BaseChangeDetector(BaseModel):
                  "category_F1-score": `F1 score`}.
 
         """
-        arrange_transforms(
-            model_type=self.model_type,
-            transforms=eval_dataset.transforms,
-            mode='eval')
+        self._check_transforms(eval_dataset.transforms, 'eval')
 
         self.net.eval()
         nranks = paddle.distributed.get_world_size()
@@ -641,8 +642,7 @@ class BaseChangeDetector(BaseModel):
         print("GeoTiff saved in {}.".format(save_file))
 
     def _preprocess(self, images, transforms, to_tensor=True):
-        arrange_transforms(
-            model_type=self.model_type, transforms=transforms, mode='test')
+        self._check_transforms(transforms, 'test')
         batch_im1, batch_im2 = list(), list()
         batch_ori_shape = list()
         for im1, im2 in images:
@@ -786,6 +786,13 @@ class BaseChangeDetector(BaseModel):
             score_maps.append(score_map.squeeze())
         return label_maps, score_maps
 
+    def _check_transforms(self, transforms, mode):
+        super()._check_transforms(transforms, mode)
+        if not isinstance(transforms.arrange,
+                          paddlers.transforms.ArrangeChangeDetector):
+            raise TypeError(
+                "`transforms.arrange` must be an ArrangeChangeDetector object.")
+
 
 class CDNet(BaseChangeDetector):
     def __init__(self,

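Each task class narrows the base check to its own Arrange type; classifier.py, object_detector.py, and segmenter.py below follow the same pattern. A sketch of the layered check, with a deliberately mismatched arrange (the model name is hypothetical):

    import paddlers.transforms as T

    transforms = T.Compose(
        [T.DecodeImg()], arrange=T.ArrangeSegmenter('test'))

    # For a change-detection model `cd_model`, the base checks (Compose,
    # Arrange, matching mode) all pass, but the subclass override raises
    # TypeError("`transforms.arrange` must be an ArrangeChangeDetector object."):
    # cd_model._check_transforms(transforms, 'test')
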
+ 9 - 7
paddlers/tasks/classifier.py

@@ -25,7 +25,6 @@ from paddle.static import InputSpec
 import paddlers.models.ppcls as paddleclas
 import paddlers.custom_models.cls as cmcls
 import paddlers
-from paddlers.transforms import arrange_transforms
 from paddlers.utils import get_single_card_bs, DisablePrint
 import paddlers.utils.logging as logging
 from .base import BaseModel
@@ -358,10 +357,7 @@ class BaseClassifier(BaseModel):
                  "top5": `acc of top5`}.
 
         """
-        arrange_transforms(
-            model_type=self.model_type,
-            transforms=eval_dataset.transforms,
-            mode='eval')
+        self._check_transforms(eval_dataset.transforms, 'eval')
 
         self.net.eval()
         nranks = paddle.distributed.get_world_size()
@@ -460,8 +456,7 @@ class BaseClassifier(BaseModel):
         return prediction
 
     def _preprocess(self, images, transforms, to_tensor=True):
-        arrange_transforms(
-            model_type=self.model_type, transforms=transforms, mode='test')
+        self._check_transforms(transforms, 'test')
         batch_im = list()
         batch_ori_shape = list()
         for im in images:
@@ -527,6 +522,13 @@ class BaseClassifier(BaseModel):
             batch_restore_list.append(restore_list)
         return batch_restore_list
 
+    def _check_transforms(self, transforms, mode):
+        super()._check_transforms(transforms, mode)
+        if not isinstance(transforms.arrange,
+                          paddlers.transforms.ArrangeClassifier):
+            raise TypeError(
+                "`transforms.arrange` must be an ArrangeClassifier object.")
+
 
 class ResNet50_vd(BaseClassifier):
     def __init__(self, num_classes=2, use_mixed_loss=False, **params):

+ 9 - 7
paddlers/tasks/object_detector.py

@@ -31,7 +31,6 @@ from paddlers.transforms import decode_image
 from paddlers.transforms.operators import _NormalizeBox, _PadBox, _BboxXYXY2XYWH, Resize, Pad
 from paddlers.transforms.batch_operators import BatchCompose, BatchRandomResize, BatchRandomResizeByShort, \
     _BatchPad, _Gt2YoloTarget
-from paddlers.transforms import arrange_transforms
 from .base import BaseModel
 from .utils.det_metrics import VOCMetric, COCOMetric
 from paddlers.models.ppdet.optimizer import ModelEMA
@@ -452,10 +451,7 @@ class BaseDetector(BaseModel):
                 }
         eval_dataset.batch_transforms = self._compose_batch_transform(
             eval_dataset.transforms, mode='eval')
-        arrange_transforms(
-            model_type=self.model_type,
-            transforms=eval_dataset.transforms,
-            mode='eval')
+        self._check_transforms(eval_dataset.transforms, 'eval')
 
         self.net.eval()
         nranks = paddle.distributed.get_world_size()
@@ -545,8 +541,7 @@ class BaseDetector(BaseModel):
         return prediction
 
     def _preprocess(self, images, transforms, to_tensor=True):
-        arrange_transforms(
-            model_type=self.model_type, transforms=transforms, mode='test')
+        self._check_transforms(transforms, 'test')
         batch_samples = list()
         for im in images:
             if isinstance(im, str):
@@ -630,6 +625,13 @@ class BaseDetector(BaseModel):
 
         return results
 
+    def _check_transforms(self, transforms, mode):
+        super()._check_transforms(transforms, mode)
+        if not isinstance(transforms.arrange,
+                          paddlers.transforms.ArrangeDetector):
+            raise TypeError(
+                "`transforms.arrange` must be an ArrangeDetector object.")
+
 
 class PicoDet(BaseDetector):
     def __init__(self,

+ 14 - 7
paddlers/tasks/segmenter.py

@@ -26,7 +26,6 @@ from paddle.static import InputSpec
 import paddlers.models.ppseg as paddleseg
 import paddlers.custom_models.seg as cmseg
 import paddlers
-from paddlers.transforms import arrange_transforms
 from paddlers.utils import get_single_card_bs, DisablePrint
 import paddlers.utils.logging as logging
 from .base import BaseModel
@@ -136,6 +135,11 @@ class BaseSegmenter(BaseModel):
             else:
                 pred = paddle.argmax(logit, axis=1, keepdim=True, dtype='int32')
             label = inputs[1]
+            if label.ndim == 3:
+                paddle.unsqueeze_(label, axis=1)
+            if label.ndim != 4:
+                raise ValueError("Expected label.ndim == 4 but got {}".format(
+                    label.ndim))
             origin_shape = [label.shape[-2:]]
             pred = self._postprocess(
                 pred, origin_shape, transforms=inputs[2])[0]  # NCHW
@@ -380,10 +384,7 @@ class BaseSegmenter(BaseModel):
                  "category_F1-score": `F1 score`}.
 
         """
-        arrange_transforms(
-            model_type=self.model_type,
-            transforms=eval_dataset.transforms,
-            mode='eval')
+        self._check_transforms(eval_dataset.transforms, 'eval')
 
         self.net.eval()
         nranks = paddle.distributed.get_world_size()
@@ -606,8 +607,7 @@ class BaseSegmenter(BaseModel):
         print("GeoTiff saved in {}.".format(save_file))
 
     def _preprocess(self, images, transforms, to_tensor=True):
-        arrange_transforms(
-            model_type=self.model_type, transforms=transforms, mode='test')
+        self._check_transforms(transforms, 'test')
         batch_im = list()
         batch_ori_shape = list()
         for im in images:
@@ -746,6 +746,13 @@ class BaseSegmenter(BaseModel):
             score_maps.append(score_map.squeeze())
         return label_maps, score_maps
 
+    def _check_transforms(self, transforms, mode):
+        super()._check_transforms(transforms, mode)
+        if not isinstance(transforms.arrange,
+                          paddlers.transforms.ArrangeSegmenter):
+            raise TypeError(
+                "`transforms.arrange` must be an ArrangeSegmenter object.")
+
 
 class UNet(BaseSegmenter):
     def __init__(self,

+ 85 - 112
paddlers/transforms/operators.py

@@ -74,6 +74,59 @@ interp_dict = {
 }
 
 
+class Compose(object):
+    """
+    Apply a series of data augmentation strategies to the input.
+    All input images should be in Height-Width-Channel ([H, W, C]) format.
+
+    Args:
+        transforms (list[paddlers.transforms.Transform]): List of data preprocessing or augmentation operators.
+        arrange (paddlers.transforms.Arrange|None, optional): If not None, the Arrange operator will be used to
+            arrange the outputs of `transforms`. Defaults to None.
+
+    Raises:
+        TypeError: Invalid type of transforms.
+        ValueError: Invalid length of transforms.
+    """
+
+    def __init__(self, transforms, arrange=None):
+        super(Compose, self).__init__()
+        if not isinstance(transforms, list):
+            raise TypeError(
+                "Type of `transforms` is invalid. Must be a list, but received {}."
+                .format(type(transforms)))
+        if len(transforms) < 1:
+            raise ValueError(
+                "Length of `transforms` must not be less than 1, but received {}."
+                .format(len(transforms)))
+        self.transforms = transforms
+        self.arrange = arrange
+
+    def __call__(self, sample):
+        """
+        This is equivalent to calling `apply_transforms()` followed by `arrange_outputs()`.
+        """
+
+        sample = self.apply_transforms(sample)
+        sample = self.arrange_outputs(sample)
+        return sample
+
+    def apply_transforms(self, sample):
+        for op in self.transforms:
+            # Skip batch transforms and mixup
+            if isinstance(op, (paddlers.transforms.BatchRandomResize,
+                               paddlers.transforms.BatchRandomResizeByShort,
+                               MixupImage)):
+                continue
+            sample = op(sample)
+        return sample
+
+    def arrange_outputs(self, sample):
+        if self.arrange is not None:
+            sample = self.arrange(sample)
+        return sample
+
+
 class Transform(object):
     """
     Parent class of all data augmentation operations
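
Splitting `__call__` into `apply_transforms()` and `arrange_outputs()` is the heart of the decoupling: the augmentation pipeline can run without rearranging the sample. A small sketch, assuming a hypothetical classification sample:

    import paddlers.transforms as T

    compose = T.Compose(
        [T.DecodeImg(), T.Resize(target_size=512)],
        arrange=T.ArrangeClassifier('train'))

    sample = {'image': 'demo.jpg', 'label': 0}  # hypothetical inputs
    sample = compose.apply_transforms(sample)   # still a dict after this call
    inputs = compose.arrange_outputs(sample)    # now an (image, label) tuple
    # compose(sample) is equivalent to the two calls above in sequence.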
@@ -178,14 +231,14 @@ class DecodeImg(Transform):
         elif ext == '.npy':
             return np.load(img_path)
         else:
-            raise TypeError('Image format {} is not supported!'.format(ext))
+            raise TypeError("Image format {} is not supported!".format(ext))
 
     def apply_im(self, im_path):
         if isinstance(im_path, str):
             try:
                 image = self.read_img(im_path)
             except:
-                raise ValueError('Cannot read the image file {}!'.format(
+                raise ValueError("Cannot read the image file {}!".format(
                     im_path))
         else:
             image = im_path
@@ -244,61 +297,6 @@ class DecodeImg(Transform):
         return sample
 
 
-class Compose(Transform):
-    """
-    Apply a series of data augmentation to the input.
-    All input images are in Height-Width-Channel ([H, W, C]) format.
-
-    Args:
-        transforms (list[paddlers.transforms.Transform]): List of data preprocess or augmentations.
-    Raises:
-        TypeError: Invalid type of transforms.
-        ValueError: Invalid length of transforms.
-    """
-
-    def __init__(self, transforms, to_uint8=True):
-        super(Compose, self).__init__()
-        if not isinstance(transforms, list):
-            raise TypeError(
-                'Type of transforms is invalid. Must be a list, but received is {}'
-                .format(type(transforms)))
-        if len(transforms) < 1:
-            raise ValueError(
-                'Length of transforms must not be less than 1, but received is {}'
-                .format(len(transforms)))
-        self.transforms = transforms
-        self.decode_image = DecodeImg(to_uint8=to_uint8)
-        self.arrange_outputs = None
-        self.apply_im_only = False
-
-    def __call__(self, sample):
-        if self.apply_im_only:
-            if 'mask' in sample:
-                mask_backup = copy.deepcopy(sample['mask'])
-                del sample['mask']
-            if 'aux_masks' in sample:
-                aux_masks = copy.deepcopy(sample['aux_masks'])
-
-        sample = self.decode_image(sample)
-
-        for op in self.transforms:
-            # skip batch transforms amd mixup
-            if isinstance(op, (paddlers.transforms.BatchRandomResize,
-                               paddlers.transforms.BatchRandomResizeByShort,
-                               MixupImage)):
-                continue
-            sample = op(sample)
-
-        if self.arrange_outputs is not None:
-            if self.apply_im_only:
-                sample['mask'] = mask_backup
-                if 'aux_masks' in locals():
-                    sample['aux_masks'] = aux_masks
-            sample = self.arrange_outputs(sample)
-
-        return sample
-
-
 class Resize(Transform):
     """
     Resize input.
@@ -323,7 +321,7 @@ class Resize(Transform):
     def __init__(self, target_size, interp='LINEAR', keep_ratio=False):
         super(Resize, self).__init__()
         if not (interp == "RANDOM" or interp in interp_dict):
-            raise ValueError("interp should be one of {}".format(
+            raise ValueError("`interp` should be one of {}.".format(
                 interp_dict.keys()))
         if isinstance(target_size, int):
             target_size = (target_size, target_size)
@@ -331,7 +329,7 @@ class Resize(Transform):
             if not (isinstance(target_size,
                                (list, tuple)) and len(target_size) == 2):
                 raise TypeError(
-                    "target_size should be an int or a list of length 2, but received {}".
+                    "`target_size` should be an int or a list of length 2, but received {}.".
                     format(target_size))
         # (height, width)
         self.target_size = target_size
@@ -443,11 +441,11 @@ class RandomResize(Transform):
     def __init__(self, target_sizes, interp='LINEAR'):
         super(RandomResize, self).__init__()
         if not (interp == "RANDOM" or interp in interp_dict):
-            raise ValueError("interp should be one of {}".format(
+            raise ValueError("`interp` should be one of {}.".format(
                 interp_dict.keys()))
         self.interp = interp
         assert isinstance(target_sizes, list), \
-            "target_size must be a list."
+            "`target_sizes` must be a list."
         for i, item in enumerate(target_sizes):
             if isinstance(item, int):
                 target_sizes[i] = (item, item)
@@ -478,7 +476,7 @@ class ResizeByShort(Transform):
 
     def __init__(self, short_size=256, max_size=-1, interp='LINEAR'):
         if not (interp == "RANDOM" or interp in interp_dict):
-            raise ValueError("interp should be one of {}".format(
+            raise ValueError("`interp` should be one of {}.".format(
                 interp_dict.keys()))
         super(ResizeByShort, self).__init__()
         self.short_size = short_size
@@ -522,11 +520,11 @@ class RandomResizeByShort(Transform):
     def __init__(self, short_sizes, max_size=-1, interp='LINEAR'):
         super(RandomResizeByShort, self).__init__()
         if not (interp == "RANDOM" or interp in interp_dict):
-            raise ValueError("interp should be one of {}".format(
+            raise ValueError("`interp` should be one of {}.".format(
                 interp_dict.keys()))
         self.interp = interp
         assert isinstance(short_sizes, list), \
-            "short_sizes must be a list."
+            "`short_sizes` must be a list."
 
         self.short_sizes = short_sizes
         self.max_size = max_size
@@ -574,6 +572,7 @@ class RandomFlipOrRotate(Transform):
 
         # Define the data augmentation pipeline
         train_transforms = T.Compose([
+            T.DecodeImg(),
             T.RandomFlipOrRotate(
                 probs  = [0.3, 0.2],             # probability of flip is 0.3, of rotate 0.2, of no change 0.5
                 probsf = [0.3, 0.25, 0, 0, 0],   # for flips: horizontal 0.3, vertical 0.25; horizontal-and-vertical, diagonal, and anti-diagonal all 0; no change 0.45
@@ -609,12 +608,12 @@ class RandomFlipOrRotate(Transform):
 
     def apply_bbox(self, bbox, mode_id, flip_mode=True):
         raise TypeError(
-            "Currently, `paddlers.transforms.RandomFlipOrRotate` is not available for object detection tasks."
+            "Currently, RandomFlipOrRotate is not available for object detection tasks."
         )
 
     def apply_segm(self, bbox, mode_id, flip_mode=True):
         raise TypeError(
-            "Currently, `paddlers.transforms.RandomFlipOrRotate` is not available for object detection tasks."
+            "Currently, RandomFlipOrRotate is not available for object detection tasks."
         )
 
     def get_probs_range(self, probs):
@@ -845,11 +844,11 @@ class Normalize(Transform):
         from functools import reduce
         if reduce(lambda x, y: x * y, std) == 0:
             raise ValueError(
-                'Std should not contain 0, but received is {}.'.format(std))
+                "`std` should not contain 0, but received {}.".format(std))
         if reduce(lambda x, y: x * y,
                   [a - b for a, b in zip(max_val, min_val)]) == 0:
             raise ValueError(
-                '(max_val - min_val) should not contain 0, but received is {}.'.
+                "(`max_val` - `min_val`) should not contain 0, but received {}.".
                 format((np.asarray(max_val) - np.asarray(min_val)).tolist()))
 
         self.mean = mean
@@ -1153,11 +1152,11 @@ class RandomExpand(Transform):
                  im_padding_value=127.5,
                  label_padding_value=255):
         super(RandomExpand, self).__init__()
-        assert upper_ratio > 1.01, "expand ratio must be larger than 1.01"
+        assert upper_ratio > 1.01, "`upper_ratio` must be larger than 1.01."
         self.upper_ratio = upper_ratio
         self.prob = prob
         assert isinstance(im_padding_value, (Number, Sequence)), \
-            "fill value must be either float or sequence"
+            "`im_padding_value` must be a number or a sequence."
         self.im_padding_value = im_padding_value
         self.label_padding_value = label_padding_value
 
@@ -1204,16 +1203,16 @@ class Pad(Transform):
         if isinstance(target_size, (list, tuple)):
             if len(target_size) != 2:
                 raise ValueError(
-                    '`target_size` should include 2 elements, but it is {}'.
+                    "`target_size` should contain 2 elements, but it is {}.".
                     format(target_size))
         if isinstance(target_size, int):
             target_size = [target_size] * 2
 
         assert pad_mode in [
             -1, 0, 1, 2
-        ], 'currently only supports four modes [-1, 0, 1, 2]'
+        ], "Currently only four modes are supported: [-1, 0, 1, 2]."
         if pad_mode == -1:
-            assert offsets, 'if pad_mode is -1, offsets should not be None'
+            assert offsets, "If `pad_mode` is -1, `offsets` should not be None."
 
         self.target_size = target_size
         self.size_divisor = size_divisor
@@ -1314,9 +1313,9 @@ class MixupImage(Transform):
         """
         super(MixupImage, self).__init__()
         if alpha <= 0.0:
-            raise ValueError("alpha should be positive in {}".format(self))
+            raise ValueError("`alpha` should be positive in MixupImage.")
         if beta <= 0.0:
-            raise ValueError("beta should be positive in {}".format(self))
+            raise ValueError("`beta` should be positive in MixupImage.")
         self.alpha = alpha
         self.beta = beta
         self.mixup_epoch = mixup_epoch
@@ -1753,55 +1752,47 @@ class RandomSwap(Transform):
 
     def apply(self, sample):
         if 'image2' not in sample:
-            raise ValueError('image2 is not found in the sample.')
+            raise ValueError("'image2' is not found in the sample.")
         if random.random() < self.prob:
             sample['image'], sample['image2'] = sample['image2'], sample[
                 'image']
         return sample
 
 
-class ArrangeSegmenter(Transform):
+class Arrange(Transform):
     def __init__(self, mode):
-        super(ArrangeSegmenter, self).__init__()
+        super().__init__()
         if mode not in ['train', 'eval', 'test', 'quant']:
             raise ValueError(
-                "mode should be defined as one of ['train', 'eval', 'test', 'quant']!"
+                "`mode` should be defined as one of ['train', 'eval', 'test', 'quant']!"
             )
         self.mode = mode
 
+
+class ArrangeSegmenter(Arrange):
     def apply(self, sample):
         if 'mask' in sample:
             mask = sample['mask']
+            mask = mask.astype('int64')
 
         image = permute(sample['image'], False)
         if self.mode == 'train':
-            mask = mask.astype('int64')
             return image, mask
         if self.mode == 'eval':
-            mask = np.asarray(Image.open(mask))
-            mask = mask[np.newaxis, :, :].astype('int64')
             return image, mask
         if self.mode == 'test':
             return image,
 
 
-class ArrangeChangeDetector(Transform):
-    def __init__(self, mode):
-        super(ArrangeChangeDetector, self).__init__()
-        if mode not in ['train', 'eval', 'test', 'quant']:
-            raise ValueError(
-                "mode should be defined as one of ['train', 'eval', 'test', 'quant']!"
-            )
-        self.mode = mode
-
+class ArrangeChangeDetector(Arrange):
     def apply(self, sample):
         if 'mask' in sample:
             mask = sample['mask']
+            mask = mask.astype('int64')
 
         image_t1 = permute(sample['image'], False)
         image_t2 = permute(sample['image2'], False)
         if self.mode == 'train':
-            mask = mask.astype('int64')
             masks = [mask]
             if 'aux_masks' in sample:
                 masks.extend(
@@ -1810,22 +1801,12 @@ class ArrangeChangeDetector(Transform):
                 image_t1,
                 image_t2, ) + tuple(masks)
         if self.mode == 'eval':
-            mask = np.asarray(Image.open(mask))
-            mask = mask[np.newaxis, :, :].astype('int64')
             return image_t1, image_t2, mask
         if self.mode == 'test':
             return image_t1, image_t2,
 
 
-class ArrangeClassifier(Transform):
-    def __init__(self, mode):
-        super(ArrangeClassifier, self).__init__()
-        if mode not in ['train', 'eval', 'test', 'quant']:
-            raise ValueError(
-                "mode should be defined as one of ['train', 'eval', 'test', 'quant']!"
-            )
-        self.mode = mode
-
+class ArrangeClassifier(Arrange):
     def apply(self, sample):
         image = permute(sample['image'], False)
         if self.mode in ['train', 'eval']:
@@ -1834,15 +1815,7 @@ class ArrangeClassifier(Transform):
             return image
 
 
-class ArrangeDetector(Transform):
-    def __init__(self, mode):
-        super(ArrangeDetector, self).__init__()
-        if mode not in ['train', 'eval', 'test', 'quant']:
-            raise ValueError(
-                "mode should be defined as one of ['train', 'eval', 'test', 'quant']!"
-            )
-        self.mode = mode
-
+class ArrangeDetector(Arrange):
     def apply(self, sample):
         if self.mode == 'eval' and 'gt_poly' in sample:
             del sample['gt_poly']
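
With the shared Arrange base class, mode validation now lives in a single place and every subclass rejects unknown modes identically. A quick sketch:

    import paddlers.transforms as T

    T.ArrangeDetector('eval')       # OK: mode validated by Arrange.__init__
    try:
        T.ArrangeSegmenter('val')   # not in ['train', 'eval', 'test', 'quant']
    except ValueError as err:
        print(err)                  # the shared error message from Arrange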