json_split.py 4.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import json
  15. import argparse
  16. import pandas as pd
  17. def _get_annno(df_image_split, df_anno):
  18. df_merge = pd.merge(
  19. df_image_split, df_anno, on="image_id", suffixes=(None, "_r"))
  20. df_merge = df_merge[[c for c in df_merge.columns if not c.endswith("_r")]]
  21. df_anno_split = df_merge[df_anno.columns.to_list()]
  22. df_anno_split = df_anno_split.sort_values(by="id")
  23. return df_anno_split
  24. def json_split(json_all_path, json_train_path, json_val_path, val_split_rate,
  25. val_split_num, keep_val_in_train, img_keyname, anno_keyname):
  26. print("Split".center(100, "-"))
  27. print("json read...\n")
  28. with open(json_all_path, "r") as load_f:
  29. data = json.load(load_f)
  30. df_anno = pd.DataFrame(data[anno_keyname])
  31. df_image = pd.DataFrame(data[img_keyname])
  32. df_image = df_image.rename(columns={"id": "image_id"})
  33. df_image = df_image.sample(frac=1, random_state=0)
  34. if val_split_num is None:
  35. val_split_num = int(val_split_rate * len(df_image))
  36. if keep_val_in_train:
  37. df_image_train = df_image
  38. df_image_val = df_image[:val_split_num]
  39. df_anno_train = df_anno
  40. df_anno_val = _get_annno(df_image_val, df_anno)
  41. else:
  42. df_image_train = df_image[val_split_num:]
  43. df_image_val = df_image[:val_split_num]
  44. df_anno_train = _get_annno(df_image_train, df_anno)
  45. df_anno_val = _get_annno(df_image_val, df_anno)
  46. df_image_train = df_image_train.rename(
  47. columns={"image_id": "id"}).sort_values(by="id")
  48. df_image_val = df_image_val.rename(columns={"image_id": "id"}).sort_values(
  49. by="id")
  50. data[img_keyname] = json.loads(df_image_train.to_json(orient="records"))
  51. data[anno_keyname] = json.loads(df_anno_train.to_json(orient="records"))
  52. str_json = json.dumps(data, ensure_ascii=False)
  53. with open(json_train_path, "w", encoding="utf-8") as file_obj:
  54. file_obj.write(str_json)
  55. data[img_keyname] = json.loads(df_image_val.to_json(orient="records"))
  56. data[anno_keyname] = json.loads(df_anno_val.to_json(orient="records"))
  57. str_json = json.dumps(data, ensure_ascii=False)
  58. with open(json_val_path, "w", encoding="utf-8") as file_obj:
  59. file_obj.write(str_json)
  60. print("image total %d, train %d, val %d" %
  61. (len(df_image), len(df_image_train), len(df_image_val)))
  62. print("anno total %d, train %d, val %d" %
  63. (len(df_anno), len(df_anno_train), len(df_anno_val)))
  64. return df_image
  65. if __name__ == "__main__":
  66. parser = argparse.ArgumentParser(description="Split JSON file")
  67. parser.add_argument("--json_all_path", type=str, required=True, \
  68. help="Path to the original JSON file.")
  69. parser.add_argument("--json_train_path", type=str, required=True, \
  70. help="Generated JSON file for the train set.")
  71. parser.add_argument( "--json_val_path", type=str, required=True, \
  72. help="Generated JSON file for the val set.")
  73. parser.add_argument("--val_split_rate", type=float, default=0.1, \
  74. help="Proportion of files in the val set.")
  75. parser.add_argument("--val_split_num", type=int, default=None, \
  76. help="Number of val set files. If this parameter is set,`--val_split_rate` will be invalidated.")
  77. parser.add_argument("--keep_val_in_train", action="store_true", \
  78. help="Whether to keep the val set samples in the train set.")
  79. parser.add_argument("--img_keyname", type=str, default="images", \
  80. help="Image key in the JSON file.")
  81. parser.add_argument("--anno_keyname", type=str, default="annotations", \
  82. help="Category key in the JSON file.")
  83. args = parser.parse_args()
  84. json_split(args.json_all_path, args.json_train_path, args.json_val_path,
  85. args.val_split_rate, args.val_split_num, args.keep_val_in_train,
  86. args.img_keyname, args.anno_keyname)