Ver Fonte

[Feature] Update geojson2mask

geoyee há 3 anos atrás
pai
commit
05f0ef0059
1 ficheiros alterados com 25 adições e 8 exclusões
  1. 25 8
      tools/geojson2mask.py

+ 25 - 8
tools/geojson2mask.py

@@ -1,3 +1,17 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import os
 import os.path as osp
 import shutil
@@ -26,10 +40,8 @@ def _save_palette(label, save_path):
     visualimg.save(save_path, format='PNG')
 
 
-def _save_mask(annotation, image, save_path):
-    if isinstance(image, str):
-        image = cv2.imread(image)
-    mask = np.zeros(image.shape[:2], dtype=np.int32)
+def _save_mask(annotation, image_size, save_path):
+    mask = np.zeros(image_size, dtype=np.int32)
     for contour_points in annotation:
         contour_points = np.array(contour_points).reshape((-1, 2))
         contour_points = np.round(contour_points).astype(np.int32)[np.newaxis, :]
@@ -42,13 +54,15 @@ def _read_geojson(json_path):
         jsoner = json.load(f)
         imgs = jsoner["images"]
         images = defaultdict(list)
+        sizes = defaultdict(list)
         for img in imgs:
             images[img["id"]] = img["file_name"]
+            sizes[img["file_name"]] = (img["height"], img["width"])
         anns = jsoner["annotations"]
         annotations = defaultdict(list)
         for ann in anns:
             annotations[images[ann["image_id"]]].append(ann["segmentation"])
-        return annotations
+        return annotations, sizes
 
 
 def convertData(raw_folder, end_folder):
@@ -61,9 +75,12 @@ def convertData(raw_folder, end_folder):
     names = os.listdir(img_folder)
     print("-- Loading annotations --")
     anns = {}
+    sizes = {}
     jsons = glob.glob(osp.join(raw_folder, "*.json"))
     for json in jsons:
-        anns.update(_read_geojson(json))
+        j_ann, j_size = _read_geojson(json)
+        anns.update(j_ann)
+        sizes.update(j_size)
     print("-- Converting datas --")
     for k in tqdm(names):
     # for k in tqdm(anns.keys()):
@@ -73,9 +90,9 @@ def convertData(raw_folder, end_folder):
         lab_save_path = osp.join(save_lab_folder, k.replace(ext, ".png"))
         shutil.copy(img_path, img_save_path)
         if k in anns.keys():
-            _save_mask(anns[k], img_path, lab_save_path)
+            _save_mask(anns[k], sizes[k], lab_save_path)
         else:  # have not anns
-            _save_palette(np.zeros(cv2.imread(img_path).shape[:2], dtype="uint8"), \
+            _save_palette(np.zeros(sizes[k], dtype="uint8"), \
                           lab_save_path)