Procházet zdrojové kódy

Merge remote-tracking branch 'origin/develop' into develop

zhl před 3 roky
rodič
revize
ceb8b4eebe
3 změnil soubory, kde provedl 162 přidání a 56 odebrání
  1. 25 14
      paddlers/transforms/functions.py
  2. 125 30
      paddlers/transforms/operators.py
  3. 12 12
      requirements.txt

+ 25 - 14
paddlers/transforms/functions.py

@@ -206,6 +206,7 @@ def to_uint8(im):
     Returns:
     Returns:
         np.ndarray: Image on uint8.
         np.ndarray: Image on uint8.
     """
     """
+
     # 2% linear stretch
     # 2% linear stretch
     def _two_percentLinear(image, max_out=255, min_out=0):
     def _two_percentLinear(image, max_out=255, min_out=0):
         def _gray_process(gray, maxout=max_out, minout=min_out):
         def _gray_process(gray, maxout=max_out, minout=min_out):
@@ -216,6 +217,7 @@ def to_uint8(im):
             processed_gray = ((truncated_gray - low_value) / (high_value - low_value)) * \
             processed_gray = ((truncated_gray - low_value) / (high_value - low_value)) * \
                              (maxout - minout)
                              (maxout - minout)
             return processed_gray
             return processed_gray
+
         if len(image.shape) == 3:
         if len(image.shape) == 3:
             processes = []
             processes = []
             for b in range(image.shape[-1]):
             for b in range(image.shape[-1]):
@@ -244,7 +246,7 @@ def to_uint8(im):
         lut = []
         lut = []
         for bt in range(0, len(hist), NUMS):
         for bt in range(0, len(hist), NUMS):
             # step size
             # step size
-            step = reduce(operator.add, hist[bt : bt + NUMS]) / (NUMS - 1)
+            step = reduce(operator.add, hist[bt:bt + NUMS]) / (NUMS - 1)
             # create balanced lookup table
             # create balanced lookup table
             n = 0
             n = 0
             for i in range(NUMS):
             for i in range(NUMS):
@@ -301,14 +303,18 @@ def select_bands(im, band_list=[1, 2, 3]):
     Returns:
     Returns:
         np.ndarray: The image after band selected.
         np.ndarray: The image after band selected.
     """
     """
+    if len(im.shape) == 2:  # just have one channel
+        return im
+    if not isinstance(band_list, list) or len(band_list) == 0:
+        raise TypeError("band_list must be non empty list.")
     total_band = im.shape[-1]
     total_band = im.shape[-1]
     result = []
     result = []
     for band in band_list:
     for band in band_list:
         band = int(band - 1)
         band = int(band - 1)
         if band < 0 or band >= total_band:
         if band < 0 or band >= total_band:
-            raise ValueError(
-                "The element in band_list must > 1 and <= {}.".format(str(total_band)))
-        result.append()
+            raise ValueError("The element in band_list must > 1 and <= {}.".
+                             format(str(total_band)))
+        result.append(im[:, :, band])
     ima = np.stack(result, axis=0)
     ima = np.stack(result, axis=0)
     return ima
     return ima
 
 
@@ -323,6 +329,7 @@ def de_haze(im, gamma=False):
     Returns:
     Returns:
         np.ndarray: The image after defogged.
         np.ndarray: The image after defogged.
     """
     """
+
     def _guided_filter(I, p, r, eps):
     def _guided_filter(I, p, r, eps):
         m_I = cv2.boxFilter(I, -1, (r, r))
         m_I = cv2.boxFilter(I, -1, (r, r))
         m_p = cv2.boxFilter(p, -1, (r, r))
         m_p = cv2.boxFilter(p, -1, (r, r))
@@ -350,16 +357,17 @@ def de_haze(im, gamma=False):
         atmo_illum = np.mean(im, 2)[atmo_mask >= ht[1][lmax]].max()
         atmo_illum = np.mean(im, 2)[atmo_mask >= ht[1][lmax]].max()
         atmo_mask = np.minimum(atmo_mask * w, maxatmo_mask)
         atmo_mask = np.minimum(atmo_mask * w, maxatmo_mask)
         return atmo_mask, atmo_illum
         return atmo_mask, atmo_illum
-        
+
     if np.max(im) > 1:
     if np.max(im) > 1:
         im = im / 255.
         im = im / 255.
     result = np.zeros(im.shape)
     result = np.zeros(im.shape)
-    mask_img, atmo_illum = _de_fog(im, r=81, w=0.95, maxatmo_mask=0.80, eps=1e-8)
+    mask_img, atmo_illum = _de_fog(
+        im, r=81, w=0.95, maxatmo_mask=0.80, eps=1e-8)
     for k in range(3):
     for k in range(3):
         result[:, :, k] = (im[:, :, k] - mask_img) / (1 - mask_img / atmo_illum)
         result[:, :, k] = (im[:, :, k] - mask_img) / (1 - mask_img / atmo_illum)
     result = np.clip(result, 0, 1)
     result = np.clip(result, 0, 1)
     if gamma:
     if gamma:
-        result = result ** (np.log(0.5) / np.log(result.mean()))
+        result = result**(np.log(0.5) / np.log(result.mean()))
     return (result * 255).astype("uint8")
     return (result * 255).astype("uint8")
 
 
 
 
@@ -398,7 +406,8 @@ def match_histograms(im, ref):
         ValueError: When the number of channels of `ref` differs from that of im`.
         ValueError: When the number of channels of `ref` differs from that of im`.
     """
     """
     # TODO: Check the data types of the inputs to see if they are supported by skimage
     # TODO: Check the data types of the inputs to see if they are supported by skimage
-    return exposure.match_histograms(im, ref, channel_axis=-1 if im.ndim>2 else None)
+    return exposure.match_histograms(
+        im, ref, channel_axis=-1 if im.ndim > 2 else None)
 
 
 
 
 def match_by_regression(im, ref, pif_loc=None):
 def match_by_regression(im, ref, pif_loc=None):
@@ -418,27 +427,29 @@ def match_by_regression(im, ref, pif_loc=None):
     Raises:
     Raises:
         ValueError: When the shape of `ref` differs from that of `im`.
         ValueError: When the shape of `ref` differs from that of `im`.
     """
     """
+
     def _linear_regress(im, ref, loc):
     def _linear_regress(im, ref, loc):
         regressor = LinearRegression()
         regressor = LinearRegression()
         if loc is not None:
         if loc is not None:
             x, y = im[loc], ref[loc]
             x, y = im[loc], ref[loc]
         else:
         else:
             x, y = im, ref
             x, y = im, ref
-        x, y = x.reshape(-1,1), y.ravel()
+        x, y = x.reshape(-1, 1), y.ravel()
         regressor.fit(x, y)
         regressor.fit(x, y)
-        matched = regressor.predict(im.reshape(-1,1))
+        matched = regressor.predict(im.reshape(-1, 1))
         return matched.reshape(im.shape)
         return matched.reshape(im.shape)
 
 
     if im.shape != ref.shape:
     if im.shape != ref.shape:
-        raise  ValueError("Image and Reference must have the same shape!")
+        raise ValueError("Image and Reference must have the same shape!")
 
 
     if im.ndim > 2:
     if im.ndim > 2:
         # Multiple channels
         # Multiple channels
         matched = np.empty(im.shape, dtype=im.dtype)
         matched = np.empty(im.shape, dtype=im.dtype)
         for ch in range(im.shape[-1]):
         for ch in range(im.shape[-1]):
-            matched[..., ch] = _linear_regress(im[..., ch], ref[..., ch], pif_loc)
+            matched[..., ch] = _linear_regress(im[..., ch], ref[..., ch],
+                                               pif_loc)
     else:
     else:
         # Single channel
         # Single channel
         matched = _linear_regress(im, ref, pif_loc).astype(im.dtype)
         matched = _linear_regress(im, ref, pif_loc).astype(im.dtype)
-    
-    return matched
+
+    return matched

+ 125 - 30
paddlers/transforms/operators.py

@@ -31,17 +31,16 @@ import paddlers
 
 
 from .functions import normalize, horizontal_flip, permute, vertical_flip, center_crop, is_poly, \
 from .functions import normalize, horizontal_flip, permute, vertical_flip, center_crop, is_poly, \
     horizontal_flip_poly, horizontal_flip_rle, vertical_flip_poly, vertical_flip_rle, crop_poly, \
     horizontal_flip_poly, horizontal_flip_rle, vertical_flip_poly, vertical_flip_rle, crop_poly, \
-    crop_rle, expand_poly, expand_rle, resize_poly, resize_rle
-
+    crop_rle, expand_poly, expand_rle, resize_poly, resize_rle, de_haze, pca, select_bands, \
+    to_intensity, to_uint8
 
 
 __all__ = [
 __all__ = [
     "Compose", "ImgDecoder", "Resize", "RandomResize", "ResizeByShort",
     "Compose", "ImgDecoder", "Resize", "RandomResize", "ResizeByShort",
     "RandomResizeByShort", "ResizeByLong", "RandomHorizontalFlip",
     "RandomResizeByShort", "ResizeByLong", "RandomHorizontalFlip",
     "RandomVerticalFlip", "Normalize", "CenterCrop", "RandomCrop",
     "RandomVerticalFlip", "Normalize", "CenterCrop", "RandomCrop",
     "RandomScaleAspect", "RandomExpand", "Padding", "MixupImage",
     "RandomScaleAspect", "RandomExpand", "Padding", "MixupImage",
-    "RandomDistort", "RandomBlur", 
-    "RandomSwap",
-    "ArrangeSegmenter", "ArrangeChangeDetector", 
+    "RandomDistort", "RandomBlur", "RandomSwap", "Defogging", "DimReducing",
+    "BandSelecting", "ArrangeSegmenter", "ArrangeChangeDetector",
     "ArrangeClassifier", "ArrangeDetector"
     "ArrangeClassifier", "ArrangeDetector"
 ]
 ]
 
 
@@ -85,7 +84,8 @@ class Transform(object):
         if 'gt_bbox' in sample:
         if 'gt_bbox' in sample:
             sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'])
             sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'])
         if 'aux_masks' in sample:
         if 'aux_masks' in sample:
-            sample['aux_masks'] = list(map(self.apply_mask, sample['aux_masks']))
+            sample['aux_masks'] = list(
+                map(self.apply_mask, sample['aux_masks']))
 
 
         return sample
         return sample
 
 
@@ -105,9 +105,10 @@ class ImgDecoder(Transform):
         to_rgb (bool, optional): If True, convert input images from BGR format to RGB format. Defaults to True.
         to_rgb (bool, optional): If True, convert input images from BGR format to RGB format. Defaults to True.
     """
     """
 
 
-    def __init__(self, to_rgb=True):
+    def __init__(self, to_rgb=True, to_uint8=True):
         super(ImgDecoder, self).__init__()
         super(ImgDecoder, self).__init__()
         self.to_rgb = to_rgb
         self.to_rgb = to_rgb
+        self.to_uint8 = to_uint8
 
 
     def read_img(self, img_path, input_channel=3):
     def read_img(self, img_path, input_channel=3):
         img_format = imghdr.what(img_path)
         img_format = imghdr.what(img_path)
@@ -129,6 +130,7 @@ class ImgDecoder(Transform):
                 raise Exception('Can not open', img_path)
                 raise Exception('Can not open', img_path)
             im_data = dataset.ReadAsArray()
             im_data = dataset.ReadAsArray()
             if im_data.ndim == 2:
             if im_data.ndim == 2:
+                im_data = to_intensity(im_data)  # is read SAR
                 im_data = im_data[:, :, np.newaxis]
                 im_data = im_data[:, :, np.newaxis]
             elif im_data.ndim == 3:
             elif im_data.ndim == 3:
                 im_data = im_data.transpose((1, 2, 0))
                 im_data = im_data.transpose((1, 2, 0))
@@ -158,6 +160,9 @@ class ImgDecoder(Transform):
         if self.to_rgb and image.shape[-1] == 3:
         if self.to_rgb and image.shape[-1] == 3:
             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
 
 
+        if self.to_uint8:
+            image = to_uint8(image)
+
         return image
         return image
 
 
     def apply_mask(self, mask):
     def apply_mask(self, mask):
@@ -191,7 +196,8 @@ class ImgDecoder(Transform):
                 raise Exception(
                 raise Exception(
                     "The height or width of the im is not same as the mask")
                     "The height or width of the im is not same as the mask")
         if 'aux_masks' in sample:
         if 'aux_masks' in sample:
-            sample['aux_masks'] = list(map(self.apply_mask, sample['aux_masks']))
+            sample['aux_masks'] = list(
+                map(self.apply_mask, sample['aux_masks']))
             # TODO: check the shape of auxiliary masks
             # TODO: check the shape of auxiliary masks
 
 
         sample['im_shape'] = np.array(
         sample['im_shape'] = np.array(
@@ -350,12 +356,16 @@ class Resize(Transform):
 
 
         sample['image'] = self.apply_im(sample['image'], interp, target_size)
         sample['image'] = self.apply_im(sample['image'], interp, target_size)
         if 'image2' in sample:
         if 'image2' in sample:
-            sample['image2'] = self.apply_im(sample['image2'], interp, target_size)
+            sample['image2'] = self.apply_im(sample['image2'], interp,
+                                             target_size)
 
 
         if 'mask' in sample:
         if 'mask' in sample:
             sample['mask'] = self.apply_mask(sample['mask'], target_size)
             sample['mask'] = self.apply_mask(sample['mask'], target_size)
         if 'aux_masks' in sample:
         if 'aux_masks' in sample:
-            sample['aux_masks'] = list(map(partial(self.apply_mask, target_size=target_size), sample['aux_masks']))
+            sample['aux_masks'] = list(
+                map(partial(
+                    self.apply_mask, target_size=target_size),
+                    sample['aux_masks']))
         if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
         if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
             sample['gt_bbox'] = self.apply_bbox(
             sample['gt_bbox'] = self.apply_bbox(
                 sample['gt_bbox'], [im_scale_x, im_scale_y], target_size)
                 sample['gt_bbox'], [im_scale_x, im_scale_y], target_size)
@@ -557,7 +567,8 @@ class RandomHorizontalFlip(Transform):
             if 'mask' in sample:
             if 'mask' in sample:
                 sample['mask'] = self.apply_mask(sample['mask'])
                 sample['mask'] = self.apply_mask(sample['mask'])
             if 'aux_masks' in sample:
             if 'aux_masks' in sample:
-                sample['aux_masks'] = list(map(self.apply_mask, sample['aux_masks']))
+                sample['aux_masks'] = list(
+                    map(self.apply_mask, sample['aux_masks']))
             if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
             if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
                 sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], im_w)
                 sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], im_w)
             if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
             if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
@@ -614,7 +625,8 @@ class RandomVerticalFlip(Transform):
             if 'mask' in sample:
             if 'mask' in sample:
                 sample['mask'] = self.apply_mask(sample['mask'])
                 sample['mask'] = self.apply_mask(sample['mask'])
             if 'aux_masks' in sample:
             if 'aux_masks' in sample:
-                sample['aux_masks'] = list(map(self.apply_mask, sample['aux_masks']))
+                sample['aux_masks'] = list(
+                    map(self.apply_mask, sample['aux_masks']))
             if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
             if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
                 sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], im_h)
                 sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], im_h)
             if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
             if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
@@ -653,8 +665,8 @@ class Normalize(Transform):
 
 
         from functools import reduce
         from functools import reduce
         if reduce(lambda x, y: x * y, std) == 0:
         if reduce(lambda x, y: x * y, std) == 0:
-            raise ValueError(
-                'Std should not have 0, but received is {}'.format(std))
+            raise ValueError('Std should not have 0, but received is {}'.format(
+                std))
         if is_scale:
         if is_scale:
             if reduce(lambda x, y: x * y,
             if reduce(lambda x, y: x * y,
                       [a - b for a, b in zip(max_val, min_val)]) == 0:
                       [a - b for a, b in zip(max_val, min_val)]) == 0:
@@ -679,7 +691,7 @@ class Normalize(Transform):
     def apply(self, sample):
     def apply(self, sample):
         sample['image'] = self.apply_im(sample['image'])
         sample['image'] = self.apply_im(sample['image'])
         if 'image2' in sample:
         if 'image2' in sample:
-                sample['image2'] = self.apply_im(sample['image2'])
+            sample['image2'] = self.apply_im(sample['image2'])
 
 
         return sample
         return sample
 
 
@@ -710,11 +722,12 @@ class CenterCrop(Transform):
     def apply(self, sample):
     def apply(self, sample):
         sample['image'] = self.apply_im(sample['image'])
         sample['image'] = self.apply_im(sample['image'])
         if 'image2' in sample:
         if 'image2' in sample:
-                sample['image2'] = self.apply_im(sample['image2'])
+            sample['image2'] = self.apply_im(sample['image2'])
         if 'mask' in sample:
         if 'mask' in sample:
             sample['mask'] = self.apply_mask(sample['mask'])
             sample['mask'] = self.apply_mask(sample['mask'])
         if 'aux_masks' in sample:
         if 'aux_masks' in sample:
-            sample['aux_masks'] = list(map(self.apply_mask, sample['aux_masks']))
+            sample['aux_masks'] = list(
+                map(self.apply_mask, sample['aux_masks']))
         return sample
         return sample
 
 
 
 
@@ -779,8 +792,7 @@ class RandomCrop(Transform):
                     if self.cover_all_box and iou.min() < thresh:
                     if self.cover_all_box and iou.min() < thresh:
                         continue
                         continue
                     cropped_box, valid_ids = self._crop_box_with_center_constraint(
                     cropped_box, valid_ids = self._crop_box_with_center_constraint(
-                        sample['gt_bbox'],
-                        np.array(
+                        sample['gt_bbox'], np.array(
                             crop_box, dtype=np.float32))
                             crop_box, dtype=np.float32))
                     if valid_ids.size > 0:
                     if valid_ids.size > 0:
                         return crop_box, cropped_box, valid_ids
                         return crop_box, cropped_box, valid_ids
@@ -907,7 +919,10 @@ class RandomCrop(Transform):
                 sample['mask'] = self.apply_mask(sample['mask'], crop_box)
                 sample['mask'] = self.apply_mask(sample['mask'], crop_box)
 
 
             if 'aux_masks' in sample:
             if 'aux_masks' in sample:
-                sample['aux_masks'] = list(map(partial(self.apply_mask, crop=crop_box), sample['aux_masks']))
+                sample['aux_masks'] = list(
+                    map(partial(
+                        self.apply_mask, crop=crop_box),
+                        sample['aux_masks']))
 
 
         if self.crop_size is not None:
         if self.crop_size is not None:
             sample = Resize(self.crop_size)(sample)
             sample = Resize(self.crop_size)(sample)
@@ -1095,11 +1110,14 @@ class Padding(Transform):
 
 
         sample['image'] = self.apply_im(sample['image'], offsets, (h, w))
         sample['image'] = self.apply_im(sample['image'], offsets, (h, w))
         if 'image2' in sample:
         if 'image2' in sample:
-                sample['image2'] = self.apply_im(sample['image2'], offsets, (h, w))
+            sample['image2'] = self.apply_im(sample['image2'], offsets, (h, w))
         if 'mask' in sample:
         if 'mask' in sample:
             sample['mask'] = self.apply_mask(sample['mask'], offsets, (h, w))
             sample['mask'] = self.apply_mask(sample['mask'], offsets, (h, w))
         if 'aux_masks' in sample:
         if 'aux_masks' in sample:
-            sample['aux_masks'] = list(map(partial(self.apply_mask, offsets=offsets, target_size=(h,w)), sample['aux_masks']))
+            sample['aux_masks'] = list(
+                map(partial(
+                    self.apply_mask, offsets=offsets, target_size=(h, w)),
+                    sample['aux_masks']))
         if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
         if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
             sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], offsets)
             sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], offsets)
         if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
         if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
@@ -1251,7 +1269,7 @@ class RandomDistort(Transform):
         res_list = []
         res_list = []
         channel = image.shape[2]
         channel = image.shape[2]
         for i in range(channel // 3):
         for i in range(channel // 3):
-            sub_img = image[:, :, 3*i : 3*(i+1)]
+            sub_img = image[:, :, 3 * i:3 * (i + 1)]
             sub_img = sub_img.astype(np.float32)
             sub_img = sub_img.astype(np.float32)
             sub_img = np.dot(image, t)
             sub_img = np.dot(image, t)
             res_list.append(sub_img)
             res_list.append(sub_img)
@@ -1271,10 +1289,11 @@ class RandomDistort(Transform):
         res_list = []
         res_list = []
         channel = image.shape[2]
         channel = image.shape[2]
         for i in range(channel // 3):
         for i in range(channel // 3):
-            sub_img = image[:, :, 3*i : 3*(i+1)]
+            sub_img = image[:, :, 3 * i:3 * (i + 1)]
             sub_img = sub_img.astype(np.float32)
             sub_img = sub_img.astype(np.float32)
             # it works, but result differ from HSV version
             # it works, but result differ from HSV version
-            gray = sub_img * np.array([[[0.299, 0.587, 0.114]]], dtype=np.float32)
+            gray = sub_img * np.array(
+                [[[0.299, 0.587, 0.114]]], dtype=np.float32)
             gray = gray.sum(axis=2, keepdims=True)
             gray = gray.sum(axis=2, keepdims=True)
             gray *= (1.0 - delta)
             gray *= (1.0 - delta)
             sub_img *= delta
             sub_img *= delta
@@ -1340,7 +1359,8 @@ class RandomDistort(Transform):
             if np.random.randint(0, 2):
             if np.random.randint(0, 2):
                 sample['image'] = sample['image'][..., np.random.permutation(3)]
                 sample['image'] = sample['image'][..., np.random.permutation(3)]
                 if 'image2' in sample:
                 if 'image2' in sample:
-                    sample['image2'] = sample['image2'][..., np.random.permutation(3)]
+                    sample['image2'] = sample['image2'][
+                        ..., np.random.permutation(3)]
         return sample
         return sample
 
 
 
 
@@ -1380,6 +1400,77 @@ class RandomBlur(Transform):
         return sample
         return sample
 
 
 
 
+class Defogging(Transform):
+    """
+    Dehaze the image(s) of a sample by calling `de_haze`.
+
+    Args:
+        gamma (bool, optional): Use gamma correction or not. Defaults to False.
+    """
+
+    def __init__(self, gamma=False):
+        super().__init__()
+        self.gamma = gamma
+
+    def apply_im(self, image):
+        # The stored `gamma` flag is forwarded as the second positional argument.
+        return de_haze(image, self.gamma)
+
+    def apply(self, sample):
+        present = ['image'] + (['image2'] if 'image2' in sample else [])
+        for key in present:
+            sample[key] = self.apply_im(sample[key])
+        return sample
+
+
+class DimReducing(Transform):
+    """
+    Use PCA to reduce input image(s) dimension.
+
+    Args:
+        dim (int, optional): Reserved dimensions. Defaults to 3.
+        whiten (bool, optional): PCA whiten or not. Defaults to True.
+    """
+
+    def __init__(self, dim=3, whiten=True):
+        super(DimReducing, self).__init__()
+        self.dim = dim
+        self.whiten = whiten
+
+    def apply_im(self, image):
+        # Forward the stored options (no `gamma` here); assumes pca(im, dim, whiten) - confirm.
+        return pca(image, self.dim, self.whiten)
+
+    def apply(self, sample):
+        sample['image'] = self.apply_im(sample['image'])
+        if 'image2' in sample:
+            sample['image2'] = self.apply_im(sample['image2'])
+        return sample
+
+
+class BandSelecting(Transform):
+    """
+    Select the band of the input image(s).
+
+    Args:
+        band_list (list, optional): 1-based indices of the bands to keep. Defaults to [1, 2, 3].
+    """
+
+    def __init__(self, band_list=None):
+        super(BandSelecting, self).__init__()
+        self.band_list = [1, 2, 3] if band_list is None else band_list
+
+    def apply_im(self, image):
+        # Validation of the band indices is left to `select_bands`.
+        return select_bands(image, self.band_list)
+
+    def apply(self, sample):
+        sample['image'] = self.apply_im(sample['image'])
+        if 'image2' in sample:
+            sample['image2'] = self.apply_im(sample['image2'])
+        return sample
+
+
 class _PadBox(Transform):
 class _PadBox(Transform):
     def __init__(self, num_max_boxes=50):
     def __init__(self, num_max_boxes=50):
         """
         """
@@ -1464,7 +1555,7 @@ class _Permute(Transform):
         if 'image2' in sample:
         if 'image2' in sample:
             sample['image2'] = permute(sample['image2'], False)
             sample['image2'] = permute(sample['image2'], False)
         return sample
         return sample
-        
+
 
 
 class RandomSwap(Transform):
 class RandomSwap(Transform):
     """
     """
@@ -1482,7 +1573,8 @@ class RandomSwap(Transform):
         if 'image2' not in sample:
         if 'image2' not in sample:
             raise ValueError('image2 is not found in the sample.')
             raise ValueError('image2 is not found in the sample.')
         if random.random() < self.prob:
         if random.random() < self.prob:
-            sample['image'], sample['image2'] = sample['image2'], sample['image']
+            sample['image'], sample['image2'] = sample['image2'], sample[
+                'image']
         return sample
         return sample
 
 
 
 
@@ -1530,8 +1622,11 @@ class ArrangeChangeDetector(Transform):
             mask = mask.astype('int64')
             mask = mask.astype('int64')
             masks = [mask]
             masks = [mask]
             if 'aux_masks' in sample:
             if 'aux_masks' in sample:
-                masks.extend(map(methodcaller('astype', 'int64'), sample['aux_masks']))
-            return (image_t1, image_t2,) + tuple(masks)
+                masks.extend(
+                    map(methodcaller('astype', 'int64'), sample['aux_masks']))
+            return (
+                image_t1,
+                image_t2, ) + tuple(masks)
         if self.mode == 'eval':
         if self.mode == 'eval':
             mask = np.asarray(Image.open(mask))
             mask = np.asarray(Image.open(mask))
             mask = mask[np.newaxis, :, :].astype('int64')
             mask = mask[np.newaxis, :, :].astype('int64')

+ 12 - 12
requirements.txt

@@ -1,23 +1,23 @@
-tqdm
+paddlepaddle-gpu >= 2.2.0
+paddleslim >= 2.2.1
+visualdl >= 2.1.1
+opencv-contrib-python == 4.3.0.38
+numba == 0.53.1
+scikit-learn == 0.23.2
+scikit-image >= 0.14.0
+# numpy == 1.22.3
+pandas
 scipy
 scipy
-colorama
 cython
 cython
 pycocotools
 pycocotools
-visualdl >= 2.1.1
-paddleslim == 2.2.1
 shapely
 shapely
-paddlepaddle-gpu >= 2.2.0
-opencv-python
-opencv-contrib-python==4.3.0.38
-scikit-learn == 0.23.2  # 0.20.3
 lap
 lap
 motmetrics
 motmetrics
-matplotlib
 chardet
 chardet
 openpyxl
 openpyxl
-# GDAL >= 3.1.3  # install through https://www.lfd.uci.edu/~gohlke/pythonlibs/#gdal under windows
-scikit-image>=0.14.0
-numba==0.53.1
 easydict
 easydict
 munch
 munch
 natsort
 natsort
+
+# # Self installation
+# GDAL >= 3.1.3