Browse Source

[Fix] mask2geojson->geojson2mask

Bobholamovic 2 years ago
parent
commit
6c582195e2
3 changed files with 82 additions and 8 deletions
  1. 1 1
      README.md
  2. 8 7
      docs/data/tools.md
  3. 73 0
      tools/geojson2mask.py

+ 1 - 1
README.md

@@ -127,8 +127,8 @@ PaddleRS具有以下五大特色:
         <b>数据格式转换</b><br>
         <ul>
           <li>coco to mask</li>
+          <li>geojson to mask</li>
           <li>mask to shpfile</li>
-          <li>mask to geojson</li>
           <li>...</li>
         </ul>
         <b>数据预处理</b><br>

+ 8 - 7
docs/data/tools.md

@@ -3,8 +3,8 @@
 PaddleRS在`tools`目录中提供了丰富的遥感影像处理工具,包括:
 
 - `coco2mask.py`:用于将COCO格式的标注文件转换为.png格式。
-- `mask2shape.py`:用于将模型推理输出的.png格式栅格标签转换为矢量格式。
-- `mask2geojson.py`:用于将模型推理输出的.png格式栅格标签转换为GeoJSON格式。
+- `mask2shape.py`:用于将模型推理输出的.png格式栅格标签转换为.shp矢量格式。
+- `geojson2mask.py`:用于将GeoJSON格式标签转换为.tif栅格格式。
 - `match.py`:用于实现两幅影像的配准。
 - `split.py`:用于对大幅面影像数据进行切片。
 - `coco_tools/`:COCO工具合集,用于统计处理COCO格式标注文件。
@@ -47,18 +47,19 @@ python mask2shape.py --srcimg_path {带有地理信息的原始影像路径} --m
 - `save_path`:保存shapefile的路径,默认为`output`。
 - `ignore_index`:需要在shapefile中忽略的索引值(例如分割任务中的背景类),默认为`255`。
 
-### mask2geojson
+### geojson2mask
 
-`mask2geojson.py`的主要功能是将.png格式的分割结果转换为GeoJSON格式。使用方式如下:
+`geojson2mask.py`的主要功能是将GeoJSON格式的标签转换为.tif的栅格格式。使用方式如下:
 
 ```shell
-python mask2geojson.py --mask_path {输入分割标签路径} --save_path {输出路径}
+python geojson2mask.py --srcimg_path {带有地理信息的原始影像路径} --geojson_path {输入分割标签路径} --save_path {输出路径}
 ```
 
 其中:
 
-- `mask_path`:模型推理得到的.png格式的分割结果。
-- `save_path`:保存GeoJSON文件的路径。
+- `srcimg_path`:原始影像路径,需要带有地理元信息。
+- `geojson_path`:GeoJSON格式标签路径。
+- `save_path`:保存转换后的栅格文件的路径。
 
 ### match
 

+ 73 - 0
tools/geojson2mask.py

@@ -0,0 +1,73 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import codecs
+import argparse
+
+import paddlers
+import numpy as np
+import cv2
+import geojson
+from tqdm import tqdm
+
+from utils import Raster, save_geotiff, translate_vector, time_it
+
+
+def _gt_convert(x_geo, y_geo, geotf):
+    a = np.array([[geotf[1], geotf[2]], [geotf[4], geotf[5]]])
+    b = np.array([x_geo - geotf[0], y_geo - geotf[3]])
+    return np.round(np.linalg.solve(a,
+                                    b)).tolist()  # Solve a quadratic equation
+
+
+@time_it
+# TODO: update for vector2raster
+def convert_data(image_path, geojson_path, save_path):
+    raster = Raster(image_path)
+    tmp_img = np.zeros((raster.height, raster.width), dtype=np.int32)
+    # vector to EPSG from raster
+    temp_geojson_path = translate_vector(geojson_path, raster.proj)
+    geo_reader = codecs.open(temp_geojson_path, "r", encoding="utf-8")
+    feats = geojson.loads(geo_reader.read())["features"]  # All image patches
+    geo_reader.close()
+    for feat in tqdm(feats):
+        geo = feat["geometry"]
+        if geo["type"] == "Polygon":
+            geo_points = geo["coordinates"][0]
+        elif geo["type"] == "MultiPolygon":
+            geo_points = geo["coordinates"][0][0]
+        else:
+            raise TypeError(
+                "Geometry type must be 'Polygon' or 'MultiPolygon', not {}.".
+                format(geo["type"]))
+        xy_points = np.array([
+            _gt_convert(point[0], point[1], raster.geot) for point in geo_points
+        ]).astype(np.int32)
+        # TODO: Label category
+        cv2.fillPoly(tmp_img, [xy_points], 1)  # Fill with polygons
+    save_geotiff(tmp_img, save_path, raster.proj, raster.geot)
+    os.remove(temp_geojson_path)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--srcimg_path", type=str, required=True, \
+                        help="Path of the original image.")
+    parser.add_argument("--geojson_path", type=str, required=True, \
+                        help="Path of the GeoJSON file (the coordinate system is WGS84).")
+    parser.add_argument("--save_path", type=str, required=True, \
+                        help="Path to store the mask data.")
+    args = parser.parse_args()
+    convert_data(args.srcimg_path, args.geojson_path, args.save_path)