common.py

# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import random
import copy
import os
import os.path as osp
import shutil
from glob import glob
from itertools import count
from functools import partial
from concurrent.futures import ThreadPoolExecutor

import numpy as np
from skimage.io import imread, imsave
from tqdm import tqdm


def get_default_parser():
    """
    Get an argument parser with commonly used options.

    Returns:
        argparse.ArgumentParser: Argument parser with the following arguments:
            --in_dataset_dir: Input dataset directory.
            --out_dataset_dir: Output dataset directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--in_dataset_dir',
        type=str,
        required=True,
        help="Input dataset directory.")
    parser.add_argument(
        '--out_dataset_dir', type=str, help="Output dataset directory.")
    return parser


def add_crop_options(parser):
    """
    Add patch-cropping related arguments to an argument parser. The parser will
    be modified in place.

    Args:
        parser (argparse.ArgumentParser): Argument parser.

    Returns:
        argparse.ArgumentParser: Argument parser with the following arguments:
            --crop_size: Size of cropped patches.
            --crop_stride: Stride of sliding windows when cropping patches.
    """
    parser.add_argument(
        '--crop_size', type=int, help="Size of cropped patches.")
    parser.add_argument(
        '--crop_stride',
        type=int,
        help="Stride of sliding windows when cropping patches. `crop_stride` will be used only if `crop_size` is not None.",
    )
    return parser
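
# Illustrative usage (not part of the original module): a script could combine
# the two helpers above to build its command line. The script name and argument
# values shown here are hypothetical.
#
#   parser = add_crop_options(get_default_parser())
#   args = parser.parse_args()
#   # e.g. python prepare_data.py --in_dataset_dir data/raw \
#   #          --out_dataset_dir data/crops --crop_size 256 --crop_stride 256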


def crop_and_save(path,
                  out_subdir,
                  crop_size,
                  stride,
                  keep_last=False,
                  pad=True,
                  pad_val=0):
    name, ext = osp.splitext(osp.basename(path))
    out_subsubdir = osp.join(out_subdir, name)
    if not osp.exists(out_subsubdir):
        os.makedirs(out_subsubdir)
    img = imread(path)
    h, w = img.shape[:2]
    if h < crop_size or w < crop_size:
        if not pad:
            raise ValueError(
                f"`crop_size` must not be larger than the image size. `crop_size` is {crop_size}, but the image size is {h}x{w}."
            )
        # Pad the image with `pad_val` so that at least one full patch fits.
        padded_img = np.full(
            shape=(max(h, crop_size), max(w, crop_size)) + img.shape[2:],
            fill_value=pad_val,
            dtype=img.dtype)
        padded_img[:h, :w] = img
        h, w = padded_img.shape[:2]
        img = padded_img
    counter = count()
    # Slide a `crop_size` x `crop_size` window over the image with step `stride`.
    # If `keep_last` is True, the last window in each row/column is shifted back
    # so that it still covers a full-size patch inside the image.
    for i in range(0, h, stride):
        i_st = i
        i_ed = i_st + crop_size
        if i_ed > h:
            if keep_last:
                i_st = h - crop_size
                i_ed = h
            else:
                continue
        for j in range(0, w, stride):
            j_st = j
            j_ed = j_st + crop_size
            if j_ed > w:
                if keep_last:
                    j_st = w - crop_size
                    j_ed = w
                else:
                    continue
            imsave(
                osp.join(out_subsubdir,
                         '{}_{}{}'.format(name, next(counter), ext)),
                img[i_st:i_ed, j_st:j_ed],
                check_contrast=False)
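
# Worked example (illustrative, not in the original code): for a 100x100 image
# with crop_size=64, stride=64, and keep_last=True, the window starts are 0 and
# 36 along each axis (the last window is shifted from 64 back to 36 so it stays
# inside the image), giving 4 patches in total; with keep_last=False only the
# patch at (0, 0) is saved.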


def crop_patches(crop_size,
                 stride,
                 data_dir,
                 out_dir,
                 subsets=('train', 'val', 'test'),
                 subdirs=('A', 'B', 'label'),
                 glob_pattern='*',
                 max_workers=0,
                 keep_last=False):
    """
    Crop patches from images in specific directories.

    Args:
        crop_size (int): Height and width of the cropped patches.
        stride (int): Stride of sliding windows when cropping patches.
        data_dir (str): Root directory of the dataset that contains the input images.
        out_dir (str): Directory to save the cropped patches.
        subsets (tuple|list|None, optional): List or tuple of names of subdirectories
            or None. Images to be cropped should be stored in `data_dir/subset/subdir/`
            or `data_dir/subdir/` (when `subsets` is set to None), where `subset` is an
            element of `subsets`. Defaults to ('train', 'val', 'test').
        subdirs (tuple|list, optional): List or tuple of names of subdirectories. Images
            to be cropped should be stored in `data_dir/subset/subdir/` or
            `data_dir/subdir/` (when `subsets` is set to None), where `subdir` is an
            element of `subdirs`. Defaults to ('A', 'B', 'label').
        glob_pattern (str, optional): Glob pattern used to match image files.
            Defaults to '*', which matches any file.
        max_workers (int, optional): Number of worker threads used to perform the
            cropping operation. If 0, crop patches in the main thread. Defaults to 0.
        keep_last (bool, optional): If True, keep the last patch in each row and each
            column. The left and upper border of the last patch will be shifted to
            ensure that the size of the patch is `crop_size`. Defaults to False.
    """
    if max_workers < 0:
        raise ValueError("`max_workers` must be a non-negative integer!")

    if subsets is None:
        subsets = ('', )

    print("Cropping patches...")

    if max_workers == 0:
        # Crop patches sequentially
        for subset in subsets:
            for subdir in subdirs:
                paths = glob(
                    osp.join(data_dir, subset, subdir, glob_pattern),
                    recursive=True)
                out_subdir = osp.join(out_dir, subset, subdir)
                for p in tqdm(paths):
                    crop_and_save(
                        p,
                        out_subdir=out_subdir,
                        crop_size=crop_size,
                        stride=stride,
                        keep_last=keep_last)
    else:
        # Concurrently crop image patches
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            for subset in subsets:
                for subdir in subdirs:
                    paths = glob(
                        osp.join(data_dir, subset, subdir, glob_pattern),
                        recursive=True)
                    out_subdir = osp.join(out_dir, subset, subdir)
                    for _ in tqdm(
                            executor.map(
                                partial(
                                    crop_and_save,
                                    out_subdir=out_subdir,
                                    crop_size=crop_size,
                                    stride=stride,
                                    keep_last=keep_last),
                                paths),
                            total=len(paths)):
                        pass
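
# Illustrative usage (not part of the original module); directory names are
# hypothetical and follow the `data_dir/subset/subdir/` layout described in the
# docstring above:
#
#   crop_patches(
#       crop_size=256,
#       stride=256,
#       data_dir='data/levir_cd',
#       out_dir='data/levir_cd_crops',
#       subsets=('train', 'val', 'test'),
#       subdirs=('A', 'B', 'label'),
#       max_workers=4)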


def get_path_tuples(*dirs, glob_pattern='*', data_dir=None):
    """
    Get tuples of image paths. Each tuple corresponds to a sample in the dataset.

    Args:
        *dirs (str): Directories that contain the images.
        glob_pattern (str, optional): Glob pattern used to match image files.
            Defaults to '*', which matches any file.
        data_dir (str|None, optional): Root directory of the dataset that
            contains the images. If not None, `data_dir` will be used to
            determine relative paths of images. Defaults to None.

    Returns:
        list[tuple]: For directories with the following structure:
            ├── img
            │   ├── im1.png
            │   ├── im2.png
            │   └── im3.png
            ├── mask
            │   ├── im1.png
            │   ├── im2.png
            │   └── im3.png
            └── ...

        `get_path_tuples('img', 'mask', glob_pattern='*.png')` will return a list of tuples:
            [('img/im1.png', 'mask/im1.png'), ('img/im2.png', 'mask/im2.png'), ('img/im3.png', 'mask/im3.png')]
    """
    all_paths = []
    for dir_ in dirs:
        paths = glob(osp.join(dir_, glob_pattern), recursive=True)
        paths = sorted(paths)
        if data_dir is not None:
            paths = [osp.relpath(p, data_dir) for p in paths]
        all_paths.append(paths)
    all_paths = list(zip(*all_paths))
    return all_paths


def create_file_list(file_list, path_tuples, sep=' '):
    """
    Create file list.

    Args:
        file_list (str): Path of file list to create.
        path_tuples (list[tuple]): See get_path_tuples().
        sep (str, optional): Delimiter to use when writing lines to file list.
            Defaults to ' '.
    """
    with open(file_list, 'w') as f:
        for tup in path_tuples:
            line = sep.join(tup)
            f.write(line + '\n')
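
# Illustrative usage (not part of the original module): building a training
# file list for a change-detection dataset laid out as data/train/{A,B,label}.
# All paths are hypothetical.
#
#   tuples = get_path_tuples(
#       'data/train/A', 'data/train/B', 'data/train/label',
#       glob_pattern='*.png', data_dir='data')
#   create_file_list('data/train.txt', tuples)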


def create_label_list(label_list, labels):
    """
    Create label list.

    Args:
        label_list (str): Path of label list to create.
        labels (list[str]|tuple[str]): Label names.
    """
    with open(label_list, 'w') as f:
        for label in labels:
            f.write(label + '\n')


def link_dataset(src, dst):
    """
    Make a symbolic link to a dataset.

    Args:
        src (str): Path of the original dataset.
        dst (str): Path of the symbolic link.
    """
    if osp.exists(dst) and not osp.isdir(dst):
        raise ValueError(f"{dst} exists and is not a directory.")
    elif not osp.exists(dst):
        os.makedirs(dst)
    src = osp.realpath(src)
    name = osp.basename(osp.normpath(src))
    os.symlink(src, osp.join(dst, name), target_is_directory=True)


def copy_dataset(src, dst):
    """
    Make a copy of a dataset.

    Args:
        src (str): Path of the original dataset.
        dst (str): Path to copy to.
    """
    if osp.exists(dst) and not osp.isdir(dst):
        raise ValueError(f"{dst} exists and is not a directory.")
    elif not osp.exists(dst):
        os.makedirs(dst)
    src = osp.realpath(src)
    name = osp.basename(osp.normpath(src))
    shutil.copytree(src, osp.join(dst, name))
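
# Note (illustrative, not in the original code): both helpers place the dataset
# under `dst` using the basename of `src`. For example (hypothetical paths),
#
#   link_dataset('/datasets/levir_cd', 'data')   # -> data/levir_cd (symlink)
#   copy_dataset('/datasets/levir_cd', 'data')   # -> data/levir_cd (full copy)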


def random_split(samples,
                 ratios=(0.7, 0.2, 0.1),
                 inplace=True,
                 drop_remainder=False):
    """
    Randomly split the dataset into two or three subsets.

    Args:
        samples (list): All samples of the dataset.
        ratios (tuple[float], optional): If the length of `ratios` is 2,
            the two elements indicate the ratios of samples used for training
            and evaluation. If the length of `ratios` is 3, the three elements
            indicate the ratios of samples used for training, validation, and
            testing. Defaults to (0.7, 0.2, 0.1).
        inplace (bool, optional): Whether to shuffle `samples` in place.
            Defaults to True.
        drop_remainder (bool, optional): Whether to discard the remaining samples.
            If False, the remaining samples will be included in the last subset.
            For example, if `ratios` is (0.7, 0.1) and `drop_remainder` is False,
            the two subsets after splitting will contain 70% and 30% of the samples,
            respectively. Defaults to False.
    """
    if not inplace:
        samples = copy.deepcopy(samples)

    if len(samples) == 0:
        raise ValueError("There are no samples!")

    if len(ratios) not in (2, 3):
        raise ValueError("`len(ratios)` must be 2 or 3!")

    random.shuffle(samples)

    n_samples = len(samples)
    acc_r = 0
    st_idx, ed_idx = 0, 0
    splits = []
    for r in ratios:
        acc_r += r
        ed_idx = round(acc_r * n_samples)
        splits.append(samples[st_idx:ed_idx])
        st_idx = ed_idx
    if ed_idx < n_samples and not drop_remainder:
        # Append the remaining samples to the last split
        splits[-1].extend(samples[ed_idx:])
    return splits
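
# Illustrative usage (not part of the original module): splitting the path
# tuples from get_path_tuples() into train/val/test file lists. Paths and
# ratios are hypothetical.
#
#   tuples = get_path_tuples('A', 'B', 'label', glob_pattern='*.png')
#   train, val, test = random_split(tuples, ratios=(0.7, 0.2, 0.1))
#   create_file_list('train.txt', train)
#   create_file_list('val.txt', val)
#   create_file_list('test.txt', test)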