Преглед на файлове

[Feature][Tools] Add OIF and Timer (#30)

* [Feature][Tools] Add OIF and Timer

* [Fix] Spelling modification
Yizhou Chen преди 3 години
родител
ревизия
144a47fe42
променени са 7 файла, в които са добавени 120 реда и са изтрити 25 реда
  1. 8 6
      tools/geojson2mask.py
  2. 6 4
      tools/mask2shp.py
  3. 8 10
      tools/matcher.py
  4. 64 0
      tools/oif.py
  5. 6 4
      tools/spliter.py
  6. 2 1
      tools/utils/__init__.py
  7. 26 0
      tools/utils/timer.py

+ 8 - 6
tools/geojson2mask.py

@@ -23,6 +23,7 @@ import glob
 from tqdm import tqdm
 from PIL import Image
 from collections import defaultdict
+from utils import Timer
 
 
 def _mkdir_p(path):
@@ -34,9 +35,9 @@ def _save_palette(label, save_path):
     bin_colormap = np.ones((256, 3)) * 255
     bin_colormap[0, :] = [0, 0, 0]
     bin_colormap = bin_colormap.astype(np.uint8)
-    visualimg  = Image.fromarray(label, "P")
+    visualimg = Image.fromarray(label, "P")
     palette = bin_colormap
-    visualimg.putpalette(palette) 
+    visualimg.putpalette(palette)
     visualimg.save(save_path, format='PNG')
 
 
@@ -44,7 +45,8 @@ def _save_mask(annotation, image_size, save_path):
     mask = np.zeros(image_size, dtype=np.int32)
     for contour_points in annotation:
         contour_points = np.array(contour_points).reshape((-1, 2))
-        contour_points = np.round(contour_points).astype(np.int32)[np.newaxis, :]
+        contour_points = np.round(contour_points).astype(np.int32)[
+            np.newaxis, :]
         cv2.fillPoly(mask, contour_points, 1)
     _save_palette(mask.astype("uint8"), save_path)
 
@@ -65,6 +67,7 @@ def _read_geojson(json_path):
         return annotations, sizes
 
 
+@Timer
 def convert_data(raw_folder, end_folder):
     print("-- Initializing --")
     img_folder = osp.join(raw_folder, "images")
@@ -83,7 +86,7 @@ def convert_data(raw_folder, end_folder):
         sizes.update(j_size)
     print("-- Converting datas --")
     for k in tqdm(names):
-    # for k in tqdm(anns.keys()):
+        # for k in tqdm(anns.keys()):
         img_path = osp.join(img_folder, k)
         img_save_path = osp.join(save_img_folder, k)
         ext = "." + k.split(".")[-1]
@@ -102,7 +105,6 @@ parser.add_argument("--raw_folder", type=str, required=True, \
 parser.add_argument("--save_folder", type=str, required=True, \
                     help="The folder path to save the results, where `img` saves the image and `gt` saves the label.")
 
-
 if __name__ == "__main__":
     args = parser.parse_args()
-    convert_data(args.raw_folder, args.save_folder)
+    convert_data(args.raw_folder, args.save_folder)

+ 6 - 4
tools/mask2shp.py

@@ -17,7 +17,7 @@ import os.path as osp
 import numpy as np
 import argparse
 from PIL import Image
-from utils import Raster
+from utils import Raster, Timer
 
 try:
     from osgeo import gdal, ogr, osr
@@ -54,7 +54,8 @@ def _polygonize_raster(mask_path, shp_save_path, proj, geot, ignore_index):
         os.remove(shp_save_path)
     dst_ds = drv.CreateDataSource(shp_save_path)
     prosrs = osr.SpatialReference(wkt=ds.GetProjection())
-    dst_layer = dst_ds.CreateLayer("Building boundary", geom_type=ogr.wkbPolygon, srs=prosrs)
+    dst_layer = dst_ds.CreateLayer(
+        "Building boundary", geom_type=ogr.wkbPolygon, srs=prosrs)
     dst_fieldname = "DN"
     fd = ogr.FieldDefn(dst_fieldname, ogr.OFTInteger)
     dst_layer.CreateField(fd)
@@ -68,6 +69,7 @@ def _polygonize_raster(mask_path, shp_save_path, proj, geot, ignore_index):
     os.remove(tmp_path)
 
 
+@Timer
 def raster2shp(srcimg_path, mask_path, save_path, ignore_index=255):
     src = Raster(srcimg_path)
     _polygonize_raster(mask_path, save_path, src.proj, src.geot, ignore_index)
@@ -84,7 +86,7 @@ parser.add_argument("--save_path", type=str, default="output", \
 parser.add_argument("--ignore_index", type=int, default=255, \
                     help="It will not be converted to the value of SHP, `255` is the default.")
 
-
 if __name__ == "__main__":
     args = parser.parse_args()
-    raster2shp(args.srcimg_path, args.mask_path, args.save_path, args.ignore_index)
+    raster2shp(args.srcimg_path, args.mask_path, args.save_path,
+               args.ignore_index)

+ 8 - 10
tools/matcher.py

@@ -15,8 +15,7 @@
 import numpy as np
 import cv2
 import argparse
-import time
-from utils import Raster, raster2uint8
+from utils import Raster, raster2uint8, Timer
 
 try:
     from osgeo import gdal
@@ -24,7 +23,7 @@ except ImportError:
     import gdal
 
 
-class MatchError (Exception):
+class MatchError(Exception):
     def __str__(self):
         return "Cannot match two images."
 
@@ -45,7 +44,8 @@ def _calcu_tf(im1, im2):
                                       for m in good_matches]).reshape(-1, 1, 2)
     den_automatic_points = np.float32([kp1[m[0].trainIdx].pt \
                                       for m in good_matches]).reshape(-1, 1, 2)
-    H, _ = cv2.findHomography(src_automatic_points, den_automatic_points, cv2.RANSAC, 5.0)
+    H, _ = cv2.findHomography(src_automatic_points, den_automatic_points,
+                              cv2.RANSAC, 5.0)
     return H
 
 
@@ -83,6 +83,7 @@ def _img2tif(ima, save_path, proj, geot, dtype):
     return dst_ds
 
 
+@Timer
 def matching(im1_path, im2_path, im1_bands=[1, 2, 3], im2_bands=[1, 2, 3]):
     im1_ras = Raster(im1_path)
     im2_ras = Raster(im2_path)
@@ -92,10 +93,11 @@ def matching(im1_path, im2_path, im1_bands=[1, 2, 3], im2_bands=[1, 2, 3]):
     # test
     # im2_t = cv2.warpPerspective(im2, H, (im1.shape[1], im1.shape[0]))
     # cv2.imwrite("B_M.png", cv2.cvtColor(im2_t, cv2.COLOR_RGB2BGR))
-    im2_arr_t = cv2.warpPerspective(im2_ras.getArray(), H, (im1_ras.width, im1_ras.height))
+    im2_arr_t = cv2.warpPerspective(im2_ras.getArray(), H,
+                                    (im1_ras.width, im1_ras.height))
     save_path = im2_ras.path.replace(("." + im2_ras.ext_type), "_M.tif")
     _img2tif(im2_arr_t, save_path, im1_ras.proj, im1_ras.geot, im1_ras.datatype)
-    
+
 
 parser = argparse.ArgumentParser(description="input parameters")
 parser.add_argument("--im1_path", type=str, required=True, \
@@ -107,10 +109,6 @@ parser.add_argument("--im1_bands", type=int, nargs="+", default=[1, 2, 3], \
 parser.add_argument("--im2_bands", type=int, nargs="+", default=[1, 2, 3], \
                     help="The time2 image's band used for matching, RGB or monochrome, `[1, 2, 3]` is the default.")
 
-
 if __name__ == "__main__":
     args = parser.parse_args()
-    start_time = time.time()
     matching(args.im1_path, args.im2_path, args.im1_bands, args.im2_bands)
-    end_time = time.time()
-    print("Total time:", (end_time - start_time))

+ 64 - 0
tools/oif.py

@@ -0,0 +1,64 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import pandas as pd
+import itertools
+import argparse
+from utils import Raster, Timer
+from easydict import EasyDict as edict
+
+
+def _calcOIF(rgb, stds, rho):
+    r, g, b = rgb
+    s1 = stds[int(r)]
+    s2 = stds[int(g)]
+    s3 = stds[int(b)]
+    r12 = rho[int(r), int(g)]
+    r23 = rho[int(g), int(b)]
+    r31 = rho[int(b), int(r)]
+    return (s1 + s2 + s3) / (abs(r12) + abs(r23) + abs(r31))
+
+
+@Timer
+def oif(img_path, topk=5):
+    raster = Raster(img_path)
+    img = raster.getArray()
+    img_flatten = img.reshape([-1, raster.bands])
+    stds = np.std(img_flatten, axis=0)
+    datas = edict()
+    for c in range(raster.bands):
+        datas[str(c + 1)] = img_flatten[:, c]
+    datas = pd.DataFrame(datas)
+    rho = datas.corr().values
+    band_combs = edict()
+    for rgb in itertools.combinations(list(range(raster.bands)), 3):
+        band_combs[str(rgb)] = _calcOIF(rgb, stds, rho)
+    band_combs = sorted(
+        band_combs.items(), key=lambda kv: (kv[1], kv[0]), reverse=True)
+    print("== Optimal band combination ==")
+    for i in range(topk):
+        k, v = band_combs[i]
+        print("Bands: {0}, OIF value: {1}.".format(k, v))
+
+
+parser = argparse.ArgumentParser(description="input parameters")
+parser.add_argument("--im_path", type=str, required=True, \
+                    help="The path of HSIs image.")
+parser.add_argument("--topk", type=int, default=5, \
+                    help="Number of top results, `5` is the default.")
+
+if __name__ == "__main__":
+    args = parser.parse_args()
+    oif(args.im_path, args.topk)

+ 6 - 4
tools/spliter.py

@@ -17,9 +17,10 @@ import os.path as osp
 import argparse
 from math import ceil
 from PIL import Image
-from utils import Raster
+from utils import Raster, Timer
 
 
+@Timer
 def split_data(image_path, block_size, save_folder):
     if not osp.exists(save_folder):
         os.makedirs(save_folder)
@@ -31,8 +32,10 @@ def split_data(image_path, block_size, save_folder):
     for r in range(rows):
         for c in range(cols):
             loc_start = (c * block_size, r * block_size)
-            title = Image.fromarray(raster.getArray(loc_start, (block_size, block_size)))
-            save_path = osp.join(save_folder, (image_name + "_" + str(r) + "_" + str(c) + ".png"))
+            title = Image.fromarray(
+                raster.getArray(loc_start, (block_size, block_size)))
+            save_path = osp.join(save_folder, (
+                image_name + "_" + str(r) + "_" + str(c) + ".png"))
             title.save(save_path, "PNG")
             print("-- {:d}/{:d} --".format(int(r * cols + c + 1), total_number))
 
@@ -45,7 +48,6 @@ parser.add_argument("--block_size", type=int, default=512, \
 parser.add_argument("--save_folder", type=str, default="output", \
                     help="The folder path to save the results, `output` is the default.")
 
-
 if __name__ == "__main__":
     args = parser.parse_args()
     split_data(args.image_path, args.block_size, args.save_folder)

+ 2 - 1
tools/utils/__init__.py

@@ -16,4 +16,5 @@ import sys
 import os.path as osp
 sys.path.insert(0, osp.abspath(".."))  # add workspace
 
-from .raster import Raster, raster2uint8
+from .raster import Raster, raster2uint8
+from .timer import Timer

+ 26 - 0
tools/utils/timer.py

@@ -0,0 +1,26 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+
+
+class Timer(object):
+    def __init__(self, func):
+        self.func = func
+
+    def __call__(self, *args, **kwds):
+        start_time = time.time()
+        func_t = self.func(*args, **kwds)
+        print("Total time: {0}.".format(time.time() - start_time))
+        return func_t