Explorar o código

Add util functions for data preparation

Bobholamovic %!s(int64=2) %!d(string=hai) anos
pai
achega
1c30b71959
Modificáronse 1 ficheiros con 68 adicións e 0 borrados
  1. 68 0
      tools/prepare_dataset/common.py

+ 68 - 0
tools/prepare_dataset/common.py

@@ -1,4 +1,6 @@
 import argparse
+import random
+import copy
 import os
 import os.path as osp
 from glob import glob
@@ -198,6 +200,20 @@ def create_file_list(file_list, path_tuples, sep=' '):
             f.write(line + '\n')
 
 
+def create_label_list(label_list, labels):
+    """
+    Create label list.
+    
+    Args:
+        label_list (str): Path of label list to create.
+        labels (list[str]|tuple[str]]): Label names.
+    """
+
+    with open(label_list, 'w') as f:
+        for label in labels:
+            f.write(label + '\n')
+
+
 def link_dataset(src, dst):
     """
     Make a symbolic link to a dataset.
@@ -211,5 +227,57 @@ def link_dataset(src, dst):
         raise ValueError(f"{dst} exists and is not a directory.")
     elif not osp.exists(dst):
         os.makedirs(dst)
+    src = osp.realpath(src)
     name = osp.basename(osp.normpath(src))
     os.symlink(src, osp.join(dst, name), target_is_directory=True)
+
+
+def random_split(samples,
+                 ratios=(0.7, 0.2, 0.1),
+                 inplace=True,
+                 drop_remainder=False):
+    """
+    Randomly split the dataset into two or three subsets.
+    
+    Args:
+        samples (list): All samples of the dataset.
+        ratios (tuple[float], optional): If the length of `ratios` is 2,
+            the two elements indicate the ratios of samples used for training
+            and evaluation. If the length of `ratios` is 3, the three elements
+            indicate the ratios of samples used for training, validation, and 
+            testing. Defaults to (0.7, 0.2, 0.1).
+        inplace (bool, optional): Whether to shuffle `samples` in place. 
+            Defaults to True.
+        drop_remainder (bool, optional): Whether to discard the remaining samples.
+            If False, the remaining samples will be included in the last subset.
+            For example, if `ratios` is (0.7, 0.1) and `drop_remainder` is False, 
+            the two subsets after splitting will contain 70% and 30% of the samples, 
+            respectively. Defaults to False.
+    """
+
+    if not inplace:
+        samples = copy.deepcopy(samples)
+
+    if len(samples) == 0:
+        raise ValueError("There are no samples!")
+
+    if len(ratios) not in (2, 3):
+        raise ValueError("`len(ratios)` must be 2 or 3!")
+
+    random.shuffle(samples)
+
+    n_samples = len(samples)
+    acc_r = 0
+    st_idx, ed_idx = 0, 0
+    splits = []
+    for r in ratios:
+        acc_r += r
+        ed_idx = round(acc_r * n_samples)
+        splits.append(samples[st_idx:ed_idx])
+        st_idx = ed_idx
+
+    if ed_idx < len(ratios) and not drop_remainder:
+        # Append remainder to the last split
+        splits[-1].append(splits[ed_idx:])
+
+    return splits