Procházet zdrojové kódy

Merge branch 'develop' into add_fft

Bobholamovic před 2 roky
rodič
revize
0962fb0ac3

+ 4 - 2
docs/apis/data.md

@@ -134,10 +134,12 @@
 |-------|----|--------|-----|
 |-------|----|--------|-----|
 |`im_path`|`str`|输入图像路径。||
 |`im_path`|`str`|输入图像路径。||
 |`to_rgb`|`bool`|若为`True`,则执行BGR到RGB格式的转换。|`True`|
 |`to_rgb`|`bool`|若为`True`,则执行BGR到RGB格式的转换。|`True`|
-|`to_uint8`|`bool`|若为`True`,则将读取的像数据量化并转换为uint8类型。|`True`|
+|`to_uint8`|`bool`|若为`True`,则将读取的像数据量化并转换为uint8类型。|`True`|
 |`decode_bgr`|`bool`|若为`True`,则自动将非地学格式影像(如jpeg影像)解析为BGR格式。|`True`|
 |`decode_bgr`|`bool`|若为`True`,则自动将非地学格式影像(如jpeg影像)解析为BGR格式。|`True`|
-|`decode_sar`|`bool`|若为`True`,则自动将2通道的地学格式影像(如GeoTiff影像)作为SAR影像解析。|`True`|
+|`decode_sar`|`bool`|若为`True`,则自动将通道的地学格式影像(如GeoTiff影像)作为SAR影像解析。|`True`|
 |`read_geo_info`|`bool`|若为`True`,则从影像中读取地理信息。|`False`|
 |`read_geo_info`|`bool`|若为`True`,则从影像中读取地理信息。|`False`|
+|`use_stretch`|`bool`|是否对影像亮度进行2%线性拉伸。仅当`to_uint8`为`True`时有效。|`False`|
+|`read_raw`|`bool`|若为`True`,等价于指定`to_rgb`和`to_uint8`为`False`,且该参数的优先级高于上述参数。|`False`|
 
 
 返回格式如下:
 返回格式如下:
 
 

+ 17 - 4
docs/apis/infer.md

@@ -155,7 +155,11 @@ def slider_predict(self,
                    save_dir,
                    save_dir,
                    block_size,
                    block_size,
                    overlap=36,
                    overlap=36,
-                   transforms=None):
+                   transforms=None,
+                   invalid_value=255,
+                   merge_strategy='keep_last',
+                   batch_size=1,
+                   quiet=False):
 ```
 ```
 
 
 输入参数列表:
 输入参数列表:
@@ -164,11 +168,15 @@ def slider_predict(self,
 |-------|----|--------|-----|
 |-------|----|--------|-----|
 |`img_file`|`str`|输入影像路径。||
 |`img_file`|`str`|输入影像路径。||
 |`save_dir`|`str`|预测结果输出路径。||
 |`save_dir`|`str`|预测结果输出路径。||
-|`block_size`|`list[int]` \| `tuple[int]` \| `int`|滑窗的窗口大小(以列表或元组指定长、宽或以一个整数指定相同的宽)。||
-|`overlap`|`list[int]` \| `tuple[int]` \| `int`|滑窗的滑动步长(以列表或元组指定长、宽或以一个整数指定相同的宽)。|`36`|
+|`block_size`|`list[int]` \| `tuple[int]` \| `int`|滑窗的窗口大小(以列表或元组指定宽度、高度或以一个整数指定相同的宽)。||
+|`overlap`|`list[int]` \| `tuple[int]` \| `int`|滑窗的滑动步长(以列表或元组指定宽度、高度或以一个整数指定相同的宽)。|`36`|
 |`transforms`|`paddlers.transforms.Compose` \| `None`|对输入数据应用的数据变换算子。若为`None`,则使用训练器在验证阶段使用的数据变换算子。|`None`|
 |`transforms`|`paddlers.transforms.Compose` \| `None`|对输入数据应用的数据变换算子。若为`None`,则使用训练器在验证阶段使用的数据变换算子。|`None`|
+|`invalid_value`|`int`|输出影像中用于标记无效像素的数值。|`255`|
+|`merge_strategy`|`str`|合并滑窗重叠区域使用的策略。`'keep_first'`表示保留遍历顺序(从左至右,从上往下,列优先)最靠前的窗口的预测类别;`'keep_last'`表示保留遍历顺序最靠后的窗口的预测类别;`'accum'`表示通过将各窗口在重叠区域给出的预测概率累加,计算最终预测类别。需要注意的是,在对大尺寸影像进行`overlap`较大的密集推理时,使用`'accum'`策略可能导致较长的推理时间,但一般能够在窗口交界部分取得更好的表现。|`'keep_last'`|
+|`batch_size`|`int`|预测时使用的mini-batch大小。|`1`|
+|`quiet`|`bool`|若为`True`,不显示预测进度。|`False`|
 
 
-变化检测任务的滑窗推理API与图像分割任务类似,但需要注意的是输出结果中存储的地理变换、投影等信息以从第一时相影像中读取的信息为准。
+变化检测任务的滑窗推理API与图像分割任务类似,但需要注意的是输出结果中存储的地理变换、投影等信息以从第一时相影像中读取的信息为准,存储滑窗推理结果的文件名也与第一时相影像文件相同
 
 
 ## 静态图推理API
 ## 静态图推理API
 
 
@@ -216,5 +224,10 @@ def predict(self,
 |`transforms`|`paddlers.transforms.Compose`\|`None`|对输入数据应用的数据变换算子。若为`None`,则使用从`model.yml`中读取的算子。|`None`|
 |`transforms`|`paddlers.transforms.Compose`\|`None`|对输入数据应用的数据变换算子。若为`None`,则使用从`model.yml`中读取的算子。|`None`|
 |`warmup_iters`|`int`|预热轮数,用于评估模型推理以及前后处理速度。若大于1,将预先重复执行`warmup_iters`次推理,而后才开始正式的预测及其速度评估。|`0`|
 |`warmup_iters`|`int`|预热轮数,用于评估模型推理以及前后处理速度。若大于1,将预先重复执行`warmup_iters`次推理,而后才开始正式的预测及其速度评估。|`0`|
 |`repeats`|`int`|重复次数,用于评估模型推理以及前后处理速度。若大于1,将执行`repeats`次预测并取时间平均值。|`1`|
 |`repeats`|`int`|重复次数,用于评估模型推理以及前后处理速度。若大于1,将执行`repeats`次预测并取时间平均值。|`1`|
+|`quiet`|`bool`|若为`True`,不打印计时信息。|`False`|
 
 
 `Predictor.predict()`的返回格式与相应的动态图推理API的返回格式完全相同,详情请参考[动态图推理API](#动态图推理api)。
 `Predictor.predict()`的返回格式与相应的动态图推理API的返回格式完全相同,详情请参考[动态图推理API](#动态图推理api)。
+
+### `Predictor.slider_predict()`
+
+实现滑窗推理功能。用法与`BaseSegmenter`和`BaseChangeDetector`的`slider_predict()`方法相同。

+ 58 - 5
paddlers/deploy/predictor.py

@@ -14,6 +14,7 @@
 
 
 import os.path as osp
 import os.path as osp
 from operator import itemgetter
 from operator import itemgetter
+from functools import partial
 
 
 import numpy as np
 import numpy as np
 import paddle
 import paddle
@@ -23,6 +24,7 @@ from paddle.inference import PrecisionType
 
 
 from paddlers.tasks import load_model
 from paddlers.tasks import load_model
 from paddlers.utils import logging, Timer
 from paddlers.utils import logging, Timer
+from paddlers.tasks.utils.slider_predict import slider_predict
 
 
 
 
 class Predictor(object):
 class Predictor(object):
@@ -271,22 +273,24 @@ class Predictor(object):
                 topk=1,
                 topk=1,
                 transforms=None,
                 transforms=None,
                 warmup_iters=0,
                 warmup_iters=0,
-                repeats=1):
+                repeats=1,
+                quiet=False):
         """
         """
-        Do prediction.
+        Do inference.
 
 
         Args:
         Args:
             img_file(list[str|tuple|np.ndarray] | str | tuple | np.ndarray): For scene classification, image restoration, 
             img_file(list[str|tuple|np.ndarray] | str | tuple | np.ndarray): For scene classification, image restoration, 
                 object detection and semantic segmentation tasks, `img_file` should be either the path of the image to predict,
                 object detection and semantic segmentation tasks, `img_file` should be either the path of the image to predict,
                 a decoded image (a np.ndarray, which should be consistent with what you get from passing image path to
                 a decoded image (a np.ndarray, which should be consistent with what you get from passing image path to
-                paddlers.transforms.decode_image()), or a list of image paths or decoded images. For change detection tasks,
-                img_file should be a tuple of image paths, a tuple of decoded images, or a list of tuples.
+                paddlers.transforms.decode_image(..., read_raw=True)), or a list of image paths or decoded images. For change 
+                detection tasks, `img_file` should be a tuple of image paths, a tuple of decoded images, or a list of tuples.
             topk(int, optional): Top-k values to reserve in a classification result. Defaults to 1.
             topk(int, optional): Top-k values to reserve in a classification result. Defaults to 1.
             transforms (paddlers.transforms.Compose|None, optional): Pipeline of data preprocessing. If None, load transforms
             transforms (paddlers.transforms.Compose|None, optional): Pipeline of data preprocessing. If None, load transforms
                 from `model.yml`. Defaults to None.
                 from `model.yml`. Defaults to None.
             warmup_iters (int, optional): Warm-up iterations before measuring the execution time. Defaults to 0.
             warmup_iters (int, optional): Warm-up iterations before measuring the execution time. Defaults to 0.
             repeats (int, optional): Number of repetitions to evaluate model inference and data processing speed. If greater than
             repeats (int, optional): Number of repetitions to evaluate model inference and data processing speed. If greater than
                 1, the reported time consumption is the average of all repeats. Defaults to 1.
                 1, the reported time consumption is the average of all repeats. Defaults to 1.
+            quiet (bool, optional): If True, do not display the timing information. Defaults to False.
         """
         """
 
 
         if repeats < 1:
         if repeats < 1:
@@ -313,12 +317,61 @@ class Predictor(object):
 
 
         self.timer.repeats = repeats
         self.timer.repeats = repeats
         self.timer.img_num = len(images)
         self.timer.img_num = len(images)
-        self.timer.info(average=True)
+        if not quiet:
+            self.timer.info(average=True)
 
 
         if isinstance(img_file, (str, np.ndarray, tuple)):
         if isinstance(img_file, (str, np.ndarray, tuple)):
             results = results[0]
             results = results[0]
 
 
         return results
         return results
 
 
+    def slider_predict(self,
+                       img_file,
+                       save_dir,
+                       block_size,
+                       overlap=36,
+                       transforms=None,
+                       invalid_value=255,
+                       merge_strategy='keep_last',
+                       batch_size=1,
+                       quiet=False):
+        """
+        Do inference using sliding windows. Only semantic segmentation and change detection models are supported in the 
+            sliding-predicting mode.
+
+        Args:
+            img_file(list[str|tuple|np.ndarray] | str | tuple | np.ndarray): For semantic segmentation tasks, `img_file` 
+                should be either the path of the image to predict, a decoded image (a np.ndarray, which should be 
+                consistent with what you get from passing image path to paddlers.transforms.decode_image(..., read_raw=True)), 
+                or a list of image paths or decoded images. For change detection tasks, `img_file` should be a tuple of 
+                image paths, a tuple of decoded images, or a list of tuples.
+            save_dir (str): Directory that contains saved geotiff file.
+            block_size (list[int] | tuple[int] | int): Size of block. If `block_size` is a list or tuple, it should be in 
+                (W, H) format.
+            overlap (list[int] | tuple[int] | int, optional): Overlap between two blocks. If `overlap` is a list or tuple, 
+                it should be in (W, H) format. Defaults to 36.
+            transforms (paddlers.transforms.Compose|None, optional): Pipeline of data preprocessing. If None, load transforms
+                from `model.yml`. Defaults to None.
+            invalid_value (int, optional): Value that marks invalid pixels in output image. Defaults to 255.
+            merge_strategy (str, optional): Strategy to merge overlapping blocks. Choices are 
+                {'keep_first', 'keep_last', 'accum'}. 'keep_first' and 'keep_last' means keeping the values of the first and 
+                the last block in traversal order, respectively. 'accum' means determining the class of an overlapping pixel 
+                according to accumulated probabilities. Defaults to 'keep_last'.
+            batch_size (int, optional): Batch size used in inference. Defaults to 1.
+            quiet (bool, optional): If True, disable the progress bar. Defaults to False.
+        """
+        slider_predict(
+            partial(
+                self.predict, quiet=True),
+            img_file,
+            save_dir,
+            block_size,
+            overlap,
+            transforms,
+            invalid_value,
+            merge_strategy,
+            batch_size,
+            not quiet)
+
     def batch_predict(self, image_list, **params):
     def batch_predict(self, image_list, **params):
         return self.predict(img_file=image_list, **params)
         return self.predict(img_file=image_list, **params)

+ 16 - 1
paddlers/tasks/base.py

@@ -86,7 +86,7 @@ class BaseModel(metaclass=ModelMeta):
         self.quant_config = None
         self.quant_config = None
         self.fixed_input_shape = None
         self.fixed_input_shape = None
 
 
-    def net_initialize(self,
+    def initialize_net(self,
                        pretrain_weights=None,
                        pretrain_weights=None,
                        save_dir='.',
                        save_dir='.',
                        resume_checkpoint=None,
                        resume_checkpoint=None,
@@ -677,3 +677,18 @@ class BaseModel(metaclass=ModelMeta):
             raise ValueError(
             raise ValueError(
                 f"Incorrect arrange mode! Expected {mode} but got {arrange_obj.mode}."
                 f"Incorrect arrange mode! Expected {mode} but got {arrange_obj.mode}."
             )
             )
+
+    def run(self, net, inputs, mode):
+        raise NotImplementedError
+
+    def train(self, *args, **kwargs):
+        raise NotImplementedError
+
+    def evaluate(self, *args, **kwargs):
+        raise NotImplementedError
+
+    def preprocess(self, images, transforms, to_tensor):
+        raise NotImplementedError
+
+    def postprocess(self, *args, **kwargs):
+        raise NotImplementedError

+ 33 - 84
paddlers/tasks/change_detector.py

@@ -35,6 +35,7 @@ from paddlers.utils.checkpoint import cd_pretrain_weights_dict
 from .base import BaseModel
 from .base import BaseModel
 from .utils import seg_metrics as metrics
 from .utils import seg_metrics as metrics
 from .utils.infer_nets import InferCDNet
 from .utils.infer_nets import InferCDNet
+from .utils.slider_predict import slider_predict
 
 
 __all__ = [
 __all__ = [
     "CDNet", "FCEarlyFusion", "FCSiamConc", "FCSiamDiff", "STANet", "BIT",
     "CDNet", "FCEarlyFusion", "FCSiamConc", "FCSiamDiff", "STANet", "BIT",
@@ -315,7 +316,7 @@ class BaseChangeDetector(BaseModel):
                         exit=True)
                         exit=True)
         pretrained_dir = osp.join(save_dir, 'pretrain')
         pretrained_dir = osp.join(save_dir, 'pretrain')
         is_backbone_weights = pretrain_weights == 'IMAGENET'
         is_backbone_weights = pretrain_weights == 'IMAGENET'
-        self.net_initialize(
+        self.initialize_net(
             pretrain_weights=pretrain_weights,
             pretrain_weights=pretrain_weights,
             save_dir=pretrained_dir,
             save_dir=pretrained_dir,
             resume_checkpoint=resume_checkpoint,
             resume_checkpoint=resume_checkpoint,
@@ -581,96 +582,44 @@ class BaseChangeDetector(BaseModel):
         return prediction
         return prediction
 
 
     def slider_predict(self,
     def slider_predict(self,
-                       img_file,
+                       img_files,
                        save_dir,
                        save_dir,
                        block_size,
                        block_size,
                        overlap=36,
                        overlap=36,
-                       transforms=None):
+                       transforms=None,
+                       invalid_value=255,
+                       merge_strategy='keep_last',
+                       batch_size=1,
+                       quiet=False):
         """
         """
-        Do inference.
+        Do inference using sliding windows.
 
 
         Args:
         Args:
-            img_file (tuple[str]): Tuple of image paths.
+            img_files (tuple[str]): Tuple of image paths.
             save_dir (str): Directory that contains saved geotiff file.
             save_dir (str): Directory that contains saved geotiff file.
-            block_size (list[int] | tuple[int] | int, optional): Size of block.
-            overlap (list[int] | tuple[int] | int, optional): Overlap between two blocks. 
-                Defaults to 36.
-            transforms (paddlers.transforms.Compose|None, optional): Transforms for inputs.
-                If None, the transforms for evaluation process will be used. Defaults to None.
+            block_size (list[int] | tuple[int] | int):
+                Size of block. If `block_size` is a list or tuple, it should be in 
+                (W, H) format.
+            overlap (list[int] | tuple[int] | int, optional):
+                Overlap between two blocks. If `overlap` is a list or tuple, it should
+                be in (W, H) format. Defaults to 36.
+            transforms (paddlers.transforms.Compose|None, optional): Transforms for 
+                inputs. If None, the transforms for evaluation process will be used. 
+                Defaults to None.
+            invalid_value (int, optional): Value that marks invalid pixels in output 
+                image. Defaults to 255.
+            merge_strategy (str, optional): Strategy to merge overlapping blocks. Choices
+                are {'keep_first', 'keep_last', 'accum'}. 'keep_first' and 'keep_last' 
+                means keeping the values of the first and the last block in traversal 
+                order, respectively. 'accum' means determining the class of an overlapping 
+                pixel according to accumulated probabilities. Defaults to 'keep_last'.
+            batch_size (int, optional): Batch size used in inference. Defaults to 1.
+            quiet (bool, optional): If True, disable the progress bar. Defaults to False.
         """
         """
 
 
-        try:
-            from osgeo import gdal
-        except:
-            import gdal
-
-        if not isinstance(img_file, tuple) or len(img_file) != 2:
-            raise ValueError("`img_file` must be a tuple of length 2.")
-        if isinstance(block_size, int):
-            block_size = (block_size, block_size)
-        elif isinstance(block_size, (tuple, list)) and len(block_size) == 2:
-            block_size = tuple(block_size)
-        else:
-            raise ValueError(
-                "`block_size` must be a tuple/list of length 2 or an integer.")
-        if isinstance(overlap, int):
-            overlap = (overlap, overlap)
-        elif isinstance(overlap, (tuple, list)) and len(overlap) == 2:
-            overlap = tuple(overlap)
-        else:
-            raise ValueError(
-                "`overlap` must be a tuple/list of length 2 or an integer.")
-
-        src1_data = gdal.Open(img_file[0])
-        src2_data = gdal.Open(img_file[1])
-        width = src1_data.RasterXSize
-        height = src1_data.RasterYSize
-        bands = src1_data.RasterCount
-
-        driver = gdal.GetDriverByName("GTiff")
-        file_name = osp.splitext(osp.normpath(img_file[0]).split(os.sep)[-1])[
-            0] + ".tif"
-        if not osp.exists(save_dir):
-            os.makedirs(save_dir)
-        save_file = osp.join(save_dir, file_name)
-        dst_data = driver.Create(save_file, width, height, 1, gdal.GDT_Byte)
-        dst_data.SetGeoTransform(src1_data.GetGeoTransform())
-        dst_data.SetProjection(src1_data.GetProjection())
-        band = dst_data.GetRasterBand(1)
-        band.WriteArray(255 * np.ones((height, width), dtype="uint8"))
-
-        step = np.array(block_size) - np.array(overlap)
-        for yoff in range(0, height, step[1]):
-            for xoff in range(0, width, step[0]):
-                xsize, ysize = block_size
-                if xoff + xsize > width:
-                    xsize = int(width - xoff)
-                if yoff + ysize > height:
-                    ysize = int(height - yoff)
-                im1 = src1_data.ReadAsArray(
-                    int(xoff), int(yoff), xsize, ysize).transpose((1, 2, 0))
-                im2 = src2_data.ReadAsArray(
-                    int(xoff), int(yoff), xsize, ysize).transpose((1, 2, 0))
-                # Fill
-                h, w = im1.shape[:2]
-                im1_fill = np.zeros(
-                    (block_size[1], block_size[0], bands), dtype=im1.dtype)
-                im2_fill = im1_fill.copy()
-                im1_fill[:h, :w, :] = im1
-                im2_fill[:h, :w, :] = im2
-                im_fill = (im1_fill, im2_fill)
-                # Predict
-                pred = self.predict(im_fill,
-                                    transforms)["label_map"].astype("uint8")
-                # Overlap
-                rd_block = band.ReadAsArray(int(xoff), int(yoff), xsize, ysize)
-                mask = (rd_block == pred[:h, :w]) | (rd_block == 255)
-                temp = pred[:h, :w].copy()
-                temp[mask == False] = 0
-                band.WriteArray(temp, int(xoff), int(yoff))
-                dst_data.FlushCache()
-        dst_data = None
-        print("GeoTiff saved in {}.".format(save_file))
+        slider_predict(self.predict, img_files, save_dir, block_size, overlap,
+                       transforms, invalid_value, merge_strategy, batch_size,
+                       not quiet)
 
 
     def preprocess(self, images, transforms, to_tensor=True):
     def preprocess(self, images, transforms, to_tensor=True):
         self._check_transforms(transforms, 'test')
         self._check_transforms(transforms, 'test')
@@ -678,8 +627,8 @@ class BaseChangeDetector(BaseModel):
         batch_ori_shape = list()
         batch_ori_shape = list()
         for im1, im2 in images:
         for im1, im2 in images:
             if isinstance(im1, str) or isinstance(im2, str):
             if isinstance(im1, str) or isinstance(im2, str):
-                im1 = decode_image(im1, to_rgb=False)
-                im2 = decode_image(im2, to_rgb=False)
+                im1 = decode_image(im1, read_raw=True)
+                im2 = decode_image(im2, read_raw=True)
             ori_shape = im1.shape[:2]
             ori_shape = im1.shape[:2]
             # XXX: sample do not contain 'image_t1' and 'image_t2'.
             # XXX: sample do not contain 'image_t1' and 'image_t2'.
             sample = {'image': im1, 'image2': im2}
             sample = {'image': im1, 'image2': im2}

+ 2 - 2
paddlers/tasks/classifier.py

@@ -288,7 +288,7 @@ class BaseClassifier(BaseModel):
                         exit=True)
                         exit=True)
         pretrained_dir = osp.join(save_dir, 'pretrain')
         pretrained_dir = osp.join(save_dir, 'pretrain')
         is_backbone_weights = False
         is_backbone_weights = False
-        self.net_initialize(
+        self.initialize_net(
             pretrain_weights=pretrain_weights,
             pretrain_weights=pretrain_weights,
             save_dir=pretrained_dir,
             save_dir=pretrained_dir,
             resume_checkpoint=resume_checkpoint,
             resume_checkpoint=resume_checkpoint,
@@ -497,7 +497,7 @@ class BaseClassifier(BaseModel):
         batch_ori_shape = list()
         batch_ori_shape = list()
         for im in images:
         for im in images:
             if isinstance(im, str):
             if isinstance(im, str):
-                im = decode_image(im, to_rgb=False)
+                im = decode_image(im, read_raw=True)
             ori_shape = im.shape[:2]
             ori_shape = im.shape[:2]
             sample = {'image': im}
             sample = {'image': im}
             im = transforms(sample)
             im = transforms(sample)

+ 2 - 2
paddlers/tasks/object_detector.py

@@ -347,7 +347,7 @@ class BaseDetector(BaseModel):
                         "Invalid pretrained weights. Please specify a .pdparams file.",
                         "Invalid pretrained weights. Please specify a .pdparams file.",
                         exit=True)
                         exit=True)
         pretrained_dir = osp.join(save_dir, 'pretrain')
         pretrained_dir = osp.join(save_dir, 'pretrain')
-        self.net_initialize(
+        self.initialize_net(
             pretrain_weights=pretrain_weights,
             pretrain_weights=pretrain_weights,
             save_dir=pretrained_dir,
             save_dir=pretrained_dir,
             resume_checkpoint=resume_checkpoint,
             resume_checkpoint=resume_checkpoint,
@@ -617,7 +617,7 @@ class BaseDetector(BaseModel):
         batch_samples = list()
         batch_samples = list()
         for im in images:
         for im in images:
             if isinstance(im, str):
             if isinstance(im, str):
-                im = decode_image(im, to_rgb=False)
+                im = decode_image(im, read_raw=True)
             sample = {'image': im}
             sample = {'image': im}
             sample = transforms(sample)
             sample = transforms(sample)
             batch_samples.append(sample)
             batch_samples.append(sample)

+ 2 - 2
paddlers/tasks/restorer.py

@@ -283,7 +283,7 @@ class BaseRestorer(BaseModel):
                         exit=True)
                         exit=True)
         pretrained_dir = osp.join(save_dir, 'pretrain')
         pretrained_dir = osp.join(save_dir, 'pretrain')
         is_backbone_weights = pretrain_weights == 'IMAGENET'
         is_backbone_weights = pretrain_weights == 'IMAGENET'
-        self.net_initialize(
+        self.initialize_net(
             pretrain_weights=pretrain_weights,
             pretrain_weights=pretrain_weights,
             save_dir=pretrained_dir,
             save_dir=pretrained_dir,
             resume_checkpoint=resume_checkpoint,
             resume_checkpoint=resume_checkpoint,
@@ -481,7 +481,7 @@ class BaseRestorer(BaseModel):
         batch_tar_shape = list()
         batch_tar_shape = list()
         for im in images:
         for im in images:
             if isinstance(im, str):
             if isinstance(im, str):
-                im = decode_image(im, to_rgb=False)
+                im = decode_image(im, read_raw=True)
             ori_shape = im.shape[:2]
             ori_shape = im.shape[:2]
             sample = {'image': im}
             sample = {'image': im}
             im = transforms(sample)[0]
             im = transforms(sample)[0]

+ 25 - 70
paddlers/tasks/segmenter.py

@@ -34,6 +34,7 @@ from paddlers.utils.checkpoint import seg_pretrain_weights_dict
 from .base import BaseModel
 from .base import BaseModel
 from .utils import seg_metrics as metrics
 from .utils import seg_metrics as metrics
 from .utils.infer_nets import InferSegNet
 from .utils.infer_nets import InferSegNet
+from .utils.slider_predict import slider_predict
 
 
 __all__ = ["UNet", "DeepLabV3P", "FastSCNN", "HRNet", "BiSeNetV2", "FarSeg"]
 __all__ = ["UNet", "DeepLabV3P", "FastSCNN", "HRNet", "BiSeNetV2", "FarSeg"]
 
 
@@ -307,7 +308,7 @@ class BaseSegmenter(BaseModel):
                         exit=True)
                         exit=True)
         pretrained_dir = osp.join(save_dir, 'pretrain')
         pretrained_dir = osp.join(save_dir, 'pretrain')
         is_backbone_weights = pretrain_weights == 'IMAGENET'
         is_backbone_weights = pretrain_weights == 'IMAGENET'
-        self.net_initialize(
+        self.initialize_net(
             pretrain_weights=pretrain_weights,
             pretrain_weights=pretrain_weights,
             save_dir=pretrained_dir,
             save_dir=pretrained_dir,
             resume_checkpoint=resume_checkpoint,
             resume_checkpoint=resume_checkpoint,
@@ -557,86 +558,40 @@ class BaseSegmenter(BaseModel):
                        save_dir,
                        save_dir,
                        block_size,
                        block_size,
                        overlap=36,
                        overlap=36,
-                       transforms=None):
+                       transforms=None,
+                       invalid_value=255,
+                       merge_strategy='keep_last',
+                       batch_size=1,
+                       quiet=False):
         """
         """
-        Do inference.
+        Do inference using sliding windows.
 
 
         Args:
         Args:
             img_file (str): Image path.
             img_file (str): Image path.
             save_dir (str): Directory that contains saved geotiff file.
             save_dir (str): Directory that contains saved geotiff file.
             block_size (list[int] | tuple[int] | int):
             block_size (list[int] | tuple[int] | int):
-                Size of block.
+                Size of block. If `block_size` is list or tuple, it should be in 
+                (W, H) format.
             overlap (list[int] | tuple[int] | int, optional):
             overlap (list[int] | tuple[int] | int, optional):
-                Overlap between two blocks. Defaults to 36.
+                Overlap between two blocks. If `overlap` is list or tuple, it should
+                be in (W, H) format. Defaults to 36.
             transforms (paddlers.transforms.Compose|None, optional): Transforms for 
             transforms (paddlers.transforms.Compose|None, optional): Transforms for 
                 inputs. If None, the transforms for evaluation process will be used. 
                 inputs. If None, the transforms for evaluation process will be used. 
                 Defaults to None.
                 Defaults to None.
+            invalid_value (int, optional): Value that marks invalid pixels in output 
+                image. Defaults to 255.
+            merge_strategy (str, optional): Strategy to merge overlapping blocks. Choices
+                are {'keep_first', 'keep_last', 'accum'}. 'keep_first' and 'keep_last' 
+                means keeping the values of the first and the last block in traversal 
+                order, respectively. 'accum' means determining the class of an overlapping 
+                pixel according to accumulated probabilities. Defaults to 'keep_last'.
+            batch_size (int, optional): Batch size used in inference. Defaults to 1.
+            quiet (bool, optional): If True, disable the progress bar. Defaults to False.
         """
         """
 
 
-        try:
-            from osgeo import gdal
-        except:
-            import gdal
-
-        if isinstance(block_size, int):
-            block_size = (block_size, block_size)
-        elif isinstance(block_size, (tuple, list)) and len(block_size) == 2:
-            block_size = tuple(block_size)
-        else:
-            raise ValueError(
-                "`block_size` must be a tuple/list of length 2 or an integer.")
-        if isinstance(overlap, int):
-            overlap = (overlap, overlap)
-        elif isinstance(overlap, (tuple, list)) and len(overlap) == 2:
-            overlap = tuple(overlap)
-        else:
-            raise ValueError(
-                "`overlap` must be a tuple/list of length 2 or an integer.")
-
-        src_data = gdal.Open(img_file)
-        width = src_data.RasterXSize
-        height = src_data.RasterYSize
-        bands = src_data.RasterCount
-
-        driver = gdal.GetDriverByName("GTiff")
-        file_name = osp.splitext(osp.normpath(img_file).split(os.sep)[-1])[
-            0] + ".tif"
-        if not osp.exists(save_dir):
-            os.makedirs(save_dir)
-        save_file = osp.join(save_dir, file_name)
-        dst_data = driver.Create(save_file, width, height, 1, gdal.GDT_Byte)
-        dst_data.SetGeoTransform(src_data.GetGeoTransform())
-        dst_data.SetProjection(src_data.GetProjection())
-        band = dst_data.GetRasterBand(1)
-        band.WriteArray(255 * np.ones((height, width), dtype="uint8"))
-
-        step = np.array(block_size) - np.array(overlap)
-        for yoff in range(0, height, step[1]):
-            for xoff in range(0, width, step[0]):
-                xsize, ysize = block_size
-                if xoff + xsize > width:
-                    xsize = int(width - xoff)
-                if yoff + ysize > height:
-                    ysize = int(height - yoff)
-                im = src_data.ReadAsArray(int(xoff), int(yoff), xsize,
-                                          ysize).transpose((1, 2, 0))
-                # Fill
-                h, w = im.shape[:2]
-                im_fill = np.zeros(
-                    (block_size[1], block_size[0], bands), dtype=im.dtype)
-                im_fill[:h, :w, :] = im
-                # Predict
-                pred = self.predict(im_fill,
-                                    transforms)["label_map"].astype("uint8")
-                # Overlap
-                rd_block = band.ReadAsArray(int(xoff), int(yoff), xsize, ysize)
-                mask = (rd_block == pred[:h, :w]) | (rd_block == 255)
-                temp = pred[:h, :w].copy()
-                temp[mask == False] = 0
-                band.WriteArray(temp, int(xoff), int(yoff))
-                dst_data.FlushCache()
-        dst_data = None
-        print("GeoTiff saved in {}.".format(save_file))
+        slider_predict(self.predict, img_file, save_dir, block_size, overlap,
+                       transforms, invalid_value, merge_strategy, batch_size,
+                       not quiet)
 
 
     def preprocess(self, images, transforms, to_tensor=True):
     def preprocess(self, images, transforms, to_tensor=True):
         self._check_transforms(transforms, 'test')
         self._check_transforms(transforms, 'test')
@@ -644,7 +599,7 @@ class BaseSegmenter(BaseModel):
         batch_ori_shape = list()
         batch_ori_shape = list()
         for im in images:
         for im in images:
             if isinstance(im, str):
             if isinstance(im, str):
-                im = decode_image(im, to_rgb=False)
+                im = decode_image(im, read_raw=True)
             ori_shape = im.shape[:2]
             ori_shape = im.shape[:2]
             sample = {'image': im}
             sample = {'image': im}
             im = transforms(sample)[0]
             im = transforms(sample)[0]

+ 437 - 0
paddlers/tasks/utils/slider_predict.py

@@ -0,0 +1,437 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import os.path as osp
+import math
+from abc import ABCMeta, abstractmethod
+from collections import Counter, defaultdict
+
+import numpy as np
+from tqdm import tqdm
+
+import paddlers.utils.logging as logging
+
+
+class Cache(metaclass=ABCMeta):
+    @abstractmethod
+    def get_block(self, i_st, j_st, h, w):
+        pass
+
+
+class SlowCache(Cache):
+    def __init__(self):
+        super(SlowCache, self).__init__()
+        self.cache = defaultdict(Counter)
+
+    def push_pixel(self, i, j, l):
+        self.cache[(i, j)][l] += 1
+
+    def push_block(self, i_st, j_st, h, w, data):
+        for i in range(0, h):
+            for j in range(0, w):
+                self.push_pixel(i_st + i, j_st + j, data[i, j])
+
+    def pop_pixel(self, i, j):
+        self.cache.pop((i, j))
+
+    def pop_block(self, i_st, j_st, h, w):
+        for i in range(0, h):
+            for j in range(0, w):
+                self.pop_pixel(i_st + i, j_st + j)
+
+    def get_pixel(self, i, j):
+        winners = self.cache[(i, j)].most_common(1)
+        winner = winners[0]
+        return winner[0]
+
+    def get_block(self, i_st, j_st, h, w):
+        block = []
+        for i in range(i_st, i_st + h):
+            row = []
+            for j in range(j_st, j_st + w):
+                row.append(self.get_pixel(i, j))
+            block.append(row)
+        return np.asarray(block)
+
+
+class ProbCache(Cache):
+    def __init__(self, h, w, ch, cw, sh, sw, dtype=np.float32, order='c'):
+        super(ProbCache, self).__init__()
+        self.cache = None
+        self.h = h
+        self.w = w
+        self.ch = ch
+        self.cw = cw
+        self.sh = sh
+        self.sw = sw
+        if not issubclass(dtype, np.floating):
+            raise TypeError("`dtype` must be one of the floating types.")
+        self.dtype = dtype
+        order = order.lower()
+        if order not in ('c', 'f'):
+            raise ValueError("`order` other than 'c' and 'f' is not supported.")
+        self.order = order
+
+    def _alloc_memory(self, nc):
+        if self.order == 'c':
+            # Colomn-first order (C-style)
+            #
+            # <-- cw -->
+            # |--------|---------------------|^    ^
+            # |                              ||    | sh
+            # |--------|---------------------|| ch v
+            # |                              || 
+            # |--------|---------------------|v
+            # <------------ w --------------->
+            self.cache = np.zeros((self.ch, self.w, nc), dtype=self.dtype)
+        elif self.order == 'f':
+            # Row-first order (Fortran-style)
+            #
+            # <-- sw -->
+            # <---- cw ---->
+            # |--------|---|^   ^
+            # |        |   ||   |
+            # |        |   ||   ch
+            # |        |   ||   |
+            # |--------|---|| h v
+            # |        |   ||
+            # |        |   ||
+            # |        |   ||
+            # |--------|---|v
+            self.cache = np.zeros((self.h, self.cw, nc), dtype=self.dtype)
+
+    def update_block(self, i_st, j_st, h, w, prob_map):
+        if self.cache is None:
+            nc = prob_map.shape[2]
+            # Lazy allocation of memory
+            self._alloc_memory(nc)
+        self.cache[i_st:i_st + h, j_st:j_st + w] += prob_map
+
+    def roll_cache(self, shift):
+        if self.order == 'c':
+            self.cache[:-shift] = self.cache[shift:]
+            self.cache[-shift:, :] = 0
+        elif self.order == 'f':
+            self.cache[:, :-shift] = self.cache[:, shift:]
+            self.cache[:, -shift:] = 0
+
+    def get_block(self, i_st, j_st, h, w):
+        return np.argmax(self.cache[i_st:i_st + h, j_st:j_st + w], axis=2)
+
+
+class OverlapProcessor(metaclass=ABCMeta):
+    def __init__(self, h, w, ch, cw, sh, sw):
+        super(OverlapProcessor, self).__init__()
+        self.h = h
+        self.w = w
+        self.ch = ch
+        self.cw = cw
+        self.sh = sh
+        self.sw = sw
+
+    @abstractmethod
+    def process_pred(self, out, xoff, yoff):
+        pass
+
+
+class KeepFirstProcessor(OverlapProcessor):
+    def __init__(self, h, w, ch, cw, sh, sw, ds, inval=255):
+        super(KeepFirstProcessor, self).__init__(h, w, ch, cw, sh, sw)
+        self.ds = ds
+        self.inval = inval
+
+    def process_pred(self, out, xoff, yoff):
+        pred = out['label_map']
+        pred = pred[:self.ch, :self.cw]
+        rd_block = self.ds.ReadAsArray(xoff, yoff, self.cw, self.ch)
+        mask = rd_block != self.inval
+        pred = np.where(mask, rd_block, pred)
+        return pred
+
+
+class KeepLastProcessor(OverlapProcessor):
+    def process_pred(self, out, xoff, yoff):
+        pred = out['label_map']
+        pred = pred[:self.ch, :self.cw]
+        return pred
+
+
+class AccumProcessor(OverlapProcessor):
+    def __init__(self,
+                 h,
+                 w,
+                 ch,
+                 cw,
+                 sh,
+                 sw,
+                 dtype=np.float16,
+                 assign_weight=True):
+        super(AccumProcessor, self).__init__(h, w, ch, cw, sh, sw)
+        self.cache = ProbCache(h, w, ch, cw, sh, sw, dtype=dtype, order='c')
+        self.prev_yoff = None
+        self.assign_weight = assign_weight
+
+    def process_pred(self, out, xoff, yoff):
+        if self.prev_yoff is not None and yoff != self.prev_yoff:
+            if yoff < self.prev_yoff:
+                raise RuntimeError
+            self.cache.roll_cache(yoff - self.prev_yoff)
+        pred = out['label_map']
+        pred = pred[:self.ch, :self.cw]
+        prob = out['score_map']
+        prob = prob[:self.ch, :self.cw]
+        if self.assign_weight:
+            prob = assign_border_weights(prob, border_ratio=0.25, inplace=True)
+        self.cache.update_block(0, xoff, self.ch, self.cw, prob)
+        pred = self.cache.get_block(0, xoff, self.ch, self.cw)
+        self.prev_yoff = yoff
+        return pred
+
+
+def assign_border_weights(array, weight=0.5, border_ratio=0.25, inplace=True):
+    if not inplace:
+        array = array.copy()
+    h, w = array.shape[:2]
+    hm, wm = int(h * border_ratio), int(w * border_ratio)
+    array[:hm] *= weight
+    array[-hm:] *= weight
+    array[:, :wm] *= weight
+    array[:, -wm:] *= weight
+    return array
+
+
+def read_block(ds,
+               xoff,
+               yoff,
+               xsize,
+               ysize,
+               tar_xsize=None,
+               tar_ysize=None,
+               pad_val=0):
+    if tar_xsize is None:
+        tar_xsize = xsize
+    if tar_ysize is None:
+        tar_ysize = ysize
+    # Read data from dataset
+    block = ds.ReadAsArray(xoff, yoff, xsize, ysize)
+    c, real_ysize, real_xsize = block.shape
+    assert real_ysize == ysize and real_xsize == xsize
+    # [c, h, w] -> [h, w, c]
+    block = block.transpose((1, 2, 0))
+    if (real_ysize, real_xsize) != (tar_ysize, tar_xsize):
+        if real_ysize >= tar_ysize or real_xsize >= tar_xsize:
+            raise ValueError
+        padded_block = np.full(
+            (tar_ysize, tar_xsize, c), fill_value=pad_val, dtype=block.dtype)
+        # Fill
+        padded_block[:real_ysize, :real_xsize] = block
+        return padded_block
+    else:
+        return block
+
+
+def slider_predict(predict_func,
+                   img_file,
+                   save_dir,
+                   block_size,
+                   overlap,
+                   transforms,
+                   invalid_value,
+                   merge_strategy,
+                   batch_size,
+                   show_progress=False):
+    """
+    Do inference using sliding windows.
+
+    Args:
+        predict_func (callable): A callable object that makes the prediction.
+        img_file (str|tuple[str]): Image path(s).
+        save_dir (str): Directory that contains saved geotiff file.
+        block_size (list[int] | tuple[int] | int):
+            Size of block. If `block_size` is list or tuple, it should be in 
+            (W, H) format.
+        overlap (list[int] | tuple[int] | int):
+            Overlap between two blocks. If `overlap` is list or tuple, it should
+            be in (W, H) format.
+        transforms (paddlers.transforms.Compose|None): Transforms for inputs. If 
+            None, the transforms for evaluation process will be used. 
+        invalid_value (int): Value that marks invalid pixels in output image. 
+            Defaults to 255.
+        merge_strategy (str): Strategy to merge overlapping blocks. Choices are 
+            {'keep_first', 'keep_last', 'accum'}. 'keep_first' and 'keep_last' 
+            means keeping the values of the first and the last block in 
+            traversal order, respectively. 'accum' means determining the class 
+            of an overlapping pixel according to accumulated probabilities.
+        batch_size (int): Batch size used in inference.
+        show_progress (bool, optional): Whether to show prediction progress with a 
+            progress bar. Defaults to True.
+    """
+
+    try:
+        from osgeo import gdal
+    except:
+        import gdal
+
+    if isinstance(block_size, int):
+        block_size = (block_size, block_size)
+    elif isinstance(block_size, (tuple, list)) and len(block_size) == 2:
+        block_size = tuple(block_size)
+    else:
+        raise ValueError(
+            "`block_size` must be a tuple/list of length 2 or an integer.")
+    if isinstance(overlap, int):
+        overlap = (overlap, overlap)
+    elif isinstance(overlap, (tuple, list)) and len(overlap) == 2:
+        overlap = tuple(overlap)
+    else:
+        raise ValueError(
+            "`overlap` must be a tuple/list of length 2 or an integer.")
+
+    step = np.array(
+        block_size, dtype=np.int32) - np.array(
+            overlap, dtype=np.int32)
+    if step[0] == 0 or step[1] == 0:
+        raise ValueError("`block_size` and `overlap` should not be equal.")
+
+    if isinstance(img_file, tuple):
+        if len(img_file) != 2:
+            raise ValueError("Tuple `img_file` must have the length of two.")
+        # Assume that two input images have the same size
+        src_data = gdal.Open(img_file[0])
+        src2_data = gdal.Open(img_file[1])
+        # Output name is the same as the name of the first image
+        file_name = osp.basename(osp.normpath(img_file[0]))
+    else:
+        src_data = gdal.Open(img_file)
+        file_name = osp.basename(osp.normpath(img_file))
+
+    # Get size of original raster
+    width = src_data.RasterXSize
+    height = src_data.RasterYSize
+    bands = src_data.RasterCount
+
+    # XXX: GDAL read behavior conforms to paddlers.transforms.decode_image(read_raw=True)
+    # except for SAR images.
+    if bands == 1:
+        logging.warning(
+            f"Detected `bands=1`. Please note that currently `slider_predict()` does not properly handle SAR images."
+        )
+
+    if block_size[0] > width or block_size[1] > height:
+        raise ValueError("`block_size` should not be larger than image size.")
+
+    driver = gdal.GetDriverByName("GTiff")
+    if not osp.exists(save_dir):
+        os.makedirs(save_dir)
+    # Replace extension name with '.tif'
+    file_name = osp.splitext(file_name)[0] + ".tif"
+    save_file = osp.join(save_dir, file_name)
+    dst_data = driver.Create(save_file, width, height, 1, gdal.GDT_Byte)
+
+    # Set meta-information
+    dst_data.SetGeoTransform(src_data.GetGeoTransform())
+    dst_data.SetProjection(src_data.GetProjection())
+
+    # Initialize raster with `invalid_value`
+    band = dst_data.GetRasterBand(1)
+    band.WriteArray(
+        np.full(
+            (height, width), fill_value=invalid_value, dtype="uint8"))
+
+    if overlap == (0, 0) or block_size == (width, height):
+        # When there is no overlap or the whole image is used as input, 
+        # use 'keep_last' strategy as it introduces least overheads
+        merge_strategy = 'keep_last'
+
+    if merge_strategy == 'keep_first':
+        overlap_processor = KeepFirstProcessor(
+            height,
+            width,
+            *block_size[::-1],
+            *step[::-1],
+            band,
+            inval=invalid_value)
+    elif merge_strategy == 'keep_last':
+        overlap_processor = KeepLastProcessor(height, width, *block_size[::-1],
+                                              *step[::-1])
+    elif merge_strategy == 'accum':
+        overlap_processor = AccumProcessor(height, width, *block_size[::-1],
+                                           *step[::-1])
+    else:
+        raise ValueError("{} is not a supported stragegy for block merging.".
+                         format(merge_strategy))
+
+    xsize, ysize = block_size
+    num_blocks = math.ceil(height / step[1]) * math.ceil(width / step[0])
+    cnt = 0
+    if show_progress:
+        pb = tqdm(total=num_blocks)
+    batch_data = []
+    batch_offsets = []
+    for yoff in range(0, height, step[1]):
+        for xoff in range(0, width, step[0]):
+            if xoff + xsize > width:
+                xoff = width - xsize
+                is_end_of_row = True
+            else:
+                is_end_of_row = False
+            if yoff + ysize > height:
+                yoff = height - ysize
+                is_end_of_col = True
+            else:
+                is_end_of_col = False
+
+            # Read
+            im = read_block(src_data, xoff, yoff, xsize, ysize)
+
+            if isinstance(img_file, tuple):
+                im2 = read_block(src2_data, xoff, yoff, xsize, ysize)
+                batch_data.append((im, im2))
+            else:
+                batch_data.append(im)
+
+            batch_offsets.append((xoff, yoff))
+
+            len_batch = len(batch_data)
+
+            if is_end_of_row and is_end_of_col and len_batch < batch_size:
+                # Pad `batch_data` by repeating the last element
+                batch_data = batch_data + [batch_data[-1]] * (batch_size -
+                                                              len_batch)
+                # While keeping `len(batch_offsets)` the number of valid elements in the batch 
+
+            if len(batch_data) == batch_size:
+                # Predict
+                batch_out = predict_func(batch_data, transforms=transforms)
+
+                for out, (xoff_, yoff_) in zip(batch_out, batch_offsets):
+                    # Get processed result
+                    pred = overlap_processor.process_pred(out, xoff_, yoff_)
+                    # Write to file
+                    band.WriteArray(pred, xoff_, yoff_)
+
+                dst_data.FlushCache()
+                batch_data.clear()
+                batch_offsets.clear()
+
+            cnt += 1
+
+            if show_progress:
+                pb.update(1)
+                pb.set_description("{} out of {} blocks processed.".format(
+                    cnt, num_blocks))
+
+    dst_data = None
+    logging.info("GeoTiff file saved in {}.".format(save_file))

+ 14 - 3
paddlers/transforms/__init__.py

@@ -25,7 +25,9 @@ def decode_image(im_path,
                  to_uint8=True,
                  to_uint8=True,
                  decode_bgr=True,
                  decode_bgr=True,
                  decode_sar=True,
                  decode_sar=True,
-                 read_geo_info=False):
+                 read_geo_info=False,
+                 use_stretch=False,
+                 read_raw=False):
     """
     """
     Decode an image.
     Decode an image.
     
     
@@ -37,11 +39,16 @@ def decode_image(im_path,
             uint8 type. Defaults to True.
             uint8 type. Defaults to True.
         decode_bgr (bool, optional): If True, automatically interpret a non-geo 
         decode_bgr (bool, optional): If True, automatically interpret a non-geo 
             image (e.g. jpeg images) as a BGR image. Defaults to True.
             image (e.g. jpeg images) as a BGR image. Defaults to True.
-        decode_sar (bool, optional): If True, automatically interpret a two-channel 
+        decode_sar (bool, optional): If True, automatically interpret a single-channel 
             geo image (e.g. geotiff images) as a SAR image, set this argument to 
             geo image (e.g. geotiff images) as a SAR image, set this argument to 
             True. Defaults to True.
             True. Defaults to True.
         read_geo_info (bool, optional): If True, read geographical information from 
         read_geo_info (bool, optional): If True, read geographical information from 
             the image. Deafults to False.
             the image. Deafults to False.
+        use_stretch (bool, optional): Whether to apply 2% linear stretch. Valid only if 
+            `to_uint8` is True. Defaults to False.
+        read_raw (bool, optional): If True, equivalent to setting `to_rgb` and `to_uint8`
+            to False. Setting `read_raw` takes precedence over setting `to_rgb` and 
+            `to_uint8`. Defaults to False.
     
     
     Returns:
     Returns:
         np.ndarray|tuple: If `read_geo_info` is False, return the decoded image. 
         np.ndarray|tuple: If `read_geo_info` is False, return the decoded image. 
@@ -53,12 +60,16 @@ def decode_image(im_path,
     # Do a presence check. osp.exists() assumes `im_path` is a path-like object.
     # Do a presence check. osp.exists() assumes `im_path` is a path-like object.
     if not osp.exists(im_path):
     if not osp.exists(im_path):
         raise ValueError(f"{im_path} does not exist!")
         raise ValueError(f"{im_path} does not exist!")
+    if read_raw:
+        to_rgb = False
+        to_uint8 = False
     decoder = T.DecodeImg(
     decoder = T.DecodeImg(
         to_rgb=to_rgb,
         to_rgb=to_rgb,
         to_uint8=to_uint8,
         to_uint8=to_uint8,
         decode_bgr=decode_bgr,
         decode_bgr=decode_bgr,
         decode_sar=decode_sar,
         decode_sar=decode_sar,
-        read_geo_info=read_geo_info)
+        read_geo_info=read_geo_info,
+        use_stretch=use_stretch)
     # Deepcopy to avoid inplace modification
     # Deepcopy to avoid inplace modification
     sample = {'image': copy.deepcopy(im_path)}
     sample = {'image': copy.deepcopy(im_path)}
     sample = decoder(sample)
     sample = decoder(sample)

+ 3 - 3
paddlers/transforms/functions.py

@@ -382,13 +382,13 @@ def resize_rle(rle, im_h, im_w, im_scale_x, im_scale_y, interp):
     return rle
     return rle
 
 
 
 
-def to_uint8(im, is_linear=False):
+def to_uint8(im, stretch=False):
     """
     """
     Convert raster data to uint8 type.
     Convert raster data to uint8 type.
     
     
     Args:
     Args:
         im (np.ndarray): Input raster image.
         im (np.ndarray): Input raster image.
-        is_linear (bool, optional): Use 2% linear stretch or not. Default is False.
+        stretch (bool, optional): Use 2% linear stretch or not. Default is False.
 
 
     Returns:
     Returns:
         np.ndarray: Image data with unit8 type.
         np.ndarray: Image data with unit8 type.
@@ -430,7 +430,7 @@ def to_uint8(im, is_linear=False):
     dtype = im.dtype.name
     dtype = im.dtype.name
     if dtype != "uint8":
     if dtype != "uint8":
         im = _sample_norm(im)
         im = _sample_norm(im)
-    if is_linear:
+    if stretch:
         im = _two_percent_linear(im)
         im = _two_percent_linear(im)
     return im
     return im
 
 

+ 9 - 5
paddlers/transforms/operators.py

@@ -197,11 +197,13 @@ class DecodeImg(Transform):
             uint8 type. Defaults to True.
             uint8 type. Defaults to True.
         decode_bgr (bool, optional): If True, automatically interpret a non-geo image 
         decode_bgr (bool, optional): If True, automatically interpret a non-geo image 
             (e.g., jpeg images) as a BGR image. Defaults to True.
             (e.g., jpeg images) as a BGR image. Defaults to True.
-        decode_sar (bool, optional): If True, automatically interpret a two-channel 
+        decode_sar (bool, optional): If True, automatically interpret a single-channel 
             geo image (e.g. geotiff images) as a SAR image, set this argument to 
             geo image (e.g. geotiff images) as a SAR image, set this argument to 
             True. Defaults to True.
             True. Defaults to True.
         read_geo_info (bool, optional): If True, read geographical information from 
         read_geo_info (bool, optional): If True, read geographical information from 
             the image. Deafults to False.
             the image. Deafults to False.
+        use_stretch (bool, optional): Whether to apply 2% linear stretch. Valid only if 
+            `to_uint8` is True. Defaults to False.
     """
     """
 
 
     def __init__(self,
     def __init__(self,
@@ -209,13 +211,15 @@ class DecodeImg(Transform):
                  to_uint8=True,
                  to_uint8=True,
                  decode_bgr=True,
                  decode_bgr=True,
                  decode_sar=True,
                  decode_sar=True,
-                 read_geo_info=False):
+                 read_geo_info=False,
+                 use_stretch=False):
         super(DecodeImg, self).__init__()
         super(DecodeImg, self).__init__()
         self.to_rgb = to_rgb
         self.to_rgb = to_rgb
         self.to_uint8 = to_uint8
         self.to_uint8 = to_uint8
         self.decode_bgr = decode_bgr
         self.decode_bgr = decode_bgr
         self.decode_sar = decode_sar
         self.decode_sar = decode_sar
-        self.read_geo_info = False
+        self.read_geo_info = read_geo_info
+        self.use_stretch = use_stretch
 
 
     def read_img(self, img_path):
     def read_img(self, img_path):
         img_format = imghdr.what(img_path)
         img_format = imghdr.what(img_path)
@@ -245,7 +249,7 @@ class DecodeImg(Transform):
                     im_data = im_data.transpose((1, 2, 0))
                     im_data = im_data.transpose((1, 2, 0))
             if self.read_geo_info:
             if self.read_geo_info:
                 geo_trans = dataset.GetGeoTransform()
                 geo_trans = dataset.GetGeoTransform()
-                geo_proj = dataset.GetGeoProjection()
+                geo_proj = dataset.GetProjection()
         elif img_format in ['jpeg', 'bmp', 'png', 'jpg']:
         elif img_format in ['jpeg', 'bmp', 'png', 'jpg']:
             if self.decode_bgr:
             if self.decode_bgr:
                 im_data = cv2.imread(img_path, cv2.IMREAD_ANYDEPTH |
                 im_data = cv2.imread(img_path, cv2.IMREAD_ANYDEPTH |
@@ -282,7 +286,7 @@ class DecodeImg(Transform):
             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
 
 
         if self.to_uint8:
         if self.to_uint8:
-            image = F.to_uint8(image)
+            image = F.to_uint8(image, stretch=self.use_stretch)
 
 
         if self.read_geo_info:
         if self.read_geo_info:
             return image, geo_info_dict
             return image, geo_info_dict

+ 17 - 12
tests/deploy/test_predictor.py

@@ -151,8 +151,8 @@ class TestCDPredictor(TestPredictor):
 
 
         # Single input (ndarrays)
         # Single input (ndarrays)
         input_ = (decode_image(
         input_ = (decode_image(
-            t1_path, to_rgb=False), decode_image(
-                t2_path, to_rgb=False))  # Reuse the name `input_`
+            t1_path, read_raw=True), decode_image(
+                t2_path, read_raw=True))  # Reuse the name `input_`
         out_single_array_p = predictor.predict(input_, transforms=transforms)
         out_single_array_p = predictor.predict(input_, transforms=transforms)
         self.check_dict_equal(out_single_array_p, out_single_file_p)
         self.check_dict_equal(out_single_array_p, out_single_file_p)
         out_single_array_t = trainer.predict(input_, transforms=transforms)
         out_single_array_t = trainer.predict(input_, transforms=transforms)
@@ -175,8 +175,9 @@ class TestCDPredictor(TestPredictor):
 
 
         # Multiple inputs (ndarrays)
         # Multiple inputs (ndarrays)
         input_ = [(decode_image(
         input_ = [(decode_image(
-            t1_path, to_rgb=False), decode_image(
-                t2_path, to_rgb=False))] * num_inputs  # Reuse the name `input_`
+            t1_path, read_raw=True), decode_image(
+                t2_path,
+                read_raw=True))] * num_inputs  # Reuse the name `input_`
         out_multi_array_p = predictor.predict(input_, transforms=transforms)
         out_multi_array_p = predictor.predict(input_, transforms=transforms)
         self.assertEqual(len(out_multi_array_p), num_inputs)
         self.assertEqual(len(out_multi_array_p), num_inputs)
         out_multi_array_t = trainer.predict(input_, transforms=transforms)
         out_multi_array_t = trainer.predict(input_, transforms=transforms)
@@ -217,7 +218,7 @@ class TestClasPredictor(TestPredictor):
 
 
         # Single input (ndarray)
         # Single input (ndarray)
         input_ = decode_image(
         input_ = decode_image(
-            single_input, to_rgb=False)  # Reuse the name `input_`
+            single_input, read_raw=True)  # Reuse the name `input_`
         out_single_array_p = predictor.predict(input_, transforms=transforms)
         out_single_array_p = predictor.predict(input_, transforms=transforms)
         self.check_dict_equal(out_single_array_p, out_single_file_p)
         self.check_dict_equal(out_single_array_p, out_single_file_p)
         out_single_array_t = trainer.predict(input_, transforms=transforms)
         out_single_array_t = trainer.predict(input_, transforms=transforms)
@@ -241,7 +242,8 @@ class TestClasPredictor(TestPredictor):
 
 
         # Multiple inputs (ndarrays)
         # Multiple inputs (ndarrays)
         input_ = [decode_image(
         input_ = [decode_image(
-            single_input, to_rgb=False)] * num_inputs  # Reuse the name `input_`
+            single_input,
+            read_raw=True)] * num_inputs  # Reuse the name `input_`
         out_multi_array_p = predictor.predict(input_, transforms=transforms)
         out_multi_array_p = predictor.predict(input_, transforms=transforms)
         self.assertEqual(len(out_multi_array_p), num_inputs)
         self.assertEqual(len(out_multi_array_p), num_inputs)
         out_multi_array_t = trainer.predict(input_, transforms=transforms)
         out_multi_array_t = trainer.predict(input_, transforms=transforms)
@@ -282,7 +284,7 @@ class TestDetPredictor(TestPredictor):
 
 
         # Single input (ndarray)
         # Single input (ndarray)
         input_ = decode_image(
         input_ = decode_image(
-            single_input, to_rgb=False)  # Reuse the name `input_`
+            single_input, read_raw=True)  # Reuse the name `input_`
         predictor.predict(input_, transforms=transforms)
         predictor.predict(input_, transforms=transforms)
         trainer.predict(input_, transforms=transforms)
         trainer.predict(input_, transforms=transforms)
         out_single_array_list_p = predictor.predict(
         out_single_array_list_p = predictor.predict(
@@ -301,7 +303,8 @@ class TestDetPredictor(TestPredictor):
 
 
         # Multiple inputs (ndarrays)
         # Multiple inputs (ndarrays)
         input_ = [decode_image(
         input_ = [decode_image(
-            single_input, to_rgb=False)] * num_inputs  # Reuse the name `input_`
+            single_input,
+            read_raw=True)] * num_inputs  # Reuse the name `input_`
         out_multi_array_p = predictor.predict(input_, transforms=transforms)
         out_multi_array_p = predictor.predict(input_, transforms=transforms)
         self.assertEqual(len(out_multi_array_p), num_inputs)
         self.assertEqual(len(out_multi_array_p), num_inputs)
         out_multi_array_t = trainer.predict(input_, transforms=transforms)
         out_multi_array_t = trainer.predict(input_, transforms=transforms)
@@ -343,7 +346,7 @@ class TestResPredictor(TestPredictor):
 
 
         # Single input (ndarray)
         # Single input (ndarray)
         input_ = decode_image(
         input_ = decode_image(
-            single_input, to_rgb=False)  # Reuse the name `input_`
+            single_input, read_raw=True)  # Reuse the name `input_`
         predictor.predict(input_, transforms=transforms)
         predictor.predict(input_, transforms=transforms)
         trainer.predict(input_, transforms=transforms)
         trainer.predict(input_, transforms=transforms)
         out_single_array_list_p = predictor.predict(
         out_single_array_list_p = predictor.predict(
@@ -362,7 +365,8 @@ class TestResPredictor(TestPredictor):
 
 
         # Multiple inputs (ndarrays)
         # Multiple inputs (ndarrays)
         input_ = [decode_image(
         input_ = [decode_image(
-            single_input, to_rgb=False)] * num_inputs  # Reuse the name `input_`
+            single_input,
+            read_raw=True)] * num_inputs  # Reuse the name `input_`
         out_multi_array_p = predictor.predict(input_, transforms=transforms)
         out_multi_array_p = predictor.predict(input_, transforms=transforms)
         self.assertEqual(len(out_multi_array_p), num_inputs)
         self.assertEqual(len(out_multi_array_p), num_inputs)
         out_multi_array_t = trainer.predict(input_, transforms=transforms)
         out_multi_array_t = trainer.predict(input_, transforms=transforms)
@@ -400,7 +404,7 @@ class TestSegPredictor(TestPredictor):
 
 
         # Single input (ndarray)
         # Single input (ndarray)
         input_ = decode_image(
         input_ = decode_image(
-            single_input, to_rgb=False)  # Reuse the name `input_`
+            single_input, read_raw=True)  # Reuse the name `input_`
         out_single_array_p = predictor.predict(input_, transforms=transforms)
         out_single_array_p = predictor.predict(input_, transforms=transforms)
         self.check_dict_equal(out_single_array_p, out_single_file_p)
         self.check_dict_equal(out_single_array_p, out_single_file_p)
         out_single_array_t = trainer.predict(input_, transforms=transforms)
         out_single_array_t = trainer.predict(input_, transforms=transforms)
@@ -423,7 +427,8 @@ class TestSegPredictor(TestPredictor):
 
 
         # Multiple inputs (ndarrays)
         # Multiple inputs (ndarrays)
         input_ = [decode_image(
         input_ = [decode_image(
-            single_input, to_rgb=False)] * num_inputs  # Reuse the name `input_`
+            single_input,
+            read_raw=True)] * num_inputs  # Reuse the name `input_`
         out_multi_array_p = predictor.predict(input_, transforms=transforms)
         out_multi_array_p = predictor.predict(input_, transforms=transforms)
         self.assertEqual(len(out_multi_array_p), num_inputs)
         self.assertEqual(len(out_multi_array_p), num_inputs)
         out_multi_array_t = trainer.predict(input_, transforms=transforms)
         out_multi_array_t = trainer.predict(input_, transforms=transforms)

+ 1 - 0
tests/fast_tests.py

@@ -13,4 +13,5 @@
 # limitations under the License.
 # limitations under the License.
 
 
 from rs_models import *
 from rs_models import *
+from tasks import *
 from transforms import *
 from transforms import *

+ 2 - 0
tests/tasks/__init__.py

@@ -11,3 +11,5 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # limitations under the License.
+
+from .test_slider_predict import *

+ 212 - 0
tests/tasks/test_slider_predict.py

@@ -0,0 +1,212 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os.path as osp
+import tempfile
+
+import paddlers as pdrs
+import paddlers.transforms as T
+from testing_utils import CommonTest
+
+
+class _TestSliderPredictNamespace:
+    class TestSliderPredict(CommonTest):
+        def test_blocksize_and_overlap_whole(self):
+            # Original image size (256, 256)
+            with tempfile.TemporaryDirectory() as td:
+                # Whole-image inference using predict()
+                pred_whole = self.model.predict(self.image_path,
+                                                self.transforms)
+                pred_whole = pred_whole['label_map']
+
+                # Whole-image inference using slider_predict()
+                save_dir = osp.join(td, 'pred1')
+                self.model.slider_predict(self.image_path, save_dir, 256, 0,
+                                          self.transforms)
+                pred1 = T.decode_image(
+                    osp.join(save_dir, self.basename),
+                    read_raw=True,
+                    decode_sar=False)
+                self.check_output_equal(pred1.shape, pred_whole.shape)
+
+                # `block_size` == `overlap`
+                save_dir = osp.join(td, 'pred2')
+                with self.assertRaises(ValueError):
+                    self.model.slider_predict(self.image_path, save_dir, 128,
+                                              128, self.transforms)
+
+                # `block_size` is a tuple
+                save_dir = osp.join(td, 'pred3')
+                self.model.slider_predict(self.image_path, save_dir, (128, 32),
+                                          0, self.transforms)
+                pred3 = T.decode_image(
+                    osp.join(save_dir, self.basename),
+                    read_raw=True,
+                    decode_sar=False)
+                self.check_output_equal(pred3.shape, pred_whole.shape)
+
+                # `block_size` and `overlap` are both tuples
+                save_dir = osp.join(td, 'pred4')
+                self.model.slider_predict(self.image_path, save_dir, (128, 100),
+                                          (10, 5), self.transforms)
+                pred4 = T.decode_image(
+                    osp.join(save_dir, self.basename),
+                    read_raw=True,
+                    decode_sar=False)
+                self.check_output_equal(pred4.shape, pred_whole.shape)
+
+                # `block_size` larger than image size
+                save_dir = osp.join(td, 'pred5')
+                with self.assertRaises(ValueError):
+                    self.model.slider_predict(self.image_path, save_dir, 512, 0,
+                                              self.transforms)
+
+        def test_merge_strategy(self):
+            with tempfile.TemporaryDirectory() as td:
+                # Whole-image inference using predict()
+                pred_whole = self.model.predict(self.image_path,
+                                                self.transforms)
+                pred_whole = pred_whole['label_map']
+
+                # 'keep_first'
+                save_dir = osp.join(td, 'keep_first')
+                self.model.slider_predict(
+                    self.image_path,
+                    save_dir,
+                    128,
+                    64,
+                    self.transforms,
+                    merge_strategy='keep_first')
+                pred_keepfirst = T.decode_image(
+                    osp.join(save_dir, self.basename),
+                    read_raw=True,
+                    decode_sar=False)
+                self.check_output_equal(pred_keepfirst.shape, pred_whole.shape)
+
+                # 'keep_last'
+                save_dir = osp.join(td, 'keep_last')
+                self.model.slider_predict(
+                    self.image_path,
+                    save_dir,
+                    128,
+                    64,
+                    self.transforms,
+                    merge_strategy='keep_last')
+                pred_keeplast = T.decode_image(
+                    osp.join(save_dir, self.basename),
+                    read_raw=True,
+                    decode_sar=False)
+                self.check_output_equal(pred_keeplast.shape, pred_whole.shape)
+
+                # 'accum'
+                save_dir = osp.join(td, 'accum')
+                self.model.slider_predict(
+                    self.image_path,
+                    save_dir,
+                    128,
+                    64,
+                    self.transforms,
+                    merge_strategy='accum')
+                pred_accum = T.decode_image(
+                    osp.join(save_dir, self.basename),
+                    read_raw=True,
+                    decode_sar=False)
+                self.check_output_equal(pred_accum.shape, pred_whole.shape)
+
+        def test_geo_info(self):
+            with tempfile.TemporaryDirectory() as td:
+                _, geo_info_in = T.decode_image(
+                    self.ref_path, read_geo_info=True)
+                self.model.slider_predict(self.image_path, td, 128, 0,
+                                          self.transforms)
+                _, geo_info_out = T.decode_image(
+                    osp.join(td, self.basename), read_geo_info=True)
+                self.assertEqual(geo_info_out['geo_trans'],
+                                 geo_info_in['geo_trans'])
+                self.assertEqual(geo_info_out['geo_proj'],
+                                 geo_info_in['geo_proj'])
+
+        def test_batch_size(self):
+            with tempfile.TemporaryDirectory() as td:
+                # batch_size = 1
+                save_dir = osp.join(td, 'bs1')
+                self.model.slider_predict(
+                    self.image_path,
+                    save_dir,
+                    128,
+                    64,
+                    self.transforms,
+                    merge_strategy='keep_first',
+                    batch_size=1)
+                pred_bs1 = T.decode_image(
+                    osp.join(save_dir, self.basename),
+                    read_raw=True,
+                    decode_sar=False)
+
+                # batch_size = 4
+                save_dir = osp.join(td, 'bs4')
+                self.model.slider_predict(
+                    self.image_path,
+                    save_dir,
+                    128,
+                    64,
+                    self.transforms,
+                    merge_strategy='keep_first',
+                    batch_size=4)
+                pred_bs4 = T.decode_image(
+                    osp.join(save_dir, self.basename),
+                    read_raw=True,
+                    decode_sar=False)
+                self.check_output_equal(pred_bs4, pred_bs1)
+
+                # batch_size = 8
+                save_dir = osp.join(td, 'bs4')
+                self.model.slider_predict(
+                    self.image_path,
+                    save_dir,
+                    128,
+                    64,
+                    self.transforms,
+                    merge_strategy='keep_first',
+                    batch_size=8)
+                pred_bs8 = T.decode_image(
+                    osp.join(save_dir, self.basename),
+                    read_raw=True,
+                    decode_sar=False)
+                self.check_output_equal(pred_bs8, pred_bs1)
+
+
+class TestSegSliderPredict(_TestSliderPredictNamespace.TestSliderPredict):
+    def setUp(self):
+        self.model = pdrs.tasks.seg.UNet(in_channels=10)
+        self.transforms = T.Compose([
+            T.DecodeImg(), T.Normalize([0.5] * 10, [0.5] * 10),
+            T.ArrangeSegmenter('test')
+        ])
+        self.image_path = "data/ssst/multispectral.tif"
+        self.ref_path = self.image_path
+        self.basename = osp.basename(self.ref_path)
+
+
+class TestCDSliderPredict(_TestSliderPredictNamespace.TestSliderPredict):
+    def setUp(self):
+        self.model = pdrs.tasks.cd.BIT(in_channels=10)
+        self.transforms = T.Compose([
+            T.DecodeImg(), T.Normalize([0.5] * 10, [0.5] * 10),
+            T.ArrangeChangeDetector('test')
+        ])
+        self.image_path = ("data/ssmt/multispectral_t1.tif",
+                           "data/ssmt/multispectral_t2.tif")
+        self.ref_path = self.image_path[0]
+        self.basename = osp.basename(self.ref_path)