|
@@ -35,8 +35,8 @@ __all__ = [
|
|
|
"RandomResizeByShort", "ResizeByLong", "RandomHorizontalFlip",
|
|
|
"RandomVerticalFlip", "Normalize", "CenterCrop", "RandomCrop",
|
|
|
"RandomScaleAspect", "RandomExpand", "Padding", "MixupImage",
|
|
|
- "RandomDistort", "RandomBlur", "ArrangeSegmenter", "ArrangeClassifier",
|
|
|
- "ArrangeDetector"
|
|
|
+ "RandomDistort", "RandomBlur", "ArrangeSegmenter", "ArrangeChangeDetector",
|
|
|
+ "ArrangeClassifier", "ArrangeDetector"
|
|
|
]
|
|
|
|
|
|
interp_dict = {
|
|
@@ -69,7 +69,11 @@ class Transform(object):
|
|
|
pass
|
|
|
|
|
|
def apply(self, sample):
|
|
|
- sample['image'] = self.apply_im(sample['image'])
|
|
|
+ if 'image' in sample:
|
|
|
+ sample['image'] = self.apply_im(sample['image'])
|
|
|
+ else: # image_tx
|
|
|
+ sample['image'] = self.apply_im(sample['image_t1'])
|
|
|
+ sample['image2'] = self.apply_im(sample['image_t2'])
|
|
|
if 'mask' in sample:
|
|
|
sample['mask'] = self.apply_mask(sample['mask'])
|
|
|
if 'gt_bbox' in sample:
|
|
@@ -175,7 +179,7 @@ class Decode(Transform):
|
|
|
return cv2.imread(img_path, cv2.IMREAD_ANYDEPTH |
|
|
|
cv2.IMREAD_ANYCOLOR | cv2.IMREAD_COLOR)
|
|
|
else:
|
|
|
- return cv2.imread(im_file, cv2.IMREAD_ANYDEPTH |
|
|
|
+ return cv2.imread(img_path, cv2.IMREAD_ANYDEPTH |
|
|
|
cv2.IMREAD_ANYCOLOR)
|
|
|
elif ext == '.npy':
|
|
|
return np.load(img_path)
|
|
@@ -218,7 +222,11 @@ class Decode(Transform):
|
|
|
dict: Decoded sample.
|
|
|
|
|
|
"""
|
|
|
- sample['image'] = self.apply_im(sample['image'])
|
|
|
+ if 'image' in sample:
|
|
|
+ sample['image'] = self.apply_im(sample['image'])
|
|
|
+ else: # image_tx
|
|
|
+ sample['image'] = self.apply_im(sample['image_t1'])
|
|
|
+ sample['image2'] = self.apply_im(sample['image_t2'])
|
|
|
if 'mask' in sample:
|
|
|
sample['mask'] = self.apply_mask(sample['mask'])
|
|
|
im_height, im_width, _ = sample['image'].shape
|
|
@@ -323,6 +331,8 @@ class Resize(Transform):
|
|
|
im_scale_x = target_w / im_w
|
|
|
|
|
|
sample['image'] = self.apply_im(sample['image'], interp, target_size)
|
|
|
+ if 'image2' in sample:
|
|
|
+ sample['image2'] = self.apply_im(sample['image2'], interp, target_size)
|
|
|
|
|
|
if 'mask' in sample:
|
|
|
sample['mask'] = self.apply_mask(sample['mask'], target_size)
|
|
@@ -523,6 +533,8 @@ class RandomHorizontalFlip(Transform):
|
|
|
if random.random() < self.prob:
|
|
|
im_h, im_w = sample['image'].shape[:2]
|
|
|
sample['image'] = self.apply_im(sample['image'])
|
|
|
+ if 'image2' in sample:
|
|
|
+ sample['image2'] = self.apply_im(sample['image2'])
|
|
|
if 'mask' in sample:
|
|
|
sample['mask'] = self.apply_mask(sample['mask'])
|
|
|
if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
|
|
@@ -576,6 +588,8 @@ class RandomVerticalFlip(Transform):
|
|
|
if random.random() < self.prob:
|
|
|
im_h, im_w = sample['image'].shape[:2]
|
|
|
sample['image'] = self.apply_im(sample['image'])
|
|
|
+ if 'image2' in sample:
|
|
|
+ sample['image2'] = self.apply_im(sample['image2'])
|
|
|
if 'mask' in sample:
|
|
|
sample['mask'] = self.apply_mask(sample['mask'])
|
|
|
if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
|
|
@@ -636,6 +650,8 @@ class Normalize(Transform):
|
|
|
|
|
|
def apply(self, sample):
    """Run ``self.apply_im`` on every image held by ``sample``.

    Args:
        sample (dict): Must contain ``'image'``. May also contain
            ``'image2'``, which is processed with the same ``apply_im``
            call (presumably the second image of a bi-temporal
            change-detection pair -- confirm against the decoding step).

    Returns:
        dict: The same ``sample`` object, updated in place.
    """
    sample['image'] = self.apply_im(sample['image'])
    # The second image is optional; only single-image samples lack it.
    if 'image2' in sample:
        sample['image2'] = self.apply_im(sample['image2'])

    return sample
|
|
|
|
|
@@ -665,6 +681,8 @@ class CenterCrop(Transform):
|
|
|
|
|
|
def apply(self, sample):
    """Run the crop on the image(s) and, if present, the mask of ``sample``.

    Args:
        sample (dict): Must contain ``'image'``. ``'image2'`` and ``'mask'``
            are optional; when present they go through ``self.apply_im``
            and ``self.apply_mask`` respectively.

    Returns:
        dict: The same ``sample`` object, updated in place.
    """
    sample['image'] = self.apply_im(sample['image'])
    # Optional second image: processed by the same apply_im call as 'image'.
    if 'image2' in sample:
        sample['image2'] = self.apply_im(sample['image2'])
    if 'mask' in sample:
        sample['mask'] = self.apply_mask(sample['mask'])
    return sample
|
|
@@ -819,6 +837,8 @@ class RandomCrop(Transform):
|
|
|
crop_box, cropped_box, valid_ids = crop_info
|
|
|
im_h, im_w = sample['image'].shape[:2]
|
|
|
sample['image'] = self.apply_im(sample['image'], crop_box)
|
|
|
+ if 'image2' in sample:
|
|
|
+ sample['image2'] = self.apply_im(sample['image2'], crop_box)
|
|
|
if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
|
|
|
crop_polys = self._crop_segm(
|
|
|
sample['gt_poly'],
|
|
@@ -1045,6 +1065,8 @@ class Padding(Transform):
|
|
|
offsets = [w - im_w, h - im_h]
|
|
|
|
|
|
sample['image'] = self.apply_im(sample['image'], offsets, (h, w))
|
|
|
+ if 'image2' in sample:
|
|
|
+ sample['image2'] = self.apply_im(sample['image2'], offsets, (h, w))
|
|
|
if 'mask' in sample:
|
|
|
sample['mask'] = self.apply_mask(sample['mask'], offsets, (h, w))
|
|
|
if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
|
|
@@ -1239,22 +1261,33 @@ class RandomDistort(Transform):
|
|
|
distortions = np.random.permutation(functions)[:self.count]
|
|
|
for func in distortions:
|
|
|
sample['image'] = func(sample['image'])
|
|
|
+ if 'image2' in sample:
|
|
|
+ sample['image2'] = func(sample['image2'])
|
|
|
return sample
|
|
|
|
|
|
sample['image'] = self.apply_brightness(sample['image'])
|
|
|
+ if 'image2' in sample:
|
|
|
+ sample['image2'] = self.apply_brightness(sample['image2'])
|
|
|
mode = np.random.randint(0, 2)
|
|
|
if mode:
|
|
|
sample['image'] = self.apply_contrast(sample['image'])
|
|
|
+ if 'image2' in sample:
|
|
|
+ sample['image2'] = self.apply_contrast(sample['image2'])
|
|
|
sample['image'] = self.apply_saturation(sample['image'])
|
|
|
sample['image'] = self.apply_hue(sample['image'])
|
|
|
+ if 'image2' in sample:
|
|
|
+ sample['image2'] = self.apply_saturation(sample['image2'])
|
|
|
+ sample['image2'] = self.apply_hue(sample['image2'])
|
|
|
if not mode:
|
|
|
sample['image'] = self.apply_contrast(sample['image'])
|
|
|
+ if 'image2' in sample:
|
|
|
+ sample['image2'] = self.apply_contrast(sample['image2'])
|
|
|
|
|
|
if self.shuffle_channel:
|
|
|
if np.random.randint(0, 2):
|
|
|
- sample['image'] = sample['image'][..., np.random.permutation(
|
|
|
- 3)]
|
|
|
-
|
|
|
+ sample['image'] = sample['image'][..., np.random.permutation(3)]
|
|
|
+ if 'image2' in sample:
|
|
|
+ sample['image2'] = sample['image2'][..., np.random.permutation(3)]
|
|
|
return sample
|
|
|
|
|
|
|
|
@@ -1289,7 +1322,8 @@ class RandomBlur(Transform):
|
|
|
if radius > 9:
|
|
|
radius = 9
|
|
|
sample['image'] = self.apply_im(sample['image'], radius)
|
|
|
-
|
|
|
+ if 'image2' in sample:
|
|
|
+ sample['image2'] = self.apply_im(sample['image2'], radius)
|
|
|
return sample
|
|
|
|
|
|
|
|
@@ -1374,6 +1408,8 @@ class _Permute(Transform):
|
|
|
|
|
|
def apply(self, sample):
    """Pass the image(s) of ``sample`` through ``permute(..., False)``.

    Args:
        sample (dict): Must contain ``'image'``; ``'image2'`` is optional
            and is permuted the same way when present.

    Returns:
        dict: The same ``sample`` object, updated in place.
    """
    # NOTE(review): `permute` is defined elsewhere in the module; the meaning
    # of its second argument is not visible here -- confirm before changing.
    sample['image'] = permute(sample['image'], False)
    if 'image2' in sample:
        sample['image2'] = permute(sample['image2'], False)
    return sample
|
|
|
|
|
|
|
|
@@ -1415,8 +1451,8 @@ class ArrangeChangeDetector(Transform):
|
|
|
if 'mask' in sample:
|
|
|
mask = sample['mask']
|
|
|
|
|
|
- image_t1 = permute(sample['image_t1'], False)
|
|
|
- image_t2 = permute(sample['image_t2'], False)
|
|
|
+ image_t1 = permute(sample['image'], False)
|
|
|
+ image_t2 = permute(sample['image2'], False)
|
|
|
if self.mode == 'train':
|
|
|
mask = mask.astype('int64')
|
|
|
return image_t1, image_t2, mask
|