|
@@ -1943,8 +1943,9 @@ class RandomSwap(Transform):
|
|
|
|
|
|
class ReloadMask(Transform):
|
|
|
def apply(self, sample):
|
|
|
- sample['mask'] = F.decode_seg_mask(sample['mask_ori'])
|
|
|
- if 'aux_masks' in sample:
|
|
|
+ if 'mask' in sample or 'mask_ori' in sample:
|
|
|
+ sample['mask'] = F.decode_seg_mask(sample['mask_ori'])
|
|
|
+ if 'aux_masks' in sample or 'aux_masks_ori' in sample:
|
|
|
sample['aux_masks'] = list(
|
|
|
map(F.decode_seg_mask, sample['aux_masks_ori']))
|
|
|
return sample
|