common.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. import argparse
  2. import os
  3. import os.path as osp
  4. from glob import glob
  5. from itertools import count
  6. from functools import partial
  7. from concurrent.futures import ThreadPoolExecutor
  8. from skimage.io import imread, imsave
  9. from tqdm import tqdm
  10. def get_default_parser():
  11. """
  12. Get argument parser with commonly used options.
  13. Returns:
  14. argparse.ArgumentParser: Argument parser with the following arguments:
  15. --in_dataset_dir: Input dataset directory.
  16. --out_dataset_dir: Output dataset directory.
  17. """
  18. parser = argparse.ArgumentParser()
  19. parser.add_argument(
  20. '--in_dataset_dir',
  21. type=str,
  22. required=True,
  23. help="Input dataset directory.")
  24. parser.add_argument(
  25. '--out_dataset_dir', type=str, help="Output dataset directory.")
  26. return parser
  27. def add_crop_options(parser):
  28. """
  29. Add patch cropping related arguments to an argument parser. The parser will be
  30. modified in place.
  31. Args:
  32. parser (argparse.ArgumentParser): Argument parser.
  33. Returns:
  34. argparse.ArgumentParser: Argument parser with the following arguments:
  35. --crop_size: Size of cropped patches.
  36. --crop_stride: Stride of sliding windows when cropping patches.
  37. """
  38. parser.add_argument(
  39. '--crop_size', type=int, help="Size of cropped patches.")
  40. parser.add_argument(
  41. '--crop_stride',
  42. type=int,
  43. help="Stride of sliding windows when cropping patches. `crop_size` will be used only if `crop_size` is not None.",
  44. )
  45. return parser
  46. def crop_and_save(path, out_subdir, crop_size, stride):
  47. name, ext = osp.splitext(osp.basename(path))
  48. out_subsubdir = osp.join(out_subdir, name)
  49. if not osp.exists(out_subsubdir):
  50. os.makedirs(out_subsubdir)
  51. img = imread(path)
  52. w, h = img.shape[:2]
  53. counter = count()
  54. for i in range(0, h - crop_size + 1, stride):
  55. for j in range(0, w - crop_size + 1, stride):
  56. imsave(
  57. osp.join(out_subsubdir, '{}_{}{}'.format(name,
  58. next(counter), ext)),
  59. img[i:i + crop_size, j:j + crop_size],
  60. check_contrast=False)
  61. def crop_patches(crop_size,
  62. stride,
  63. data_dir,
  64. out_dir,
  65. subsets=('train', 'val', 'test'),
  66. subdirs=('A', 'B', 'label'),
  67. glob_pattern='*',
  68. max_workers=0):
  69. """
  70. Crop patches from images in specific directories.
  71. Args:
  72. crop_size (int): Height and width of the cropped patches will be `crop_size`.
  73. stride (int): Stride of sliding windows when cropping patches.
  74. data_dir (str): Root directory of the dataset that contains the input images.
  75. out_dir (str): Directory to save the cropped patches.
  76. subsets (tuple|list|None, optional): List or tuple of names of subdirectories
  77. or None. Images to be cropped should be stored in `data_dir/subset/subdir/`
  78. or `data_dir/subdir/` (when `subsets` is set to None), where `subset` is an
  79. element of `subsets`. Defaults to ('train', 'val', 'test').
  80. subdirs (tuple|list, optional): List or tuple of names of subdirectories. Images
  81. to be cropped should be stored in `data_dir/subset/subdir/` or
  82. `data_dir/subdir/` (when `subsets` is set to None), where `subdir` is an
  83. element of `subdirs`. Defaults to ('A', 'B', 'label').
  84. glob_pattern (str, optional): Glob pattern used to match image files.
  85. Defaults to '*', which matches arbitrary file.
  86. max_workers (int, optional): Number of worker threads to perform the cropping
  87. operation. Deafults to 0.
  88. """
  89. if max_workers < 0:
  90. raise ValueError("`max_workers` must be a non-negative integer!")
  91. if subset is None:
  92. subsets = ('', )
  93. if max_workers == 0:
  94. for subset in subsets:
  95. for subdir in subdirs:
  96. paths = glob(
  97. osp.join(data_dir, subset, subdir, glob_pattern),
  98. recursive=True)
  99. out_subdir = osp.join(out_dir, subset, subdir)
  100. for p in tqdm(paths):
  101. crop_and_save(
  102. p,
  103. out_subdir=out_subdir,
  104. crop_size=crop_size,
  105. stride=stride)
  106. else:
  107. # Concurrently crop image patches
  108. with ThreadPoolExecutor(max_workers=max_workers) as executor:
  109. for subset in subsets:
  110. for subdir in subdirs:
  111. paths = glob(
  112. osp.join(data_dir, subset, subdir, glob_pattern),
  113. recursive=True)
  114. out_subdir = osp.join(out_dir, subset, subdir)
  115. for _ in tqdm(
  116. executor.map(partial(
  117. crop_and_save,
  118. out_subdir=out_subdir,
  119. crop_size=crop_size,
  120. stride=stride),
  121. paths),
  122. total=len(paths)):
  123. pass
  124. def get_path_tuples(*dirs, glob_pattern='*', data_dir=None):
  125. """
  126. Get tuples of image paths. Each tuple corresponds to a sample in the dataset.
  127. Args:
  128. *dirs (str): Directories that contains the images.
  129. glob_pattern (str, optional): Glob pattern used to match image files.
  130. Defaults to '*', which matches arbitrary file.
  131. data_dir (str|None, optional): Root directory of the dataset that
  132. contains the images. If not None, `data_dir` will be used to
  133. determine relative paths of images. Defaults to None.
  134. Returns:
  135. list[tuple]: For directories with the following structure:
  136. ├── img
  137. │ ├── im1.png
  138. │ ├── im2.png
  139. │ └── im3.png
  140. ├── mask
  141. │ ├── im1.png
  142. │ ├── im2.png
  143. │ └── im3.png
  144. └── ...
  145. `get_path_tuples('img', 'mask', '*.png')` will return list of tuples:
  146. [('img/im1.png', 'mask/im1.png'), ('img/im2.png', 'mask/im2.png'), ('img/im3.png', 'mask/im3.png')]
  147. """
  148. all_paths = []
  149. for dir_ in dirs:
  150. paths = glob(osp.join(dir_, glob_pattern), recursive=True)
  151. paths = sorted(paths)
  152. if data_dir is not None:
  153. paths = [osp.relpath(p, data_dir) for p in paths]
  154. all_paths.append(paths)
  155. all_paths = list(zip(*all_paths))
  156. return all_paths
  157. def create_file_list(file_list, path_tuples, sep=' '):
  158. """
  159. Create file list.
  160. Args:
  161. file_list (str): Path of file list to create.
  162. path_tuples (list[tuple]): See get_path_tuples().
  163. sep (str, optional): Delimiter to use when writing lines to file list.
  164. Defaults to ' '.
  165. """
  166. with open(file_list, 'w') as f:
  167. for tup in path_tuples:
  168. line = sep.join(tup)
  169. f.write(line + '\n')
  170. def link_dataset(src, dst):
  171. """
  172. Make a symbolic link to a dataset.
  173. Args:
  174. src (str): Path of the original dataset.
  175. dst (str): Path of the symbolic link.
  176. """
  177. if osp.exists(dst) and not osp.isdir(dst):
  178. raise ValueError(f"{dst} exists and is not a directory.")
  179. elif not osp.exists(dst):
  180. os.makedirs(dst)
  181. name = osp.basename(osp.normpath(src))
  182. os.symlink(src, osp.join(dst, name), target_is_directory=True)