Browse Source

Add eager sliding window

Bobholamovic 2 years ago
parent
commit
b641b2fd93

+ 2 - 0
docs/apis/infer.md

@@ -159,6 +159,7 @@ def slider_predict(self,
                    invalid_value=255,
                    merge_strategy='keep_last',
                    batch_size=1,
+                   eager_load=False,
                    quiet=False):
 ```
 
@@ -174,6 +175,7 @@ def slider_predict(self,
 |`invalid_value`|`int`|输出影像中用于标记无效像素的数值。|`255`|
 |`merge_strategy`|`str`|合并滑窗重叠区域使用的策略。`'keep_first'`表示保留遍历顺序(从左至右,从上往下,列优先)最靠前的窗口的预测类别;`'keep_last'`表示保留遍历顺序最靠后的窗口的预测类别;`'accum'`表示通过将各窗口在重叠区域给出的预测概率累加,计算最终预测类别。需要注意的是,在对大尺寸影像进行`overlap`较大的密集推理时,使用`'accum'`策略可能导致较长的推理时间,但一般能够在窗口交界部分取得更好的表现。|`'keep_last'`|
 |`batch_size`|`int`|预测时使用的mini-batch大小。|`1`|
+|`eager_load`|`bool`|若为`True`,则不使用延迟内存载入,而是在预测开始时一次性将整幅影像载入到内存。|`False`|
 |`quiet`|`bool`|若为`True`,不显示预测进度。|`False`|
 
 变化检测任务的滑窗推理API与图像分割任务类似,但需要注意的是输出结果中存储的地理变换、投影等信息以从第一时相影像中读取的信息为准,存储滑窗推理结果的文件名也与第一时相影像文件相同。

+ 3 - 0
paddlers/deploy/predictor.py

@@ -332,6 +332,7 @@ class Predictor(object):
                        invalid_value=255,
                        merge_strategy='keep_last',
                        batch_size=1,
+                       eager_load=False,
                        quiet=False):
         """
         Do inference using sliding windows. Only semantic segmentation and change detection models are supported in the 
@@ -356,6 +357,7 @@ class Predictor(object):
                 the last block in traversal order, respectively. 'accum' means determining the class of an overlapping pixel 
                 according to accumulated probabilities. Defaults to 'keep_last'.
             batch_size (int, optional): Batch size used in inference. Defaults to 1.
+            eager_load (bool, optional): Whether to load the whole image(s) eagerly. Defaults to False.
             quiet (bool, optional): If True, disable the progress bar. Defaults to False.
         """
 
@@ -375,6 +377,7 @@ class Predictor(object):
             invalid_value,
             merge_strategy,
             batch_size,
+            eager_load,
             not quiet)
 
     def batch_predict(self, image_list, **params):

+ 4 - 1
paddlers/tasks/change_detector.py

@@ -591,6 +591,7 @@ class BaseChangeDetector(BaseModel):
                        invalid_value=255,
                        merge_strategy='keep_last',
                        batch_size=1,
+                       eager_load=False,
                        quiet=False):
         """
         Do inference using sliding windows.
@@ -615,12 +616,14 @@ class BaseChangeDetector(BaseModel):
                 order, respectively. 'accum' means determining the class of an overlapping 
                 pixel according to accumulated probabilities. Defaults to 'keep_last'.
             batch_size (int, optional): Batch size used in inference. Defaults to 1.
+            eager_load (bool, optional): Whether to load the whole image(s) eagerly.
+                Defaults to False.
             quiet (bool, optional): If True, disable the progress bar. Defaults to False.
         """
 
         slider_predict(self.predict, img_files, save_dir, block_size, overlap,
                        transforms, invalid_value, merge_strategy, batch_size,
-                       not quiet)
+                       eager_load, not quiet)
 
     def preprocess(self, images, transforms, to_tensor=True):
         self._check_transforms(transforms, 'test')

+ 6 - 1
paddlers/tasks/segmenter.py

@@ -557,6 +557,7 @@ class BaseSegmenter(BaseModel):
                        invalid_value=255,
                        merge_strategy='keep_last',
                        batch_size=1,
+                       eager_load=False,
                        quiet=False):
         """
         Do inference using sliding windows.
@@ -581,12 +582,14 @@ class BaseSegmenter(BaseModel):
                 order, respectively. 'accum' means determining the class of an overlapping 
                 pixel according to accumulated probabilities. Defaults to 'keep_last'.
             batch_size (int, optional): Batch size used in inference. Defaults to 1.
+            eager_load (bool, optional): Whether to load the whole image(s) eagerly.
+                Defaults to False.
             quiet (bool, optional): If True, disable the progress bar. Defaults to False.
         """
 
         slider_predict(self.predict, img_file, save_dir, block_size, overlap,
                        transforms, invalid_value, merge_strategy, batch_size,
-                       not quiet)
+                       eager_load, not quiet)
 
     def preprocess(self, images, transforms, to_tensor=True):
         self._check_transforms(transforms, 'test')
@@ -609,6 +612,8 @@ class BaseSegmenter(BaseModel):
 
     @staticmethod
     def get_transforms_shape_info(batch_ori_shape, transforms):
+        # TODO: Store transform meta info when applying transforms
+        # and not here
         batch_restore_list = list()
         for ori_shape in batch_ori_shape:
             restore_list = list()

+ 72 - 30
paddlers/tasks/utils/slider_predict.py

@@ -212,36 +212,63 @@ def assign_border_weights(array, weight=0.5, border_ratio=0.25, inplace=True):
     return array
 
 
-def read_block(ds,
-               xoff,
-               yoff,
-               xsize,
-               ysize,
-               tar_xsize=None,
-               tar_ysize=None,
-               pad_val=0):
-    if tar_xsize is None:
-        tar_xsize = xsize
-    if tar_ysize is None:
-        tar_ysize = ysize
-    # Read data from dataset
-    block = ds.ReadAsArray(xoff, yoff, xsize, ysize)
-    c, real_ysize, real_xsize = block.shape
-    assert real_ysize == ysize and real_xsize == xsize
-    # [c, h, w] -> [h, w, c]
-    block = block.transpose((1, 2, 0))
-    if (real_ysize, real_xsize) != (tar_ysize, tar_xsize):
-        if real_ysize >= tar_ysize or real_xsize >= tar_xsize:
-            raise ValueError
-        padded_block = np.full(
-            (tar_ysize, tar_xsize, c), fill_value=pad_val, dtype=block.dtype)
-        # Fill
-        padded_block[:real_ysize, :real_xsize] = block
-        return padded_block
-    else:
+class BlockReader(metaclass=ABCMeta):
+    def __init__(self, ds):
+        super().__init__()
+        self.ds = ds
+
+    @abstractmethod
+    def read_block(self, xoff, yoff, xsize, ysize):
+        pass
+
+    def get_block(self,
+                  xoff,
+                  yoff,
+                  xsize,
+                  ysize,
+                  tar_xsize=None,
+                  tar_ysize=None,
+                  pad_val=0):
+        if tar_xsize is None:
+            tar_xsize = xsize
+        if tar_ysize is None:
+            tar_ysize = ysize
+        block = self.read_block(xoff, yoff, xsize, ysize)
+        c, real_ysize, real_xsize = block.shape
+        assert real_ysize == ysize and real_xsize == xsize
+        # [c, h, w] -> [h, w, c]
+        block = block.transpose((1, 2, 0))
+        if (real_ysize, real_xsize) != (tar_ysize, tar_xsize):
+            if real_ysize >= tar_ysize or real_xsize >= tar_xsize:
+                raise ValueError
+            padded_block = np.full(
+                (tar_ysize, tar_xsize, c),
+                fill_value=pad_val,
+                dtype=block.dtype)
+            # Fill
+            padded_block[:real_ysize, :real_xsize] = block
+            return padded_block
+        else:
+            return block
+
+
+class GDALLazyBlockReader(BlockReader):
+    def read_block(self, xoff, yoff, xsize, ysize):
+        block = self.ds.ReadAsArray(xoff, yoff, xsize, ysize)
         return block
 
 
+class EagerBlockReader(BlockReader):
+    def __init__(self, ds):
+        super().__init__(ds)
+        # Read the whole image eagerly
+        self._whole_image = self.ds.ReadAsArray()
+
+    def read_block(self, xoff, yoff, xsize, ysize):
+        # First dim is channel
+        return self._whole_image[:, yoff:yoff + ysize, xoff:xoff + xsize]
+
+
 def slider_predict(predict_func,
                    img_file,
                    save_dir,
@@ -251,6 +278,7 @@ def slider_predict(predict_func,
                    invalid_value,
                    merge_strategy,
                    batch_size,
+                   eager_load=False,
                    show_progress=False):
     """
     Do inference using sliding windows.
@@ -275,10 +303,19 @@ def slider_predict(predict_func,
             traversal order, respectively. 'accum' means determining the class 
             of an overlapping pixel according to accumulated probabilities.
         batch_size (int): Batch size used in inference.
+        eager_load (bool, optional): Whether to load the whole image(s) eagerly.
+            Defaults to False.
         show_progress (bool, optional): Whether to show prediction progress with a 
             progress bar. Defaults to True.
     """
 
+    def _construct_reader(eager_load, *args, **kwargs):
+        if eager_load:
+            reader = EagerBlockReader(*args, **kwargs)
+        else:
+            reader = GDALLazyBlockReader(*args, **kwargs)
+        return reader
+
     try:
         from osgeo import gdal
     except:
@@ -311,11 +348,14 @@ def slider_predict(predict_func,
             raise ValueError("Tuple `img_file` must have the length of two.")
         # Assume that two input images have the same size
         src_data = gdal.Open(img_file[0])
+        reader = _construct_reader(eager_load=eager_load, ds=src_data)
         src2_data = gdal.Open(img_file[1])
+        reader2 = _construct_reader(eager_load=eager_load, ds=src2_data)
         # Output name is the same as the name of the first image
         file_name = osp.basename(osp.normpath(img_file[0]))
     else:
         src_data = gdal.Open(img_file)
+        reader = _construct_reader(eager_load=eager_load, ds=src_data)
         file_name = osp.basename(osp.normpath(img_file))
 
     # Get size of original raster
@@ -395,10 +435,10 @@ def slider_predict(predict_func,
                 is_end_of_col = False
 
             # Read
-            im = read_block(src_data, xoff, yoff, xsize, ysize)
+            im = reader.get_block(xoff, yoff, xsize, ysize)
 
             if isinstance(img_file, tuple):
-                im2 = read_block(src2_data, xoff, yoff, xsize, ysize)
+                im2 = reader2.get_block(xoff, yoff, xsize, ysize)
                 batch_data.append((im, im2))
             else:
                 batch_data.append(im)
@@ -423,7 +463,6 @@ def slider_predict(predict_func,
                     # Write to file
                     band.WriteArray(pred, xoff_, yoff_)
 
-                dst_data.FlushCache()
                 batch_data.clear()
                 batch_offsets.clear()
 
@@ -433,6 +472,9 @@ def slider_predict(predict_func,
                 pb.update(1)
                 pb.set_description("{} out of {} blocks processed.".format(
                     cnt, num_blocks))
+        # Flush cache when finishing each row
+        dst_data.FlushCache()
 
+    dst_data.FlushCache()
     dst_data = None
     logging.info("GeoTiff file saved in {}.".format(save_file))

+ 27 - 0
tests/tasks/test_slider_predict.py

@@ -72,6 +72,33 @@ class _TestSliderPredictNamespace:
                     self.model.slider_predict(self.image_path, save_dir, 512, 0,
                                               self.transforms)
 
+        def test_eager_load(self):
+            with tempfile.TemporaryDirectory() as td:
+                # Lazy
+                save_dir = osp.join(td, 'lazy')
+                self.model.slider_predict(self.image_path, save_dir, 128, 64,
+                                          self.transforms)
+                pred_lazy = T.decode_image(
+                    osp.join(save_dir, self.basename),
+                    read_raw=True,
+                    decode_sar=False)
+
+                # Eager
+                save_dir = osp.join(td, 'eager')
+                self.model.slider_predict(
+                    self.image_path,
+                    save_dir,
+                    128,
+                    64,
+                    self.transforms,
+                    eager_load=True)
+                pred_eager = T.decode_image(
+                    osp.join(save_dir, self.basename),
+                    read_raw=True,
+                    decode_sar=False)
+
+                self.check_output_equal(pred_lazy, pred_eager)
+
         def test_merge_strategy(self):
             with tempfile.TemporaryDirectory() as td:
                 # Whole-image inference using predict()