Browse Source

Enhance slider_predict()

Bobholamovic 2 years ago
parent
commit
f7a4ebc58d

+ 15 - 0
paddlers/tasks/base.py

@@ -677,3 +677,18 @@ class BaseModel(metaclass=ModelMeta):
             raise ValueError(
                 f"Incorrect arrange mode! Expected {mode} but got {arrange_obj.mode}."
             )
+
+    def run(self, net, inputs, mode):
+        raise NotImplementedError
+
+    def train(self, *args, **kwargs):
+        raise NotImplementedError
+
+    def evaluate(self, *args, **kwargs):
+        raise NotImplementedError
+
+    def preprocess(self, images, transforms, to_tensor):
+        raise NotImplementedError
+
+    def postprocess(self, *args, **kwargs):
+        raise NotImplementedError

+ 108 - 33
paddlers/tasks/change_detector.py

@@ -35,6 +35,7 @@ from paddlers.utils.checkpoint import seg_pretrain_weights_dict
 from .base import BaseModel
 from .utils import seg_metrics as metrics
 from .utils.infer_nets import InferCDNet
+from .utils.slider_predict import SlowCache as Cache
 
 __all__ = [
     "CDNet", "FCEarlyFusion", "FCSiamConc", "FCSiamDiff", "STANet", "BIT",
@@ -574,22 +575,35 @@ class BaseChangeDetector(BaseModel):
         return prediction
 
     def slider_predict(self,
-                       img_file,
+                       img_files,
                        save_dir,
                        block_size,
                        overlap=36,
-                       transforms=None):
+                       transforms=None,
+                       invalid_value=255,
+                       merge_strategy='keep_last'):
         """
-        Do inference.
+        Do inference using sliding windows.
 
         Args:
-            img_file (tuple[str]): Tuple of image paths.
+            img_files (tuple[str]): Tuple of image paths.
             save_dir (str): Directory that contains saved geotiff file.
-            block_size (list[int] | tuple[int] | int, optional): Size of block.
-            overlap (list[int] | tuple[int] | int, optional): Overlap between two blocks. 
-                Defaults to 36.
-            transforms (paddlers.transforms.Compose|None, optional): Transforms for inputs.
-                If None, the transforms for evaluation process will be used. Defaults to None.
+            block_size (list[int] | tuple[int] | int):
+                Size of block. If `block_size` is a list or tuple, it should be in 
+                (W, H) format.
+            overlap (list[int] | tuple[int] | int, optional):
+                Overlap between two blocks. If `overlap` is a list or tuple, it should
+                be in (W, H) format. Defaults to 36.
+            transforms (paddlers.transforms.Compose|None, optional): Transforms for 
+                inputs. If None, the transforms for evaluation process will be used. 
+                Defaults to None.
+            invalid_value (int, optional): Value that marks invalid pixels in output 
+                image. Defaults to 255.
+            merge_strategy (str, optional): Strategy to merge overlapping blocks. Choices
+                are {'keep_first', 'keep_last', 'vote'}. 'keep_first' and 'keep_last' 
+                means keeping the values of the first and the last block in traversal 
+                order, respectively. 'vote' means applying a simple voting strategy when 
+                there are conflicts in the overlapping pixels. Defaults to 'keep_last'.
         """
 
         try:
@@ -597,8 +611,6 @@ class BaseChangeDetector(BaseModel):
         except:
             import gdal
 
-        if not isinstance(img_file, tuple) or len(img_file) != 2:
-            raise ValueError("`img_file` must be a tuple of length 2.")
         if isinstance(block_size, int):
             block_size = (block_size, block_size)
         elif isinstance(block_size, (tuple, list)) and len(block_size) == 2:
@@ -614,25 +626,54 @@ class BaseChangeDetector(BaseModel):
             raise ValueError(
                 "`overlap` must be a tuple/list of length 2 or an integer.")
 
-        src1_data = gdal.Open(img_file[0])
-        src2_data = gdal.Open(img_file[1])
+        if merge_strategy not in ('keep_first', 'keep_last', 'vote'):
+            raise ValueError(
+                "{} is not a supported stragegy for block merging.".format(
+                    merge_strategy))
+        if overlap == (0, 0):
+            # When there is no overlap, use 'keep_last' strategy as it introduces least overheads
+            merge_strategy = 'keep_last'
+        if merge_strategy == 'vote':
+            logging.warning(
+                "Currently, a naive Python-implemented cache is used for aggregating voting results. "
+                "For higher performance in inferring large images, please set `merge_strategy` to 'keep_first' or "
+                "'keep_last'.")
+            cache = Cache()
+
+        src1_data = gdal.Open(img_files[0])
+        src2_data = gdal.Open(img_files[1])
+
+        # Assume that two input images have the same size
         width = src1_data.RasterXSize
         height = src1_data.RasterYSize
         bands = src1_data.RasterCount
 
         driver = gdal.GetDriverByName("GTiff")
-        file_name = osp.splitext(osp.normpath(img_file[0]).split(os.sep)[-1])[
-            0] + ".tif"
+        # Output name is the same as the name of the first image
+        file_name = osp.basename(osp.normpath(img_files[0]))
+        # Replace extension name with '.tif'
+        file_name = osp.splitext(file_name)[0] + ".tif"
         if not osp.exists(save_dir):
             os.makedirs(save_dir)
         save_file = osp.join(save_dir, file_name)
         dst_data = driver.Create(save_file, width, height, 1, gdal.GDT_Byte)
+
+        # Set meta-information (consistent with the first image)
         dst_data.SetGeoTransform(src1_data.GetGeoTransform())
         dst_data.SetProjection(src1_data.GetProjection())
-        band = dst_data.GetRasterBand(1)
-        band.WriteArray(255 * np.ones((height, width), dtype="uint8"))
 
-        step = np.array(block_size) - np.array(overlap)
+        band = dst_data.GetRasterBand(1)
+        band.WriteArray(
+            np.full(
+                (height, width), fill_value=invalid_value, dtype="uint8"))
+
+        prev_yoff, prev_xoff = None, None
+        prev_h, prev_w = None, None
+        step = np.array(
+            block_size, dtype=np.int32) - np.array(
+                overlap, dtype=np.int32)
+        if step[0] == 0 or step[1] == 0:
+            raise ValueError("`block_size` and `overlap` should not be equal.")
         for yoff in range(0, height, step[1]):
             for xoff in range(0, width, step[0]):
                 xsize, ysize = block_size
@@ -640,30 +681,64 @@ class BaseChangeDetector(BaseModel):
                     xsize = int(width - xoff)
                 if yoff + ysize > height:
                     ysize = int(height - yoff)
-                im1 = src1_data.ReadAsArray(
-                    int(xoff), int(yoff), xsize, ysize).transpose((1, 2, 0))
-                im2 = src2_data.ReadAsArray(
-                    int(xoff), int(yoff), xsize, ysize).transpose((1, 2, 0))
+
+                xoff = int(xoff)
+                yoff = int(yoff)
+                im1 = src1_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose(
+                    (1, 2, 0))
+                im2 = src2_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose(
+                    (1, 2, 0))
                 # Fill
                 h, w = im1.shape[:2]
                 im1_fill = np.zeros(
                     (block_size[1], block_size[0], bands), dtype=im1.dtype)
-                im2_fill = im1_fill.copy()
                 im1_fill[:h, :w, :] = im1
+
+                im2_fill = np.zeros(
+                    (block_size[1], block_size[0], bands), dtype=im2.dtype)
                 im2_fill[:h, :w, :] = im2
-                im_fill = (im1_fill, im2_fill)
+
                 # Predict
-                pred = self.predict(im_fill,
-                                    transforms)["label_map"].astype("uint8")
-                # Overlap
-                rd_block = band.ReadAsArray(int(xoff), int(yoff), xsize, ysize)
-                mask = (rd_block == pred[:h, :w]) | (rd_block == 255)
-                temp = pred[:h, :w].copy()
-                temp[mask == False] = 0
-                band.WriteArray(temp, int(xoff), int(yoff))
+                pred = self.predict((im1_fill, im2_fill), transforms)
+                pred = pred["label_map"].astype('uint8')
+                pred = pred[:h, :w]
+
+                # Deal with overlapping pixels
+                if merge_strategy == 'vote':
+                    cache.push_block(yoff, xoff, h, w, pred)
+                    pred = cache.get_block(yoff, xoff, h, w)
+                    pred = pred.astype('uint8')
+                    if prev_yoff is not None:
+                        pop_h = yoff - prev_yoff
+                    else:
+                        pop_h = 0
+                    if prev_xoff is not None:
+                        if xoff < prev_xoff:
+                            pop_w = prev_w
+                        else:
+                            pop_w = xoff - prev_xoff
+                    else:
+                        pop_w = 0
+                    cache.pop_block(prev_yoff, prev_xoff, pop_h, pop_w)
+                elif merge_strategy == 'keep_first':
+                    rd_block = band.ReadAsArray(xoff, yoff, xsize, ysize)
+                    mask = rd_block != invalid_value
+                    pred = np.where(mask, rd_block, pred)
+                elif merge_strategy == 'keep_last':
+                    pass
+
+                # Write to file
+                band.WriteArray(pred, xoff, yoff)
                 dst_data.FlushCache()
+
+                prev_xoff = xoff
+                prev_w = w
+
+            prev_yoff = yoff
+            prev_h = h
+
         dst_data = None
-        print("GeoTiff saved in {}.".format(save_file))
+        logging.info("GeoTiff file saved in {}.".format(save_file))
 
     def preprocess(self, images, transforms, to_tensor=True):
         self._check_transforms(transforms, 'test')

+ 91 - 20
paddlers/tasks/segmenter.py

@@ -34,6 +34,7 @@ from paddlers.utils.checkpoint import seg_pretrain_weights_dict
 from .base import BaseModel
 from .utils import seg_metrics as metrics
 from .utils.infer_nets import InferSegNet
+from .utils.slider_predict import SlowCache as Cache
 
 __all__ = ["UNet", "DeepLabV3P", "FastSCNN", "HRNet", "BiSeNetV2", "FarSeg"]
 
@@ -550,20 +551,31 @@ class BaseSegmenter(BaseModel):
                        save_dir,
                        block_size,
                        overlap=36,
-                       transforms=None):
+                       transforms=None,
+                       invalid_value=255,
+                       merge_strategy='keep_last'):
         """
-        Do inference.
+        Do inference using sliding windows.
 
         Args:
             img_file (str): Image path.
             save_dir (str): Directory that contains saved geotiff file.
             block_size (list[int] | tuple[int] | int):
-                Size of block.
+                Size of block. If `block_size` is list or tuple, it should be in 
+                (W, H) format.
             overlap (list[int] | tuple[int] | int, optional):
-                Overlap between two blocks. Defaults to 36.
+                Overlap between two blocks. If `overlap` is list or tuple, it should
+                be in (W, H) format. Defaults to 36.
             transforms (paddlers.transforms.Compose|None, optional): Transforms for 
                 inputs. If None, the transforms for evaluation process will be used. 
                 Defaults to None.
+            invalid_value (int, optional): Value that marks invalid pixels in output 
+                image. Defaults to 255.
+            merge_strategy (str, optional): Strategy to merge overlapping blocks. Choices
+                are {'keep_first', 'keep_last', 'vote'}. 'keep_first' and 'keep_last' 
+                means keeping the values of the first and the last block in traversal 
+                order, respectively. 'vote' means applying a simple voting strategy when 
+                there are conflicts in the overlapping pixels. Defaults to 'keep_last'.
         """
 
         try:
@@ -586,24 +598,50 @@ class BaseSegmenter(BaseModel):
             raise ValueError(
                 "`overlap` must be a tuple/list of length 2 or an integer.")
 
+        if merge_strategy not in ('keep_first', 'keep_last', 'vote'):
+            raise ValueError(
+                "{} is not a supported stragegy for block merging.".format(
+                    merge_strategy))
+        if overlap == (0, 0):
+            # When there is no overlap, use 'keep_last' strategy as it introduces least overheads
+            merge_strategy = 'keep_last'
+        if merge_strategy == 'vote':
+            logging.warning(
+                "Currently, a naive Python-implemented cache is used for aggregating voting results. "
+                "For higher performance in inferring large images, please set `merge_strategy` to 'keep_first' or "
+                "'keep_last'.")
+            cache = Cache()
+
         src_data = gdal.Open(img_file)
         width = src_data.RasterXSize
         height = src_data.RasterYSize
         bands = src_data.RasterCount
 
         driver = gdal.GetDriverByName("GTiff")
-        file_name = osp.splitext(osp.normpath(img_file).split(os.sep)[-1])[
-            0] + ".tif"
+        file_name = osp.basename(osp.normpath(img_file))
+        # Replace extension name with '.tif'
+        file_name = osp.splitext(file_name)[0] + ".tif"
         if not osp.exists(save_dir):
             os.makedirs(save_dir)
         save_file = osp.join(save_dir, file_name)
         dst_data = driver.Create(save_file, width, height, 1, gdal.GDT_Byte)
+
+        # Set meta-information
         dst_data.SetGeoTransform(src_data.GetGeoTransform())
         dst_data.SetProjection(src_data.GetProjection())
-        band = dst_data.GetRasterBand(1)
-        band.WriteArray(255 * np.ones((height, width), dtype="uint8"))
 
-        step = np.array(block_size) - np.array(overlap)
+        band = dst_data.GetRasterBand(1)
+        band.WriteArray(
+            np.full(
+                (height, width), fill_value=invalid_value, dtype="uint8"))
+
+        prev_yoff, prev_xoff = None, None
+        prev_h, prev_w = None, None
+        step = np.array(
+            block_size, dtype=np.int32) - np.array(
+                overlap, dtype=np.int32)
+        if step[0] == 0 or step[1] == 0:
+            raise ValueError("`block_size` and `overlap` should not be equal.")
         for yoff in range(0, height, step[1]):
             for xoff in range(0, width, step[0]):
                 xsize, ysize = block_size
@@ -611,25 +649,58 @@ class BaseSegmenter(BaseModel):
                     xsize = int(width - xoff)
                 if yoff + ysize > height:
                     ysize = int(height - yoff)
-                im = src_data.ReadAsArray(int(xoff), int(yoff), xsize,
-                                          ysize).transpose((1, 2, 0))
+
+                xoff = int(xoff)
+                yoff = int(yoff)
+                im = src_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose(
+                    (1, 2, 0))
                 # Fill
                 h, w = im.shape[:2]
                 im_fill = np.zeros(
                     (block_size[1], block_size[0], bands), dtype=im.dtype)
                 im_fill[:h, :w, :] = im
+
                 # Predict
-                pred = self.predict(im_fill,
-                                    transforms)["label_map"].astype("uint8")
-                # Overlap
-                rd_block = band.ReadAsArray(int(xoff), int(yoff), xsize, ysize)
-                mask = (rd_block == pred[:h, :w]) | (rd_block == 255)
-                temp = pred[:h, :w].copy()
-                temp[mask == False] = 0
-                band.WriteArray(temp, int(xoff), int(yoff))
+                pred = self.predict(im_fill, transforms)
+                pred = pred["label_map"].astype('uint8')
+                pred = pred[:h, :w]
+
+                # Deal with overlapping pixels
+                if merge_strategy == 'vote':
+                    cache.push_block(yoff, xoff, h, w, pred)
+                    pred = cache.get_block(yoff, xoff, h, w)
+                    pred = pred.astype('uint8')
+                    if prev_yoff is not None:
+                        pop_h = yoff - prev_yoff
+                    else:
+                        pop_h = 0
+                    if prev_xoff is not None:
+                        if xoff < prev_xoff:
+                            pop_w = prev_w
+                        else:
+                            pop_w = xoff - prev_xoff
+                    else:
+                        pop_w = 0
+                    cache.pop_block(prev_yoff, prev_xoff, pop_h, pop_w)
+                elif merge_strategy == 'keep_first':
+                    rd_block = band.ReadAsArray(xoff, yoff, xsize, ysize)
+                    mask = rd_block != invalid_value
+                    pred = np.where(mask, rd_block, pred)
+                elif merge_strategy == 'keep_last':
+                    pass
+
+                # Write to file
+                band.WriteArray(pred, xoff, yoff)
                 dst_data.FlushCache()
+
+                prev_xoff = xoff
+                prev_w = w
+
+            prev_yoff = yoff
+            prev_h = h
+
         dst_data = None
-        print("GeoTiff saved in {}.".format(save_file))
+        logging.info("GeoTiff file saved in {}.".format(save_file))
 
     def preprocess(self, images, transforms, to_tensor=True):
         self._check_transforms(transforms, 'test')

+ 52 - 0
paddlers/tasks/utils/slider_predict.py

@@ -0,0 +1,52 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import Counter, defaultdict
+
+import numpy as np
+
+
+class SlowCache(object):
+    def __init__(self):
+        self.cache = defaultdict(Counter)
+
+    def push_pixel(self, i, j, l):
+        self.cache[(i, j)][l] += 1
+
+    def push_block(self, i_st, j_st, h, w, data):
+        for i in range(0, h):
+            for j in range(0, w):
+                self.push_pixel(i_st + i, j_st + j, data[i, j])
+
+    def pop_pixel(self, i, j):
+        self.cache.pop((i, j))
+
+    def pop_block(self, i_st, j_st, h, w):
+        for i in range(0, h):
+            for j in range(0, w):
+                self.pop_pixel(i_st + i, j_st + j)
+
+    def get_pixel(self, i, j):
+        winners = self.cache[(i, j)].most_common(1)
+        winner = winners[0]
+        return winner[0]
+
+    def get_block(self, i_st, j_st, h, w):
+        block = []
+        for i in range(i_st, i_st + h):
+            row = []
+            for j in range(j_st, j_st + w):
+                row.append(self.get_pixel(i, j))
+            block.append(row)
+        return np.asarray(block)

+ 2 - 2
paddlers/transforms/operators.py

@@ -197,7 +197,7 @@ class DecodeImg(Transform):
         self.to_uint8 = to_uint8
         self.decode_bgr = decode_bgr
         self.decode_sar = decode_sar
-        self.read_geo_info = False
+        self.read_geo_info = read_geo_info
 
     def read_img(self, img_path):
         img_format = imghdr.what(img_path)
@@ -227,7 +227,7 @@ class DecodeImg(Transform):
                     im_data = im_data.transpose((1, 2, 0))
             if self.read_geo_info:
                 geo_trans = dataset.GetGeoTransform()
-                geo_proj = dataset.GetGeoProjection()
+                geo_proj = dataset.GetProjection()
         elif img_format in ['jpeg', 'bmp', 'png', 'jpg']:
             if self.decode_bgr:
                 im_data = cv2.imread(img_path, cv2.IMREAD_ANYDEPTH |

+ 1 - 0
tests/fast_tests.py

@@ -13,4 +13,5 @@
 # limitations under the License.
 
 from rs_models import *
+from tasks import *
 from transforms import *

+ 2 - 0
tests/tasks/__init__.py

@@ -11,3 +11,5 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+from .test_slider_predict import *

+ 274 - 0
tests/tasks/test_slider_predict.py

@@ -0,0 +1,274 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os.path as osp
+import tempfile
+
+import paddlers as pdrs
+import paddlers.transforms as T
+from testing_utils import CommonTest
+
+
+class TestSegSliderPredict(CommonTest):
+    def setUp(self):
+        self.model = pdrs.tasks.seg.UNet(in_channels=10)
+        self.transforms = T.Compose([
+            T.DecodeImg(), T.Normalize([0.5] * 10, [0.5] * 10),
+            T.ArrangeSegmenter('test')
+        ])
+        self.image_path = "data/ssst/multispectral.tif"
+        self.basename = osp.basename(self.image_path)
+
+    def test_blocksize_and_overlap_whole(self):
+        # Original image size (256, 256)
+        with tempfile.TemporaryDirectory() as td:
+            # Whole-image inference using predict()
+            pred_whole = self.model.predict(self.image_path, self.transforms)
+            pred_whole = pred_whole['label_map']
+
+            # Whole-image inference using slider_predict()
+            save_dir = osp.join(td, 'pred1')
+            self.model.slider_predict(self.image_path, save_dir, 256, 0,
+                                      self.transforms)
+            pred1 = T.decode_image(
+                osp.join(save_dir, self.basename),
+                to_uint8=False,
+                decode_sar=False)
+            self.check_output_equal(pred1.shape, pred_whole.shape)
+
+            # `block_size` == `overlap`
+            save_dir = osp.join(td, 'pred2')
+            with self.assertRaises(ValueError):
+                self.model.slider_predict(self.image_path, save_dir, 128, 128,
+                                          self.transforms)
+
+            # `block_size` is a tuple
+            save_dir = osp.join(td, 'pred3')
+            self.model.slider_predict(self.image_path, save_dir, (128, 32), 0,
+                                      self.transforms)
+            pred3 = T.decode_image(
+                osp.join(save_dir, self.basename),
+                to_uint8=False,
+                decode_sar=False)
+            self.check_output_equal(pred3.shape, pred_whole.shape)
+
+            # `block_size` and `overlap` are both tuples
+            save_dir = osp.join(td, 'pred4')
+            self.model.slider_predict(self.image_path, save_dir, (128, 100),
+                                      (10, 5), self.transforms)
+            pred4 = T.decode_image(
+                osp.join(save_dir, self.basename),
+                to_uint8=False,
+                decode_sar=False)
+            self.check_output_equal(pred4.shape, pred_whole.shape)
+
+            # `block_size` larger than image size
+            save_dir = osp.join(td, 'pred5')
+            self.model.slider_predict(self.image_path, save_dir, 512, 0,
+                                      self.transforms)
+            pred5 = T.decode_image(
+                osp.join(save_dir, self.basename),
+                to_uint8=False,
+                decode_sar=False)
+            self.check_output_equal(pred5.shape, pred_whole.shape)
+
+    def test_merge_strategy(self):
+        with tempfile.TemporaryDirectory() as td:
+            # Whole-image inference using predict()
+            pred_whole = self.model.predict(self.image_path, self.transforms)
+            pred_whole = pred_whole['label_map']
+
+            # 'keep_first'
+            save_dir = osp.join(td, 'keep_first')
+            self.model.slider_predict(
+                self.image_path,
+                save_dir,
+                128,
+                64,
+                self.transforms,
+                merge_strategy='keep_first')
+            pred_keepfirst = T.decode_image(
+                osp.join(save_dir, self.basename),
+                to_uint8=False,
+                decode_sar=False)
+            self.check_output_equal(pred_keepfirst.shape, pred_whole.shape)
+
+            # 'keep_last'
+            save_dir = osp.join(td, 'keep_last')
+            self.model.slider_predict(
+                self.image_path,
+                save_dir,
+                128,
+                64,
+                self.transforms,
+                merge_strategy='keep_last')
+            pred_keeplast = T.decode_image(
+                osp.join(save_dir, self.basename),
+                to_uint8=False,
+                decode_sar=False)
+            self.check_output_equal(pred_keeplast.shape, pred_whole.shape)
+
+            # 'vote'
+            save_dir = osp.join(td, 'vote')
+            self.model.slider_predict(
+                self.image_path,
+                save_dir,
+                128,
+                64,
+                self.transforms,
+                merge_strategy='vote')
+            pred_vote = T.decode_image(
+                osp.join(save_dir, self.basename),
+                to_uint8=False,
+                decode_sar=False)
+            self.check_output_equal(pred_vote.shape, pred_whole.shape)
+
+    def test_geo_info(self):
+        with tempfile.TemporaryDirectory() as td:
+            _, geo_info_in = T.decode_image(self.image_path, read_geo_info=True)
+            self.model.slider_predict(self.image_path, td, 128, 0,
+                                      self.transforms)
+            _, geo_info_out = T.decode_image(
+                osp.join(td, self.basename), read_geo_info=True)
+            self.assertEqual(geo_info_out['geo_trans'],
+                             geo_info_in['geo_trans'])
+            self.assertEqual(geo_info_out['geo_proj'], geo_info_in['geo_proj'])
+
+
+class TestCDSliderPredict(CommonTest):
+    def setUp(self):
+        self.model = pdrs.tasks.cd.BIT(in_channels=10)
+        self.transforms = T.Compose([
+            T.DecodeImg(), T.Normalize([0.5] * 10, [0.5] * 10),
+            T.ArrangeChangeDetector('test')
+        ])
+        self.image_paths = ("data/ssmt/multispectral_t1.tif",
+                            "data/ssmt/multispectral_t2.tif")
+        self.basename = osp.basename(self.image_paths[0])
+
+    def test_blocksize_and_overlap_whole(self):
+        # Original image size (256, 256)
+        with tempfile.TemporaryDirectory() as td:
+            # Whole-image inference using predict()
+            pred_whole = self.model.predict(self.image_paths, self.transforms)
+            pred_whole = pred_whole['label_map']
+
+            # Whole-image inference using slider_predict()
+            save_dir = osp.join(td, 'pred1')
+            self.model.slider_predict(self.image_paths, save_dir, 256, 0,
+                                      self.transforms)
+            pred1 = T.decode_image(
+                osp.join(save_dir, self.basename),
+                to_uint8=False,
+                decode_sar=False)
+            self.check_output_equal(pred1.shape, pred_whole.shape)
+
+            # `block_size` == `overlap`
+            save_dir = osp.join(td, 'pred2')
+            with self.assertRaises(ValueError):
+                self.model.slider_predict(self.image_paths, save_dir, 128, 128,
+                                          self.transforms)
+
+            # `block_size` is a tuple
+            save_dir = osp.join(td, 'pred3')
+            self.model.slider_predict(self.image_paths, save_dir, (128, 32), 0,
+                                      self.transforms)
+            pred3 = T.decode_image(
+                osp.join(save_dir, self.basename),
+                to_uint8=False,
+                decode_sar=False)
+            self.check_output_equal(pred3.shape, pred_whole.shape)
+
+            # `block_size` and `overlap` are both tuples
+            save_dir = osp.join(td, 'pred4')
+            self.model.slider_predict(self.image_paths, save_dir, (128, 100),
+                                      (10, 5), self.transforms)
+            pred4 = T.decode_image(
+                osp.join(save_dir, self.basename),
+                to_uint8=False,
+                decode_sar=False)
+            self.check_output_equal(pred4.shape, pred_whole.shape)
+
+            # `block_size` larger than image size
+            save_dir = osp.join(td, 'pred5')
+            self.model.slider_predict(self.image_paths, save_dir, 512, 0,
+                                      self.transforms)
+            pred5 = T.decode_image(
+                osp.join(save_dir, self.basename),
+                to_uint8=False,
+                decode_sar=False)
+            self.check_output_equal(pred5.shape, pred_whole.shape)
+
+    def test_merge_strategy(self):
+        with tempfile.TemporaryDirectory() as td:
+            # Whole-image inference using predict()
+            pred_whole = self.model.predict(self.image_paths, self.transforms)
+            pred_whole = pred_whole['label_map']
+
+            # 'keep_first'
+            save_dir = osp.join(td, 'keep_first')
+            self.model.slider_predict(
+                self.image_paths,
+                save_dir,
+                128,
+                64,
+                self.transforms,
+                merge_strategy='keep_first')
+            pred_keepfirst = T.decode_image(
+                osp.join(save_dir, self.basename),
+                to_uint8=False,
+                decode_sar=False)
+            self.check_output_equal(pred_keepfirst.shape, pred_whole.shape)
+
+            # 'keep_last'
+            save_dir = osp.join(td, 'keep_last')
+            self.model.slider_predict(
+                self.image_paths,
+                save_dir,
+                128,
+                64,
+                self.transforms,
+                merge_strategy='keep_last')
+            pred_keeplast = T.decode_image(
+                osp.join(save_dir, self.basename),
+                to_uint8=False,
+                decode_sar=False)
+            self.check_output_equal(pred_keeplast.shape, pred_whole.shape)
+
+            # 'vote'
+            save_dir = osp.join(td, 'vote')
+            self.model.slider_predict(
+                self.image_paths,
+                save_dir,
+                128,
+                64,
+                self.transforms,
+                merge_strategy='vote')
+            pred_vote = T.decode_image(
+                osp.join(save_dir, self.basename),
+                to_uint8=False,
+                decode_sar=False)
+            self.check_output_equal(pred_vote.shape, pred_whole.shape)
+
+    def test_geo_info(self):
+        with tempfile.TemporaryDirectory() as td:
+            _, geo_info_in = T.decode_image(
+                self.image_paths[0], read_geo_info=True)
+            self.model.slider_predict(self.image_paths, td, 128, 0,
+                                      self.transforms)
+            _, geo_info_out = T.decode_image(
+                osp.join(td, self.basename), read_geo_info=True)
+            self.assertEqual(geo_info_out['geo_trans'],
+                             geo_info_in['geo_trans'])
+            self.assertEqual(geo_info_out['geo_proj'], geo_info_in['geo_proj'])