|
@@ -12,24 +12,28 @@
|
|
|
# See the License for the specific language governing permissions and
|
|
|
# limitations under the License.
|
|
|
|
|
|
-import numpy as np
|
|
|
-import cv2
|
|
|
+import os
|
|
|
import copy
|
|
|
import random
|
|
|
-import imghdr
|
|
|
-import os
|
|
|
-from PIL import Image
|
|
|
-import paddlers
|
|
|
-
|
|
|
+from numbers import Number
|
|
|
+from functools import partial
|
|
|
+from operator import methodcaller
|
|
|
try:
|
|
|
from collections.abc import Sequence
|
|
|
except Exception:
|
|
|
from collections import Sequence
|
|
|
-from numbers import Number
|
|
|
+
|
|
|
+import numpy as np
|
|
|
+import cv2
|
|
|
+import imghdr
|
|
|
+from PIL import Image
|
|
|
+import paddlers
|
|
|
+
|
|
|
from .functions import normalize, horizontal_flip, permute, vertical_flip, center_crop, is_poly, \
|
|
|
horizontal_flip_poly, horizontal_flip_rle, vertical_flip_poly, vertical_flip_rle, crop_poly, \
|
|
|
crop_rle, expand_poly, expand_rle, resize_poly, resize_rle
|
|
|
|
|
|
+
|
|
|
__all__ = [
|
|
|
"Compose", "ImgDecoder", "Resize", "RandomResize", "ResizeByShort",
|
|
|
"RandomResizeByShort", "ResizeByLong", "RandomHorizontalFlip",
|
|
@@ -78,6 +82,8 @@ class Transform(object):
|
|
|
sample['mask'] = self.apply_mask(sample['mask'])
|
|
|
if 'gt_bbox' in sample:
|
|
|
sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'])
|
|
|
+ if 'aux_masks' in sample:
|
|
|
+ sample['aux_masks'] = list(map(self.apply_mask, sample['aux_masks']))
|
|
|
|
|
|
return sample
|
|
|
|
|
@@ -182,6 +188,9 @@ class ImgDecoder(Transform):
|
|
|
if im_height != se_height or im_width != se_width:
|
|
|
raise Exception(
|
|
|
"The height or width of the im is not same as the mask")
|
|
|
+ if 'aux_masks' in sample:
|
|
|
+ sample['aux_masks'] = list(map(self.apply_mask, sample['aux_masks']))
|
|
|
+ # TODO: check the shape of auxiliary masks
|
|
|
|
|
|
sample['im_shape'] = np.array(
|
|
|
sample['image'].shape[:2], dtype=np.float32)
|
|
@@ -217,9 +226,12 @@ class Compose(Transform):
|
|
|
self.apply_im_only = False
|
|
|
|
|
|
def __call__(self, sample):
|
|
|
- if self.apply_im_only and 'mask' in sample:
|
|
|
- mask_backup = copy.deepcopy(sample['mask'])
|
|
|
- del sample['mask']
|
|
|
+ if self.apply_im_only:
|
|
|
+ if 'mask' in sample:
|
|
|
+ mask_backup = copy.deepcopy(sample['mask'])
|
|
|
+ del sample['mask']
|
|
|
+ if 'aux_masks' in sample:
|
|
|
+ aux_masks = copy.deepcopy(sample['aux_masks'])
|
|
|
|
|
|
sample = self.decode_image(sample)
|
|
|
|
|
@@ -234,6 +246,8 @@ class Compose(Transform):
|
|
|
if self.arrange_outputs is not None:
|
|
|
if self.apply_im_only:
|
|
|
sample['mask'] = mask_backup
|
|
|
+ if 'aux_masks' in locals():
|
|
|
+ sample['aux_masks'] = aux_masks
|
|
|
sample = self.arrange_outputs(sample)
|
|
|
|
|
|
return sample
|
|
@@ -338,6 +352,8 @@ class Resize(Transform):
|
|
|
|
|
|
if 'mask' in sample:
|
|
|
sample['mask'] = self.apply_mask(sample['mask'], target_size)
|
|
|
+ if 'aux_masks' in sample:
|
|
|
+ sample['aux_masks'] = list(map(partial(self.apply_mask, target_size=target_size), sample['aux_masks']))
|
|
|
if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
|
|
|
sample['gt_bbox'] = self.apply_bbox(
|
|
|
sample['gt_bbox'], [im_scale_x, im_scale_y], target_size)
|
|
@@ -538,6 +554,8 @@ class RandomHorizontalFlip(Transform):
|
|
|
sample['image2'] = self.apply_im(sample['image2'])
|
|
|
if 'mask' in sample:
|
|
|
sample['mask'] = self.apply_mask(sample['mask'])
|
|
|
+ if 'aux_masks' in sample:
|
|
|
+ sample['aux_masks'] = list(map(self.apply_mask, sample['aux_masks']))
|
|
|
if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
|
|
|
sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], im_w)
|
|
|
if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
|
|
@@ -593,6 +611,8 @@ class RandomVerticalFlip(Transform):
|
|
|
sample['image2'] = self.apply_im(sample['image2'])
|
|
|
if 'mask' in sample:
|
|
|
sample['mask'] = self.apply_mask(sample['mask'])
|
|
|
+ if 'aux_masks' in sample:
|
|
|
+ sample['aux_masks'] = list(map(self.apply_mask, sample['aux_masks']))
|
|
|
if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
|
|
|
sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], im_h)
|
|
|
if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
|
|
@@ -691,6 +711,8 @@ class CenterCrop(Transform):
|
|
|
sample['image2'] = self.apply_im(sample['image2'])
|
|
|
if 'mask' in sample:
|
|
|
sample['mask'] = self.apply_mask(sample['mask'])
|
|
|
+ if 'aux_masks' in sample:
|
|
|
+ sample['aux_masks'] = list(map(self.apply_mask, sample['aux_masks']))
|
|
|
return sample
|
|
|
|
|
|
|
|
@@ -882,6 +904,9 @@ class RandomCrop(Transform):
|
|
|
if 'mask' in sample:
|
|
|
sample['mask'] = self.apply_mask(sample['mask'], crop_box)
|
|
|
|
|
|
+ if 'aux_masks' in sample:
|
|
|
+ sample['aux_masks'] = list(map(partial(self.apply_mask, crop=crop_box), sample['aux_masks']))
|
|
|
+
|
|
|
if self.crop_size is not None:
|
|
|
sample = Resize(self.crop_size)(sample)
|
|
|
|
|
@@ -1071,6 +1096,8 @@ class Padding(Transform):
|
|
|
sample['image2'] = self.apply_im(sample['image2'], offsets, (h, w))
|
|
|
if 'mask' in sample:
|
|
|
sample['mask'] = self.apply_mask(sample['mask'], offsets, (h, w))
|
|
|
+ if 'aux_masks' in sample:
|
|
|
+ sample['aux_masks'] = list(map(partial(self.apply_mask, offsets=offsets, target_size=(h,w)), sample['aux_masks']))
|
|
|
if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
|
|
|
sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], offsets)
|
|
|
if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
|
|
@@ -1479,7 +1506,10 @@ class ArrangeChangeDetector(Transform):
|
|
|
image_t2 = permute(sample['image2'], False)
|
|
|
if self.mode == 'train':
|
|
|
mask = mask.astype('int64')
|
|
|
- return image_t1, image_t2, mask
|
|
|
+ masks = [mask]
|
|
|
+ if 'aux_masks' in sample:
|
|
|
+ masks.extend(map(methodcaller('astype', 'int64'), sample['aux_masks']))
|
|
|
+ return (image_t1, image_t2,) + tuple(masks)
|
|
|
if self.mode == 'eval':
|
|
|
mask = np.asarray(Image.open(mask))
|
|
|
mask = mask[np.newaxis, :, :].astype('int64')
|