Эх сурвалжийг харах

Merge pull request #34 from Bobholamovic/enhance_slide

[Feat] Enhance `slider_predict()`
cc 3 жил өмнө
parent
commit
18b399cb0a

+ 4 - 2
docs/apis/data.md

@@ -134,10 +134,12 @@
 |-------|----|--------|-----|
 |`im_path`|`str`|输入图像路径。||
 |`to_rgb`|`bool`|若为`True`,则执行BGR到RGB格式的转换。|`True`|
-|`to_uint8`|`bool`|若为`True`,则将读取的像数据量化并转换为uint8类型。|`True`|
+|`to_uint8`|`bool`|若为`True`,则将读取的像数据量化并转换为uint8类型。|`True`|
 |`decode_bgr`|`bool`|若为`True`,则自动将非地学格式影像(如jpeg影像)解析为BGR格式。|`True`|
-|`decode_sar`|`bool`|若为`True`,则自动将2通道的地学格式影像(如GeoTiff影像)作为SAR影像解析。|`True`|
+|`decode_sar`|`bool`|若为`True`,则自动将通道的地学格式影像(如GeoTiff影像)作为SAR影像解析。|`True`|
 |`read_geo_info`|`bool`|若为`True`,则从影像中读取地理信息。|`False`|
+|`use_stretch`|`bool`|是否对影像亮度进行2%线性拉伸。仅当`to_uint8`为`True`时有效。|`False`|
+|`read_raw`|`bool`|若为`True`,等价于指定`to_rgb`和`to_uint8`为`False`,且该参数的优先级高于上述参数。|`False`|
 
 返回格式如下:
 

+ 17 - 4
docs/apis/infer.md

@@ -155,7 +155,11 @@ def slider_predict(self,
                    save_dir,
                    block_size,
                    overlap=36,
-                   transforms=None):
+                   transforms=None,
+                   invalid_value=255,
+                   merge_strategy='keep_last',
+                   batch_size=1,
+                   quiet=False):
 ```
 
 输入参数列表:
@@ -164,11 +168,15 @@ def slider_predict(self,
 |-------|----|--------|-----|
 |`img_file`|`str`|输入影像路径。||
 |`save_dir`|`str`|预测结果输出路径。||
-|`block_size`|`list[int]` \| `tuple[int]` \| `int`|滑窗的窗口大小(以列表或元组指定长、宽或以一个整数指定相同的宽)。||
-|`overlap`|`list[int]` \| `tuple[int]` \| `int`|滑窗的滑动步长(以列表或元组指定长、宽或以一个整数指定相同的宽)。|`36`|
+|`block_size`|`list[int]` \| `tuple[int]` \| `int`|滑窗的窗口大小(以列表或元组指定宽度、高度或以一个整数指定相同的宽)。||
+|`overlap`|`list[int]` \| `tuple[int]` \| `int`|滑窗的滑动步长(以列表或元组指定宽度、高度或以一个整数指定相同的宽)。|`36`|
 |`transforms`|`paddlers.transforms.Compose` \| `None`|对输入数据应用的数据变换算子。若为`None`,则使用训练器在验证阶段使用的数据变换算子。|`None`|
+|`invalid_value`|`int`|输出影像中用于标记无效像素的数值。|`255`|
+|`merge_strategy`|`str`|合并滑窗重叠区域使用的策略。`'keep_first'`表示保留遍历顺序(从左至右,从上往下,列优先)最靠前的窗口的预测类别;`'keep_last'`表示保留遍历顺序最靠后的窗口的预测类别;`'accum'`表示通过将各窗口在重叠区域给出的预测概率累加,计算最终预测类别。需要注意的是,在对大尺寸影像进行`overlap`较大的密集推理时,使用`'accum'`策略可能导致较长的推理时间,但一般能够在窗口交界部分取得更好的表现。|`'keep_last'`|
+|`batch_size`|`int`|预测时使用的mini-batch大小。|`1`|
+|`quiet`|`bool`|若为`True`,不显示预测进度。|`False`|
 
-变化检测任务的滑窗推理API与图像分割任务类似,但需要注意的是输出结果中存储的地理变换、投影等信息以从第一时相影像中读取的信息为准。
+变化检测任务的滑窗推理API与图像分割任务类似,但需要注意的是输出结果中存储的地理变换、投影等信息以从第一时相影像中读取的信息为准,存储滑窗推理结果的文件名也与第一时相影像文件相同
 
 ## 静态图推理API
 
@@ -216,5 +224,10 @@ def predict(self,
 |`transforms`|`paddlers.transforms.Compose`\|`None`|对输入数据应用的数据变换算子。若为`None`,则使用从`model.yml`中读取的算子。|`None`|
 |`warmup_iters`|`int`|预热轮数,用于评估模型推理以及前后处理速度。若大于1,将预先重复执行`warmup_iters`次推理,而后才开始正式的预测及其速度评估。|`0`|
 |`repeats`|`int`|重复次数,用于评估模型推理以及前后处理速度。若大于1,将执行`repeats`次预测并取时间平均值。|`1`|
+|`quiet`|`bool`|若为`True`,不打印计时信息。|`False`|
 
 `Predictor.predict()`的返回格式与相应的动态图推理API的返回格式完全相同,详情请参考[动态图推理API](#动态图推理api)。
+
+### `Predictor.slider_predict()`
+
+实现滑窗推理功能。用法与`BaseSegmenter`和`BaseChangeDetector`的`slider_predict()`方法相同。

+ 58 - 5
paddlers/deploy/predictor.py

@@ -14,6 +14,7 @@
 
 import os.path as osp
 from operator import itemgetter
+from functools import partial
 
 import numpy as np
 import paddle
@@ -23,6 +24,7 @@ from paddle.inference import PrecisionType
 
 from paddlers.tasks import load_model
 from paddlers.utils import logging, Timer
+from paddlers.tasks.utils.slider_predict import slider_predict
 
 
 class Predictor(object):
@@ -271,22 +273,24 @@ class Predictor(object):
                 topk=1,
                 transforms=None,
                 warmup_iters=0,
-                repeats=1):
+                repeats=1,
+                quiet=False):
         """
-        Do prediction.
+        Do inference.
 
         Args:
             img_file(list[str|tuple|np.ndarray] | str | tuple | np.ndarray): For scene classification, image restoration, 
                 object detection and semantic segmentation tasks, `img_file` should be either the path of the image to predict,
                 a decoded image (a np.ndarray, which should be consistent with what you get from passing image path to
-                paddlers.transforms.decode_image()), or a list of image paths or decoded images. For change detection tasks,
-                img_file should be a tuple of image paths, a tuple of decoded images, or a list of tuples.
+                paddlers.transforms.decode_image(..., read_raw=True)), or a list of image paths or decoded images. For change 
+                detection tasks, `img_file` should be a tuple of image paths, a tuple of decoded images, or a list of tuples.
             topk(int, optional): Top-k values to reserve in a classification result. Defaults to 1.
             transforms (paddlers.transforms.Compose|None, optional): Pipeline of data preprocessing. If None, load transforms
                 from `model.yml`. Defaults to None.
             warmup_iters (int, optional): Warm-up iterations before measuring the execution time. Defaults to 0.
             repeats (int, optional): Number of repetitions to evaluate model inference and data processing speed. If greater than
                 1, the reported time consumption is the average of all repeats. Defaults to 1.
+            quiet (bool, optional): If True, do not display the timing information. Defaults to False.
         """
 
         if repeats < 1:
@@ -313,12 +317,61 @@ class Predictor(object):
 
         self.timer.repeats = repeats
         self.timer.img_num = len(images)
-        self.timer.info(average=True)
+        if not quiet:
+            self.timer.info(average=True)
 
         if isinstance(img_file, (str, np.ndarray, tuple)):
             results = results[0]
 
         return results
 
+    def slider_predict(self,
+                       img_file,
+                       save_dir,
+                       block_size,
+                       overlap=36,
+                       transforms=None,
+                       invalid_value=255,
+                       merge_strategy='keep_last',
+                       batch_size=1,
+                       quiet=False):
+        """
+        Do inference using sliding windows. Only semantic segmentation and change detection models are supported in the 
+            sliding-predicting mode.
+
+        Args:
+            img_file(list[str|tuple|np.ndarray] | str | tuple | np.ndarray): For semantic segmentation tasks, `img_file` 
+                should be either the path of the image to predict, a decoded image (a np.ndarray, which should be 
+                consistent with what you get from passing image path to paddlers.transforms.decode_image(..., read_raw=True)), 
+                or a list of image paths or decoded images. For change detection tasks, `img_file` should be a tuple of 
+                image paths, a tuple of decoded images, or a list of tuples.
+            save_dir (str): Directory that contains saved geotiff file.
+            block_size (list[int] | tuple[int] | int): Size of block. If `block_size` is a list or tuple, it should be in 
+                (W, H) format.
+            overlap (list[int] | tuple[int] | int, optional): Overlap between two blocks. If `overlap` is a list or tuple, 
+                it should be in (W, H) format. Defaults to 36.
+            transforms (paddlers.transforms.Compose|None, optional): Pipeline of data preprocessing. If None, load transforms
+                from `model.yml`. Defaults to None.
+            invalid_value (int, optional): Value that marks invalid pixels in output image. Defaults to 255.
+            merge_strategy (str, optional): Strategy to merge overlapping blocks. Choices are 
+                {'keep_first', 'keep_last', 'accum'}. 'keep_first' and 'keep_last' means keeping the values of the first and 
+                the last block in traversal order, respectively. 'accum' means determining the class of an overlapping pixel 
+                according to accumulated probabilities. Defaults to 'keep_last'.
+            batch_size (int, optional): Batch size used in inference. Defaults to 1.
+            quiet (bool, optional): If True, disable the progress bar. Defaults to False.
+        """
+        slider_predict(
+            partial(
+                self.predict, quiet=True),
+            img_file,
+            save_dir,
+            block_size,
+            overlap,
+            transforms,
+            invalid_value,
+            merge_strategy,
+            batch_size,
+            not quiet)
+
     def batch_predict(self, image_list, **params):
         return self.predict(img_file=image_list, **params)

+ 16 - 1
paddlers/tasks/base.py

@@ -86,7 +86,7 @@ class BaseModel(metaclass=ModelMeta):
         self.quant_config = None
         self.fixed_input_shape = None
 
-    def net_initialize(self,
+    def initialize_net(self,
                        pretrain_weights=None,
                        save_dir='.',
                        resume_checkpoint=None,
@@ -677,3 +677,18 @@ class BaseModel(metaclass=ModelMeta):
             raise ValueError(
                 f"Incorrect arrange mode! Expected {mode} but got {arrange_obj.mode}."
             )
+
+    def run(self, net, inputs, mode):
+        raise NotImplementedError
+
+    def train(self, *args, **kwargs):
+        raise NotImplementedError
+
+    def evaluate(self, *args, **kwargs):
+        raise NotImplementedError
+
+    def preprocess(self, images, transforms, to_tensor):
+        raise NotImplementedError
+
+    def postprocess(self, *args, **kwargs):
+        raise NotImplementedError

+ 33 - 84
paddlers/tasks/change_detector.py

@@ -35,6 +35,7 @@ from paddlers.utils.checkpoint import cd_pretrain_weights_dict
 from .base import BaseModel
 from .utils import seg_metrics as metrics
 from .utils.infer_nets import InferCDNet
+from .utils.slider_predict import slider_predict
 
 __all__ = [
     "CDNet", "FCEarlyFusion", "FCSiamConc", "FCSiamDiff", "STANet", "BIT",
@@ -315,7 +316,7 @@ class BaseChangeDetector(BaseModel):
                         exit=True)
         pretrained_dir = osp.join(save_dir, 'pretrain')
         is_backbone_weights = pretrain_weights == 'IMAGENET'
-        self.net_initialize(
+        self.initialize_net(
             pretrain_weights=pretrain_weights,
             save_dir=pretrained_dir,
             resume_checkpoint=resume_checkpoint,
@@ -581,96 +582,44 @@ class BaseChangeDetector(BaseModel):
         return prediction
 
     def slider_predict(self,
-                       img_file,
+                       img_files,
                        save_dir,
                        block_size,
                        overlap=36,
-                       transforms=None):
+                       transforms=None,
+                       invalid_value=255,
+                       merge_strategy='keep_last',
+                       batch_size=1,
+                       quiet=False):
         """
-        Do inference.
+        Do inference using sliding windows.
 
         Args:
-            img_file (tuple[str]): Tuple of image paths.
+            img_files (tuple[str]): Tuple of image paths.
             save_dir (str): Directory that contains saved geotiff file.
-            block_size (list[int] | tuple[int] | int, optional): Size of block.
-            overlap (list[int] | tuple[int] | int, optional): Overlap between two blocks. 
-                Defaults to 36.
-            transforms (paddlers.transforms.Compose|None, optional): Transforms for inputs.
-                If None, the transforms for evaluation process will be used. Defaults to None.
+            block_size (list[int] | tuple[int] | int):
+                Size of block. If `block_size` is a list or tuple, it should be in 
+                (W, H) format.
+            overlap (list[int] | tuple[int] | int, optional):
+                Overlap between two blocks. If `overlap` is a list or tuple, it should
+                be in (W, H) format. Defaults to 36.
+            transforms (paddlers.transforms.Compose|None, optional): Transforms for 
+                inputs. If None, the transforms for evaluation process will be used. 
+                Defaults to None.
+            invalid_value (int, optional): Value that marks invalid pixels in output 
+                image. Defaults to 255.
+            merge_strategy (str, optional): Strategy to merge overlapping blocks. Choices
+                are {'keep_first', 'keep_last', 'accum'}. 'keep_first' and 'keep_last' 
+                means keeping the values of the first and the last block in traversal 
+                order, respectively. 'accum' means determining the class of an overlapping 
+                pixel according to accumulated probabilities. Defaults to 'keep_last'.
+            batch_size (int, optional): Batch size used in inference. Defaults to 1.
+            quiet (bool, optional): If True, disable the progress bar. Defaults to False.
         """
 
-        try:
-            from osgeo import gdal
-        except:
-            import gdal
-
-        if not isinstance(img_file, tuple) or len(img_file) != 2:
-            raise ValueError("`img_file` must be a tuple of length 2.")
-        if isinstance(block_size, int):
-            block_size = (block_size, block_size)
-        elif isinstance(block_size, (tuple, list)) and len(block_size) == 2:
-            block_size = tuple(block_size)
-        else:
-            raise ValueError(
-                "`block_size` must be a tuple/list of length 2 or an integer.")
-        if isinstance(overlap, int):
-            overlap = (overlap, overlap)
-        elif isinstance(overlap, (tuple, list)) and len(overlap) == 2:
-            overlap = tuple(overlap)
-        else:
-            raise ValueError(
-                "`overlap` must be a tuple/list of length 2 or an integer.")
-
-        src1_data = gdal.Open(img_file[0])
-        src2_data = gdal.Open(img_file[1])
-        width = src1_data.RasterXSize
-        height = src1_data.RasterYSize
-        bands = src1_data.RasterCount
-
-        driver = gdal.GetDriverByName("GTiff")
-        file_name = osp.splitext(osp.normpath(img_file[0]).split(os.sep)[-1])[
-            0] + ".tif"
-        if not osp.exists(save_dir):
-            os.makedirs(save_dir)
-        save_file = osp.join(save_dir, file_name)
-        dst_data = driver.Create(save_file, width, height, 1, gdal.GDT_Byte)
-        dst_data.SetGeoTransform(src1_data.GetGeoTransform())
-        dst_data.SetProjection(src1_data.GetProjection())
-        band = dst_data.GetRasterBand(1)
-        band.WriteArray(255 * np.ones((height, width), dtype="uint8"))
-
-        step = np.array(block_size) - np.array(overlap)
-        for yoff in range(0, height, step[1]):
-            for xoff in range(0, width, step[0]):
-                xsize, ysize = block_size
-                if xoff + xsize > width:
-                    xsize = int(width - xoff)
-                if yoff + ysize > height:
-                    ysize = int(height - yoff)
-                im1 = src1_data.ReadAsArray(
-                    int(xoff), int(yoff), xsize, ysize).transpose((1, 2, 0))
-                im2 = src2_data.ReadAsArray(
-                    int(xoff), int(yoff), xsize, ysize).transpose((1, 2, 0))
-                # Fill
-                h, w = im1.shape[:2]
-                im1_fill = np.zeros(
-                    (block_size[1], block_size[0], bands), dtype=im1.dtype)
-                im2_fill = im1_fill.copy()
-                im1_fill[:h, :w, :] = im1
-                im2_fill[:h, :w, :] = im2
-                im_fill = (im1_fill, im2_fill)
-                # Predict
-                pred = self.predict(im_fill,
-                                    transforms)["label_map"].astype("uint8")
-                # Overlap
-                rd_block = band.ReadAsArray(int(xoff), int(yoff), xsize, ysize)
-                mask = (rd_block == pred[:h, :w]) | (rd_block == 255)
-                temp = pred[:h, :w].copy()
-                temp[mask == False] = 0
-                band.WriteArray(temp, int(xoff), int(yoff))
-                dst_data.FlushCache()
-        dst_data = None
-        print("GeoTiff saved in {}.".format(save_file))
+        slider_predict(self.predict, img_files, save_dir, block_size, overlap,
+                       transforms, invalid_value, merge_strategy, batch_size,
+                       not quiet)
 
     def preprocess(self, images, transforms, to_tensor=True):
         self._check_transforms(transforms, 'test')
@@ -678,8 +627,8 @@ class BaseChangeDetector(BaseModel):
         batch_ori_shape = list()
         for im1, im2 in images:
             if isinstance(im1, str) or isinstance(im2, str):
-                im1 = decode_image(im1, to_rgb=False)
-                im2 = decode_image(im2, to_rgb=False)
+                im1 = decode_image(im1, read_raw=True)
+                im2 = decode_image(im2, read_raw=True)
             ori_shape = im1.shape[:2]
             # XXX: sample do not contain 'image_t1' and 'image_t2'.
             sample = {'image': im1, 'image2': im2}

+ 2 - 2
paddlers/tasks/classifier.py

@@ -288,7 +288,7 @@ class BaseClassifier(BaseModel):
                         exit=True)
         pretrained_dir = osp.join(save_dir, 'pretrain')
         is_backbone_weights = False
-        self.net_initialize(
+        self.initialize_net(
             pretrain_weights=pretrain_weights,
             save_dir=pretrained_dir,
             resume_checkpoint=resume_checkpoint,
@@ -497,7 +497,7 @@ class BaseClassifier(BaseModel):
         batch_ori_shape = list()
         for im in images:
             if isinstance(im, str):
-                im = decode_image(im, to_rgb=False)
+                im = decode_image(im, read_raw=True)
             ori_shape = im.shape[:2]
             sample = {'image': im}
             im = transforms(sample)

+ 2 - 2
paddlers/tasks/object_detector.py

@@ -347,7 +347,7 @@ class BaseDetector(BaseModel):
                         "Invalid pretrained weights. Please specify a .pdparams file.",
                         exit=True)
         pretrained_dir = osp.join(save_dir, 'pretrain')
-        self.net_initialize(
+        self.initialize_net(
             pretrain_weights=pretrain_weights,
             save_dir=pretrained_dir,
             resume_checkpoint=resume_checkpoint,
@@ -617,7 +617,7 @@ class BaseDetector(BaseModel):
         batch_samples = list()
         for im in images:
             if isinstance(im, str):
-                im = decode_image(im, to_rgb=False)
+                im = decode_image(im, read_raw=True)
             sample = {'image': im}
             sample = transforms(sample)
             batch_samples.append(sample)

+ 2 - 2
paddlers/tasks/restorer.py

@@ -283,7 +283,7 @@ class BaseRestorer(BaseModel):
                         exit=True)
         pretrained_dir = osp.join(save_dir, 'pretrain')
         is_backbone_weights = pretrain_weights == 'IMAGENET'
-        self.net_initialize(
+        self.initialize_net(
             pretrain_weights=pretrain_weights,
             save_dir=pretrained_dir,
             resume_checkpoint=resume_checkpoint,
@@ -481,7 +481,7 @@ class BaseRestorer(BaseModel):
         batch_tar_shape = list()
         for im in images:
             if isinstance(im, str):
-                im = decode_image(im, to_rgb=False)
+                im = decode_image(im, read_raw=True)
             ori_shape = im.shape[:2]
             sample = {'image': im}
             im = transforms(sample)[0]

+ 25 - 70
paddlers/tasks/segmenter.py

@@ -34,6 +34,7 @@ from paddlers.utils.checkpoint import seg_pretrain_weights_dict
 from .base import BaseModel
 from .utils import seg_metrics as metrics
 from .utils.infer_nets import InferSegNet
+from .utils.slider_predict import slider_predict
 
 __all__ = ["UNet", "DeepLabV3P", "FastSCNN", "HRNet", "BiSeNetV2", "FarSeg"]
 
@@ -307,7 +308,7 @@ class BaseSegmenter(BaseModel):
                         exit=True)
         pretrained_dir = osp.join(save_dir, 'pretrain')
         is_backbone_weights = pretrain_weights == 'IMAGENET'
-        self.net_initialize(
+        self.initialize_net(
             pretrain_weights=pretrain_weights,
             save_dir=pretrained_dir,
             resume_checkpoint=resume_checkpoint,
@@ -557,86 +558,40 @@ class BaseSegmenter(BaseModel):
                        save_dir,
                        block_size,
                        overlap=36,
-                       transforms=None):
+                       transforms=None,
+                       invalid_value=255,
+                       merge_strategy='keep_last',
+                       batch_size=1,
+                       quiet=False):
         """
-        Do inference.
+        Do inference using sliding windows.
 
         Args:
             img_file (str): Image path.
             save_dir (str): Directory that contains saved geotiff file.
             block_size (list[int] | tuple[int] | int):
-                Size of block.
+                Size of block. If `block_size` is list or tuple, it should be in 
+                (W, H) format.
             overlap (list[int] | tuple[int] | int, optional):
-                Overlap between two blocks. Defaults to 36.
+                Overlap between two blocks. If `overlap` is list or tuple, it should
+                be in (W, H) format. Defaults to 36.
             transforms (paddlers.transforms.Compose|None, optional): Transforms for 
                 inputs. If None, the transforms for evaluation process will be used. 
                 Defaults to None.
+            invalid_value (int, optional): Value that marks invalid pixels in output 
+                image. Defaults to 255.
+            merge_strategy (str, optional): Strategy to merge overlapping blocks. Choices
+                are {'keep_first', 'keep_last', 'accum'}. 'keep_first' and 'keep_last' 
+                means keeping the values of the first and the last block in traversal 
+                order, respectively. 'accum' means determining the class of an overlapping 
+                pixel according to accumulated probabilities. Defaults to 'keep_last'.
+            batch_size (int, optional): Batch size used in inference. Defaults to 1.
+            quiet (bool, optional): If True, disable the progress bar. Defaults to False.
         """
 
-        try:
-            from osgeo import gdal
-        except:
-            import gdal
-
-        if isinstance(block_size, int):
-            block_size = (block_size, block_size)
-        elif isinstance(block_size, (tuple, list)) and len(block_size) == 2:
-            block_size = tuple(block_size)
-        else:
-            raise ValueError(
-                "`block_size` must be a tuple/list of length 2 or an integer.")
-        if isinstance(overlap, int):
-            overlap = (overlap, overlap)
-        elif isinstance(overlap, (tuple, list)) and len(overlap) == 2:
-            overlap = tuple(overlap)
-        else:
-            raise ValueError(
-                "`overlap` must be a tuple/list of length 2 or an integer.")
-
-        src_data = gdal.Open(img_file)
-        width = src_data.RasterXSize
-        height = src_data.RasterYSize
-        bands = src_data.RasterCount
-
-        driver = gdal.GetDriverByName("GTiff")
-        file_name = osp.splitext(osp.normpath(img_file).split(os.sep)[-1])[
-            0] + ".tif"
-        if not osp.exists(save_dir):
-            os.makedirs(save_dir)
-        save_file = osp.join(save_dir, file_name)
-        dst_data = driver.Create(save_file, width, height, 1, gdal.GDT_Byte)
-        dst_data.SetGeoTransform(src_data.GetGeoTransform())
-        dst_data.SetProjection(src_data.GetProjection())
-        band = dst_data.GetRasterBand(1)
-        band.WriteArray(255 * np.ones((height, width), dtype="uint8"))
-
-        step = np.array(block_size) - np.array(overlap)
-        for yoff in range(0, height, step[1]):
-            for xoff in range(0, width, step[0]):
-                xsize, ysize = block_size
-                if xoff + xsize > width:
-                    xsize = int(width - xoff)
-                if yoff + ysize > height:
-                    ysize = int(height - yoff)
-                im = src_data.ReadAsArray(int(xoff), int(yoff), xsize,
-                                          ysize).transpose((1, 2, 0))
-                # Fill
-                h, w = im.shape[:2]
-                im_fill = np.zeros(
-                    (block_size[1], block_size[0], bands), dtype=im.dtype)
-                im_fill[:h, :w, :] = im
-                # Predict
-                pred = self.predict(im_fill,
-                                    transforms)["label_map"].astype("uint8")
-                # Overlap
-                rd_block = band.ReadAsArray(int(xoff), int(yoff), xsize, ysize)
-                mask = (rd_block == pred[:h, :w]) | (rd_block == 255)
-                temp = pred[:h, :w].copy()
-                temp[mask == False] = 0
-                band.WriteArray(temp, int(xoff), int(yoff))
-                dst_data.FlushCache()
-        dst_data = None
-        print("GeoTiff saved in {}.".format(save_file))
+        slider_predict(self.predict, img_file, save_dir, block_size, overlap,
+                       transforms, invalid_value, merge_strategy, batch_size,
+                       not quiet)
 
     def preprocess(self, images, transforms, to_tensor=True):
         self._check_transforms(transforms, 'test')
@@ -644,7 +599,7 @@ class BaseSegmenter(BaseModel):
         batch_ori_shape = list()
         for im in images:
             if isinstance(im, str):
-                im = decode_image(im, to_rgb=False)
+                im = decode_image(im, read_raw=True)
             ori_shape = im.shape[:2]
             sample = {'image': im}
             im = transforms(sample)[0]

+ 437 - 0
paddlers/tasks/utils/slider_predict.py

@@ -0,0 +1,437 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import os.path as osp
+import math
+from abc import ABCMeta, abstractmethod
+from collections import Counter, defaultdict
+
+import numpy as np
+from tqdm import tqdm
+
+import paddlers.utils.logging as logging
+
+
+class Cache(metaclass=ABCMeta):
+    @abstractmethod
+    def get_block(self, i_st, j_st, h, w):
+        pass
+
+
+class SlowCache(Cache):
+    def __init__(self):
+        super(SlowCache, self).__init__()
+        self.cache = defaultdict(Counter)
+
+    def push_pixel(self, i, j, l):
+        self.cache[(i, j)][l] += 1
+
+    def push_block(self, i_st, j_st, h, w, data):
+        for i in range(0, h):
+            for j in range(0, w):
+                self.push_pixel(i_st + i, j_st + j, data[i, j])
+
+    def pop_pixel(self, i, j):
+        self.cache.pop((i, j))
+
+    def pop_block(self, i_st, j_st, h, w):
+        for i in range(0, h):
+            for j in range(0, w):
+                self.pop_pixel(i_st + i, j_st + j)
+
+    def get_pixel(self, i, j):
+        winners = self.cache[(i, j)].most_common(1)
+        winner = winners[0]
+        return winner[0]
+
+    def get_block(self, i_st, j_st, h, w):
+        block = []
+        for i in range(i_st, i_st + h):
+            row = []
+            for j in range(j_st, j_st + w):
+                row.append(self.get_pixel(i, j))
+            block.append(row)
+        return np.asarray(block)
+
+
+class ProbCache(Cache):
+    def __init__(self, h, w, ch, cw, sh, sw, dtype=np.float32, order='c'):
+        super(ProbCache, self).__init__()
+        self.cache = None
+        self.h = h
+        self.w = w
+        self.ch = ch
+        self.cw = cw
+        self.sh = sh
+        self.sw = sw
+        if not issubclass(dtype, np.floating):
+            raise TypeError("`dtype` must be one of the floating types.")
+        self.dtype = dtype
+        order = order.lower()
+        if order not in ('c', 'f'):
+            raise ValueError("`order` other than 'c' and 'f' is not supported.")
+        self.order = order
+
+    def _alloc_memory(self, nc):
+        if self.order == 'c':
+            # Colomn-first order (C-style)
+            #
+            # <-- cw -->
+            # |--------|---------------------|^    ^
+            # |                              ||    | sh
+            # |--------|---------------------|| ch v
+            # |                              || 
+            # |--------|---------------------|v
+            # <------------ w --------------->
+            self.cache = np.zeros((self.ch, self.w, nc), dtype=self.dtype)
+        elif self.order == 'f':
+            # Row-first order (Fortran-style)
+            #
+            # <-- sw -->
+            # <---- cw ---->
+            # |--------|---|^   ^
+            # |        |   ||   |
+            # |        |   ||   ch
+            # |        |   ||   |
+            # |--------|---|| h v
+            # |        |   ||
+            # |        |   ||
+            # |        |   ||
+            # |--------|---|v
+            self.cache = np.zeros((self.h, self.cw, nc), dtype=self.dtype)
+
+    def update_block(self, i_st, j_st, h, w, prob_map):
+        if self.cache is None:
+            nc = prob_map.shape[2]
+            # Lazy allocation of memory
+            self._alloc_memory(nc)
+        self.cache[i_st:i_st + h, j_st:j_st + w] += prob_map
+
+    def roll_cache(self, shift):
+        if self.order == 'c':
+            self.cache[:-shift] = self.cache[shift:]
+            self.cache[-shift:, :] = 0
+        elif self.order == 'f':
+            self.cache[:, :-shift] = self.cache[:, shift:]
+            self.cache[:, -shift:] = 0
+
+    def get_block(self, i_st, j_st, h, w):
+        return np.argmax(self.cache[i_st:i_st + h, j_st:j_st + w], axis=2)
+
+
+class OverlapProcessor(metaclass=ABCMeta):
+    def __init__(self, h, w, ch, cw, sh, sw):
+        super(OverlapProcessor, self).__init__()
+        self.h = h
+        self.w = w
+        self.ch = ch
+        self.cw = cw
+        self.sh = sh
+        self.sw = sw
+
+    @abstractmethod
+    def process_pred(self, out, xoff, yoff):
+        pass
+
+
+class KeepFirstProcessor(OverlapProcessor):
+    def __init__(self, h, w, ch, cw, sh, sw, ds, inval=255):
+        super(KeepFirstProcessor, self).__init__(h, w, ch, cw, sh, sw)
+        self.ds = ds
+        self.inval = inval
+
+    def process_pred(self, out, xoff, yoff):
+        pred = out['label_map']
+        pred = pred[:self.ch, :self.cw]
+        rd_block = self.ds.ReadAsArray(xoff, yoff, self.cw, self.ch)
+        mask = rd_block != self.inval
+        pred = np.where(mask, rd_block, pred)
+        return pred
+
+
+class KeepLastProcessor(OverlapProcessor):
+    def process_pred(self, out, xoff, yoff):
+        pred = out['label_map']
+        pred = pred[:self.ch, :self.cw]
+        return pred
+
+
+class AccumProcessor(OverlapProcessor):
+    def __init__(self,
+                 h,
+                 w,
+                 ch,
+                 cw,
+                 sh,
+                 sw,
+                 dtype=np.float16,
+                 assign_weight=True):
+        super(AccumProcessor, self).__init__(h, w, ch, cw, sh, sw)
+        self.cache = ProbCache(h, w, ch, cw, sh, sw, dtype=dtype, order='c')
+        self.prev_yoff = None
+        self.assign_weight = assign_weight
+
+    def process_pred(self, out, xoff, yoff):
+        if self.prev_yoff is not None and yoff != self.prev_yoff:
+            if yoff < self.prev_yoff:
+                raise RuntimeError
+            self.cache.roll_cache(yoff - self.prev_yoff)
+        pred = out['label_map']
+        pred = pred[:self.ch, :self.cw]
+        prob = out['score_map']
+        prob = prob[:self.ch, :self.cw]
+        if self.assign_weight:
+            prob = assign_border_weights(prob, border_ratio=0.25, inplace=True)
+        self.cache.update_block(0, xoff, self.ch, self.cw, prob)
+        pred = self.cache.get_block(0, xoff, self.ch, self.cw)
+        self.prev_yoff = yoff
+        return pred
+
+
+def assign_border_weights(array, weight=0.5, border_ratio=0.25, inplace=True):
+    if not inplace:
+        array = array.copy()
+    h, w = array.shape[:2]
+    hm, wm = int(h * border_ratio), int(w * border_ratio)
+    array[:hm] *= weight
+    array[-hm:] *= weight
+    array[:, :wm] *= weight
+    array[:, -wm:] *= weight
+    return array
+
+
+def read_block(ds,
+               xoff,
+               yoff,
+               xsize,
+               ysize,
+               tar_xsize=None,
+               tar_ysize=None,
+               pad_val=0):
+    if tar_xsize is None:
+        tar_xsize = xsize
+    if tar_ysize is None:
+        tar_ysize = ysize
+    # Read data from dataset
+    block = ds.ReadAsArray(xoff, yoff, xsize, ysize)
+    c, real_ysize, real_xsize = block.shape
+    assert real_ysize == ysize and real_xsize == xsize
+    # [c, h, w] -> [h, w, c]
+    block = block.transpose((1, 2, 0))
+    if (real_ysize, real_xsize) != (tar_ysize, tar_xsize):
+        if real_ysize >= tar_ysize or real_xsize >= tar_xsize:
+            raise ValueError
+        padded_block = np.full(
+            (tar_ysize, tar_xsize, c), fill_value=pad_val, dtype=block.dtype)
+        # Fill
+        padded_block[:real_ysize, :real_xsize] = block
+        return padded_block
+    else:
+        return block
+
+
+def slider_predict(predict_func,
+                   img_file,
+                   save_dir,
+                   block_size,
+                   overlap,
+                   transforms,
+                   invalid_value,
+                   merge_strategy,
+                   batch_size,
+                   show_progress=False):
+    """
+    Do inference using sliding windows.
+
+    Args:
+        predict_func (callable): A callable object that makes the prediction.
+        img_file (str|tuple[str]): Image path(s).
+        save_dir (str): Directory that contains saved geotiff file.
+        block_size (list[int] | tuple[int] | int):
+            Size of block. If `block_size` is list or tuple, it should be in 
+            (W, H) format.
+        overlap (list[int] | tuple[int] | int):
+            Overlap between two blocks. If `overlap` is list or tuple, it should
+            be in (W, H) format.
+        transforms (paddlers.transforms.Compose|None): Transforms for inputs. If 
+            None, the transforms for evaluation process will be used. 
+        invalid_value (int): Value that marks invalid pixels in output image. 
+            Defaults to 255.
+        merge_strategy (str): Strategy to merge overlapping blocks. Choices are 
+            {'keep_first', 'keep_last', 'accum'}. 'keep_first' and 'keep_last' 
+            means keeping the values of the first and the last block in 
+            traversal order, respectively. 'accum' means determining the class 
+            of an overlapping pixel according to accumulated probabilities.
+        batch_size (int): Batch size used in inference.
+        show_progress (bool, optional): Whether to show prediction progress with a 
+            progress bar. Defaults to True.
+    """
+
+    try:
+        from osgeo import gdal
+    except:
+        import gdal
+
+    if isinstance(block_size, int):
+        block_size = (block_size, block_size)
+    elif isinstance(block_size, (tuple, list)) and len(block_size) == 2:
+        block_size = tuple(block_size)
+    else:
+        raise ValueError(
+            "`block_size` must be a tuple/list of length 2 or an integer.")
+    if isinstance(overlap, int):
+        overlap = (overlap, overlap)
+    elif isinstance(overlap, (tuple, list)) and len(overlap) == 2:
+        overlap = tuple(overlap)
+    else:
+        raise ValueError(
+            "`overlap` must be a tuple/list of length 2 or an integer.")
+
+    step = np.array(
+        block_size, dtype=np.int32) - np.array(
+            overlap, dtype=np.int32)
+    if step[0] == 0 or step[1] == 0:
+        raise ValueError("`block_size` and `overlap` should not be equal.")
+
+    if isinstance(img_file, tuple):
+        if len(img_file) != 2:
+            raise ValueError("Tuple `img_file` must have the length of two.")
+        # Assume that two input images have the same size
+        src_data = gdal.Open(img_file[0])
+        src2_data = gdal.Open(img_file[1])
+        # Output name is the same as the name of the first image
+        file_name = osp.basename(osp.normpath(img_file[0]))
+    else:
+        src_data = gdal.Open(img_file)
+        file_name = osp.basename(osp.normpath(img_file))
+
+    # Get size of original raster
+    width = src_data.RasterXSize
+    height = src_data.RasterYSize
+    bands = src_data.RasterCount
+
+    # XXX: GDAL read behavior conforms to paddlers.transforms.decode_image(read_raw=True)
+    # except for SAR images.
+    if bands == 1:
+        logging.warning(
+            f"Detected `bands=1`. Please note that currently `slider_predict()` does not properly handle SAR images."
+        )
+
+    if block_size[0] > width or block_size[1] > height:
+        raise ValueError("`block_size` should not be larger than image size.")
+
+    driver = gdal.GetDriverByName("GTiff")
+    if not osp.exists(save_dir):
+        os.makedirs(save_dir)
+    # Replace extension name with '.tif'
+    file_name = osp.splitext(file_name)[0] + ".tif"
+    save_file = osp.join(save_dir, file_name)
+    dst_data = driver.Create(save_file, width, height, 1, gdal.GDT_Byte)
+
+    # Set meta-information
+    dst_data.SetGeoTransform(src_data.GetGeoTransform())
+    dst_data.SetProjection(src_data.GetProjection())
+
+    # Initialize raster with `invalid_value`
+    band = dst_data.GetRasterBand(1)
+    band.WriteArray(
+        np.full(
+            (height, width), fill_value=invalid_value, dtype="uint8"))
+
+    if overlap == (0, 0) or block_size == (width, height):
+        # When there is no overlap or the whole image is used as input, 
+        # use 'keep_last' strategy as it introduces least overheads
+        merge_strategy = 'keep_last'
+
+    if merge_strategy == 'keep_first':
+        overlap_processor = KeepFirstProcessor(
+            height,
+            width,
+            *block_size[::-1],
+            *step[::-1],
+            band,
+            inval=invalid_value)
+    elif merge_strategy == 'keep_last':
+        overlap_processor = KeepLastProcessor(height, width, *block_size[::-1],
+                                              *step[::-1])
+    elif merge_strategy == 'accum':
+        overlap_processor = AccumProcessor(height, width, *block_size[::-1],
+                                           *step[::-1])
+    else:
+        raise ValueError("{} is not a supported stragegy for block merging.".
+                         format(merge_strategy))
+
+    xsize, ysize = block_size
+    num_blocks = math.ceil(height / step[1]) * math.ceil(width / step[0])
+    cnt = 0
+    if show_progress:
+        pb = tqdm(total=num_blocks)
+    batch_data = []
+    batch_offsets = []
+    for yoff in range(0, height, step[1]):
+        for xoff in range(0, width, step[0]):
+            if xoff + xsize > width:
+                xoff = width - xsize
+                is_end_of_row = True
+            else:
+                is_end_of_row = False
+            if yoff + ysize > height:
+                yoff = height - ysize
+                is_end_of_col = True
+            else:
+                is_end_of_col = False
+
+            # Read
+            im = read_block(src_data, xoff, yoff, xsize, ysize)
+
+            if isinstance(img_file, tuple):
+                im2 = read_block(src2_data, xoff, yoff, xsize, ysize)
+                batch_data.append((im, im2))
+            else:
+                batch_data.append(im)
+
+            batch_offsets.append((xoff, yoff))
+
+            len_batch = len(batch_data)
+
+            if is_end_of_row and is_end_of_col and len_batch < batch_size:
+                # Pad `batch_data` by repeating the last element
+                batch_data = batch_data + [batch_data[-1]] * (batch_size -
+                                                              len_batch)
+                # While keeping `len(batch_offsets)` the number of valid elements in the batch 
+
+            if len(batch_data) == batch_size:
+                # Predict
+                batch_out = predict_func(batch_data, transforms=transforms)
+
+                for out, (xoff_, yoff_) in zip(batch_out, batch_offsets):
+                    # Get processed result
+                    pred = overlap_processor.process_pred(out, xoff_, yoff_)
+                    # Write to file
+                    band.WriteArray(pred, xoff_, yoff_)
+
+                dst_data.FlushCache()
+                batch_data.clear()
+                batch_offsets.clear()
+
+            cnt += 1
+
+            if show_progress:
+                pb.update(1)
+                pb.set_description("{} out of {} blocks processed.".format(
+                    cnt, num_blocks))
+
+    dst_data = None
+    logging.info("GeoTiff file saved in {}.".format(save_file))

+ 14 - 3
paddlers/transforms/__init__.py

@@ -25,7 +25,9 @@ def decode_image(im_path,
                  to_uint8=True,
                  decode_bgr=True,
                  decode_sar=True,
-                 read_geo_info=False):
+                 read_geo_info=False,
+                 use_stretch=False,
+                 read_raw=False):
     """
     Decode an image.
     
@@ -37,11 +39,16 @@ def decode_image(im_path,
             uint8 type. Defaults to True.
         decode_bgr (bool, optional): If True, automatically interpret a non-geo 
             image (e.g. jpeg images) as a BGR image. Defaults to True.
-        decode_sar (bool, optional): If True, automatically interpret a two-channel 
+        decode_sar (bool, optional): If True, automatically interpret a single-channel 
             geo image (e.g. geotiff images) as a SAR image, set this argument to 
             True. Defaults to True.
         read_geo_info (bool, optional): If True, read geographical information from 
             the image. Deafults to False.
+        use_stretch (bool, optional): Whether to apply 2% linear stretch. Valid only if 
+            `to_uint8` is True. Defaults to False.
+        read_raw (bool, optional): If True, equivalent to setting `to_rgb` and `to_uint8`
+            to False. Setting `read_raw` takes precedence over setting `to_rgb` and 
+            `to_uint8`. Defaults to False.
     
     Returns:
         np.ndarray|tuple: If `read_geo_info` is False, return the decoded image. 
@@ -53,12 +60,16 @@ def decode_image(im_path,
     # Do a presence check. osp.exists() assumes `im_path` is a path-like object.
     if not osp.exists(im_path):
         raise ValueError(f"{im_path} does not exist!")
+    if read_raw:
+        to_rgb = False
+        to_uint8 = False
     decoder = T.DecodeImg(
         to_rgb=to_rgb,
         to_uint8=to_uint8,
         decode_bgr=decode_bgr,
         decode_sar=decode_sar,
-        read_geo_info=read_geo_info)
+        read_geo_info=read_geo_info,
+        use_stretch=use_stretch)
     # Deepcopy to avoid inplace modification
     sample = {'image': copy.deepcopy(im_path)}
     sample = decoder(sample)

+ 3 - 3
paddlers/transforms/functions.py

@@ -382,13 +382,13 @@ def resize_rle(rle, im_h, im_w, im_scale_x, im_scale_y, interp):
     return rle
 
 
-def to_uint8(im, is_linear=False):
+def to_uint8(im, stretch=False):
     """
     Convert raster data to uint8 type.
     
     Args:
         im (np.ndarray): Input raster image.
-        is_linear (bool, optional): Use 2% linear stretch or not. Default is False.
+        stretch (bool, optional): Use 2% linear stretch or not. Default is False.
 
     Returns:
         np.ndarray: Image data with unit8 type.
@@ -430,7 +430,7 @@ def to_uint8(im, is_linear=False):
     dtype = im.dtype.name
     if dtype != "uint8":
         im = _sample_norm(im)
-    if is_linear:
+    if stretch:
         im = _two_percent_linear(im)
     return im
 

+ 9 - 5
paddlers/transforms/operators.py

@@ -203,11 +203,13 @@ class DecodeImg(Transform):
             uint8 type. Defaults to True.
         decode_bgr (bool, optional): If True, automatically interpret a non-geo image 
             (e.g., jpeg images) as a BGR image. Defaults to True.
-        decode_sar (bool, optional): If True, automatically interpret a two-channel 
+        decode_sar (bool, optional): If True, automatically interpret a single-channel 
             geo image (e.g. geotiff images) as a SAR image, set this argument to 
             True. Defaults to True.
         read_geo_info (bool, optional): If True, read geographical information from 
             the image. Deafults to False.
+        use_stretch (bool, optional): Whether to apply 2% linear stretch. Valid only if 
+            `to_uint8` is True. Defaults to False.
     """
 
     def __init__(self,
@@ -215,13 +217,15 @@ class DecodeImg(Transform):
                  to_uint8=True,
                  decode_bgr=True,
                  decode_sar=True,
-                 read_geo_info=False):
+                 read_geo_info=False,
+                 use_stretch=False):
         super(DecodeImg, self).__init__()
         self.to_rgb = to_rgb
         self.to_uint8 = to_uint8
         self.decode_bgr = decode_bgr
         self.decode_sar = decode_sar
-        self.read_geo_info = False
+        self.read_geo_info = read_geo_info
+        self.use_stretch = use_stretch
 
     def read_img(self, img_path):
         img_format = imghdr.what(img_path)
@@ -251,7 +255,7 @@ class DecodeImg(Transform):
                     im_data = im_data.transpose((1, 2, 0))
             if self.read_geo_info:
                 geo_trans = dataset.GetGeoTransform()
-                geo_proj = dataset.GetGeoProjection()
+                geo_proj = dataset.GetProjection()
         elif img_format in ['jpeg', 'bmp', 'png', 'jpg']:
             if self.decode_bgr:
                 im_data = cv2.imread(img_path, cv2.IMREAD_ANYDEPTH |
@@ -288,7 +292,7 @@ class DecodeImg(Transform):
             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
 
         if self.to_uint8:
-            image = to_uint8(image)
+            image = to_uint8(image, stretch=self.use_stretch)
 
         if self.read_geo_info:
             return image, geo_info_dict

+ 17 - 12
tests/deploy/test_predictor.py

@@ -151,8 +151,8 @@ class TestCDPredictor(TestPredictor):
 
         # Single input (ndarrays)
         input_ = (decode_image(
-            t1_path, to_rgb=False), decode_image(
-                t2_path, to_rgb=False))  # Reuse the name `input_`
+            t1_path, read_raw=True), decode_image(
+                t2_path, read_raw=True))  # Reuse the name `input_`
         out_single_array_p = predictor.predict(input_, transforms=transforms)
         self.check_dict_equal(out_single_array_p, out_single_file_p)
         out_single_array_t = trainer.predict(input_, transforms=transforms)
@@ -175,8 +175,9 @@ class TestCDPredictor(TestPredictor):
 
         # Multiple inputs (ndarrays)
         input_ = [(decode_image(
-            t1_path, to_rgb=False), decode_image(
-                t2_path, to_rgb=False))] * num_inputs  # Reuse the name `input_`
+            t1_path, read_raw=True), decode_image(
+                t2_path,
+                read_raw=True))] * num_inputs  # Reuse the name `input_`
         out_multi_array_p = predictor.predict(input_, transforms=transforms)
         self.assertEqual(len(out_multi_array_p), num_inputs)
         out_multi_array_t = trainer.predict(input_, transforms=transforms)
@@ -217,7 +218,7 @@ class TestClasPredictor(TestPredictor):
 
         # Single input (ndarray)
         input_ = decode_image(
-            single_input, to_rgb=False)  # Reuse the name `input_`
+            single_input, read_raw=True)  # Reuse the name `input_`
         out_single_array_p = predictor.predict(input_, transforms=transforms)
         self.check_dict_equal(out_single_array_p, out_single_file_p)
         out_single_array_t = trainer.predict(input_, transforms=transforms)
@@ -241,7 +242,8 @@ class TestClasPredictor(TestPredictor):
 
         # Multiple inputs (ndarrays)
         input_ = [decode_image(
-            single_input, to_rgb=False)] * num_inputs  # Reuse the name `input_`
+            single_input,
+            read_raw=True)] * num_inputs  # Reuse the name `input_`
         out_multi_array_p = predictor.predict(input_, transforms=transforms)
         self.assertEqual(len(out_multi_array_p), num_inputs)
         out_multi_array_t = trainer.predict(input_, transforms=transforms)
@@ -282,7 +284,7 @@ class TestDetPredictor(TestPredictor):
 
         # Single input (ndarray)
         input_ = decode_image(
-            single_input, to_rgb=False)  # Reuse the name `input_`
+            single_input, read_raw=True)  # Reuse the name `input_`
         predictor.predict(input_, transforms=transforms)
         trainer.predict(input_, transforms=transforms)
         out_single_array_list_p = predictor.predict(
@@ -301,7 +303,8 @@ class TestDetPredictor(TestPredictor):
 
         # Multiple inputs (ndarrays)
         input_ = [decode_image(
-            single_input, to_rgb=False)] * num_inputs  # Reuse the name `input_`
+            single_input,
+            read_raw=True)] * num_inputs  # Reuse the name `input_`
         out_multi_array_p = predictor.predict(input_, transforms=transforms)
         self.assertEqual(len(out_multi_array_p), num_inputs)
         out_multi_array_t = trainer.predict(input_, transforms=transforms)
@@ -343,7 +346,7 @@ class TestResPredictor(TestPredictor):
 
         # Single input (ndarray)
         input_ = decode_image(
-            single_input, to_rgb=False)  # Reuse the name `input_`
+            single_input, read_raw=True)  # Reuse the name `input_`
         predictor.predict(input_, transforms=transforms)
         trainer.predict(input_, transforms=transforms)
         out_single_array_list_p = predictor.predict(
@@ -362,7 +365,8 @@ class TestResPredictor(TestPredictor):
 
         # Multiple inputs (ndarrays)
         input_ = [decode_image(
-            single_input, to_rgb=False)] * num_inputs  # Reuse the name `input_`
+            single_input,
+            read_raw=True)] * num_inputs  # Reuse the name `input_`
         out_multi_array_p = predictor.predict(input_, transforms=transforms)
         self.assertEqual(len(out_multi_array_p), num_inputs)
         out_multi_array_t = trainer.predict(input_, transforms=transforms)
@@ -400,7 +404,7 @@ class TestSegPredictor(TestPredictor):
 
         # Single input (ndarray)
         input_ = decode_image(
-            single_input, to_rgb=False)  # Reuse the name `input_`
+            single_input, read_raw=True)  # Reuse the name `input_`
         out_single_array_p = predictor.predict(input_, transforms=transforms)
         self.check_dict_equal(out_single_array_p, out_single_file_p)
         out_single_array_t = trainer.predict(input_, transforms=transforms)
@@ -423,7 +427,8 @@ class TestSegPredictor(TestPredictor):
 
         # Multiple inputs (ndarrays)
         input_ = [decode_image(
-            single_input, to_rgb=False)] * num_inputs  # Reuse the name `input_`
+            single_input,
+            read_raw=True)] * num_inputs  # Reuse the name `input_`
         out_multi_array_p = predictor.predict(input_, transforms=transforms)
         self.assertEqual(len(out_multi_array_p), num_inputs)
         out_multi_array_t = trainer.predict(input_, transforms=transforms)

+ 1 - 0
tests/fast_tests.py

@@ -13,4 +13,5 @@
 # limitations under the License.
 
 from rs_models import *
+from tasks import *
 from transforms import *

+ 2 - 0
tests/tasks/__init__.py

@@ -11,3 +11,5 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+from .test_slider_predict import *

+ 212 - 0
tests/tasks/test_slider_predict.py

@@ -0,0 +1,212 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os.path as osp
+import tempfile
+
+import paddlers as pdrs
+import paddlers.transforms as T
+from testing_utils import CommonTest
+
+
+class _TestSliderPredictNamespace:
+    class TestSliderPredict(CommonTest):
+        def test_blocksize_and_overlap_whole(self):
+            # Original image size (256, 256)
+            with tempfile.TemporaryDirectory() as td:
+                # Whole-image inference using predict()
+                pred_whole = self.model.predict(self.image_path,
+                                                self.transforms)
+                pred_whole = pred_whole['label_map']
+
+                # Whole-image inference using slider_predict()
+                save_dir = osp.join(td, 'pred1')
+                self.model.slider_predict(self.image_path, save_dir, 256, 0,
+                                          self.transforms)
+                pred1 = T.decode_image(
+                    osp.join(save_dir, self.basename),
+                    read_raw=True,
+                    decode_sar=False)
+                self.check_output_equal(pred1.shape, pred_whole.shape)
+
+                # `block_size` == `overlap`
+                save_dir = osp.join(td, 'pred2')
+                with self.assertRaises(ValueError):
+                    self.model.slider_predict(self.image_path, save_dir, 128,
+                                              128, self.transforms)
+
+                # `block_size` is a tuple
+                save_dir = osp.join(td, 'pred3')
+                self.model.slider_predict(self.image_path, save_dir, (128, 32),
+                                          0, self.transforms)
+                pred3 = T.decode_image(
+                    osp.join(save_dir, self.basename),
+                    read_raw=True,
+                    decode_sar=False)
+                self.check_output_equal(pred3.shape, pred_whole.shape)
+
+                # `block_size` and `overlap` are both tuples
+                save_dir = osp.join(td, 'pred4')
+                self.model.slider_predict(self.image_path, save_dir, (128, 100),
+                                          (10, 5), self.transforms)
+                pred4 = T.decode_image(
+                    osp.join(save_dir, self.basename),
+                    read_raw=True,
+                    decode_sar=False)
+                self.check_output_equal(pred4.shape, pred_whole.shape)
+
+                # `block_size` larger than image size
+                save_dir = osp.join(td, 'pred5')
+                with self.assertRaises(ValueError):
+                    self.model.slider_predict(self.image_path, save_dir, 512, 0,
+                                              self.transforms)
+
+        def test_merge_strategy(self):
+            with tempfile.TemporaryDirectory() as td:
+                # Whole-image inference using predict()
+                pred_whole = self.model.predict(self.image_path,
+                                                self.transforms)
+                pred_whole = pred_whole['label_map']
+
+                # 'keep_first'
+                save_dir = osp.join(td, 'keep_first')
+                self.model.slider_predict(
+                    self.image_path,
+                    save_dir,
+                    128,
+                    64,
+                    self.transforms,
+                    merge_strategy='keep_first')
+                pred_keepfirst = T.decode_image(
+                    osp.join(save_dir, self.basename),
+                    read_raw=True,
+                    decode_sar=False)
+                self.check_output_equal(pred_keepfirst.shape, pred_whole.shape)
+
+                # 'keep_last'
+                save_dir = osp.join(td, 'keep_last')
+                self.model.slider_predict(
+                    self.image_path,
+                    save_dir,
+                    128,
+                    64,
+                    self.transforms,
+                    merge_strategy='keep_last')
+                pred_keeplast = T.decode_image(
+                    osp.join(save_dir, self.basename),
+                    read_raw=True,
+                    decode_sar=False)
+                self.check_output_equal(pred_keeplast.shape, pred_whole.shape)
+
+                # 'accum'
+                save_dir = osp.join(td, 'accum')
+                self.model.slider_predict(
+                    self.image_path,
+                    save_dir,
+                    128,
+                    64,
+                    self.transforms,
+                    merge_strategy='accum')
+                pred_accum = T.decode_image(
+                    osp.join(save_dir, self.basename),
+                    read_raw=True,
+                    decode_sar=False)
+                self.check_output_equal(pred_accum.shape, pred_whole.shape)
+
+        def test_geo_info(self):
+            with tempfile.TemporaryDirectory() as td:
+                _, geo_info_in = T.decode_image(
+                    self.ref_path, read_geo_info=True)
+                self.model.slider_predict(self.image_path, td, 128, 0,
+                                          self.transforms)
+                _, geo_info_out = T.decode_image(
+                    osp.join(td, self.basename), read_geo_info=True)
+                self.assertEqual(geo_info_out['geo_trans'],
+                                 geo_info_in['geo_trans'])
+                self.assertEqual(geo_info_out['geo_proj'],
+                                 geo_info_in['geo_proj'])
+
+        def test_batch_size(self):
+            with tempfile.TemporaryDirectory() as td:
+                # batch_size = 1
+                save_dir = osp.join(td, 'bs1')
+                self.model.slider_predict(
+                    self.image_path,
+                    save_dir,
+                    128,
+                    64,
+                    self.transforms,
+                    merge_strategy='keep_first',
+                    batch_size=1)
+                pred_bs1 = T.decode_image(
+                    osp.join(save_dir, self.basename),
+                    read_raw=True,
+                    decode_sar=False)
+
+                # batch_size = 4
+                save_dir = osp.join(td, 'bs4')
+                self.model.slider_predict(
+                    self.image_path,
+                    save_dir,
+                    128,
+                    64,
+                    self.transforms,
+                    merge_strategy='keep_first',
+                    batch_size=4)
+                pred_bs4 = T.decode_image(
+                    osp.join(save_dir, self.basename),
+                    read_raw=True,
+                    decode_sar=False)
+                self.check_output_equal(pred_bs4, pred_bs1)
+
+                # batch_size = 8
+                save_dir = osp.join(td, 'bs4')
+                self.model.slider_predict(
+                    self.image_path,
+                    save_dir,
+                    128,
+                    64,
+                    self.transforms,
+                    merge_strategy='keep_first',
+                    batch_size=8)
+                pred_bs8 = T.decode_image(
+                    osp.join(save_dir, self.basename),
+                    read_raw=True,
+                    decode_sar=False)
+                self.check_output_equal(pred_bs8, pred_bs1)
+
+
+class TestSegSliderPredict(_TestSliderPredictNamespace.TestSliderPredict):
+    def setUp(self):
+        self.model = pdrs.tasks.seg.UNet(in_channels=10)
+        self.transforms = T.Compose([
+            T.DecodeImg(), T.Normalize([0.5] * 10, [0.5] * 10),
+            T.ArrangeSegmenter('test')
+        ])
+        self.image_path = "data/ssst/multispectral.tif"
+        self.ref_path = self.image_path
+        self.basename = osp.basename(self.ref_path)
+
+
+class TestCDSliderPredict(_TestSliderPredictNamespace.TestSliderPredict):
+    def setUp(self):
+        self.model = pdrs.tasks.cd.BIT(in_channels=10)
+        self.transforms = T.Compose([
+            T.DecodeImg(), T.Normalize([0.5] * 10, [0.5] * 10),
+            T.ArrangeChangeDetector('test')
+        ])
+        self.image_path = ("data/ssmt/multispectral_t1.tif",
+                           "data/ssmt/multispectral_t2.tif")
+        self.ref_path = self.image_path[0]
+        self.basename = osp.basename(self.ref_path)