Merge remote-tracking branch 'origin/develop' into develop

zhl committed 3 years ago
commit ceb8b4eebe

3 changed files with 162 additions and 56 deletions
  1. paddlers/transforms/functions.py (+25 -14)
  2. paddlers/transforms/operators.py (+125 -30)
  3. requirements.txt (+12 -12)

+ 25 - 14
paddlers/transforms/functions.py

@@ -206,6 +206,7 @@ def to_uint8(im):
     Returns:
         np.ndarray: Image converted to uint8.
     """
+
     # 2% linear stretch
     def _two_percentLinear(image, max_out=255, min_out=0):
         def _gray_process(gray, maxout=max_out, minout=min_out):
@@ -216,6 +217,7 @@ def to_uint8(im):
             processed_gray = ((truncated_gray - low_value) / (high_value - low_value)) * \
                              (maxout - minout)
             return processed_gray
+
         if len(image.shape) == 3:
             processes = []
             for b in range(image.shape[-1]):
@@ -244,7 +246,7 @@ def to_uint8(im):
         lut = []
         for bt in range(0, len(hist), NUMS):
             # step size
-            step = reduce(operator.add, hist[bt : bt + NUMS]) / (NUMS - 1)
+            step = reduce(operator.add, hist[bt:bt + NUMS]) / (NUMS - 1)
             # create balanced lookup table
             n = 0
             for i in range(NUMS):
@@ -301,14 +303,18 @@ def select_bands(im, band_list=[1, 2, 3]):
     Returns:
         np.ndarray: The image after band selection.
     """
+    if len(im.shape) == 2:  # only one channel
+        return im
+    if not isinstance(band_list, list) or len(band_list) == 0:
+        raise TypeError("band_list must be a non-empty list.")
     total_band = im.shape[-1]
     result = []
     for band in band_list:
         band = int(band - 1)
         if band < 0 or band >= total_band:
-            raise ValueError(
-                "The element in band_list must > 1 and <= {}.".format(str(total_band)))
-        result.append()
+            raise ValueError("Each element in band_list must be >= 1 and <= {}.".
+                             format(str(total_band)))
+        result.append(im[:, :, band])
     ima = np.stack(result, axis=0)
     return ima
 
@@ -323,6 +329,7 @@ def de_haze(im, gamma=False):
     Returns:
         np.ndarray: The defogged image.
     """
+
     def _guided_filter(I, p, r, eps):
         m_I = cv2.boxFilter(I, -1, (r, r))
         m_p = cv2.boxFilter(p, -1, (r, r))
@@ -350,16 +357,17 @@ def de_haze(im, gamma=False):
         atmo_illum = np.mean(im, 2)[atmo_mask >= ht[1][lmax]].max()
         atmo_mask = np.minimum(atmo_mask * w, maxatmo_mask)
         return atmo_mask, atmo_illum
-        
+
     if np.max(im) > 1:
         im = im / 255.
     result = np.zeros(im.shape)
-    mask_img, atmo_illum = _de_fog(im, r=81, w=0.95, maxatmo_mask=0.80, eps=1e-8)
+    mask_img, atmo_illum = _de_fog(
+        im, r=81, w=0.95, maxatmo_mask=0.80, eps=1e-8)
     for k in range(3):
         result[:, :, k] = (im[:, :, k] - mask_img) / (1 - mask_img / atmo_illum)
     result = np.clip(result, 0, 1)
     if gamma:
-        result = result ** (np.log(0.5) / np.log(result.mean()))
+        result = result**(np.log(0.5) / np.log(result.mean()))
     return (result * 255).astype("uint8")
 
 
@@ -398,7 +406,8 @@ def match_histograms(im, ref):
         ValueError: When the number of channels of `ref` differs from that of `im`.
     """
     # TODO: Check the data types of the inputs to see if they are supported by skimage
-    return exposure.match_histograms(im, ref, channel_axis=-1 if im.ndim>2 else None)
+    return exposure.match_histograms(
+        im, ref, channel_axis=-1 if im.ndim > 2 else None)
 
 
 def match_by_regression(im, ref, pif_loc=None):
@@ -418,27 +427,29 @@ def match_by_regression(im, ref, pif_loc=None):
     Raises:
         ValueError: When the shape of `ref` differs from that of `im`.
     """
+
     def _linear_regress(im, ref, loc):
         regressor = LinearRegression()
         if loc is not None:
             x, y = im[loc], ref[loc]
         else:
             x, y = im, ref
-        x, y = x.reshape(-1,1), y.ravel()
+        x, y = x.reshape(-1, 1), y.ravel()
         regressor.fit(x, y)
-        matched = regressor.predict(im.reshape(-1,1))
+        matched = regressor.predict(im.reshape(-1, 1))
         return matched.reshape(im.shape)
 
     if im.shape != ref.shape:
-        raise  ValueError("Image and Reference must have the same shape!")
+        raise ValueError("Image and Reference must have the same shape!")
 
     if im.ndim > 2:
         # Multiple channels
         matched = np.empty(im.shape, dtype=im.dtype)
         for ch in range(im.shape[-1]):
-            matched[..., ch] = _linear_regress(im[..., ch], ref[..., ch], pif_loc)
+            matched[..., ch] = _linear_regress(im[..., ch], ref[..., ch],
+                                               pif_loc)
     else:
         # Single channel
         matched = _linear_regress(im, ref, pif_loc).astype(im.dtype)
-    
-    return matched
+
+    return matched
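
For reference, a minimal sketch (not part of this commit) exercising the corrected select_bands; note that np.stack(result, axis=0) returns the selected bands in channel-first order:

    import numpy as np
    from paddlers.transforms.functions import select_bands

    im = np.random.randint(0, 256, (64, 64, 4), dtype=np.uint8)  # H, W, C input
    out = select_bands(im, band_list=[3, 2, 1])  # band indices are 1-based
    print(out.shape)  # (3, 64, 64): bands stacked along axis 0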

+ 125 - 30
paddlers/transforms/operators.py

@@ -31,17 +31,16 @@ import paddlers
 
 from .functions import normalize, horizontal_flip, permute, vertical_flip, center_crop, is_poly, \
     horizontal_flip_poly, horizontal_flip_rle, vertical_flip_poly, vertical_flip_rle, crop_poly, \
-    crop_rle, expand_poly, expand_rle, resize_poly, resize_rle
-
+    crop_rle, expand_poly, expand_rle, resize_poly, resize_rle, de_haze, pca, select_bands, \
+    to_intensity, to_uint8
 
 __all__ = [
     "Compose", "ImgDecoder", "Resize", "RandomResize", "ResizeByShort",
     "RandomResizeByShort", "ResizeByLong", "RandomHorizontalFlip",
     "RandomVerticalFlip", "Normalize", "CenterCrop", "RandomCrop",
     "RandomScaleAspect", "RandomExpand", "Padding", "MixupImage",
-    "RandomDistort", "RandomBlur", 
-    "RandomSwap",
-    "ArrangeSegmenter", "ArrangeChangeDetector", 
+    "RandomDistort", "RandomBlur", "RandomSwap", "Defogging", "DimReducing",
+    "BandSelecting", "ArrangeSegmenter", "ArrangeChangeDetector",
     "ArrangeClassifier", "ArrangeDetector"
 ]
 
@@ -85,7 +84,8 @@ class Transform(object):
         if 'gt_bbox' in sample:
             sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'])
         if 'aux_masks' in sample:
-            sample['aux_masks'] = list(map(self.apply_mask, sample['aux_masks']))
+            sample['aux_masks'] = list(
+                map(self.apply_mask, sample['aux_masks']))
 
         return sample
 
@@ -105,9 +105,10 @@ class ImgDecoder(Transform):
         to_rgb (bool, optional): If True, convert input images from BGR format to RGB format. Defaults to True.
     """
 
-    def __init__(self, to_rgb=True):
+    def __init__(self, to_rgb=True, to_uint8=True):
         super(ImgDecoder, self).__init__()
         self.to_rgb = to_rgb
+        self.to_uint8 = to_uint8
 
     def read_img(self, img_path, input_channel=3):
         img_format = imghdr.what(img_path)
@@ -129,6 +130,7 @@ class ImgDecoder(Transform):
                 raise Exception('Can not open', img_path)
             im_data = dataset.ReadAsArray()
             if im_data.ndim == 2:
+                im_data = to_intensity(im_data)  # single-band data: treat as SAR, convert to intensity
                 im_data = im_data[:, :, np.newaxis]
             elif im_data.ndim == 3:
                 im_data = im_data.transpose((1, 2, 0))
@@ -158,6 +160,9 @@ class ImgDecoder(Transform):
         if self.to_rgb and image.shape[-1] == 3:
             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
 
+        if self.to_uint8:
+            image = to_uint8(image)
+
         return image
 
     def apply_mask(self, mask):
@@ -191,7 +196,8 @@ class ImgDecoder(Transform):
                 raise Exception(
                     "The height or width of the im is not same as the mask")
         if 'aux_masks' in sample:
-            sample['aux_masks'] = list(map(self.apply_mask, sample['aux_masks']))
+            sample['aux_masks'] = list(
+                map(self.apply_mask, sample['aux_masks']))
             # TODO: check the shape of auxiliary masks
 
         sample['im_shape'] = np.array(
@@ -350,12 +356,16 @@ class Resize(Transform):
 
         sample['image'] = self.apply_im(sample['image'], interp, target_size)
         if 'image2' in sample:
-            sample['image2'] = self.apply_im(sample['image2'], interp, target_size)
+            sample['image2'] = self.apply_im(sample['image2'], interp,
+                                             target_size)
 
         if 'mask' in sample:
             sample['mask'] = self.apply_mask(sample['mask'], target_size)
         if 'aux_masks' in sample:
-            sample['aux_masks'] = list(map(partial(self.apply_mask, target_size=target_size), sample['aux_masks']))
+            sample['aux_masks'] = list(
+                map(partial(
+                    self.apply_mask, target_size=target_size),
+                    sample['aux_masks']))
         if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
             sample['gt_bbox'] = self.apply_bbox(
                 sample['gt_bbox'], [im_scale_x, im_scale_y], target_size)
@@ -557,7 +567,8 @@ class RandomHorizontalFlip(Transform):
             if 'mask' in sample:
                 sample['mask'] = self.apply_mask(sample['mask'])
             if 'aux_masks' in sample:
-                sample['aux_masks'] = list(map(self.apply_mask, sample['aux_masks']))
+                sample['aux_masks'] = list(
+                    map(self.apply_mask, sample['aux_masks']))
             if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
                 sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], im_w)
             if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
@@ -614,7 +625,8 @@ class RandomVerticalFlip(Transform):
             if 'mask' in sample:
                 sample['mask'] = self.apply_mask(sample['mask'])
             if 'aux_masks' in sample:
-                sample['aux_masks'] = list(map(self.apply_mask, sample['aux_masks']))
+                sample['aux_masks'] = list(
+                    map(self.apply_mask, sample['aux_masks']))
             if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
                 sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], im_h)
             if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
@@ -653,8 +665,8 @@ class Normalize(Transform):
 
         from functools import reduce
         if reduce(lambda x, y: x * y, std) == 0:
-            raise ValueError(
-                'Std should not have 0, but received is {}'.format(std))
+            raise ValueError('Std should not have 0, but received is {}'.format(
+                std))
         if is_scale:
             if reduce(lambda x, y: x * y,
                       [a - b for a, b in zip(max_val, min_val)]) == 0:
@@ -679,7 +691,7 @@ class Normalize(Transform):
     def apply(self, sample):
         sample['image'] = self.apply_im(sample['image'])
         if 'image2' in sample:
-                sample['image2'] = self.apply_im(sample['image2'])
+            sample['image2'] = self.apply_im(sample['image2'])
 
         return sample
 
@@ -710,11 +722,12 @@ class CenterCrop(Transform):
     def apply(self, sample):
         sample['image'] = self.apply_im(sample['image'])
         if 'image2' in sample:
-                sample['image2'] = self.apply_im(sample['image2'])
+            sample['image2'] = self.apply_im(sample['image2'])
         if 'mask' in sample:
             sample['mask'] = self.apply_mask(sample['mask'])
         if 'aux_masks' in sample:
-            sample['aux_masks'] = list(map(self.apply_mask, sample['aux_masks']))
+            sample['aux_masks'] = list(
+                map(self.apply_mask, sample['aux_masks']))
         return sample
 
 
@@ -779,8 +792,7 @@ class RandomCrop(Transform):
                     if self.cover_all_box and iou.min() < thresh:
                         continue
                     cropped_box, valid_ids = self._crop_box_with_center_constraint(
-                        sample['gt_bbox'],
-                        np.array(
+                        sample['gt_bbox'], np.array(
                             crop_box, dtype=np.float32))
                     if valid_ids.size > 0:
                         return crop_box, cropped_box, valid_ids
@@ -907,7 +919,10 @@ class RandomCrop(Transform):
                 sample['mask'] = self.apply_mask(sample['mask'], crop_box)
 
             if 'aux_masks' in sample:
-                sample['aux_masks'] = list(map(partial(self.apply_mask, crop=crop_box), sample['aux_masks']))
+                sample['aux_masks'] = list(
+                    map(partial(
+                        self.apply_mask, crop=crop_box),
+                        sample['aux_masks']))
 
         if self.crop_size is not None:
             sample = Resize(self.crop_size)(sample)
@@ -1095,11 +1110,14 @@ class Padding(Transform):
 
         sample['image'] = self.apply_im(sample['image'], offsets, (h, w))
         if 'image2' in sample:
-                sample['image2'] = self.apply_im(sample['image2'], offsets, (h, w))
+            sample['image2'] = self.apply_im(sample['image2'], offsets, (h, w))
         if 'mask' in sample:
             sample['mask'] = self.apply_mask(sample['mask'], offsets, (h, w))
         if 'aux_masks' in sample:
-            sample['aux_masks'] = list(map(partial(self.apply_mask, offsets=offsets, target_size=(h,w)), sample['aux_masks']))
+            sample['aux_masks'] = list(
+                map(partial(
+                    self.apply_mask, offsets=offsets, target_size=(h, w)),
+                    sample['aux_masks']))
         if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
             sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], offsets)
         if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
@@ -1251,7 +1269,7 @@ class RandomDistort(Transform):
         res_list = []
         channel = image.shape[2]
         for i in range(channel // 3):
-            sub_img = image[:, :, 3*i : 3*(i+1)]
+            sub_img = image[:, :, 3 * i:3 * (i + 1)]
             sub_img = sub_img.astype(np.float32)
             sub_img = np.dot(sub_img, t)
             res_list.append(sub_img)
@@ -1271,10 +1289,11 @@ class RandomDistort(Transform):
         res_list = []
         channel = image.shape[2]
         for i in range(channel // 3):
-            sub_img = image[:, :, 3*i : 3*(i+1)]
+            sub_img = image[:, :, 3 * i:3 * (i + 1)]
             sub_img = sub_img.astype(np.float32)
             # it works, but the result differs from the HSV version
-            gray = sub_img * np.array([[[0.299, 0.587, 0.114]]], dtype=np.float32)
+            gray = sub_img * np.array(
+                [[[0.299, 0.587, 0.114]]], dtype=np.float32)
             gray = gray.sum(axis=2, keepdims=True)
             gray *= (1.0 - delta)
             sub_img *= delta
@@ -1340,7 +1359,8 @@ class RandomDistort(Transform):
             if np.random.randint(0, 2):
                 sample['image'] = sample['image'][..., np.random.permutation(3)]
                 if 'image2' in sample:
-                    sample['image2'] = sample['image2'][..., np.random.permutation(3)]
+                    sample['image2'] = sample['image2'][
+                        ..., np.random.permutation(3)]
         return sample
 
 
@@ -1380,6 +1400,77 @@ class RandomBlur(Transform):
         return sample
 
 
+class Defogging(Transform):
+    """
+    Defog input image(s).
+
+    Args: 
+        gamma (bool, optional): Whether to apply gamma correction. Defaults to False.
+    """
+
+    def __init__(self, gamma=False):
+        super(Defogging, self).__init__()
+        self.gamma = gamma
+
+    def apply_im(self, image):
+        image = de_haze(image, self.gamma)
+        return image
+
+    def apply(self, sample):
+        sample['image'] = self.apply_im(sample['image'])
+        if 'image2' in sample:
+            sample['image2'] = self.apply_im(sample['image2'])
+        return sample
+
+
+class DimReducing(Transform):
+    """
+    Use PCA to reduce the dimension (number of bands) of input image(s).
+
+    Args: 
+        dim (int, optional): Number of dimensions to keep. Defaults to 3.
+        whiten (bool, optional): Whether to apply PCA whitening. Defaults to True.
+    """
+
+    def __init__(self, dim=3, whiten=True):
+        super(DimReducing, self).__init__()
+        self.dim = dim
+        self.whiten = whiten
+
+    def apply_im(self, image):
+        image = pca(image, self.dim, self.whiten)
+        return image
+
+    def apply(self, sample):
+        sample['image'] = self.apply_im(sample['image'])
+        if 'image2' in sample:
+            sample['image2'] = self.apply_im(sample['image2'])
+        return sample
+
+
+class BandSelecting(Transform):
+    """
+    Select bands of the input image(s).
+
+    Args: 
+        band_list (list, optional): Bands to select (band indices start from 1). Defaults to [1, 2, 3].
+    """
+
+    def __init__(self, band_list=[1, 2, 3]):
+        super(BandSelecting, self).__init__()
+        self.band_list = band_list
+
+    def apply_im(self, image):
+        image = select_bands(image, self.band_list)
+        return image
+
+    def apply(self, sample):
+        sample['image'] = self.apply_im(sample['image'])
+        if 'image2' in sample:
+            sample['image2'] = self.apply_im(sample['image2'])
+        return sample
+
+
 class _PadBox(Transform):
     def __init__(self, num_max_boxes=50):
         """
@@ -1464,7 +1555,7 @@ class _Permute(Transform):
         if 'image2' in sample:
             sample['image2'] = permute(sample['image2'], False)
         return sample
-        
+
 
 class RandomSwap(Transform):
     """
@@ -1482,7 +1573,8 @@ class RandomSwap(Transform):
         if 'image2' not in sample:
             raise ValueError('image2 is not found in the sample.')
         if random.random() < self.prob:
-            sample['image'], sample['image2'] = sample['image2'], sample['image']
+            sample['image'], sample['image2'] = sample['image2'], sample[
+                'image']
         return sample
 
 
@@ -1530,8 +1622,11 @@ class ArrangeChangeDetector(Transform):
             mask = mask.astype('int64')
             masks = [mask]
             if 'aux_masks' in sample:
-                masks.extend(map(methodcaller('astype', 'int64'), sample['aux_masks']))
-            return (image_t1, image_t2,) + tuple(masks)
+                masks.extend(
+                    map(methodcaller('astype', 'int64'), sample['aux_masks']))
+            return (
+                image_t1,
+                image_t2, ) + tuple(masks)
         if self.mode == 'eval':
             mask = np.asarray(Image.open(mask))
             mask = mask[np.newaxis, :, :].astype('int64')
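
A hypothetical composition sketch (not part of this commit) showing how the new operators and the ImgDecoder to_uint8 flag fit together; band indices and parameter values are illustrative only:

    import paddlers.transforms as T

    # Decode without BGR-to-RGB conversion, stretch to uint8 via the new flag,
    # keep three bands (1-based indices), defog, then normalize.
    train_transforms = T.Compose([
        T.ImgDecoder(to_rgb=False, to_uint8=True),
        T.BandSelecting(band_list=[5, 4, 3]),
        T.Defogging(gamma=True),
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])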

+ 12 - 12
requirements.txt

@@ -1,23 +1,23 @@
-tqdm
+paddlepaddle-gpu >= 2.2.0
+paddleslim >= 2.2.1
+visualdl >= 2.1.1
+opencv-contrib-python == 4.3.0.38
+numba == 0.53.1
+scikit-learn == 0.23.2
+scikit-image >= 0.14.0
+# numpy == 1.22.3
+pandas
 scipy
-colorama
 cython
 pycocotools
-visualdl >= 2.1.1
-paddleslim == 2.2.1
 shapely
-paddlepaddle-gpu >= 2.2.0
-opencv-python
-opencv-contrib-python==4.3.0.38
-scikit-learn == 0.23.2  # 0.20.3
 lap
 motmetrics
-matplotlib
 chardet
 openpyxl
-# GDAL >= 3.1.3  # install through https://www.lfd.uci.edu/~gohlke/pythonlibs/#gdal under windows
-scikit-image>=0.14.0
-numba==0.53.1
 easydict
 munch
 natsort
+
+# Self installation (on Windows, install via https://www.lfd.uci.edu/~gohlke/pythonlibs/#gdal)
+# GDAL >= 3.1.3
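
Since GDAL is left for manual installation, a quick check (illustrative only) can confirm the pinned versions above:

    # Illustrative version check; GDAL is installed separately.
    import cv2, numba, sklearn
    print(sklearn.__version__)  # expect 0.23.2
    print(numba.__version__)    # expect 0.53.1
    print(cv2.__version__)      # expect 4.3.x contrib build
    try:
        from osgeo import gdal
        print(gdal.__version__)  # expect >= 3.1.3
    except ImportError:
        print("GDAL not installed")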