Browse Source

[Feature] Update matching tool

geoyee 3 năm trước cách đây
mục cha
commit
35eee43fba
3 tập tin đã thay đổi với 25 bổ sung18 xóa
  1. 4 10
      docs/tools.md
  2. 10 6
      tools/matcher.py
  3. 11 2
      tools/utils/raster.py

+ 4 - 10
docs/tools.md

@@ -51,25 +51,19 @@ python mask2shp.py --srcimg_path xxx.tif --mask_path xxx.png [--save_path output
 ` matcher`的主要功能是在进行变化检测的推理前,匹配两期影像的位置,并将转换后的`im2`图像保存在原地址下,命名为`im2_M.tif`。使用代码如下:
 
 ```shell
-python geojson2mask.py --raw_folder xxx --save_folder xxx
-```
-
-### spliter
-
-`geojson2mask`的主要功能是将图像以及对应json格式的分割标签转换为图像与png格式的标签,结果会分别存放在`img`和`gt`两个文件夹中。相关的数据样例可以参考[中国典型城市建筑物实例数据集](https://www.scidb.cn/detail?dataSetId=806674532768153600&dataSetType=journal)。使用代码如下:
-
-```shell
-python matcher.py --im1_path xxx.tif --im2_path xxx.xxx
+python matcher.py --im1_path xxx.tif --im2_path xxx.xxx [--im1_bands 1 2 3] [--im2_bands 1 2 3]
 ```
 
 其中:
 
 - `im1_path`:时段一的图像路径,该图像需要存在地理信息,且以该图像为基准图像。
 - `im2_path`:时段二的图像路径,该图像可以为非遥感格式的图像,该图像为带匹配图像。
+- `im1_bands`:时段一图像所用于配准的波段,为RGB或单通道,默认为[1, 2, 3]。
+- `im2_bands`:时段二图像所用于配准的波段,为RGB或单通道,默认为[1, 2, 3]。
 
 ### spliter
 
-` spliter`的主要功能是在划分大的遥感图像为图像块,便于进行训练。使用代码如下:
+`spliter`的主要功能是在划分大的遥感图像为图像块,便于进行训练。使用代码如下:
 
 ```shell
 python spliter.py --image_path xxx.tif [--block_size 512] [--save_folder output]

+ 10 - 6
tools/matcher.py

@@ -57,21 +57,21 @@ def _get_match_img(raster, bands):
         band_i = raster.GetRasterBand(b).ReadAsArray()
         band_array.append(band_i)
     if len(band_array) == 1:
-        ima = raster2uint8(band_array)
+        ima = raster2uint8(band_array[0])
     else:
         ima = raster2uint8(np.stack(band_array, axis=-1))
         ima = cv2.cvtColor(ima, cv2.COLOR_RGB2GRAY)
     return ima
 
 
-def _img2tif(ima, save_path, proj, geot):
+def _img2tif(ima, save_path, proj, geot, dtype):
     if len(ima.shape) == 3:
         row, columns, bands = ima.shape
     else:
         row, columns = ima.shape
         bands = 1
     driver = gdal.GetDriverByName("GTiff")
-    dst_ds = driver.Create(save_path, columns, row, bands, gdal.GDT_UInt16)
+    dst_ds = driver.Create(save_path, columns, row, bands, dtype)
     dst_ds.SetGeoTransform(geot)
     dst_ds.SetProjection(proj)
     if bands != 1:
@@ -94,19 +94,23 @@ def matching(im1_path, im2_path, im1_bands=[1, 2, 3], im2_bands=[1, 2, 3]):
     # cv2.imwrite("B_M.png", cv2.cvtColor(im2_t, cv2.COLOR_RGB2BGR))
     im2_arr_t = cv2.warpPerspective(im2_ras.getArray(), H, (im1_ras.width, im1_ras.height))
     save_path = im2_ras.path.replace(("." + im2_ras.ext_type), "_M.tif")
-    _img2tif(im2_arr_t, save_path, im1_ras.proj, im1_ras.geot)
+    _img2tif(im2_arr_t, save_path, im1_ras.proj, im1_ras.geot, im1_ras.datatype)
     
 
 parser = argparse.ArgumentParser(description="input parameters")
 parser.add_argument("--im1_path", type=str, required=True, \
                     help="The path of time1 image (with geoinfo).")
 parser.add_argument("--im2_path", type=str, required=True, \
-                    help="The path of time1 image.")
+                    help="The path of time2 image.")
+parser.add_argument("--im1_bands", type=int, nargs="+", default=[1, 2, 3], \
+                    help="The time1 image's band used for matching, RGB or monochrome, `[1, 2, 3]` is the default.")
+parser.add_argument("--im2_bands", type=int, nargs="+", default=[1, 2, 3], \
+                    help="The time2 image's band used for matching, RGB or monochrome, `[1, 2, 3]` is the default.")
 
 
 if __name__ == "__main__":
     args = parser.parse_args()
     start_time = time.time()
-    matching(args.im1_path, args.im2_path)
+    matching(args.im1_path, args.im2_path, args.im1_bands, args.im2_bands)
     end_time = time.time()
     print("Total time:", (end_time - start_time))

+ 11 - 2
tools/utils/raster.py

@@ -50,9 +50,9 @@ class Raster:
                     self._src_data = gdal.Open(path)
                 except:
                     raise TypeError("Unsupported data format: `{}`".format(self.ext_type))
-            self._getInfo()
             self.to_uint8 = to_uint8
             self.setBands(band_list)
+            self._getInfo()
         else:
             raise ValueError("The path {0} not exists.".format(path))
 
@@ -101,8 +101,11 @@ class Raster:
             self.height = self._src_data.RasterYSize
             self.geot = self._src_data.GetGeoTransform()
             self.proj = self._src_data.GetProjection()
+            d_name = self._getBlock([0, 0], [1, 1]).dtype.name
         else:
-            d_shape = self._getNumpy().shape
+            d_img = self._getNumpy()
+            d_shape = d_img.shape
+            d_name = d_img.dtype.name
             if len(d_shape) == 3:
                 self.height, self.width, self.bands = d_shape
             else:
@@ -110,6 +113,12 @@ class Raster:
                 self.bands = 1
             self.geot = None
             self.proj = None
+        if "int8" in d_name:
+            self.datatype = gdal.GDT_Byte
+        elif "int16" in d_name:
+            self.datatype = gdal.GDT_UInt16
+        else:
+            self.datatype = gdal.GDT_Float32
 
     def _getNumpy(self):
         ima = np.load(self.path)