Przeglądaj źródła

Add eager sliding window

Bobholamovic 2 lat temu
rodzic
commit
b641b2fd93

+ 2 - 0
docs/apis/infer.md

@@ -159,6 +159,7 @@ def slider_predict(self,
                    invalid_value=255,
                    invalid_value=255,
                    merge_strategy='keep_last',
                    merge_strategy='keep_last',
                    batch_size=1,
                    batch_size=1,
+                   eager_load=False,
                    quiet=False):
                    quiet=False):
 ```
 ```
 
 
@@ -174,6 +175,7 @@ def slider_predict(self,
 |`invalid_value`|`int`|输出影像中用于标记无效像素的数值。|`255`|
 |`invalid_value`|`int`|输出影像中用于标记无效像素的数值。|`255`|
 |`merge_strategy`|`str`|合并滑窗重叠区域使用的策略。`'keep_first'`表示保留遍历顺序(从左至右,从上往下,列优先)最靠前的窗口的预测类别;`'keep_last'`表示保留遍历顺序最靠后的窗口的预测类别;`'accum'`表示通过将各窗口在重叠区域给出的预测概率累加,计算最终预测类别。需要注意的是,在对大尺寸影像进行`overlap`较大的密集推理时,使用`'accum'`策略可能导致较长的推理时间,但一般能够在窗口交界部分取得更好的表现。|`'keep_last'`|
 |`merge_strategy`|`str`|合并滑窗重叠区域使用的策略。`'keep_first'`表示保留遍历顺序(从左至右,从上往下,列优先)最靠前的窗口的预测类别;`'keep_last'`表示保留遍历顺序最靠后的窗口的预测类别;`'accum'`表示通过将各窗口在重叠区域给出的预测概率累加,计算最终预测类别。需要注意的是,在对大尺寸影像进行`overlap`较大的密集推理时,使用`'accum'`策略可能导致较长的推理时间,但一般能够在窗口交界部分取得更好的表现。|`'keep_last'`|
 |`batch_size`|`int`|预测时使用的mini-batch大小。|`1`|
 |`batch_size`|`int`|预测时使用的mini-batch大小。|`1`|
+|`eager_load`|`bool`|若为`True`,则不使用延迟内存载入,而是在预测开始时一次性将整幅影像载入到内存。|`False`|
 |`quiet`|`bool`|若为`True`,不显示预测进度。|`False`|
 |`quiet`|`bool`|若为`True`,不显示预测进度。|`False`|
 
 
 变化检测任务的滑窗推理API与图像分割任务类似,但需要注意的是输出结果中存储的地理变换、投影等信息以从第一时相影像中读取的信息为准,存储滑窗推理结果的文件名也与第一时相影像文件相同。
 变化检测任务的滑窗推理API与图像分割任务类似,但需要注意的是输出结果中存储的地理变换、投影等信息以从第一时相影像中读取的信息为准,存储滑窗推理结果的文件名也与第一时相影像文件相同。

+ 3 - 0
paddlers/deploy/predictor.py

@@ -332,6 +332,7 @@ class Predictor(object):
                        invalid_value=255,
                        invalid_value=255,
                        merge_strategy='keep_last',
                        merge_strategy='keep_last',
                        batch_size=1,
                        batch_size=1,
+                       eager_load=False,
                        quiet=False):
                        quiet=False):
         """
         """
         Do inference using sliding windows. Only semantic segmentation and change detection models are supported in the 
         Do inference using sliding windows. Only semantic segmentation and change detection models are supported in the 
@@ -356,6 +357,7 @@ class Predictor(object):
                 the last block in traversal order, respectively. 'accum' means determining the class of an overlapping pixel 
                 the last block in traversal order, respectively. 'accum' means determining the class of an overlapping pixel 
                 according to accumulated probabilities. Defaults to 'keep_last'.
                 according to accumulated probabilities. Defaults to 'keep_last'.
             batch_size (int, optional): Batch size used in inference. Defaults to 1.
             batch_size (int, optional): Batch size used in inference. Defaults to 1.
+            eager_load (bool, optional): Whether to load the whole image(s) eagerly. Defaults to False.
             quiet (bool, optional): If True, disable the progress bar. Defaults to False.
             quiet (bool, optional): If True, disable the progress bar. Defaults to False.
         """
         """
 
 
@@ -375,6 +377,7 @@ class Predictor(object):
             invalid_value,
             invalid_value,
             merge_strategy,
             merge_strategy,
             batch_size,
             batch_size,
+            eager_load,
             not quiet)
             not quiet)
 
 
     def batch_predict(self, image_list, **params):
     def batch_predict(self, image_list, **params):

+ 4 - 1
paddlers/tasks/change_detector.py

@@ -591,6 +591,7 @@ class BaseChangeDetector(BaseModel):
                        invalid_value=255,
                        invalid_value=255,
                        merge_strategy='keep_last',
                        merge_strategy='keep_last',
                        batch_size=1,
                        batch_size=1,
+                       eager_load=False,
                        quiet=False):
                        quiet=False):
         """
         """
         Do inference using sliding windows.
         Do inference using sliding windows.
@@ -615,12 +616,14 @@ class BaseChangeDetector(BaseModel):
                 order, respectively. 'accum' means determining the class of an overlapping 
                 order, respectively. 'accum' means determining the class of an overlapping 
                 pixel according to accumulated probabilities. Defaults to 'keep_last'.
                 pixel according to accumulated probabilities. Defaults to 'keep_last'.
             batch_size (int, optional): Batch size used in inference. Defaults to 1.
             batch_size (int, optional): Batch size used in inference. Defaults to 1.
+            eager_load (bool, optional): Whether to load the whole image(s) eagerly.
+                Defaults to False.
             quiet (bool, optional): If True, disable the progress bar. Defaults to False.
             quiet (bool, optional): If True, disable the progress bar. Defaults to False.
         """
         """
 
 
         slider_predict(self.predict, img_files, save_dir, block_size, overlap,
         slider_predict(self.predict, img_files, save_dir, block_size, overlap,
                        transforms, invalid_value, merge_strategy, batch_size,
                        transforms, invalid_value, merge_strategy, batch_size,
-                       not quiet)
+                       eager_load, not quiet)
 
 
     def preprocess(self, images, transforms, to_tensor=True):
     def preprocess(self, images, transforms, to_tensor=True):
         self._check_transforms(transforms, 'test')
         self._check_transforms(transforms, 'test')

+ 6 - 1
paddlers/tasks/segmenter.py

@@ -557,6 +557,7 @@ class BaseSegmenter(BaseModel):
                        invalid_value=255,
                        invalid_value=255,
                        merge_strategy='keep_last',
                        merge_strategy='keep_last',
                        batch_size=1,
                        batch_size=1,
+                       eager_load=False,
                        quiet=False):
                        quiet=False):
         """
         """
         Do inference using sliding windows.
         Do inference using sliding windows.
@@ -581,12 +582,14 @@ class BaseSegmenter(BaseModel):
                 order, respectively. 'accum' means determining the class of an overlapping 
                 order, respectively. 'accum' means determining the class of an overlapping 
                 pixel according to accumulated probabilities. Defaults to 'keep_last'.
                 pixel according to accumulated probabilities. Defaults to 'keep_last'.
             batch_size (int, optional): Batch size used in inference. Defaults to 1.
             batch_size (int, optional): Batch size used in inference. Defaults to 1.
+            eager_load (bool, optional): Whether to load the whole image(s) eagerly.
+                Defaults to False.
             quiet (bool, optional): If True, disable the progress bar. Defaults to False.
             quiet (bool, optional): If True, disable the progress bar. Defaults to False.
         """
         """
 
 
         slider_predict(self.predict, img_file, save_dir, block_size, overlap,
         slider_predict(self.predict, img_file, save_dir, block_size, overlap,
                        transforms, invalid_value, merge_strategy, batch_size,
                        transforms, invalid_value, merge_strategy, batch_size,
-                       not quiet)
+                       eager_load, not quiet)
 
 
     def preprocess(self, images, transforms, to_tensor=True):
     def preprocess(self, images, transforms, to_tensor=True):
         self._check_transforms(transforms, 'test')
         self._check_transforms(transforms, 'test')
@@ -609,6 +612,8 @@ class BaseSegmenter(BaseModel):
 
 
     @staticmethod
     @staticmethod
     def get_transforms_shape_info(batch_ori_shape, transforms):
     def get_transforms_shape_info(batch_ori_shape, transforms):
+        # TODO: Store transform meta info when applying transforms
+        # and not here
         batch_restore_list = list()
         batch_restore_list = list()
         for ori_shape in batch_ori_shape:
         for ori_shape in batch_ori_shape:
             restore_list = list()
             restore_list = list()

+ 72 - 30
paddlers/tasks/utils/slider_predict.py

@@ -212,36 +212,63 @@ def assign_border_weights(array, weight=0.5, border_ratio=0.25, inplace=True):
     return array
     return array
 
 
 
 
-def read_block(ds,
-               xoff,
-               yoff,
-               xsize,
-               ysize,
-               tar_xsize=None,
-               tar_ysize=None,
-               pad_val=0):
-    if tar_xsize is None:
-        tar_xsize = xsize
-    if tar_ysize is None:
-        tar_ysize = ysize
-    # Read data from dataset
-    block = ds.ReadAsArray(xoff, yoff, xsize, ysize)
-    c, real_ysize, real_xsize = block.shape
-    assert real_ysize == ysize and real_xsize == xsize
-    # [c, h, w] -> [h, w, c]
-    block = block.transpose((1, 2, 0))
-    if (real_ysize, real_xsize) != (tar_ysize, tar_xsize):
-        if real_ysize >= tar_ysize or real_xsize >= tar_xsize:
-            raise ValueError
-        padded_block = np.full(
-            (tar_ysize, tar_xsize, c), fill_value=pad_val, dtype=block.dtype)
-        # Fill
-        padded_block[:real_ysize, :real_xsize] = block
-        return padded_block
-    else:
+class BlockReader(metaclass=ABCMeta):
+    def __init__(self, ds):
+        super().__init__()
+        self.ds = ds
+
+    @abstractmethod
+    def read_block(self, xoff, yoff, xsize, ysize):
+        pass
+
+    def get_block(self,
+                  xoff,
+                  yoff,
+                  xsize,
+                  ysize,
+                  tar_xsize=None,
+                  tar_ysize=None,
+                  pad_val=0):
+        if tar_xsize is None:
+            tar_xsize = xsize
+        if tar_ysize is None:
+            tar_ysize = ysize
+        block = self.read_block(xoff, yoff, xsize, ysize)
+        c, real_ysize, real_xsize = block.shape
+        assert real_ysize == ysize and real_xsize == xsize
+        # [c, h, w] -> [h, w, c]
+        block = block.transpose((1, 2, 0))
+        if (real_ysize, real_xsize) != (tar_ysize, tar_xsize):
+            if real_ysize >= tar_ysize or real_xsize >= tar_xsize:
+                raise ValueError
+            padded_block = np.full(
+                (tar_ysize, tar_xsize, c),
+                fill_value=pad_val,
+                dtype=block.dtype)
+            # Fill
+            padded_block[:real_ysize, :real_xsize] = block
+            return padded_block
+        else:
+            return block
+
+
+class GDALLazyBlockReader(BlockReader):
+    def read_block(self, xoff, yoff, xsize, ysize):
+        block = self.ds.ReadAsArray(xoff, yoff, xsize, ysize)
         return block
         return block
 
 
 
 
+class EagerBlockReader(BlockReader):
+    def __init__(self, ds):
+        super().__init__(ds)
+        # Read the whole image eagerly
+        self._whole_image = self.ds.ReadAsArray()
+
+    def read_block(self, xoff, yoff, xsize, ysize):
+        # First dim is channel
+        return self._whole_image[:, yoff:yoff + ysize, xoff:xoff + xsize]
+
+
 def slider_predict(predict_func,
 def slider_predict(predict_func,
                    img_file,
                    img_file,
                    save_dir,
                    save_dir,
@@ -251,6 +278,7 @@ def slider_predict(predict_func,
                    invalid_value,
                    invalid_value,
                    merge_strategy,
                    merge_strategy,
                    batch_size,
                    batch_size,
+                   eager_load=False,
                    show_progress=False):
                    show_progress=False):
     """
     """
     Do inference using sliding windows.
     Do inference using sliding windows.
@@ -275,10 +303,19 @@ def slider_predict(predict_func,
             traversal order, respectively. 'accum' means determining the class 
             traversal order, respectively. 'accum' means determining the class 
             of an overlapping pixel according to accumulated probabilities.
             of an overlapping pixel according to accumulated probabilities.
         batch_size (int): Batch size used in inference.
         batch_size (int): Batch size used in inference.
+        eager_load (bool, optional): Whether to load the whole image(s) eagerly.
+            Defaults to False.
         show_progress (bool, optional): Whether to show prediction progress with a 
         show_progress (bool, optional): Whether to show prediction progress with a 
             progress bar. Defaults to True.
             progress bar. Defaults to True.
     """
     """
 
 
+    def _construct_reader(eager_load, *args, **kwargs):
+        if eager_load:
+            reader = EagerBlockReader(*args, **kwargs)
+        else:
+            reader = GDALLazyBlockReader(*args, **kwargs)
+        return reader
+
     try:
     try:
         from osgeo import gdal
         from osgeo import gdal
     except:
     except:
@@ -311,11 +348,14 @@ def slider_predict(predict_func,
             raise ValueError("Tuple `img_file` must have the length of two.")
             raise ValueError("Tuple `img_file` must have the length of two.")
         # Assume that two input images have the same size
         # Assume that two input images have the same size
         src_data = gdal.Open(img_file[0])
         src_data = gdal.Open(img_file[0])
+        reader = _construct_reader(eager_load=eager_load, ds=src_data)
         src2_data = gdal.Open(img_file[1])
         src2_data = gdal.Open(img_file[1])
+        reader2 = _construct_reader(eager_load=eager_load, ds=src2_data)
         # Output name is the same as the name of the first image
         # Output name is the same as the name of the first image
         file_name = osp.basename(osp.normpath(img_file[0]))
         file_name = osp.basename(osp.normpath(img_file[0]))
     else:
     else:
         src_data = gdal.Open(img_file)
         src_data = gdal.Open(img_file)
+        reader = _construct_reader(eager_load=eager_load, ds=src_data)
         file_name = osp.basename(osp.normpath(img_file))
         file_name = osp.basename(osp.normpath(img_file))
 
 
     # Get size of original raster
     # Get size of original raster
@@ -395,10 +435,10 @@ def slider_predict(predict_func,
                 is_end_of_col = False
                 is_end_of_col = False
 
 
             # Read
             # Read
-            im = read_block(src_data, xoff, yoff, xsize, ysize)
+            im = reader.get_block(xoff, yoff, xsize, ysize)
 
 
             if isinstance(img_file, tuple):
             if isinstance(img_file, tuple):
-                im2 = read_block(src2_data, xoff, yoff, xsize, ysize)
+                im2 = reader2.get_block(xoff, yoff, xsize, ysize)
                 batch_data.append((im, im2))
                 batch_data.append((im, im2))
             else:
             else:
                 batch_data.append(im)
                 batch_data.append(im)
@@ -423,7 +463,6 @@ def slider_predict(predict_func,
                     # Write to file
                     # Write to file
                     band.WriteArray(pred, xoff_, yoff_)
                     band.WriteArray(pred, xoff_, yoff_)
 
 
-                dst_data.FlushCache()
                 batch_data.clear()
                 batch_data.clear()
                 batch_offsets.clear()
                 batch_offsets.clear()
 
 
@@ -433,6 +472,9 @@ def slider_predict(predict_func,
                 pb.update(1)
                 pb.update(1)
                 pb.set_description("{} out of {} blocks processed.".format(
                 pb.set_description("{} out of {} blocks processed.".format(
                     cnt, num_blocks))
                     cnt, num_blocks))
+        # Flush cache when finishing each row
+        dst_data.FlushCache()
 
 
+    dst_data.FlushCache()
     dst_data = None
     dst_data = None
     logging.info("GeoTiff file saved in {}.".format(save_file))
     logging.info("GeoTiff file saved in {}.".format(save_file))

+ 27 - 0
tests/tasks/test_slider_predict.py

@@ -72,6 +72,33 @@ class _TestSliderPredictNamespace:
                     self.model.slider_predict(self.image_path, save_dir, 512, 0,
                     self.model.slider_predict(self.image_path, save_dir, 512, 0,
                                               self.transforms)
                                               self.transforms)
 
 
+        def test_eager_load(self):
+            with tempfile.TemporaryDirectory() as td:
+                # Lazy
+                save_dir = osp.join(td, 'lazy')
+                self.model.slider_predict(self.image_path, save_dir, 128, 64,
+                                          self.transforms)
+                pred_lazy = T.decode_image(
+                    osp.join(save_dir, self.basename),
+                    read_raw=True,
+                    decode_sar=False)
+
+                # Eager
+                save_dir = osp.join(td, 'eager')
+                self.model.slider_predict(
+                    self.image_path,
+                    save_dir,
+                    128,
+                    64,
+                    self.transforms,
+                    eager_load=True)
+                pred_eager = T.decode_image(
+                    osp.join(save_dir, self.basename),
+                    read_raw=True,
+                    decode_sar=False)
+
+                self.check_output_equal(pred_lazy, pred_eager)
+
         def test_merge_strategy(self):
         def test_merge_strategy(self):
             with tempfile.TemporaryDirectory() as td:
             with tempfile.TemporaryDirectory() as td:
                 # Whole-image inference using predict()
                 # Whole-image inference using predict()