Parcourir la source

Add match_lf_components code

Bobholamovic il y a 2 ans
Parent
commit
bd7c671e21

+ 49 - 0
paddlers/transforms/functions.py

@@ -604,6 +604,55 @@ def match_by_regression(im, ref, pif_loc=None):
     return matched
 
 
+def match_lf_components(im, ref, lf_ratio=0.01):
+    """
+    Match the low-frequency components of two images.
+
+    Each channel of `im` is transformed with a 2-D FFT, its low-frequency
+    coefficients are overwritten with those of the same channel of `ref`,
+    and the magnitude of the inverse transform is returned.
+
+    Args:
+        im (np.ndarray): Input image. Either a single 2-D channel, or an
+            array whose last axis indexes the channels.
+        ref (np.ndarray): Reference image to match. `ref` must have the same
+            shape as `im`.
+        lf_ratio (float, optional): Proportion of frequency components that
+            should be recognized as low-frequency components in the
+            frequency domain. Default: 0.01.
+
+    Returns:
+        np.ndarray: Transformed input image, with the same shape and dtype
+            as `im`.
+
+    Raises:
+        ValueError: When the shape of `ref` differs from that of `im`.
+    """
+
+    def _replace_lf(im, ref, lf_ratio):
+        # Operates on one channel: the unpacking below requires `im` to be
+        # exactly two-dimensional.
+        h, w = im.shape
+        # Number of rows/columns treated as low-frequency at each end of the
+        # spectrum: `lf_ratio` of half the image size, truncated to an
+        # integer. Small images can yield 0, in which case nothing is
+        # replaced along that axis.
+        h_lf, w_lf = int(h // 2 * lf_ratio), int(w // 2 * lf_ratio)
+        freq_im = np.fft.fft2(im)
+        freq_ref = np.fft.fft2(ref)
+        # The spectrum is not shifted here, so the low frequencies sit at
+        # both the start and the end of each axis; both ends are replaced.
+        # The `> 0` guards are required because a `[-0:]` slice would select
+        # the whole axis instead of nothing.
+        if h_lf > 0:
+            freq_im[:h_lf] = freq_ref[:h_lf]
+            freq_im[-h_lf:] = freq_ref[-h_lf:]
+        if w_lf > 0:
+            freq_im[:, :w_lf] = freq_ref[:, :w_lf]
+            freq_im[:, -w_lf:] = freq_ref[:, -w_lf:]
+        recon_im = np.fft.ifft2(freq_im)
+        # The inverse transform is complex-valued; only its magnitude is
+        # kept, so the reconstructed channel is real and non-negative.
+        recon_im = np.abs(recon_im)
+        return recon_im
+
+    if im.shape != ref.shape:
+        raise ValueError("Image and Reference must have the same shape!")
+
+    if im.ndim > 2:
+        # Multiple channels
+        # Channels are taken from the last axis and matched independently.
+        # Writing into `matched` converts the float result to `im.dtype`.
+        # NOTE(review): for integer dtypes this conversion does not clip to
+        # the valid range -- confirm whether callers need clipping.
+        matched = np.empty(im.shape, dtype=im.dtype)
+        for ch in range(im.shape[-1]):
+            matched[..., ch] = _replace_lf(im[..., ch], ref[..., ch], lf_ratio)
+    else:
+        # Single channel
+        # The float result is converted back to the input dtype explicitly.
+        matched = _replace_lf(im, ref, lf_ratio).astype(im.dtype)
+
+    return matched
+
+
 def inv_pca(im, joblib_path):
     """
     Perform inverse PCA transformation.

+ 45 - 48
paddlers/transforms/operators.py

@@ -27,14 +27,8 @@ from PIL import Image
 from joblib import load
 
 import paddlers
+import paddlers.transforms.functions as F
 import paddlers.transforms.indices as indices
-from .functions import (
-    normalize, horizontal_flip, permute, vertical_flip, center_crop, is_poly,
-    horizontal_flip_poly, horizontal_flip_rle, vertical_flip_poly,
-    vertical_flip_rle, crop_poly, crop_rle, expand_poly, expand_rle,
-    resize_poly, resize_rle, dehaze, select_bands, to_intensity, to_uint8,
-    img_flip, img_simple_rotate, decode_seg_mask, calc_hr_shape,
-    match_by_regression, match_histograms)
 
 __all__ = [
     "Compose",
@@ -243,7 +237,7 @@ class DecodeImg(Transform):
                 raise IOError('Cannot open', img_path)
             im_data = dataset.ReadAsArray()
             if im_data.ndim == 2 and self.decode_sar:
-                im_data = to_intensity(im_data)
+                im_data = F.to_intensity(im_data)
                 im_data = im_data[:, :, np.newaxis]
             else:
                 if im_data.ndim == 3:
@@ -287,7 +281,7 @@ class DecodeImg(Transform):
             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
 
         if self.to_uint8:
-            image = to_uint8(image)
+            image = F.to_uint8(image)
 
         if self.read_geo_info:
             return image, geo_info_dict
@@ -442,15 +436,15 @@ class Resize(Transform):
         im_scale_x, im_scale_y = scale
         resized_segms = []
         for segm in segms:
-            if is_poly(segm):
+            if F.is_poly(segm):
                 # Polygon format
                 resized_segms.append([
-                    resize_poly(poly, im_scale_x, im_scale_y) for poly in segm
+                    F.resize_poly(poly, im_scale_x, im_scale_y) for poly in segm
                 ])
             else:
                 # RLE format
                 resized_segms.append(
-                    resize_rle(segm, im_h, im_w, im_scale_x, im_scale_y))
+                    F.resize_rle(segm, im_h, im_w, im_scale_x, im_scale_y))
 
         return resized_segms
 
@@ -495,7 +489,7 @@ class Resize(Transform):
                 # For SR tasks
                 sample['target'] = self.apply_im(
                     sample['target'], interp,
-                    calc_hr_shape(target_size, sample['sr_factor']))
+                    F.calc_hr_shape(target_size, sample['sr_factor']))
             else:
                 # For non-SR tasks
                 sample['target'] = self.apply_im(sample['target'], interp,
@@ -692,16 +686,16 @@ class RandomFlipOrRotate(Transform):
 
     def apply_im(self, image, mode_id, flip_mode=True):
         if flip_mode:
-            image = img_flip(image, mode_id)
+            image = F.img_flip(image, mode_id)
         else:
-            image = img_simple_rotate(image, mode_id)
+            image = F.img_simple_rotate(image, mode_id)
         return image
 
     def apply_mask(self, mask, mode_id, flip_mode=True):
         if flip_mode:
-            mask = img_flip(mask, mode_id)
+            mask = F.img_flip(mask, mode_id)
         else:
-            mask = img_simple_rotate(mask, mode_id)
+            mask = F.img_simple_rotate(mask, mode_id)
         return mask
 
     def apply_bbox(self, bbox, mode_id, flip_mode=True):
@@ -817,11 +811,11 @@ class RandomHorizontalFlip(Transform):
         self.prob = prob
 
     def apply_im(self, image):
-        image = horizontal_flip(image)
+        image = F.horizontal_flip(image)
         return image
 
     def apply_mask(self, mask):
-        mask = horizontal_flip(mask)
+        mask = F.horizontal_flip(mask)
         return mask
 
     def apply_bbox(self, bbox, width):
@@ -834,13 +828,13 @@ class RandomHorizontalFlip(Transform):
     def apply_segm(self, segms, height, width):
         flipped_segms = []
         for segm in segms:
-            if is_poly(segm):
+            if F.is_poly(segm):
                 # Polygon format
                 flipped_segms.append(
-                    [horizontal_flip_poly(poly, width) for poly in segm])
+                    [F.horizontal_flip_poly(poly, width) for poly in segm])
             else:
                 # RLE format
-                flipped_segms.append(horizontal_flip_rle(segm, height, width))
+                flipped_segms.append(F.horizontal_flip_rle(segm, height, width))
         return flipped_segms
 
     def apply(self, sample):
@@ -877,11 +871,11 @@ class RandomVerticalFlip(Transform):
         self.prob = prob
 
     def apply_im(self, image):
-        image = vertical_flip(image)
+        image = F.vertical_flip(image)
         return image
 
     def apply_mask(self, mask):
-        mask = vertical_flip(mask)
+        mask = F.vertical_flip(mask)
         return mask
 
     def apply_bbox(self, bbox, height):
@@ -894,13 +888,13 @@ class RandomVerticalFlip(Transform):
     def apply_segm(self, segms, height, width):
         flipped_segms = []
         for segm in segms:
-            if is_poly(segm):
+            if F.is_poly(segm):
                 # Polygon format
                 flipped_segms.append(
-                    [vertical_flip_poly(poly, height) for poly in segm])
+                    [F.vertical_flip_poly(poly, height) for poly in segm])
             else:
                 # RLE format
-                flipped_segms.append(vertical_flip_rle(segm, height, width))
+                flipped_segms.append(F.vertical_flip_rle(segm, height, width))
         return flipped_segms
 
     def apply(self, sample):
@@ -978,7 +972,7 @@ class Normalize(Transform):
         mean = np.asarray(
             self.mean, dtype=np.float32)[np.newaxis, np.newaxis, :]
         std = np.asarray(self.std, dtype=np.float32)[np.newaxis, np.newaxis, :]
-        image = normalize(image, mean, std, self.min_val, self.max_val)
+        image = F.normalize(image, mean, std, self.min_val, self.max_val)
         return image
 
     def apply(self, sample):
@@ -1007,12 +1001,12 @@ class CenterCrop(Transform):
         self.crop_size = crop_size
 
     def apply_im(self, image):
-        image = center_crop(image, self.crop_size)
+        image = F.center_crop(image, self.crop_size)
 
         return image
 
     def apply_mask(self, mask):
-        mask = center_crop(mask, self.crop_size)
+        mask = F.center_crop(mask, self.crop_size)
         return mask
 
     def apply(self, sample):
@@ -1159,12 +1153,12 @@ class RandomCrop(Transform):
         crop_segms = []
         for id in valid_ids:
             segm = segms[id]
-            if is_poly(segm):
+            if F.is_poly(segm):
                 # Polygon format
-                crop_segms.append(crop_poly(segm, crop))
+                crop_segms.append(F.crop_poly(segm, crop))
             else:
                 # RLE format
-                crop_segms.append(crop_rle(segm, crop, height, width))
+                crop_segms.append(F.crop_rle(segm, crop, height, width))
 
         return crop_segms
 
@@ -1196,7 +1190,7 @@ class RandomCrop(Transform):
                     delete_id = list()
                     valid_polys = list()
                     for idx, poly in enumerate(crop_polys):
-                        if not crop_poly:
+                        if not poly:
                             delete_id.append(idx)
                         else:
                             valid_polys.append(poly)
@@ -1231,7 +1225,7 @@ class RandomCrop(Transform):
                 if 'sr_factor' in sample:
                     sample['target'] = self.apply_im(
                         sample['target'],
-                        calc_hr_shape(crop_box, sample['sr_factor']))
+                        F.calc_hr_shape(crop_box, sample['sr_factor']))
                 else:
                     sample['target'] = self.apply_im(sample['image'], crop_box)
 
@@ -1393,14 +1387,14 @@ class Pad(Transform):
         h, w = size
         expanded_segms = []
         for segm in segms:
-            if is_poly(segm):
+            if F.is_poly(segm):
                 # Polygon format
                 expanded_segms.append(
-                    [expand_poly(poly, x, y) for poly in segm])
+                    [F.expand_poly(poly, x, y) for poly in segm])
             else:
                 # RLE format
                 expanded_segms.append(
-                    expand_rle(segm, x, y, height, width, h, w))
+                    F.expand_rle(segm, x, y, height, width, h, w))
         return expanded_segms
 
     def _get_offsets(self, im_h, im_w, h, w):
@@ -1450,7 +1444,7 @@ class Pad(Transform):
                 sample['gt_poly'], offsets, im_size=[im_h, im_w], size=[h, w])
         if 'target' in sample:
             if 'sr_factor' in sample:
-                hr_shape = calc_hr_shape((h, w), sample['sr_factor'])
+                hr_shape = F.calc_hr_shape((h, w), sample['sr_factor'])
                 hr_offsets = self._get_offsets(*sample['target'].shape[:2],
                                                *hr_shape)
                 sample['target'] = self.apply_im(sample['target'], hr_offsets,
@@ -1757,7 +1751,7 @@ class Dehaze(Transform):
         self.gamma = gamma
 
     def apply_im(self, image):
-        image = dehaze(image, self.gamma)
+        image = F.dehaze(image, self.gamma)
         return image
 
     def apply(self, sample):
@@ -1819,7 +1813,7 @@ class SelectBand(Transform):
         self.apply_to_tar = apply_to_tar
 
     def apply_im(self, image):
-        image = select_bands(image, self.band_list)
+        image = F.select_bands(image, self.band_list)
         return image
 
     def apply(self, sample):
@@ -1944,10 +1938,10 @@ class RandomSwap(Transform):
 
 class ReloadMask(Transform):
     def apply(self, sample):
-        sample['mask'] = decode_seg_mask(sample['mask_ori'])
+        sample['mask'] = F.decode_seg_mask(sample['mask_ori'])
         if 'aux_masks' in sample:
             sample['aux_masks'] = list(
-                map(decode_seg_mask, sample['aux_masks_ori']))
+                map(F.decode_seg_mask, sample['aux_masks_ori']))
         return sample
 
 
@@ -1987,18 +1981,21 @@ class MatchRadiance(Transform):
 
     Args:
         method (str, optional): Method used to match the radiance of the
-            bi-temporal images. Choices are {'hist', 'lsr'}. 'hist' stands
-            for histogram matching and 'lsr' stands for least-squares 
-            regression. Default: 'hist'.
+            bi-temporal images. Choices are {'hist', 'lsr', 'fft'}. 'hist' 
+            stands for histogram matching, 'lsr' stands for least-squares 
+            regression, and 'fft' replaces the low-frequency components of
+            the image to match the reference image. Default: 'hist'.
     """
 
     def __init__(self, method='hist'):
         super(MatchRadiance, self).__init__()
 
         if method == 'hist':
-            self._match_func = match_histograms
+            self._match_func = F.match_histograms
         elif method == 'lsr':
-            self._match_func = match_by_regression
+            self._match_func = F.match_by_regression
+        elif method == 'fft':
+            self._match_func = F.match_lf_components
         else:
             raise ValueError(
                 "{} is not a supported radiometric correction method.".format(

+ 3 - 0
tests/transforms/test_operators.py

@@ -390,6 +390,9 @@ class TestTransform(CpuCommonTest):
         test_lsr = make_test_func(
             T.MatchRadiance, 'lsr', _filter=_filter_only_mt)
         test_lsr(self)
+        test_fft = make_test_func(
+            T.MatchRadiance, 'fft', _filter=_filter_only_mt)
+        test_fft(self)
 
 
 class TestCompose(CpuCommonTest):