Procházet zdrojové kódy

[Feature] Add slider in cd (#92)

* [Feature] Add cd slider

* [Fix] Tuple instead of list

* [Fix] Spell repair

* [Fix] Spell repair
Yizhou Chen před 2 roky
rodič
revize
05e1eeee6e
2 změnil soubory, kde provedl 86 přidání a 3 odebrání
  1. 83 0
      paddlers/tasks/change_detector.py
  2. 3 3
      paddlers/tasks/segmenter.py

+ 83 - 0
paddlers/tasks/change_detector.py

@@ -13,6 +13,7 @@
 # limitations under the License.
 # limitations under the License.
 
 
 import math
 import math
+import os
 import os.path as osp
 import os.path as osp
 from collections import OrderedDict
 from collections import OrderedDict
 from operator import attrgetter
 from operator import attrgetter
@@ -545,6 +546,88 @@ class BaseChangeDetector(BaseModel):
             }
             }
         return prediction
         return prediction
 
 
+    def slider_predict(self, img_file, save_dir, block_size, overlap=36, transforms=None):
+        """
+        Do inference.
+        Args:
+            Args:
+            img_file(List[str]):
+                List of image paths.
+            save_dir(str):
+                Directory that contains saved geotiff file.
+            block_size(List[int] or Tuple[int], int):
+                The size of block.
+            overlap(List[int] or Tuple[int], int):
+                The overlap between two blocks. Defaults to 36.
+            transforms(paddlers.transforms.Compose or None, optional):
+                Transforms for inputs. If None, the transforms for evaluation process will be used. Defaults to None.
+        """
+        try:
+            from osgeo import gdal
+        except:
+            import gdal
+        
+        if len(img_file) != 2:
+            raise ValueError("`img_file` must be a list of length 2.")
+        if isinstance(block_size, int):
+            block_size = (block_size, block_size)
+        elif isinstance(block_size, (tuple, list)) and len(block_size) == 2:
+            block_size = tuple(block_size)
+        else:
+            raise ValueError("`block_size` must be a tuple/list of length 2 or an integer.")
+        if isinstance(overlap, int):
+            overlap = (overlap, overlap)
+        elif isinstance(overlap, (tuple, list)) and len(overlap) == 2:
+            overlap = tuple(overlap)
+        else:
+            raise ValueError("`overlap` must be a tuple/list of length 2 or an integer.")
+
+        src1_data = gdal.Open(img_file[0])
+        src2_data = gdal.Open(img_file[1])
+        width = src1_data.RasterXSize
+        height = src1_data.RasterYSize
+        bands = src1_data.RasterCount
+
+        driver = gdal.GetDriverByName("GTiff")
+        file_name = osp.splitext(osp.normpath(img_file[0]).split(os.sep)[-1])[0] + ".tif"
+        if not osp.exists(save_dir):
+            os.makedirs(save_dir)
+        save_file = osp.join(save_dir, file_name)
+        dst_data = driver.Create(save_file, width, height, 1, gdal.GDT_Byte)
+        dst_data.SetGeoTransform(src1_data.GetGeoTransform())
+        dst_data.SetProjection(src1_data.GetProjection())
+        band = dst_data.GetRasterBand(1)
+        band.WriteArray(255 * np.ones((height, width), dtype="uint8"))
+
+        step = np.array(block_size) - np.array(overlap)
+        for yoff in range(0, height, step[1]):
+            for xoff in range(0, width, step[0]):
+                xsize, ysize = block_size
+                if xoff + xsize > width:
+                    xsize = int(width - xoff)
+                if yoff + ysize > height:
+                    ysize = int(height - yoff)
+                im1 = src1_data.ReadAsArray(int(xoff), int(yoff), xsize, ysize).transpose((1, 2, 0))
+                im2 = src2_data.ReadAsArray(int(xoff), int(yoff), xsize, ysize).transpose((1, 2, 0))
+                # fill
+                h, w = im1.shape[:2]
+                im1_fill = np.zeros((block_size[1], block_size[0], bands), dtype=im1.dtype)
+                im2_fill = im1_fill.copy()
+                im1_fill[:h, :w, :] = im1
+                im2_fill[:h, :w, :] = im2
+                im_fill = (im1_fill, im2_fill)
+                # predict
+                pred = self.predict(im_fill, transforms)["label_map"].astype("uint8")
+                # overlap
+                rd_block = band.ReadAsArray(int(xoff), int(yoff), xsize, ysize)
+                mask = (rd_block == pred[:h, :w]) | (rd_block == 255)
+                temp = pred[:h, :w].copy()
+                temp[mask == False] = 0
+                band.WriteArray(temp, int(xoff), int(yoff))
+                dst_data.FlushCache()
+        dst_data = None
+        print("GeoTiff saved in {}.".format(save_file))
+
     def _preprocess(self, images, transforms, to_tensor=True):
     def _preprocess(self, images, transforms, to_tensor=True):
         arrange_transforms(
         arrange_transforms(
             model_type=self.model_type, transforms=transforms, mode='test')
             model_type=self.model_type, transforms=transforms, mode='test')

+ 3 - 3
paddlers/tasks/segmenter.py

@@ -527,7 +527,7 @@ class BaseSegmenter(BaseModel):
             img_file(str):
             img_file(str):
                 Image path.
                 Image path.
             save_dir(str):
             save_dir(str):
-                Folder of geotiff saved.
+                Directory that contains saved geotiff file.
             block_size(List[int] or Tuple[int], int):
             block_size(List[int] or Tuple[int], int):
                 The size of block.
                 The size of block.
             overlap(List[int] or Tuple[int], int):
             overlap(List[int] or Tuple[int], int):
@@ -545,13 +545,13 @@ class BaseSegmenter(BaseModel):
         elif isinstance(block_size, (tuple, list)) and len(block_size) == 2:
         elif isinstance(block_size, (tuple, list)) and len(block_size) == 2:
             block_size = tuple(block_size)
             block_size = tuple(block_size)
         else:
         else:
-            raise ValueError("`block_size` must be a tuple/list of length 2 or a integer.")
+            raise ValueError("`block_size` must be a tuple/list of length 2 or an integer.")
         if isinstance(overlap, int):
         if isinstance(overlap, int):
             overlap = (overlap, overlap)
             overlap = (overlap, overlap)
         elif isinstance(overlap, (tuple, list)) and len(overlap) == 2:
         elif isinstance(overlap, (tuple, list)) and len(overlap) == 2:
             overlap = tuple(overlap)
             overlap = tuple(overlap)
         else:
         else:
-            raise ValueError("`overlap` must be a tuple/list of length 2 or a integer.")
+            raise ValueError("`overlap` must be a tuple/list of length 2 or an integer.")
 
 
         src_data = gdal.Open(img_file)
         src_data = gdal.Open(img_file)
         width = src_data.RasterXSize
         width = src_data.RasterXSize