# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
- import json
- import argparse
- import pandas as pd
- def _get_annno(df_image_split, df_anno):
- df_merge = pd.merge(
- df_image_split, df_anno, on="image_id", suffixes=(None, "_r"))
- df_merge = df_merge[[c for c in df_merge.columns if not c.endswith("_r")]]
- df_anno_split = df_merge[df_anno.columns.to_list()]
- df_anno_split = df_anno_split.sort_values(by="id")
- return df_anno_split
def json_split(json_all_path, json_train_path, json_val_path, val_split_rate,
               val_split_num, keep_val_in_train, img_keyname, anno_keyname):
    """Split a COCO-style JSON file into a train file and a val file.

    Images are shuffled with a fixed seed (``random_state=0``). The first
    ``val_split_num`` shuffled images form the val set. Each annotation goes
    to the split that contains its ``image_id``. Every other top-level key of
    the input JSON is copied unchanged into both outputs. Image and annotation
    records are written exactly as they were read.

    Args:
        json_all_path: Path of the JSON file to split.
        json_train_path: Output path for the train split.
        json_val_path: Output path for the val split.
        val_split_rate: Fraction of images to put in val. Only used when
            ``val_split_num`` is None.
        val_split_num: Number of val images, or None to derive it from
            ``val_split_rate``.
        keep_val_in_train: If true, the train split keeps every image and
            annotation, so val images also appear in train.
        img_keyname: Top-level key holding the image records.
        anno_keyname: Top-level key holding the annotation records.

    Returns:
        pandas.DataFrame: The shuffled image table, with the image ``id``
        column renamed to ``image_id``.

    Raises:
        ValueError: If the resulting val size is negative.
    """
    print("Split".center(100, "-"))
    print("json read...\n")

    # JSON text is UTF-8. Don't depend on the platform's default encoding;
    # the outputs below are written as UTF-8 as well.
    with open(json_all_path, "r", encoding="utf-8") as load_f:
        data = json.load(load_f)
    images = data[img_keyname]
    annos = data[anno_keyname]

    # Shuffle through pandas with the same seed so existing splits stay
    # reproducible.
    df_image = pd.DataFrame(images).rename(columns={"id": "image_id"})
    df_image = df_image.sample(frac=1, random_state=0)

    if val_split_num is None:
        val_split_num = int(val_split_rate * len(df_image))
    if val_split_num < 0:
        raise ValueError(
            "val split size must be non-negative, got %s" % val_split_num)

    # An empty image list yields a frame without any columns.
    shuffled_ids = (df_image["image_id"].tolist()
                    if "image_id" in df_image.columns else [])
    val_ids = set(shuffled_ids[:val_split_num])
    if keep_val_in_train:
        train_ids = set(shuffled_ids)
    else:
        train_ids = set(shuffled_ids[val_split_num:])

    def _sorted_by_id(records):
        # Outputs are ordered by record id.
        return sorted(records, key=lambda rec: rec["id"])

    # Select the original dicts instead of round-tripping through
    # DataFrame.to_json. That round trip adds null-valued keys to records
    # missing a column, turns such int columns into floats, and rounds
    # floats to 10 decimal places.
    images_train = _sorted_by_id(im for im in images if im["id"] in train_ids)
    images_val = _sorted_by_id(im for im in images if im["id"] in val_ids)
    if keep_val_in_train:
        annos_train = annos
    else:
        annos_train = _sorted_by_id(
            a for a in annos if a["image_id"] in train_ids)
    annos_val = _sorted_by_id(a for a in annos if a["image_id"] in val_ids)

    for out_path, split_images, split_annos in (
            (json_train_path, images_train, annos_train),
            (json_val_path, images_val, annos_val)):
        # Shallow copy keeps all other top-level keys (and their order).
        split = dict(data)
        split[img_keyname] = split_images
        split[anno_keyname] = split_annos
        with open(out_path, "w", encoding="utf-8") as file_obj:
            json.dump(split, file_obj, ensure_ascii=False)

    print("image total %d, train %d, val %d" %
          (len(df_image), len(images_train), len(images_val)))
    print("anno total %d, train %d, val %d" %
          (len(annos), len(annos_train), len(annos_val)))
    return df_image
- if __name__ == "__main__":
- parser = argparse.ArgumentParser(description="Split JSON file")
- parser.add_argument("--json_all_path", type=str, required=True, \
- help="Path to the original JSON file.")
- parser.add_argument("--json_train_path", type=str, required=True, \
- help="Generated JSON file for the train set.")
- parser.add_argument( "--json_val_path", type=str, required=True, \
- help="Generated JSON file for the val set.")
- parser.add_argument("--val_split_rate", type=float, default=0.1, \
- help="Proportion of files in the val set.")
- parser.add_argument("--val_split_num", type=int, default=None, \
- help="Number of val set files. If this parameter is set,`--val_split_rate` will be invalidated.")
- parser.add_argument("--keep_val_in_train", action="store_true", \
- help="Whether to keep the val set samples in the train set.")
- parser.add_argument("--img_keyname", type=str, default="images", \
- help="Image key in the JSON file.")
- parser.add_argument("--anno_keyname", type=str, default="annotations", \
- help="Category key in the JSON file.")
- args = parser.parse_args()
- json_split(args.json_all_path, args.json_train_path, args.json_val_path,
- args.val_split_rate, args.val_split_num, args.keep_val_in_train,
- args.img_keyname, args.anno_keyname)
|