
[Re-request][Feature] Add support for multi-task CD models (#22)

* [Style] Remove blank line and return builtin list

* [Feature] Add support for multitask CD models
Lin Manhui, 3 years ago
Parent commit: d2e80e829d

+ 1 - 0
paddlers/custom_models/cd/models/backbones/resnet.py

@@ -36,6 +36,7 @@ import paddle.nn as nn
 
 from paddle.utils.download import get_weights_path_from_url
 
+
 __all__ = []
 
 model_urls = {

+ 1 - 2
paddlers/custom_models/cd/models/bit.py

@@ -17,7 +17,6 @@ import paddle.nn as nn
 import paddle.nn.functional as F
 from paddle.nn.initializer import Normal
 
-
 from .backbones import resnet
 from .layers import Conv3x3, Conv1x1, get_norm_layer, Identity
 from .param_init import KaimingInitMixin
@@ -182,7 +181,7 @@ class BIT(nn.Layer):
 
         # Classifier forward
         pred = self.conv_out(y)
-        return pred,
+        return [pred]
 
     def init_weight(self):
         # Use the default initialization method.
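
Note on the `return pred,` → `return [pred]` change (applied to BIT here and to the other single-output networks below): a trailing-comma tuple and a one-element list index the same way, but the list makes the single-output case explicit and matches the multi-output models, which also return lists after this commit. A minimal, self-contained sketch of the difference:

```python
# Minimal sketch: a trailing comma silently builds a one-element tuple, which is
# easy to misread; returning a builtin list makes the single-output case explicit
# and keeps every model's forward() returning the same container type.
def old_forward(pred):
    return pred,     # one-element tuple

def new_forward(pred):
    return [pred]    # one-element list

assert old_forward(42)[0] == new_forward(42)[0] == 42
```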

+ 6 - 5
paddlers/custom_models/cd/models/dsamnet.py

@@ -16,7 +16,6 @@ import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
 
-
 from .layers import make_norm, Conv3x3, CBAM
 from .stanet import Backbone, Decoder
 
@@ -76,10 +75,12 @@ class DSAMNet(nn.Layer):
         out = F.interpolate(out, size=paddle.shape(t1)[2:], mode='bilinear', align_corners=True)
         pred = self.conv_out(out)
 
-        ds2 = self.dsl2(paddle.abs(f1[0]-f2[0]))
-        ds3 = self.dsl3(paddle.abs(f1[1]-f2[1]))
-
-        return pred, ds2, ds3
+        if not self.training:
+            return [pred]
+        else:
+            ds2 = self.dsl2(paddle.abs(f1[0]-f2[0]))
+            ds3 = self.dsl3(paddle.abs(f1[1]-f2[1]))
+            return [pred, ds2, ds3]
 
     def init_weight(self):
         pass
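
DSAMNet now emits its two deep-supervision maps only in training mode, so inference returns just the change map. A hedged sketch (a toy network, not DSAMNet itself) of the mode-dependent return pattern:

```python
# Hedged sketch (toy network, not the PaddleRS DSAMNet) of the training/inference
# split: auxiliary outputs exist only while self.training is True.
import paddle
import paddle.nn as nn

class ToyDeepSupNet(nn.Layer):
    def __init__(self):
        super().__init__()
        self.head = nn.Conv2D(3, 2, 1)
        self.ds_head = nn.Conv2D(3, 2, 1)

    def forward(self, x):
        pred = self.head(x)
        if not self.training:
            return [pred]                  # inference: main prediction only
        return [pred, self.ds_head(x)]     # training: extra map for deep supervision

net = ToyDeepSupNet()
net.eval()
assert len(net(paddle.rand([1, 3, 8, 8]))) == 1
net.train()
assert len(net(paddle.rand([1, 3, 8, 8]))) == 2
```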

+ 39 - 30
paddlers/custom_models/cd/models/dsifn.py

@@ -17,7 +17,6 @@ import paddle.nn as nn
 import paddle.nn.functional as F
 from paddle.vision.models import vgg16
 
-
 from .layers import Conv1x1, make_norm, ChannelAttention, SpatialAttention
 
 
@@ -101,19 +100,16 @@ class DSIFN(nn.Layer):
         t1_f_l3, t1_f_l8, t1_f_l15, t1_f_l22, t1_f_l29 = t1_feats
         t2_f_l3, t2_f_l8, t2_f_l15, t2_f_l22, t2_f_l29,= t2_feats
 
+        aux_x = []
+
         # Multi-level decoding
         x = paddle.concat([t1_f_l29, t2_f_l29], axis=1)
         x = self.o1_conv1(x)
         x = self.o1_conv2(x)
         x = self.sa1(x) * x
         x = self.bn_sa1(x)
-
-        out1 = F.interpolate(
-            self.o1_conv3(x), 
-            size=paddle.shape(t1)[2:], 
-            mode='bilinear', 
-            align_corners=True
-        )
+        if self.training:
+            aux_x.append(x)
 
         x = self.trans_conv1(x)
         x = paddle.concat([x, t1_f_l22, t2_f_l22], axis=1)
@@ -123,13 +119,8 @@ class DSIFN(nn.Layer):
         x = self.o2_conv3(x)
         x = self.sa2(x) *x
         x = self.bn_sa2(x)
-
-        out2 = F.interpolate(
-            self.o2_conv4(x), 
-            size=paddle.shape(t1)[2:], 
-            mode='bilinear', 
-            align_corners=True
-        )
+        if self.training:
+            aux_x.append(x)
 
         x = self.trans_conv2(x)
         x = paddle.concat([x, t1_f_l15, t2_f_l15], axis=1)
@@ -139,13 +130,8 @@ class DSIFN(nn.Layer):
         x = self.o3_conv3(x)
         x = self.sa3(x) *x
         x = self.bn_sa3(x)
-
-        out3 = F.interpolate(
-            self.o3_conv4(x), 
-            size=paddle.shape(t1)[2:], 
-            mode='bilinear', 
-            align_corners=True
-        )
+        if self.training:
+            aux_x.append(x)
 
         x = self.trans_conv3(x)
         x = paddle.concat([x, t1_f_l8, t2_f_l8], axis=1)
@@ -155,13 +141,8 @@ class DSIFN(nn.Layer):
         x = self.o4_conv3(x)
         x = self.sa4(x) *x
         x = self.bn_sa4(x)
-
-        out4 = F.interpolate(
-            self.o4_conv4(x), 
-            size=paddle.shape(t1)[2:], 
-            mode='bilinear', 
-            align_corners=True
-        )
+        if self.training:
+            aux_x.append(x)
 
         x = self.trans_conv4(x)
         x = paddle.concat([x, t1_f_l3, t2_f_l3], axis=1)
@@ -174,7 +155,35 @@ class DSIFN(nn.Layer):
 
         out5 = self.o5_conv4(x)
 
-        return out5, out4, out3, out2, out1
+        if not self.training:
+            return [out5]
+        else:
+            size = paddle.shape(t1)[2:]
+            out1 = F.interpolate(
+                self.o1_conv3(aux_x[0]), 
+                size=size, 
+                mode='bilinear', 
+                align_corners=True
+            )
+            out2 = F.interpolate(
+                self.o2_conv4(aux_x[1]), 
+                size=size, 
+                mode='bilinear', 
+                align_corners=True
+            )
+            out3 = F.interpolate(
+                self.o3_conv4(aux_x[2]), 
+                size=size, 
+                mode='bilinear', 
+                align_corners=True
+            )
+            out4 = F.interpolate(
+                self.o4_conv4(aux_x[3]), 
+                size=size, 
+                mode='bilinear', 
+                align_corners=True
+            )
+            return [out5, out4, out3, out2, out1]
 
     def init_weight(self):
         # Do nothing
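
The DSIFN rewrite stashes the intermediate features in `aux_x` and performs the four auxiliary upsamplings only inside the training branch, all to the same `paddle.shape(t1)[2:]` target. A hedged sketch of that deferred-upsampling step with stand-in tensors:

```python
# Hedged sketch (stand-in tensors) of the deferred upsampling: features collected
# during decoding are resized to the input resolution in one place.
import paddle
import paddle.nn.functional as F

t1 = paddle.rand([1, 3, 32, 32])            # stand-in for the first-epoch image
aux_x = [paddle.rand([1, 2, 4, 4]),         # stand-ins for the stashed features
         paddle.rand([1, 2, 8, 8])]

size = paddle.shape(t1)[2:]                 # same target size as the diff uses
outs = [F.interpolate(x, size=size, mode='bilinear', align_corners=True) for x in aux_x]
assert all(list(o.shape[2:]) == [32, 32] for o in outs)
```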

+ 0 - 1
paddlers/custom_models/cd/models/layers/attention.py

@@ -16,7 +16,6 @@ import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
 
-
 from .blocks import Conv1x1, BasicConv
 
 

+ 1 - 3
paddlers/custom_models/cd/models/snunet.py

@@ -12,12 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
 
-
 from .layers import Conv1x1, MaxPool2x2, make_norm, ChannelAttention
 from .param_init import KaimingInitMixin
 
@@ -116,7 +114,7 @@ class SNUNet(nn.Layer, KaimingInitMixin):
         out = self.ca_inter(out) * (out + paddle.tile(m_intra, (1,4,1,1)))
 
         pred = self.conv_out(out)
-        return pred,
+        return [pred]
 
 
 class ConvBlockNested(nn.Layer):

+ 1 - 2
paddlers/custom_models/cd/models/stanet.py

@@ -16,7 +16,6 @@ import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
 
-
 from .backbones import resnet
 from .layers import Conv1x1, Conv3x3, get_norm_layer, Identity
 from .param_init import KaimingInitMixin
@@ -76,7 +75,7 @@ class STANet(nn.Layer):
         y = F.interpolate(y, size=paddle.shape(t1)[2:], mode='bilinear', align_corners=True)
 
         pred = self.conv_out(y)
-        return pred,
+        return [pred]
 
     def init_weight(self):
         # Do nothing here as the encoder and decoder weights have already been initialized.

+ 1 - 2
paddlers/custom_models/cd/models/unet_ef.py

@@ -16,7 +16,6 @@ import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
 
-
 from .layers import Conv3x3, MaxPool2x2, ConvTransposed3x3, Identity
 from .param_init import normal_init, constant_init
 
@@ -184,7 +183,7 @@ class UNetEarlyFusion(nn.Layer):
         x12d = self.do12d(self.conv12d(x1d))
         x11d = self.conv11d(x12d)
 
-        return x11d,
+        return [x11d]
 
     def init_weight(self):
         for sublayer in self.sublayers():

+ 1 - 2
paddlers/custom_models/cd/models/unet_siamconc.py

@@ -16,7 +16,6 @@ import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
 
-
 from .layers import Conv3x3, MaxPool2x2, ConvTransposed3x3, Identity
 from .param_init import normal_init, constant_init
 
@@ -207,7 +206,7 @@ class UNetSiamConc(nn.Layer):
         x12d = self.do12d(self.conv12d(x1d))
         x11d = self.conv11d(x12d)
 
-        return x11d,
+        return [x11d]
 
     def init_weight(self):
         for sublayer in self.sublayers():

+ 1 - 2
paddlers/custom_models/cd/models/unet_siamdiff.py

@@ -16,7 +16,6 @@ import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
 
-
 from .layers import Conv3x3, MaxPool2x2, ConvTransposed3x3, Identity
 from .param_init import normal_init, constant_init
 
@@ -207,7 +206,7 @@ class UNetSiamDiff(nn.Layer):
         x12d = self.do12d(self.conv12d(x1d))
         x11d = self.conv11d(x12d)
 
-        return x11d,
+        return [x11d]
 
     def init_weight(self):
         for sublayer in self.sublayers():

+ 57 - 20
paddlers/datasets/cd_dataset.py

@@ -12,15 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import os.path as osp
 import copy
+from enum import IntEnum
+import os.path as osp
 
 from paddle.io import Dataset
 from paddlers.utils import logging, get_num_workers, get_encoding, path_normalization, is_pic
 
 
 class CDDataset(Dataset):
-    """读取变化检测任务数据集,并对样本进行相应的处理(来自SegDataset,图像标签需要两个)。
+    """
+    读取变化检测任务数据集,并对样本进行相应的处理(来自SegDataset,图像标签需要两个)。
 
     Args:
         data_dir (str): 数据集所在的目录路径。
@@ -29,6 +31,7 @@ class CDDataset(Dataset):
         transforms (paddlers.transforms): 数据集中每个样本的预处理/增强算子。
         num_workers (int|str): 数据集中样本在预处理过程中的线程或进程数。默认为'auto'。
         shuffle (bool): 是否需要对数据集中样本打乱顺序。默认为False。
+        with_seg_labels (bool, optional): 数据集中是否包含两个时相的语义分割标签。默认为False。
     """
 
     def __init__(self,
@@ -37,15 +40,24 @@ class CDDataset(Dataset):
                  label_list=None,
                  transforms=None,
                  num_workers='auto',
-                 shuffle=False):
+                 shuffle=False,
+                 with_seg_labels=False):
         super(CDDataset, self).__init__()
+
+        DELIMETER = ' '
+
         self.transforms = copy.deepcopy(transforms)
-        # TODO batch padding
+        # TODO: batch padding
         self.batch_transforms = None
         self.num_workers = get_num_workers(num_workers)
         self.shuffle = shuffle
         self.file_list = list()
         self.labels = list()
+        self.with_seg_labels = with_seg_labels
+        if self.with_seg_labels:
+            num_items = 5   # 3+2
+        else:
+            num_items = 3
 
         # TODO:非None时,让用户跳转数据集分析生成label_list
         # 不要在此处分析label file
@@ -54,19 +66,20 @@ class CDDataset(Dataset):
                 for line in f:
                     item = line.strip()
                     self.labels.append(item)
+                    
         with open(file_list, encoding=get_encoding(file_list)) as f:
             for line in f:
-                items = line.strip().split()
-                if len(items) > 3:
-                    raise Exception(
-                        "A space is defined as the delimiter to separate the image and label path, " \
-                        "so the space cannot be in the image or label path, but the line[{}] of " \
-                        " file_list[{}] has a space in the image or label path.".format(line, file_list))
-                items[0] = path_normalization(items[0])
-                items[1] = path_normalization(items[1])
-                items[2] = path_normalization(items[2])
-                if not is_pic(items[0]) or not is_pic(items[1]) or not is_pic(items[2]):
+                items = line.strip().split(DELIMETER)
+
+                if len(items) != num_items:
+                    raise Exception("Line[{}] in file_list[{}] has an incorrect number of file paths.".format(
+                        line.strip(), file_list
+                    ))
+
+                items = list(map(path_normalization, items))
+                if not all(map(is_pic, items)):
                     continue
+
                 full_path_im_t1 = osp.join(data_dir, items[0])
                 full_path_im_t2 = osp.join(data_dir, items[1])
                 full_path_label = osp.join(data_dir, items[2])
@@ -79,11 +92,27 @@ class CDDataset(Dataset):
                 if not osp.exists(full_path_label):
                     raise IOError('Label file {} does not exist!'.format(
                         full_path_label))
-                self.file_list.append({
-                    'image_t1': full_path_im_t1,
-                    'image_t2': full_path_im_t2,
-                    'mask': full_path_label
-                })
+
+                if with_seg_labels:
+                    full_path_seg_label_t1 = osp.join(data_dir, items[3])
+                    full_path_seg_label_t2 = osp.join(data_dir, items[4])
+                    if not osp.exists(full_path_seg_label_t1):
+                        raise IOError('Label file {} does not exist!'.format(
+                            full_path_seg_label_t1))
+                    if not osp.exists(full_path_seg_label_t2):
+                        raise IOError('Label file {} does not exist!'.format(
+                            full_path_seg_label_t2))
+
+                item_dict = dict(
+                    image_t1=full_path_im_t1,
+                    image_t2=full_path_im_t2,
+                    mask=full_path_label
+                )
+                if with_seg_labels:
+                    item_dict['aux_masks'] = [full_path_seg_label_t1, full_path_seg_label_t2]
+
+                self.file_list.append(item_dict)
+
         self.num_samples = len(self.file_list)
         logging.info("{} samples in file {}".format(
             len(self.file_list), file_list))
@@ -91,7 +120,15 @@ class CDDataset(Dataset):
     def __getitem__(self, idx):
         sample = copy.deepcopy(self.file_list[idx])
         outputs = self.transforms(sample)
+
         return outputs
 
     def __len__(self):
-        return len(self.file_list)
+        return len(self.file_list)
+
+
+class MaskType(IntEnum):
+    """Enumeration of the mask types used in the change detection task."""
+    CD = 0
+    SEG_T1 = 1
+    SEG_T2 = 2
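
With `with_seg_labels=True`, every line of `file_list` must hold five space-separated paths: the two bi-temporal images, the change mask, and the T1/T2 segmentation masks; the last two are stored under the sample key `'aux_masks'`. A hedged usage sketch (paths are made up; the import path is assumed from the file layout):

```python
# Hedged usage sketch. Each line of train_list.txt is expected to look like:
#   A/0001.png B/0001.png cd/0001.png segA/0001.png segB/0001.png
from paddlers.datasets.cd_dataset import CDDataset   # assumed module path

train_dataset = CDDataset(
    data_dir='my_dataset',
    file_list='my_dataset/train_list.txt',
    label_list=None,
    transforms=None,          # normally a paddlers.transforms.Compose pipeline
    num_workers='auto',
    shuffle=True,
    with_seg_labels=True)     # expect 5 columns per line; adds sample['aux_masks']
```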

+ 47 - 23
paddlers/tasks/changedetector.py

@@ -14,24 +14,38 @@
 
 import math
 import os.path as osp
-import numpy as np
-import cv2
 from collections import OrderedDict
+from operator import attrgetter
+
+import cv2
+import numpy as np
 import paddle
 import paddle.nn.functional as F
 from paddle.static import InputSpec
-import paddlers.models.ppseg as paddleseg
+
 import paddlers
+import paddlers.custom_models.cd as cd
+import paddlers.utils.logging as logging
+import paddlers.models.ppseg as paddleseg
 from paddlers.transforms import arrange_transforms
+from paddlers.transforms import ImgDecoder, Resize
 from paddlers.utils import get_single_card_bs, DisablePrint
-import paddlers.utils.logging as logging
+from paddlers.utils.checkpoint import seg_pretrain_weights_dict
 from .base import BaseModel
 from .utils import seg_metrics as metrics
-from paddlers.utils.checkpoint import seg_pretrain_weights_dict
-from paddlers.transforms import ImgDecoder, Resize
-import paddlers.custom_models.cd as cd
 
-__all__ = ["CDNet", "UNetEarlyFusion", "UNetSiamConc", "UNetSiamDiff", "STANet", "BIT", "SNUNet", "DSIFN", "DSAMNet"]
+
+__all__ = [
+    "CDNet", 
+    "UNetEarlyFusion", 
+    "UNetSiamConc", 
+    "UNetSiamDiff", 
+    "STANet", 
+    "BIT", 
+    "SNUNet", 
+    "DSIFN", 
+    "DSAMNet"
+]
 
 
 class BaseChangeDetector(BaseModel):
@@ -143,8 +157,16 @@ class BaseChangeDetector(BaseModel):
             outputs['conf_mat'] = metrics.confusion_matrix(pred, label,
                                                            self.num_classes)
         if mode == 'train':
-            loss_list = metrics.loss_computation(
-                logits_list=net_out, labels=inputs[2], losses=self.losses)
+            if hasattr(net, 'USE_MULTITASK_DECODER') and net.USE_MULTITASK_DECODER is True:
+                # CD+Seg
+                if len(inputs) != 5:
+                    raise ValueError("Cannot perform loss computation with {} inputs.".format(len(inputs)))
+                labels_list = [inputs[2+idx] for idx in map(attrgetter('value'), net.OUT_TYPES)]
+                loss_list = metrics.multitask_loss_computation(
+                    logits_list=net_out, labels_list=labels_list, losses=self.losses)
+            else:
+                loss_list = metrics.loss_computation(
+                    logits_list=net_out, labels=inputs[2], losses=self.losses)
             loss = sum(loss_list)
             outputs['loss'] = loss
         return outputs
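
The CD+Seg branch expects the network class to expose `USE_MULTITASK_DECODER` and an `OUT_TYPES` sequence of `MaskType` members, and it uses those enum values to pick the matching label out of the five-element input tuple. A hedged sketch of just that index mapping:

```python
# Hedged sketch (dummy class, strings in place of tensors) of how OUT_TYPES
# selects labels: each enum value offsets into inputs after the two images.
from operator import attrgetter
from paddlers.datasets.cd_dataset import MaskType   # CD=0, SEG_T1=1, SEG_T2=2

class DummyMultitaskNet:
    USE_MULTITASK_DECODER = True
    OUT_TYPES = (MaskType.CD, MaskType.SEG_T1, MaskType.SEG_T2)

inputs = ('im_t1', 'im_t2', 'cd_mask', 'seg_mask_t1', 'seg_mask_t2')
labels_list = [inputs[2 + idx] for idx in map(attrgetter('value'), DummyMultitaskNet.OUT_TYPES)]
assert labels_list == ['cd_mask', 'seg_mask_t1', 'seg_mask_t2']
```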
@@ -798,7 +820,7 @@ class SNUNet(BaseChangeDetector):
 class DSIFN(BaseChangeDetector):
     def __init__(self,
                  num_classes=2,
-                 use_mixed_loss=None,
+                 use_mixed_loss=False,
                  use_dropout=False,
                  **params):
         params.update({
@@ -809,21 +831,22 @@ class DSIFN(BaseChangeDetector):
             num_classes=num_classes,
             use_mixed_loss=use_mixed_loss,
             **params)
-        # HACK: currently the only legal value of `use_mixed_loss` is None, in which case the loss specifications are
-        # constructed automatically.
-        assert use_mixed_loss is None
-        if use_mixed_loss is None:
-            self.losses = {
+
+    def default_loss(self):
+        if self.use_mixed_loss is False:
+            return {
                 # XXX: make sure the shallow copy works correctly here.
                 'types': [paddleseg.models.CrossEntropyLoss()]*5,
                 'coef': [1.0]*5
             }
+        else:
+            return super().default_loss()
 
 
 class DSAMNet(BaseChangeDetector):
     def __init__(self,
                  num_classes=2,
-                 use_mixed_loss=None,
+                 use_mixed_loss=False,
                  in_channels=3,
                  ca_ratio=8,
                  sa_kernel=7,
@@ -838,15 +861,16 @@ class DSAMNet(BaseChangeDetector):
             num_classes=num_classes,
             use_mixed_loss=use_mixed_loss,
             **params)
-        # HACK: currently the only legal value of `use_mixed_loss` is None, in which case the loss specifications are
-        # constructed automatically.
-        assert use_mixed_loss is None
-        if use_mixed_loss is None:
-            self.losses = {
+
+    def default_loss(self):
+        if self.use_mixed_loss is False:
+            return {
                 'types': [
                     paddleseg.models.CrossEntropyLoss(), 
                     paddleseg.models.DiceLoss(), 
                     paddleseg.models.DiceLoss()
                 ],
                 'coef': [1.0, 0.05, 0.05]
-            }
+            }
+        else:
+            return super().default_loss()
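
DSIFN and DSAMNet now describe their multi-output losses through `default_loss()` instead of asserting `use_mixed_loss is None`, so the per-output loss list is built for the default configuration. A hedged usage sketch (module path assumed; whether `self.losses` is filled from `default_loss()` is handled by the base class, which this diff does not show):

```python
# Hedged usage sketch: with the default use_mixed_loss=False, DSAMNet's loss spec
# comes from default_loss(): CrossEntropyLoss plus two DiceLoss terms with
# coefficients [1.0, 0.05, 0.05], matching its three training outputs.
from paddlers.tasks.changedetector import DSAMNet   # assumed module path

model = DSAMNet(num_classes=2)
loss_spec = model.default_loss()
assert loss_spec['coef'] == [1.0, 0.05, 0.05]
```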

+ 11 - 0
paddlers/tasks/utils/seg_metrics.py

@@ -26,6 +26,17 @@ def loss_computation(logits_list, labels, losses):
     return loss_list
 
 
+def multitask_loss_computation(logits_list, labels_list, losses):
+    loss_list = []
+    for i in range(len(logits_list)):
+        logits = logits_list[i]
+        labels = labels_list[i]
+        loss_i = losses['types'][i]
+        loss_list.append(losses['coef'][i] * loss_i(logits, labels))
+
+    return loss_list
+
+
 def f1_score(intersect_area, pred_area, label_area):
     intersect_area = intersect_area.numpy()
     pred_area = pred_area.numpy()
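
A hedged usage sketch of the new `multitask_loss_computation`: it pairs each output with its own loss type and coefficient from the same `'types'`/`'coef'` dictionary layout used elsewhere (tensor shapes here are illustrative):

```python
# Hedged usage sketch (illustrative shapes): one loss term per task output.
import paddle
import paddlers.models.ppseg as paddleseg
from paddlers.tasks.utils.seg_metrics import multitask_loss_computation   # assumed module path

losses = {
    'types': [paddleseg.models.CrossEntropyLoss()] * 2,            # one loss per output
    'coef': [1.0, 0.5],
}
logits_list = [paddle.rand([2, 2, 16, 16]) for _ in range(2)]      # NCHW logits per task
labels_list = [paddle.zeros([2, 16, 16], dtype='int64')] * 2       # NHW labels per task
loss_list = multitask_loss_computation(logits_list, labels_list, losses)
total_loss = sum(loss_list)   # same reduction the training branch applies
```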

+ 42 - 12
paddlers/transforms/operators.py

@@ -12,24 +12,28 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import numpy as np
-import cv2
+import os
 import copy
 import random
-import imghdr
-import os
-from PIL import Image
-import paddlers
-
+from numbers import Number
+from functools import partial
+from operator import methodcaller
 try:
     from collections.abc import Sequence
 except Exception:
     from collections import Sequence
-from numbers import Number
+
+import numpy as np
+import cv2
+import imghdr
+from PIL import Image
+import paddlers
+
 from .functions import normalize, horizontal_flip, permute, vertical_flip, center_crop, is_poly, \
     horizontal_flip_poly, horizontal_flip_rle, vertical_flip_poly, vertical_flip_rle, crop_poly, \
     crop_rle, expand_poly, expand_rle, resize_poly, resize_rle
 
+
 __all__ = [
     "Compose", "ImgDecoder", "Resize", "RandomResize", "ResizeByShort",
     "RandomResizeByShort", "ResizeByLong", "RandomHorizontalFlip",
@@ -78,6 +82,8 @@ class Transform(object):
             sample['mask'] = self.apply_mask(sample['mask'])
         if 'gt_bbox' in sample:
             sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'])
+        if 'aux_masks' in sample:
+            sample['aux_masks'] = list(map(self.apply_mask, sample['aux_masks']))
 
         return sample
 
@@ -182,6 +188,9 @@ class ImgDecoder(Transform):
             if im_height != se_height or im_width != se_width:
                 raise Exception(
                     "The height or width of the im is not same as the mask")
+        if 'aux_masks' in sample:
+            sample['aux_masks'] = list(map(self.apply_mask, sample['aux_masks']))
+            # TODO: check the shape of auxiliary masks
 
         sample['im_shape'] = np.array(
             sample['image'].shape[:2], dtype=np.float32)
@@ -217,9 +226,12 @@ class Compose(Transform):
         self.apply_im_only = False
 
     def __call__(self, sample):
-        if self.apply_im_only and 'mask' in sample:
-            mask_backup = copy.deepcopy(sample['mask'])
-            del sample['mask']
+        if self.apply_im_only:
+            if 'mask' in sample:
+                mask_backup = copy.deepcopy(sample['mask'])
+                del sample['mask']
+            if 'aux_masks' in sample:
+                aux_masks = copy.deepcopy(sample['aux_masks'])
 
         sample = self.decode_image(sample)
 
@@ -234,6 +246,8 @@ class Compose(Transform):
         if self.arrange_outputs is not None:
             if self.apply_im_only:
                 sample['mask'] = mask_backup
+                if 'aux_masks' in locals():
+                    sample['aux_masks'] = aux_masks
             sample = self.arrange_outputs(sample)
 
         return sample
@@ -338,6 +352,8 @@ class Resize(Transform):
 
         if 'mask' in sample:
             sample['mask'] = self.apply_mask(sample['mask'], target_size)
+        if 'aux_masks' in sample:
+            sample['aux_masks'] = list(map(partial(self.apply_mask, target_size=target_size), sample['aux_masks']))
         if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
             sample['gt_bbox'] = self.apply_bbox(
                 sample['gt_bbox'], [im_scale_x, im_scale_y], target_size)
@@ -538,6 +554,8 @@ class RandomHorizontalFlip(Transform):
                 sample['image2'] = self.apply_im(sample['image2'])
             if 'mask' in sample:
                 sample['mask'] = self.apply_mask(sample['mask'])
+            if 'aux_masks' in sample:
+                sample['aux_masks'] = list(map(self.apply_mask, sample['aux_masks']))
             if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
                 sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], im_w)
             if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
@@ -593,6 +611,8 @@ class RandomVerticalFlip(Transform):
                 sample['image2'] = self.apply_im(sample['image2'])
             if 'mask' in sample:
                 sample['mask'] = self.apply_mask(sample['mask'])
+            if 'aux_masks' in sample:
+                sample['aux_masks'] = list(map(self.apply_mask, sample['aux_masks']))
             if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
                 sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], im_h)
             if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
@@ -691,6 +711,8 @@ class CenterCrop(Transform):
                 sample['image2'] = self.apply_im(sample['image2'])
         if 'mask' in sample:
             sample['mask'] = self.apply_mask(sample['mask'])
+        if 'aux_masks' in sample:
+            sample['aux_masks'] = list(map(self.apply_mask, sample['aux_masks']))
         return sample
 
 
@@ -882,6 +904,9 @@ class RandomCrop(Transform):
             if 'mask' in sample:
                 sample['mask'] = self.apply_mask(sample['mask'], crop_box)
 
+            if 'aux_masks' in sample:
+                sample['aux_masks'] = list(map(partial(self.apply_mask, crop=crop_box), sample['aux_masks']))
+
         if self.crop_size is not None:
             sample = Resize(self.crop_size)(sample)
 
@@ -1071,6 +1096,8 @@ class Padding(Transform):
                 sample['image2'] = self.apply_im(sample['image2'], offsets, (h, w))
         if 'mask' in sample:
             sample['mask'] = self.apply_mask(sample['mask'], offsets, (h, w))
+        if 'aux_masks' in sample:
+            sample['aux_masks'] = list(map(partial(self.apply_mask, offsets=offsets, target_size=(h,w)), sample['aux_masks']))
         if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
             sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], offsets)
         if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
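
The geometric operators above now forward `'aux_masks'` through the same `apply_mask` routine used for `'mask'` (with `functools.partial` binding the extra geometry arguments), and ArrangeChangeDetector in the next hunk appends the int64-cast auxiliary masks after the change mask in train mode. A hedged sketch of the sample layout these operators handle:

```python
# Hedged sketch (synthetic arrays, no real transform applied) of the decoded
# sample dict: 'aux_masks' is a list of masks transformed in lockstep with 'mask'.
import numpy as np

sample = {
    'image': np.zeros((32, 32, 3), dtype='float32'),     # T1 image
    'image2': np.zeros((32, 32, 3), dtype='float32'),    # T2 image
    'mask': np.zeros((32, 32), dtype='uint8'),            # change mask
    'aux_masks': [np.zeros((32, 32), dtype='uint8'),      # T1 segmentation mask
                  np.zeros((32, 32), dtype='uint8')],     # T2 segmentation mask
}
# Arranged train output: (image_t1, image_t2, cd_mask, seg_mask_t1, seg_mask_t2),
# with every mask cast to int64.
```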
@@ -1479,7 +1506,10 @@ class ArrangeChangeDetector(Transform):
         image_t2 = permute(sample['image2'], False)
         if self.mode == 'train':
             mask = mask.astype('int64')
-            return image_t1, image_t2, mask
+            masks = [mask]
+            if 'aux_masks' in sample:
+                masks.extend(map(methodcaller('astype', 'int64'), sample['aux_masks']))
+            return (image_t1, image_t2,) + tuple(masks)
         if self.mode == 'eval':
             mask = np.asarray(Image.open(mask))
             mask = mask[np.newaxis, :, :].astype('int64')