common.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. import argparse
  2. import os
  3. import os.path as osp
  4. from glob import glob
  5. from itertools import count
  6. from functools import partial
  7. from concurrent.futures import ThreadPoolExecutor
  8. from skimage.io import imread, imsave
  9. from tqdm import tqdm
  10. def get_default_parser():
  11. parser = argparse.ArgumentParser()
  12. parser.add_argument(
  13. '--in_dataset_dir',
  14. type=str,
  15. required=True,
  16. help="Input dataset directory.")
  17. parser.add_argument(
  18. '--out_dataset_dir', type=str, help="Output dataset directory.")
  19. return parser
  20. def add_crop_options(parser):
  21. parser.add_argument(
  22. '--crop_size', type=int, help="Size of cropped patches.")
  23. parser.add_argument(
  24. '--crop_stride',
  25. type=int,
  26. help="Stride of sliding windows when cropping patches. `crop_size` will be used only if `crop_size` is not None.",
  27. )
  28. return parser
  29. def crop_and_save(path, out_subdir, crop_size, stride):
  30. name, ext = osp.splitext(osp.basename(path))
  31. out_subsubdir = osp.join(out_subdir, name)
  32. if not osp.exists(out_subsubdir):
  33. os.makedirs(out_subsubdir)
  34. img = imread(path)
  35. w, h = img.shape[:2]
  36. counter = count()
  37. for i in range(0, h - crop_size + 1, stride):
  38. for j in range(0, w - crop_size + 1, stride):
  39. imsave(
  40. osp.join(out_subsubdir, '{}_{}{}'.format(name,
  41. next(counter), ext)),
  42. img[i:i + crop_size, j:j + crop_size],
  43. check_contrast=False)
  44. def crop_patches(crop_size,
  45. stride,
  46. data_dir,
  47. out_dir,
  48. subsets=('train', 'val', 'test'),
  49. subdirs=('A', 'B', 'label'),
  50. glob_pattern='*',
  51. max_workers=0):
  52. if max_workers < 0:
  53. raise ValueError("`max_workers` must be a non-negative integer!")
  54. if max_workers == 0:
  55. for subset in subsets:
  56. for subdir in subdirs:
  57. paths = glob(
  58. osp.join(data_dir, subset, subdir, glob_pattern),
  59. recursive=True)
  60. out_subdir = osp.join(out_dir, subset, subdir)
  61. for p in tqdm(paths):
  62. crop_and_save(
  63. p,
  64. out_subdir=out_subdir,
  65. crop_size=crop_size,
  66. stride=stride)
  67. else:
  68. # Concurrently crop image patches
  69. with ThreadPoolExecutor(max_workers=max_workers) as executor:
  70. for subset in subsets:
  71. for subdir in subdirs:
  72. paths = glob(
  73. osp.join(data_dir, subset, subdir, glob_pattern),
  74. recursive=True)
  75. out_subdir = osp.join(out_dir, subset, subdir)
  76. for _ in tqdm(
  77. executor.map(partial(
  78. crop_and_save,
  79. out_subdir=out_subdir,
  80. crop_size=crop_size,
  81. stride=stride),
  82. paths),
  83. total=len(paths)):
  84. pass
  85. def get_path_tuples(*dirs, glob_pattern='*', data_dir=None):
  86. all_paths = []
  87. for dir_ in dirs:
  88. paths = glob(osp.join(dir_, glob_pattern), recursive=True)
  89. paths = sorted(paths)
  90. if data_dir is not None:
  91. paths = [osp.relpath(p, data_dir) for p in paths]
  92. all_paths.append(paths)
  93. all_paths = list(zip(*all_paths))
  94. return all_paths
  95. def create_file_list(file_list, path_tuples, sep=' '):
  96. with open(file_list, 'w') as f:
  97. for tup in path_tuples:
  98. line = sep.join(tup)
  99. f.write(line + '\n')
  100. def link_dataset(src, dst):
  101. if osp.exists(dst) and not osp.isdir(dst):
  102. raise ValueError(f"{dst} exists and is not a directory.")
  103. elif not osp.exists(dst):
  104. os.makedirs(dst)
  105. name = osp.basename(osp.normpath(src))
  106. os.symlink(src, osp.join(dst, name), target_is_directory=True)