Procházet zdrojové kódy

[Feature] Add mask2shp tool

geoyee před 3 roky
rodič
revize
7eb11d30b1
3 změnil soubory, kde provedl 140 přidání a 13 odebrání
  1. 46 12
      paddlers/datasets/raster.py
  2. 1 1
      tools/geojson2mask.py
  3. 93 0
      tools/mask2shp.py

+ 46 - 12
paddlers/datasets/raster.py

@@ -40,8 +40,16 @@ class Raster:
         super(Raster, self).__init__()
         super(Raster, self).__init__()
         if osp.exists(path):
         if osp.exists(path):
             self.path = path
             self.path = path
-            self.__src_data = np.load(path) if path.split(".")[-1] == "npy" \
-                                            else gdal.Open(path)
+            self.ext_type = path.split(".")[-1]
+            if self.ext_type.lower() in ["npy", "npz"]:
+                self._src_data = None
+            else:
+                try:
+                    # raster format support in GDAL: 
+                    # https://www.osgeo.cn/gdal/drivers/raster/index.html
+                    self._src_data = gdal.Open(path)
+                except:
+                    raise TypeError("Unsupported data format: `{}`".format(self.ext_type))
             self._getInfo()
             self._getInfo()
             self.to_uint8 = to_uint8
             self.to_uint8 = to_uint8
             self.setBands(band_list)
             self.setBands(band_list)
@@ -77,31 +85,57 @@ class Raster:
         Returns:
         Returns:
             np.ndarray: data's ndarray.
             np.ndarray: data's ndarray.
         """
         """
-        if start_loc is None:
-            return self._getAarray()
+        if self._src_data is not None:
+            if start_loc is None:
+                return self._getAarray()
+            else:
+                return self._getBlock(start_loc, block_size)
         else:
         else:
-            return self._getBlock(start_loc, block_size)
+            print("Numpy doesn't support blocking temporarily.")
+            return self._getNumpy()
 
 
     def _getInfo(self) -> None:
     def _getInfo(self) -> None:
-        self.bands = self.__src_data.RasterCount
-        self.width = self.__src_data.RasterXSize
-        self.height = self.__src_data.RasterYSize
+        if self._src_data is not None:
+            self.bands = self._src_data.RasterCount
+            self.width = self._src_data.RasterXSize
+            self.height = self._src_data.RasterYSize
+            self.geot = self._src_data.GetGeoTransform()
+            self.proj = self._src_data.GetProjection()
+        else:
+            d_shape = self._getNumpy().shape
+            if len(d_shape) == 3:
+                self.height, self.width, self.bands = d_shape
+            else:
+                self.height, self.width = d_shape
+                self.bands = 1
+            self.geot = None
+            self.proj = None
+
+    def _getNumpy(self):
+        """Load the array stored at ``self.path`` with ``np.load``.
+
+        When a band list is configured, the requested bands (1-based indices
+        into the last axis of the loaded array) are extracted and stacked
+        along axis 0; otherwise the loaded array is returned unchanged.
+
+        Returns:
+            np.ndarray: the loaded (and optionally band-selected) data.
+        """
+        ima = np.load(self.path)
+        # `__init__` calls `_getInfo` (which calls this method) before it
+        # calls `setBands`, so `band_list` may not have been assigned yet;
+        # a missing attribute is treated the same as "no band selection".
+        band_list = getattr(self, "band_list", None)
+        if band_list is not None:
+            # Band numbers are 1-based, hence `b - 1`.
+            # NOTE(review): stacking on axis 0 yields (bands, H, W) while
+            # `_getInfo` unpacks the array shape as (H, W, bands) -- confirm
+            # which layout callers expect.
+            ima = np.stack([ima[:, :, b - 1] for b in band_list], axis=0)
+        return ima
 
 
     def _getAarray(self, window: Union[None, List[int], Tuple[int]]=None) -> np.ndarray:
     def _getAarray(self, window: Union[None, List[int], Tuple[int]]=None) -> np.ndarray:
         if window is not None:
         if window is not None:
             xoff, yoff, xsize, ysize = window
             xoff, yoff, xsize, ysize = window
         if self.band_list is None:
         if self.band_list is None:
             if window is None:
             if window is None:
-                ima = self.__src_data.ReadAsArray()
+                ima = self._src_data.ReadAsArray()
             else:
             else:
-                ima = self.__src_data.ReadAsArray(xoff, yoff, xsize, ysize)
+                ima = self._src_data.ReadAsArray(xoff, yoff, xsize, ysize)
         else:
         else:
             band_array = []
             band_array = []
             for b in self.band_list:
             for b in self.band_list:
                 if window is None:
                 if window is None:
-                    band_i = self.__src_data.GetRasterBand(b).ReadAsArray()
+                    band_i = self._src_data.GetRasterBand(b).ReadAsArray()
                 else:
                 else:
-                    band_i = self.__src_data.GetRasterBand(b).ReadAsArray(xoff, yoff, xsize, ysize)
+                    band_i = self._src_data.GetRasterBand(b).ReadAsArray(xoff, yoff, xsize, ysize)
                 band_array.append(band_i)
                 band_array.append(band_i)
             ima = np.stack(band_array, axis=0)
             ima = np.stack(band_array, axis=0)
         if self.bands == 1:
         if self.bands == 1:

+ 1 - 1
tools/geojson2mask.py

@@ -100,7 +100,7 @@ parser = argparse.ArgumentParser(description="input parameters")
 parser.add_argument("--raw_folder", type=str, required=True, \
 parser.add_argument("--raw_folder", type=str, required=True, \
                     help="The folder path about original data, where `images` saves the original image, `annotation.json` saves the corresponding annotation information.")
                     help="The folder path about original data, where `images` saves the original image, `annotation.json` saves the corresponding annotation information.")
 parser.add_argument("--save_folder", type=str, required=True, \
 parser.add_argument("--save_folder", type=str, required=True, \
-                    help="The folder path to save the results, where `img` saves the image and `gt` saves the label")
+                    help="The folder path to save the results, where `img` saves the image and `gt` saves the label.")
 
 
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":

+ 93 - 0
tools/mask2shp.py

@@ -0,0 +1,93 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import os.path as osp
+sys.path.insert(0, osp.abspath(".."))  # add workspace
+
+import os
+import numpy as np
+import argparse
+from PIL import Image
+from paddlers.datasets.raster import Raster
+
+try:
+    from osgeo import gdal, ogr, osr
+except ImportError:
+    import gdal
+    import ogr
+    import osr
+
+
+def _mask2tif(mask_path, tmp_path, proj, geot):
+    """Write a mask image to a single-band GeoTIFF carrying the given geo-info.
+
+    Args:
+        mask_path (str): path of the mask image, opened with PIL.
+        tmp_path (str): path of the GeoTIFF to create.
+        proj: projection handed to ``SetProjection``.
+        geot: geotransform handed to ``SetGeoTransform``.
+
+    Returns:
+        The GDAL dataset that was created; it is returned still open.
+    """
+    label = np.asarray(Image.open(mask_path))
+    # A three-dimensional mask is reduced to its first channel.
+    if label.ndim == 3:
+        label = label[:, :, 0]
+    height, width = label.shape[:2]
+    # Create(path, x size, y size, band count, data type)
+    tif_ds = gdal.GetDriverByName("GTiff").Create(
+        tmp_path, width, height, 1, gdal.GDT_UInt16)
+    tif_ds.SetGeoTransform(geot)
+    tif_ds.SetProjection(proj)
+    out_band = tif_ds.GetRasterBand(1)
+    out_band.WriteArray(label)
+    tif_ds.FlushCache()
+    return tif_ds
+
+
+def _polygonize_raster(mask_path, shp_save_path, proj, geot, ignore_index):
+    """Vectorize a mask image into an ESRI Shapefile of polygons.
+
+    The mask is first written to a temporary GeoTIFF next to the output,
+    polygonized into a layer whose integer field ``DN`` stores each polygon's
+    pixel value, and finally every polygon whose ``DN`` equals
+    ``ignore_index`` is deleted. The temporary GeoTIFF is always removed.
+
+    Args:
+        mask_path (str): path of the mask image.
+        shp_save_path (str): path of the shapefile to write; an existing
+            data source at that path is deleted first.
+        proj: projection of the georeferenced source image.
+        geot: geotransform of the georeferenced source image.
+        ignore_index (int): pixel value whose polygons are dropped.
+
+    Raises:
+        RuntimeError: if the output data source cannot be created.
+    """
+    # Build the temporary name from the extension-less root: str.replace left
+    # a path without ".shp" (e.g. the CLI default "output") unchanged, so the
+    # temporary GeoTIFF and the output collided on the same path.
+    root = osp.splitext(shp_save_path)[0]
+    tmp_path = root + ".tif"
+    if tmp_path == shp_save_path:
+        tmp_path = root + "_tmp.tif"
+    ds = None
+    dst_ds = None
+    srcband = None
+    maskband = None
+    try:
+        ds = _mask2tif(mask_path, tmp_path, proj, geot)
+        srcband = ds.GetRasterBand(1)
+        maskband = srcband.GetMaskBand()
+        gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "YES")
+        gdal.SetConfigOption("SHAPE_ENCODING", "UTF-8")
+        ogr.RegisterAll()
+        drv = ogr.GetDriverByName("ESRI Shapefile")
+        if osp.exists(shp_save_path):
+            # Let the driver delete the data source so that sidecar files
+            # (.shx, .dbf, ...) of a previous run are removed as well.
+            drv.DeleteDataSource(shp_save_path)
+        dst_ds = drv.CreateDataSource(shp_save_path)
+        if dst_ds is None:
+            raise RuntimeError(
+                "Failed to create shapefile: `{}`".format(shp_save_path))
+        prosrs = osr.SpatialReference(wkt=ds.GetProjection())
+        dst_layer = dst_ds.CreateLayer("Building boundary", geom_type=ogr.wkbPolygon, srs=prosrs)
+        dst_fieldname = "DN"
+        dst_layer.CreateField(ogr.FieldDefn(dst_fieldname, ogr.OFTInteger))
+        # Field index 0 (the "DN" field) receives the pixel value of each polygon.
+        gdal.Polygonize(srcband, maskband, dst_layer, 0, [])
+        lyr = dst_ds.GetLayer()
+        lyr.SetAttributeFilter("DN = '{}'".format(str(ignore_index)))
+        # Collect the ids first so that no feature is deleted while the
+        # filtered layer is still being iterated.
+        for fid in [feature.GetFID() for feature in lyr]:
+            lyr.DeleteFeature(fid)
+    finally:
+        if dst_ds is not None:
+            dst_ds.Destroy()
+        # Drop every handle on the temporary GeoTIFF before removing it.
+        srcband = None
+        maskband = None
+        ds = None
+        if osp.exists(tmp_path):
+            os.remove(tmp_path)
+
+
+def raster2shp(srcimg_path, mask_path, save_path, ignore_index=255):
+    """Convert a mask into a shapefile using the geo-info of a source image.
+
+    Args:
+        srcimg_path (str): path of the original image; its ``proj`` and
+            ``geot`` attributes (read through ``Raster``) georeference the
+            output.
+        mask_path (str): path of the mask image to vectorize.
+        save_path (str): path of the shapefile to write.
+        ignore_index (int): mask value that is not converted. Defaults to 255.
+    """
+    reference = Raster(srcimg_path)
+    projection, geotransform = reference.proj, reference.geot
+    _polygonize_raster(
+        mask_path, save_path, projection, geotransform, ignore_index)
+    # Drop the reference to the source raster once the conversion is done.
+    reference = None
+
+
+# Command-line interface: source image, mask, output path and ignored value.
+parser = argparse.ArgumentParser(description="input parameters")
+parser.add_argument(
+    "--srcimg_path", type=str, required=True,
+    help="The path of original data with geoinfos.")
+parser.add_argument(
+    "--mask_path", type=str, required=True,
+    help="The path of mask data.")
+parser.add_argument(
+    "--save_path", type=str, default="output",
+    help="The path to save the results shapefile, `output` is the default.")
+parser.add_argument(
+    "--ignore_index", type=int, default=255,
+    help="It will not be converted to the value of SHP, `255` is the default.")
+
+
+if __name__ == "__main__":
+    # Parse the command line and run the mask -> shapefile conversion.
+    cli_args = parser.parse_args()
+    raster2shp(
+        cli_args.srcimg_path,
+        cli_args.mask_path,
+        cli_args.save_path,
+        cli_args.ignore_index)