Explorar o código

Add prepare_isaid.py

Bobholamovic hai 2 anos
pai
achega
62f34b3b68
Modificáronse 2 ficheiros con 205 adicións e 7 borrados
  1. 69 7
      tools/prepare_dataset/common.py
  2. 136 0
      tools/prepare_dataset/prepare_isaid.py

+ 69 - 7
tools/prepare_dataset/common.py

@@ -3,11 +3,13 @@ import random
 import copy
 import copy
 import os
 import os
 import os.path as osp
 import os.path as osp
+import shutil
 from glob import glob
 from glob import glob
 from itertools import count
 from itertools import count
 from functools import partial
 from functools import partial
 from concurrent.futures import ThreadPoolExecutor
 from concurrent.futures import ThreadPoolExecutor
 
 
+import numpy as np
 from skimage.io import imread, imsave
 from skimage.io import imread, imsave
 from tqdm import tqdm
 from tqdm import tqdm
 
 
@@ -57,20 +59,54 @@ def add_crop_options(parser):
     return parser
     return parser
 
 
 
 
-def crop_and_save(path, out_subdir, crop_size, stride):
+def crop_and_save(path,
+                  out_subdir,
+                  crop_size,
+                  stride,
+                  keep_last=False,
+                  pad=True,
+                  pad_val=0):
     name, ext = osp.splitext(osp.basename(path))
     name, ext = osp.splitext(osp.basename(path))
     out_subsubdir = osp.join(out_subdir, name)
     out_subsubdir = osp.join(out_subdir, name)
     if not osp.exists(out_subsubdir):
     if not osp.exists(out_subsubdir):
         os.makedirs(out_subsubdir)
         os.makedirs(out_subsubdir)
     img = imread(path)
     img = imread(path)
-    w, h = img.shape[:2]
+    h, w = img.shape[:2]
+    if h < crop_size or w < crop_size:
+        if not pad:
+            raise ValueError(
+                f"`crop_size` must be smaller than image size. `crop_size` is {crop_size}, but got image size {h}x{w}."
+            )
+        padded_img = np.full(
+            shape=(max(h, crop_size), max(w, crop_size)) + img.shape[2:],
+            fill_value=pad_val,
+            dtype=img.dtype)
+        padded_img[:h, :w] = img
+        h, w = padded_img.shape[:2]
+        img = padded_img
     counter = count()
     counter = count()
-    for i in range(0, h - crop_size + 1, stride):
-        for j in range(0, w - crop_size + 1, stride):
+    for i in range(0, h, stride):
+        i_st = i
+        i_ed = i_st + crop_size
+        if i_ed > h:
+            if keep_last:
+                i_st = h - crop_size
+                i_ed = h
+            else:
+                continue
+        for j in range(0, w, stride):
+            j_st = j
+            j_ed = j_st + crop_size
+            if j_ed > w:
+                if keep_last:
+                    j_st = w - crop_size
+                    j_ed = w
+                else:
+                    continue
             imsave(
             imsave(
                 osp.join(out_subsubdir, '{}_{}{}'.format(name,
                 osp.join(out_subsubdir, '{}_{}{}'.format(name,
                                                          next(counter), ext)),
                                                          next(counter), ext)),
-                img[i:i + crop_size, j:j + crop_size],
+                img[i_st:i_ed, j_st:j_ed],
                 check_contrast=False)
                 check_contrast=False)
 
 
 
 
@@ -81,7 +117,8 @@ def crop_patches(crop_size,
                  subsets=('train', 'val', 'test'),
                  subsets=('train', 'val', 'test'),
                  subdirs=('A', 'B', 'label'),
                  subdirs=('A', 'B', 'label'),
                  glob_pattern='*',
                  glob_pattern='*',
-                 max_workers=0):
+                 max_workers=0,
+                 keep_last=False):
     """
     """
     Crop patches from images in specific directories.
     Crop patches from images in specific directories.
     
     
@@ -102,6 +139,9 @@ def crop_patches(crop_size,
             Defaults to '*', which matches arbitrary file. 
             Defaults to '*', which matches arbitrary file. 
         max_workers (int, optional): Number of worker threads to perform the cropping 
         max_workers (int, optional): Number of worker threads to perform the cropping 
             operation. Defaults to 0.
             operation. Defaults to 0.
+        keep_last (bool, optional): If True, keep the last patch in each row and each 
+            column. The left and upper border of the last patch will be shifted to 
+            ensure that the size of the patch is `crop_size`. Defaults to False.
     """
     """
 
 
     if max_workers < 0:
     if max_workers < 0:
@@ -110,6 +150,8 @@ def crop_patches(crop_size,
     if subsets is None:
     if subsets is None:
         subsets = ('', )
         subsets = ('', )
 
 
+    print("Cropping patches...")
+
     if max_workers == 0:
     if max_workers == 0:
         for subset in subsets:
         for subset in subsets:
             for subdir in subdirs:
             for subdir in subdirs:
@@ -122,7 +164,8 @@ def crop_patches(crop_size,
                         p,
                         p,
                         out_subdir=out_subdir,
                         out_subdir=out_subdir,
                         crop_size=crop_size,
                         crop_size=crop_size,
-                        stride=stride)
+                        stride=stride,
+                        keep_last=keep_last)
     else:
     else:
         # Concurrently crop image patches
         # Concurrently crop image patches
         with ThreadPoolExecutor(max_workers=max_workers) as executor:
         with ThreadPoolExecutor(max_workers=max_workers) as executor:
@@ -232,6 +275,25 @@ def link_dataset(src, dst):
     os.symlink(src, osp.join(dst, name), target_is_directory=True)
     os.symlink(src, osp.join(dst, name), target_is_directory=True)
 
 
 
 
+def copy_dataset(src, dst):
+    """
+    Make a copy of a dataset.
+    
+    Args:
+        src (str): Path of the original dataset.
+        dst (str): Path to copy to.
+    """
+
+    if osp.exists(dst) and not osp.isdir(dst):
+        raise ValueError(f"{dst} exists and is not a directory.")
+    elif not osp.exists(dst):
+        os.makedirs(dst)
+
+    src = osp.realpath(src)
+    name = osp.basename(osp.normpath(src))
+    shutil.copytree(src, osp.join(dst, name))
+
+
 def random_split(samples,
 def random_split(samples,
                  ratios=(0.7, 0.2, 0.1),
                  ratios=(0.7, 0.2, 0.1),
                  inplace=True,
                  inplace=True,

+ 136 - 0
tools/prepare_dataset/prepare_isaid.py

@@ -0,0 +1,136 @@
+#!/usr/bin/env python
+
+import os.path as osp
+from glob import glob
+
+from PIL import Image
+from tqdm import tqdm
+
+from common import (get_default_parser, add_crop_options, crop_patches,
+                    create_file_list, copy_dataset, create_label_list,
+                    get_path_tuples)
+
+# According to the official doc(https://github.com/CAPTAIN-WHU/iSAID_Devkit), 
+# the files should be organized as follows:
+# 
+# iSAID
+# ├── test
+# │   └── images
+# │       ├── P0006.png
+# │       └── ...
+# │       └── P0009.png
+# ├── train
+# │   └── images
+# │       ├── P0002_instance_color_RGB.png
+# │       ├── P0002_instance_id_RGB.png
+# │       ├── P0002.png
+# │       ├── ...
+# │       ├── P0010_instance_color_RGB.png
+# │       ├── P0010_instance_id_RGB.png
+# │       └── P0010.png
+# └── val
+#     └── images
+#         ├── P0003_instance_color_RGB.png
+#         ├── P0003_instance_id_RGB.png
+#         ├── P0003.png
+#         ├── ...
+#         ├── P0004_instance_color_RGB.png
+#         ├── P0004_instance_id_RGB.png
+#         └── P0004.png
+
+CLASSES = ('background', 'ship', 'storage_tank', 'baseball_diamond',
+           'tennis_court', 'basketball_court', 'ground_track_field', 'bridge',
+           'large_vehicle', 'small_vehicle', 'helicopter', 'swimming_pool',
+           'roundabout', 'soccer_ball_field', 'plane', 'harbor')
+# Refer to https://github.com/Z-Zheng/FarSeg/blob/master/data/isaid.py
+COLOR_MAP = [[0, 0, 0], [0, 0, 63], [0, 191, 127], [0, 63, 0], [0, 63, 127],
+             [0, 63, 191], [0, 63, 255], [0, 127, 63], [0, 127, 127],
+             [0, 0, 127], [0, 0, 191], [0, 0, 255], [0, 63, 63], [0, 127, 191],
+             [0, 127, 255], [0, 100, 155]]
+SUBSETS = ('train', 'val')
+SUBDIR = 'images'
+FILE_LIST_PATTERN = "{subset}.txt"
+LABEL_LIST_NAME = "labels.txt"
+URL = ""
+
+
+def flatten(nested_list):
+    flattened_list = []
+    for ele in nested_list:
+        if isinstance(ele, list):
+            flattened_list.extend(flatten(ele))
+        else:
+            flattened_list.append(ele)
+    return flattened_list
+
+
+def rgb2mask(rgb):
+    palette = flatten(COLOR_MAP)
+    # Pad with zero
+    palette = palette + [0] * (256 * 3 - len(palette))
+    ref = Image.new(mode='P', size=(1, 1))
+    ref.putpalette(palette)
+    mask = rgb.quantize(palette=ref, dither=0)
+    return mask
+
+
+if __name__ == '__main__':
+    parser = get_default_parser()
+    parser.add_argument(
+        '--crop_size', type=int, help="Size of cropped patches.", default=800)
+    parser.add_argument(
+        '--crop_stride',
+        type=int,
+        help="Stride of sliding windows when cropping patches. `crop_size` will be used only if `crop_size` is not None.",
+        default=600)
+    args = parser.parse_args()
+
+    out_dir = osp.join(args.out_dataset_dir,
+                       osp.basename(osp.normpath(args.in_dataset_dir)))
+
+    assert args.crop_size is not None
+    # According to https://github.com/CAPTAIN-WHU/iSAID_Devkit/blob/master/preprocess/split.py
+    # Set keep_last=True
+    crop_patches(
+        args.crop_size,
+        args.crop_stride,
+        data_dir=args.in_dataset_dir,
+        out_dir=out_dir,
+        subsets=SUBSETS,
+        subdirs=(SUBDIR, ),
+        glob_pattern='*.png',
+        max_workers=8,
+        keep_last=True)
+
+    for subset in SUBSETS:
+        path_tuples = []
+        print(f"Processing {subset} labels...")
+        for im_subdir in tqdm(glob(osp.join(out_dir, subset, SUBDIR, "*/"))):
+            im_name = osp.basename(im_subdir[:-1])  # Strip trailing '/'
+            if '_' in im_name:
+                # Do not process labels
+                continue
+            mask_subdir = osp.join(out_dir, subset, SUBDIR,
+                                   im_name + '_instance_color_RGB')
+            for mask_path in glob(osp.join(mask_subdir, '*.png')):
+                # Convert RGB files to mask files (pseudo color)
+                rgb = Image.open(mask_path).convert('RGB')
+                mask = rgb2mask(rgb)
+                # Write to the original location
+                mask.save(mask_path)
+            path_tuples.extend(
+                get_path_tuples(
+                    im_subdir,
+                    mask_subdir,
+                    glob_pattern='*.png',
+                    data_dir=args.out_dataset_dir))
+        path_tuples.sort()
+
+        file_list = osp.join(
+            args.out_dataset_dir, FILE_LIST_PATTERN.format(subset=subset))
+        create_file_list(file_list, path_tuples)
+        print(f"Write file list to {file_list}.")
+
+    label_list = osp.join(args.out_dataset_dir, LABEL_LIST_NAME)
+    create_label_list(label_list, CLASSES)
+    print(f"Write label list to {label_list}.")