Bobholamovic преди 2 години
родител
ревизия
62f34b3b68
променени са 2 файла, в които са добавени 205 реда и са изтрити 7 реда
  1. 69 7
      tools/prepare_dataset/common.py
  2. 136 0
      tools/prepare_dataset/prepare_isaid.py

+ 69 - 7
tools/prepare_dataset/common.py

@@ -3,11 +3,13 @@ import random
 import copy
 import os
 import os.path as osp
+import shutil
 from glob import glob
 from itertools import count
 from functools import partial
 from concurrent.futures import ThreadPoolExecutor
 
+import numpy as np
 from skimage.io import imread, imsave
 from tqdm import tqdm
 
@@ -57,20 +59,54 @@ def add_crop_options(parser):
     return parser
 
 
-def crop_and_save(path, out_subdir, crop_size, stride):
+def crop_and_save(path,
+                  out_subdir,
+                  crop_size,
+                  stride,
+                  keep_last=False,
+                  pad=True,
+                  pad_val=0):
     name, ext = osp.splitext(osp.basename(path))
     out_subsubdir = osp.join(out_subdir, name)
     if not osp.exists(out_subsubdir):
         os.makedirs(out_subsubdir)
     img = imread(path)
-    w, h = img.shape[:2]
+    h, w = img.shape[:2]
+    if h < crop_size or w < crop_size:
+        if not pad:
+            raise ValueError(
+                f"`crop_size` must be smaller than image size. `crop_size` is {crop_size}, but got image size {h}x{w}."
+            )
+        padded_img = np.full(
+            shape=(max(h, crop_size), max(w, crop_size)) + img.shape[2:],
+            fill_value=pad_val,
+            dtype=img.dtype)
+        padded_img[:h, :w] = img
+        h, w = padded_img.shape[:2]
+        img = padded_img
     counter = count()
-    for i in range(0, h - crop_size + 1, stride):
-        for j in range(0, w - crop_size + 1, stride):
+    for i in range(0, h, stride):
+        i_st = i
+        i_ed = i_st + crop_size
+        if i_ed > h:
+            if keep_last:
+                i_st = h - crop_size
+                i_ed = h
+            else:
+                continue
+        for j in range(0, w, stride):
+            j_st = j
+            j_ed = j_st + crop_size
+            if j_ed > w:
+                if keep_last:
+                    j_st = w - crop_size
+                    j_ed = w
+                else:
+                    continue
             imsave(
                 osp.join(out_subsubdir, '{}_{}{}'.format(name,
                                                          next(counter), ext)),
-                img[i:i + crop_size, j:j + crop_size],
+                img[i_st:i_ed, j_st:j_ed],
                 check_contrast=False)
 
 
@@ -81,7 +117,8 @@ def crop_patches(crop_size,
                  subsets=('train', 'val', 'test'),
                  subdirs=('A', 'B', 'label'),
                  glob_pattern='*',
-                 max_workers=0):
+                 max_workers=0,
+                 keep_last=False):
     """
     Crop patches from images in specific directories.
     
@@ -102,6 +139,9 @@ def crop_patches(crop_size,
             Defaults to '*', which matches arbitrary file. 
         max_workers (int, optional): Number of worker threads to perform the cropping 
             operation. Deafults to 0.
+        keep_last (bool, optional): If True, keep the last patch in each row and each 
+            column. The left and upper border of the last patch will be shifted to 
+            ensure that size of the patch be `crop_size`. Defaults to False.
     """
 
     if max_workers < 0:
@@ -110,6 +150,8 @@ def crop_patches(crop_size,
     if subsets is None:
         subsets = ('', )
 
+    print("Cropping patches...")
+
     if max_workers == 0:
         for subset in subsets:
             for subdir in subdirs:
@@ -122,7 +164,8 @@ def crop_patches(crop_size,
                         p,
                         out_subdir=out_subdir,
                         crop_size=crop_size,
-                        stride=stride)
+                        stride=stride,
+                        keep_last=keep_last)
     else:
         # Concurrently crop image patches
         with ThreadPoolExecutor(max_workers=max_workers) as executor:
@@ -232,6 +275,25 @@ def link_dataset(src, dst):
     os.symlink(src, osp.join(dst, name), target_is_directory=True)
 
 
+def copy_dataset(src, dst):
+    """
+    Make a copy a dataset.
+    
+    Args:
+        src (str): Path of the original dataset.
+        dst (str): Path to copy to.
+    """
+
+    if osp.exists(dst) and not osp.isdir(dst):
+        raise ValueError(f"{dst} exists and is not a directory.")
+    elif not osp.exists(dst):
+        os.makedirs(dst)
+
+    src = osp.realpath(src)
+    name = osp.basename(osp.normpath(src))
+    shutil.copytree(src, osp.join(dst, name))
+
+
 def random_split(samples,
                  ratios=(0.7, 0.2, 0.1),
                  inplace=True,

+ 136 - 0
tools/prepare_dataset/prepare_isaid.py

@@ -0,0 +1,136 @@
+#!/usr/bin/env python
+
+import os.path as osp
+from glob import glob
+
+from PIL import Image
+from tqdm import tqdm
+
+from common import (get_default_parser, add_crop_options, crop_patches,
+                    create_file_list, copy_dataset, create_label_list,
+                    get_path_tuples)
+
+# According to the official doc(https://github.com/CAPTAIN-WHU/iSAID_Devkit), 
+# the files should be organized as follows:
+# 
+# iSAID
+# ├── test
+# │   └── images
+# │       ├── P0006.png
+# │       └── ...
+# │       └── P0009.png
+# ├── train
+# │   └── images
+# │       ├── P0002_instance_color_RGB.png
+# │       ├── P0002_instance_id_RGB.png
+# │       ├── P0002.png
+# │       ├── ...
+# │       ├── P0010_instance_color_RGB.png
+# │       ├── P0010_instance_id_RGB.png
+# │       └── P0010.png
+# └── val
+#     └── images
+#         ├── P0003_instance_color_RGB.png
+#         ├── P0003_instance_id_RGB.png
+#         ├── P0003.png
+#         ├── ...
+#         ├── P0004_instance_color_RGB.png
+#         ├── P0004_instance_id_RGB.png
+#         └── P0004.png
+
+CLASSES = ('background', 'ship', 'storage_tank', 'baseball_diamond',
+           'tennis_court', 'basketball_court', 'ground_track_field', 'bridge',
+           'large_vehicle', 'small_vehicle', 'helicopter', 'swimming_pool',
+           'roundabout', 'soccer_ball_field', 'plane', 'harbor')
+# Refer to https://github.com/Z-Zheng/FarSeg/blob/master/data/isaid.py
+COLOR_MAP = [[0, 0, 0], [0, 0, 63], [0, 191, 127], [0, 63, 0], [0, 63, 127],
+             [0, 63, 191], [0, 63, 255], [0, 127, 63], [0, 127, 127],
+             [0, 0, 127], [0, 0, 191], [0, 0, 255], [0, 63, 63], [0, 127, 191],
+             [0, 127, 255], [0, 100, 155]]
+SUBSETS = ('train', 'val')
+SUBDIR = 'images'
+FILE_LIST_PATTERN = "{subset}.txt"
+LABEL_LIST_NAME = "labels.txt"
+URL = ""
+
+
+def flatten(nested_list):
+    flattened_list = []
+    for ele in nested_list:
+        if isinstance(ele, list):
+            flattened_list.extend(flatten(ele))
+        else:
+            flattened_list.append(ele)
+    return flattened_list
+
+
+def rgb2mask(rgb):
+    palette = flatten(COLOR_MAP)
+    # Pad with zero
+    palette = palette + [0] * (256 * 3 - len(palette))
+    ref = Image.new(mode='P', size=(1, 1))
+    ref.putpalette(palette)
+    mask = rgb.quantize(palette=ref, dither=0)
+    return mask
+
+
+if __name__ == '__main__':
+    parser = get_default_parser()
+    parser.add_argument(
+        '--crop_size', type=int, help="Size of cropped patches.", default=800)
+    parser.add_argument(
+        '--crop_stride',
+        type=int,
+        help="Stride of sliding windows when cropping patches. `crop_size` will be used only if `crop_size` is not None.",
+        default=600)
+    args = parser.parse_args()
+
+    out_dir = osp.join(args.out_dataset_dir,
+                       osp.basename(osp.normpath(args.in_dataset_dir)))
+
+    assert args.crop_size is not None
+    # According to https://github.com/CAPTAIN-WHU/iSAID_Devkit/blob/master/preprocess/split.py
+    # Set keep_last=True
+    crop_patches(
+        args.crop_size,
+        args.crop_stride,
+        data_dir=args.in_dataset_dir,
+        out_dir=out_dir,
+        subsets=SUBSETS,
+        subdirs=(SUBDIR, ),
+        glob_pattern='*.png',
+        max_workers=8,
+        keep_last=True)
+
+    for subset in SUBSETS:
+        path_tuples = []
+        print(f"Processing {subset} labels...")
+        for im_subdir in tqdm(glob(osp.join(out_dir, subset, SUBDIR, "*/"))):
+            im_name = osp.basename(im_subdir[:-1])  # Strip trailing '/'
+            if '_' in im_name:
+                # Do not process labels
+                continue
+            mask_subdir = osp.join(out_dir, subset, SUBDIR,
+                                   im_name + '_instance_color_RGB')
+            for mask_path in glob(osp.join(mask_subdir, '*.png')):
+                # Convert RGB files to mask files (pseudo color)
+                rgb = Image.open(mask_path).convert('RGB')
+                mask = rgb2mask(rgb)
+                # Write to the original location
+                mask.save(mask_path)
+            path_tuples.extend(
+                get_path_tuples(
+                    im_subdir,
+                    mask_subdir,
+                    glob_pattern='*.png',
+                    data_dir=args.out_dataset_dir))
+        path_tuples.sort()
+
+        file_list = osp.join(
+            args.out_dataset_dir, FILE_LIST_PATTERN.format(subset=subset))
+        create_file_list(file_list, path_tuples)
+        print(f"Write file list to {file_list}.")
+
+    label_list = osp.join(args.out_dataset_dir, LABEL_LIST_NAME)
+    create_label_list(label_list, CLASSES)
+    print(f"Write label list to {label_list}.")