Bladeren bron

Add accum strategy

Bobholamovic 2 jaren geleden
bovenliggende
commit
bf54499f5a
4 gewijzigde bestanden met toevoegingen van 301 en 277 verwijderingen
  1. 3 134
      paddlers/tasks/change_detector.py
  2. 9 128
      paddlers/tasks/segmenter.py
  3. 253 1
      paddlers/tasks/utils/slider_predict.py
  4. 36 14
      tests/tasks/test_slider_predict.py

+ 3 - 134
paddlers/tasks/change_detector.py

@@ -35,7 +35,7 @@ from paddlers.utils.checkpoint import seg_pretrain_weights_dict
 from .base import BaseModel
 from .utils import seg_metrics as metrics
 from .utils.infer_nets import InferCDNet
-from .utils.slider_predict import SlowCache as Cache
+from .utils.slider_predict import slider_predict
 
 __all__ = [
     "CDNet", "FCEarlyFusion", "FCSiamConc", "FCSiamDiff", "STANet", "BIT",
@@ -606,139 +606,8 @@ class BaseChangeDetector(BaseModel):
                 there are conflicts in the overlapping pixels. Defaults to 'keep_last'.
         """
 
-        try:
-            from osgeo import gdal
-        except:
-            import gdal
-
-        if isinstance(block_size, int):
-            block_size = (block_size, block_size)
-        elif isinstance(block_size, (tuple, list)) and len(block_size) == 2:
-            block_size = tuple(block_size)
-        else:
-            raise ValueError(
-                "`block_size` must be a tuple/list of length 2 or an integer.")
-        if isinstance(overlap, int):
-            overlap = (overlap, overlap)
-        elif isinstance(overlap, (tuple, list)) and len(overlap) == 2:
-            overlap = tuple(overlap)
-        else:
-            raise ValueError(
-                "`overlap` must be a tuple/list of length 2 or an integer.")
-
-        if merge_strategy not in ('keep_first', 'keep_last', 'vote'):
-            raise ValueError(
-                "{} is not a supported stragegy for block merging.".format(
-                    merge_strategy))
-        if overlap == (0, 0):
-            # When there is no overlap, use 'keep_last' strategy as it introduces least overheads
-            merge_strategy = 'keep_last'
-        if merge_strategy == 'vote':
-            logging.warning(
-                "Currently, a naive Python-implemented cache is used for aggregating voting results. "
-                "For higher performance in inferring large images, please set `merge_strategy` to 'keep_first' or "
-                "'keep_last'.")
-            cache = Cache()
-
-        src1_data = gdal.Open(img_files[0])
-        src2_data = gdal.Open(img_files[1])
-
-        # Assume that two input images have the same size
-        width = src1_data.RasterXSize
-        height = src1_data.RasterYSize
-        bands = src1_data.RasterCount
-
-        driver = gdal.GetDriverByName("GTiff")
-        # Output name is the same as the name of the first image
-        file_name = osp.basename(osp.normpath(img_files[0]))
-        # Replace extension name with '.tif'
-        file_name = osp.splitext(file_name)[0] + ".tif"
-        if not osp.exists(save_dir):
-            os.makedirs(save_dir)
-        save_file = osp.join(save_dir, file_name)
-        dst_data = driver.Create(save_file, width, height, 1, gdal.GDT_Byte)
-
-        # Set meta-information (consistent with the first image)
-        dst_data.SetGeoTransform(src1_data.GetGeoTransform())
-        dst_data.SetProjection(src1_data.GetProjection())
-
-        band = dst_data.GetRasterBand(1)
-        band.WriteArray(
-            np.full(
-                (height, width), fill_value=invalid_value, dtype="uint8"))
-
-        prev_yoff, prev_xoff = None, None
-        prev_h, prev_w = None, None
-        step = np.array(
-            block_size, dtype=np.int32) - np.array(
-                overlap, dtype=np.int32)
-        if step[0] == 0 or step[1] == 0:
-            raise ValueError("`block_size` and `overlap` should not be equal.")
-        for yoff in range(0, height, step[1]):
-            for xoff in range(0, width, step[0]):
-                xsize, ysize = block_size
-                if xoff + xsize > width:
-                    xsize = int(width - xoff)
-                if yoff + ysize > height:
-                    ysize = int(height - yoff)
-
-                xoff = int(xoff)
-                yoff = int(yoff)
-                im1 = src1_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose(
-                    (1, 2, 0))
-                im2 = src2_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose(
-                    (1, 2, 0))
-                # Fill
-                h, w = im1.shape[:2]
-                im1_fill = np.zeros(
-                    (block_size[1], block_size[0], bands), dtype=im1.dtype)
-                im1_fill[:h, :w, :] = im1
-
-                im2_fill = np.zeros(
-                    (block_size[1], block_size[0], bands), dtype=im2.dtype)
-                im2_fill[:h, :w, :] = im2
-
-                # Predict
-                pred = self.predict((im1_fill, im2_fill), transforms)
-                pred = pred["label_map"].astype('uint8')
-                pred = pred[:h, :w]
-
-                # Deal with overlapping pixels
-                if merge_strategy == 'vote':
-                    cache.push_block(yoff, xoff, h, w, pred)
-                    pred = cache.get_block(yoff, xoff, h, w)
-                    pred = pred.astype('uint8')
-                    if prev_yoff is not None:
-                        pop_h = yoff - prev_yoff
-                    else:
-                        pop_h = 0
-                    if prev_xoff is not None:
-                        if xoff < prev_xoff:
-                            pop_w = prev_w
-                        else:
-                            pop_w = xoff - prev_xoff
-                    else:
-                        pop_w = 0
-                    cache.pop_block(prev_yoff, prev_xoff, pop_h, pop_w)
-                elif merge_strategy == 'keep_first':
-                    rd_block = band.ReadAsArray(xoff, yoff, xsize, ysize)
-                    mask = rd_block != invalid_value
-                    pred = np.where(mask, rd_block, pred)
-                elif merge_strategy == 'keep_last':
-                    pass
-
-                # Write to file
-                band.WriteArray(pred, xoff, yoff)
-                dst_data.FlushCache()
-
-                prev_xoff = xoff
-                prev_w = w
-
-            prev_yoff = yoff
-            prev_h = h
-
-        dst_data = None
-        logging.info("GeoTiff file saved in {}.".format(save_file))
+        slider_predict(self, img_files, save_dir, block_size, overlap,
+                       transforms, invalid_value, merge_strategy)
 
     def preprocess(self, images, transforms, to_tensor=True):
         self._check_transforms(transforms, 'test')

+ 9 - 128
paddlers/tasks/segmenter.py

@@ -34,7 +34,7 @@ from paddlers.utils.checkpoint import seg_pretrain_weights_dict
 from .base import BaseModel
 from .utils import seg_metrics as metrics
 from .utils.infer_nets import InferSegNet
-from .utils.slider_predict import SlowCache as Cache
+from .utils.slider_predict import slider_predict
 
 __all__ = ["UNet", "DeepLabV3P", "FastSCNN", "HRNet", "BiSeNetV2", "FarSeg"]
 
@@ -572,135 +572,16 @@ class BaseSegmenter(BaseModel):
             invalid_value (int, optional): Value that marks invalid pixels in output 
                 image. Defaults to 255.
             merge_strategy (str, optional): Strategy to merge overlapping blocks. Choices
-                are {'keep_first', 'keep_last', 'vote'}. 'keep_first' and 'keep_last' 
-                means keeping the values of the first and the last block in traversal 
-                order, respectively. 'vote' means applying a simple voting strategy when 
-                there are conflicts in the overlapping pixels. Defaults to 'keep_last'.
+                are {'keep_first', 'keep_last', 'vote', 'accum'}. 'keep_first' and 
+                'keep_last' means keeping the values of the first and the last block in 
+                traversal order, respectively. 'vote' means applying a simple voting 
+                strategy when there are conflicts in the overlapping pixels. 'accum' 
+                means determining the class of an overlapping pixel according to 
+                accumulated probabilities. Defaults to 'keep_last'.
         """
 
-        try:
-            from osgeo import gdal
-        except:
-            import gdal
-
-        if isinstance(block_size, int):
-            block_size = (block_size, block_size)
-        elif isinstance(block_size, (tuple, list)) and len(block_size) == 2:
-            block_size = tuple(block_size)
-        else:
-            raise ValueError(
-                "`block_size` must be a tuple/list of length 2 or an integer.")
-        if isinstance(overlap, int):
-            overlap = (overlap, overlap)
-        elif isinstance(overlap, (tuple, list)) and len(overlap) == 2:
-            overlap = tuple(overlap)
-        else:
-            raise ValueError(
-                "`overlap` must be a tuple/list of length 2 or an integer.")
-
-        if merge_strategy not in ('keep_first', 'keep_last', 'vote'):
-            raise ValueError(
-                "{} is not a supported stragegy for block merging.".format(
-                    merge_strategy))
-        if overlap == (0, 0):
-            # When there is no overlap, use 'keep_last' strategy as it introduces least overheads
-            merge_strategy = 'keep_last'
-        if merge_strategy == 'vote':
-            logging.warning(
-                "Currently, a naive Python-implemented cache is used for aggregating voting results. "
-                "For higher performance in inferring large images, please set `merge_strategy` to 'keep_first' or "
-                "'keep_last'.")
-            cache = Cache()
-
-        src_data = gdal.Open(img_file)
-        width = src_data.RasterXSize
-        height = src_data.RasterYSize
-        bands = src_data.RasterCount
-
-        driver = gdal.GetDriverByName("GTiff")
-        file_name = osp.basename(osp.normpath(img_file))
-        # Replace extension name with '.tif'
-        file_name = osp.splitext(file_name)[0] + ".tif"
-        if not osp.exists(save_dir):
-            os.makedirs(save_dir)
-        save_file = osp.join(save_dir, file_name)
-        dst_data = driver.Create(save_file, width, height, 1, gdal.GDT_Byte)
-
-        # Set meta-information
-        dst_data.SetGeoTransform(src_data.GetGeoTransform())
-        dst_data.SetProjection(src_data.GetProjection())
-
-        band = dst_data.GetRasterBand(1)
-        band.WriteArray(
-            np.full(
-                (height, width), fill_value=invalid_value, dtype="uint8"))
-
-        prev_yoff, prev_xoff = None, None
-        prev_h, prev_w = None, None
-        step = np.array(
-            block_size, dtype=np.int32) - np.array(
-                overlap, dtype=np.int32)
-        if step[0] == 0 or step[1] == 0:
-            raise ValueError("`block_size` and `overlap` should not be equal.")
-        for yoff in range(0, height, step[1]):
-            for xoff in range(0, width, step[0]):
-                xsize, ysize = block_size
-                if xoff + xsize > width:
-                    xsize = int(width - xoff)
-                if yoff + ysize > height:
-                    ysize = int(height - yoff)
-
-                xoff = int(xoff)
-                yoff = int(yoff)
-                im = src_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose(
-                    (1, 2, 0))
-                # Fill
-                h, w = im.shape[:2]
-                im_fill = np.zeros(
-                    (block_size[1], block_size[0], bands), dtype=im.dtype)
-                im_fill[:h, :w, :] = im
-
-                # Predict
-                pred = self.predict(im_fill, transforms)
-                pred = pred["label_map"].astype('uint8')
-                pred = pred[:h, :w]
-
-                # Deal with overlapping pixels
-                if merge_strategy == 'vote':
-                    cache.push_block(yoff, xoff, h, w, pred)
-                    pred = cache.get_block(yoff, xoff, h, w)
-                    pred = pred.astype('uint8')
-                    if prev_yoff is not None:
-                        pop_h = yoff - prev_yoff
-                    else:
-                        pop_h = 0
-                    if prev_xoff is not None:
-                        if xoff < prev_xoff:
-                            pop_w = prev_w
-                        else:
-                            pop_w = xoff - prev_xoff
-                    else:
-                        pop_w = 0
-                    cache.pop_block(prev_yoff, prev_xoff, pop_h, pop_w)
-                elif merge_strategy == 'keep_first':
-                    rd_block = band.ReadAsArray(xoff, yoff, xsize, ysize)
-                    mask = rd_block != invalid_value
-                    pred = np.where(mask, rd_block, pred)
-                elif merge_strategy == 'keep_last':
-                    pass
-
-                # Write to file
-                band.WriteArray(pred, xoff, yoff)
-                dst_data.FlushCache()
-
-                prev_xoff = xoff
-                prev_w = w
-
-            prev_yoff = yoff
-            prev_h = h
-
-        dst_data = None
-        logging.info("GeoTiff file saved in {}.".format(save_file))
+        slider_predict(self, img_file, save_dir, block_size, overlap,
+                       transforms, invalid_value, merge_strategy)
 
     def preprocess(self, images, transforms, to_tensor=True):
         self._check_transforms(transforms, 'test')

+ 253 - 1
paddlers/tasks/utils/slider_predict.py

@@ -12,12 +12,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
+import os.path as osp
+from abc import ABCMeta, abstractmethod
 from collections import Counter, defaultdict
 
 import numpy as np
 
+import paddlers.utils.logging as logging
 
-class SlowCache(object):
+
+class Cache(metaclass=ABCMeta):
+    @abstractmethod
+    def get_block(self, i_st, j_st, h, w):
+        pass
+
+
+class SlowCache(Cache):
     def __init__(self):
         self.cache = defaultdict(Counter)
 
@@ -50,3 +61,244 @@ class SlowCache(object):
                 row.append(self.get_pixel(i, j))
             block.append(row)
         return np.asarray(block)
+
+
+class ProbCache(Cache):
+    def __init__(self, h, w, ch, cw, sh, sw, dtype=np.float32, order='c'):
+        self.cache = None
+        self.h = h
+        self.w = w
+        self.ch = ch
+        self.cw = cw
+        self.sh = sh
+        self.sw = sw
+        if not issubclass(dtype, np.floating):
+            raise TypeError("`dtype` must be one of the floating types.")
+        self.dtype = dtype
+        order = order.lower()
+        if order not in ('c', 'f'):
+            raise ValueError("`order` other than 'c' and 'f' is not supported.")
+        self.order = order
+
+    def _alloc_memory(self, nc):
+        if self.order == 'c':
+            # Colomn-first order (C-style)
+            #
+            # <-- cw -->
+            # |--------|---------------------|^    ^
+            # |                              ||    | sh
+            # |--------|---------------------|| ch v
+            # |                              || 
+            # |--------|---------------------|v
+            # <------------ w --------------->
+            self.cache = np.zeros((self.ch, self.w, nc), dtype=self.dtype)
+        elif self.order == 'f':
+            # Row-first order (Fortran-style)
+            #
+            # <-- sw -->
+            # <---- cw ---->
+            # |--------|---|^   ^
+            # |        |   ||   |
+            # |        |   ||   ch
+            # |        |   ||   |
+            # |--------|---|| h v
+            # |        |   ||
+            # |        |   ||
+            # |        |   ||
+            # |--------|---|v
+            self.cache = np.zeros((self.h, self.cw, nc), dtype=self.dtype)
+
+    def update_block(self, i_st, j_st, h, w, prob_map):
+        if self.cache is None:
+            nc = prob_map.shape[2]
+            # Lazy allocation of memory
+            self._alloc_memory(nc)
+        self.cache[i_st:i_st + h, j_st:j_st + w] += prob_map
+
+    def roll_cache(self):
+        if self.order == 'c':
+            self.cache = np.roll(self.cache, -self.sh, axis=0)
+            self.cache[self.sh:self.ch, :] = 0
+        elif self.order == 'f':
+            self.cache = np.roll(self.cache, -self.sw, axis=1)
+            self.cache[:, self.sw:self.cw] = 0
+
+    def get_block(self, i_st, j_st, h, w):
+        return np.argmax(self.cache[i_st:i_st + h, j_st:j_st + w], axis=2)
+
+
+def slider_predict(predictor, img_file, save_dir, block_size, overlap,
+                   transforms, invalid_value, merge_strategy):
+    """
+    Do inference using sliding windows.
+
+    Args:
+        predictor (object): Object that implements `predict()` method.
+        img_file (str|tuple[str]): Image path(s).
+        save_dir (str): Directory that contains saved geotiff file.
+        block_size (list[int] | tuple[int] | int):
+            Size of block. If `block_size` is list or tuple, it should be in 
+            (W, H) format.
+        overlap (list[int] | tuple[int] | int):
+            Overlap between two blocks. If `overlap` is list or tuple, it should
+            be in (W, H) format.
+        transforms (paddlers.transforms.Compose|None): Transforms for inputs. If 
+            None, the transforms for evaluation process will be used. 
+        invalid_value (int): Value that marks invalid pixels in output image. 
+            Defaults to 255.
+        merge_strategy (str): Strategy to merge overlapping blocks. Choices are 
+            {'keep_first', 'keep_last', 'vote', 'accum'}. 'keep_first' and 
+            'keep_last' means keeping the values of the first and the last block in 
+            traversal order, respectively. 'vote' means applying a simple voting 
+            strategy when there are conflicts in the overlapping pixels. 'accum' 
+            means determining the class of an overlapping pixel according to 
+            accumulated probabilities.
+    """
+
+    try:
+        from osgeo import gdal
+    except:
+        import gdal
+
+    if isinstance(block_size, int):
+        block_size = (block_size, block_size)
+    elif isinstance(block_size, (tuple, list)) and len(block_size) == 2:
+        block_size = tuple(block_size)
+    else:
+        raise ValueError(
+            "`block_size` must be a tuple/list of length 2 or an integer.")
+    if isinstance(overlap, int):
+        overlap = (overlap, overlap)
+    elif isinstance(overlap, (tuple, list)) and len(overlap) == 2:
+        overlap = tuple(overlap)
+    else:
+        raise ValueError(
+            "`overlap` must be a tuple/list of length 2 or an integer.")
+
+    if merge_strategy not in ('keep_first', 'keep_last', 'vote', 'accum'):
+        raise ValueError("{} is not a supported stragegy for block merging.".
+                         format(merge_strategy))
+
+    step = np.array(
+        block_size, dtype=np.int32) - np.array(
+            overlap, dtype=np.int32)
+    if step[0] == 0 or step[1] == 0:
+        raise ValueError("`block_size` and `overlap` should not be equal.")
+
+    if isinstance(img_file, tuple):
+        if len(img_file) != 2:
+            raise ValueError("Tuple `img_file` must have the length of two.")
+        # Assume that two input images have the same size
+        src_data = gdal.Open(img_file[0])
+        src2_data = gdal.Open(img_file[1])
+        # Output name is the same as the name of the first image
+        file_name = osp.basename(osp.normpath(img_file[0]))
+    else:
+        src_data = gdal.Open(img_file)
+        file_name = osp.basename(osp.normpath(img_file))
+
+    # Get size of original raster
+    width = src_data.RasterXSize
+    height = src_data.RasterYSize
+    bands = src_data.RasterCount
+
+    if block_size[0] > width or block_size[1] > height:
+        raise ValueError("`block_size` should not be larger than image size.")
+
+    driver = gdal.GetDriverByName("GTiff")
+    if not osp.exists(save_dir):
+        os.makedirs(save_dir)
+    # Replace extension name with '.tif'
+    file_name = osp.splitext(file_name)[0] + ".tif"
+    save_file = osp.join(save_dir, file_name)
+    dst_data = driver.Create(save_file, width, height, 1, gdal.GDT_Byte)
+
+    # Set meta-information
+    dst_data.SetGeoTransform(src_data.GetGeoTransform())
+    dst_data.SetProjection(src_data.GetProjection())
+
+    # Initialize raster with `invalid_value`
+    band = dst_data.GetRasterBand(1)
+    band.WriteArray(
+        np.full(
+            (height, width), fill_value=invalid_value, dtype="uint8"))
+
+    if overlap == (0, 0) or block_size == (width, height):
+        # When there is no overlap or the whole image is used as input, 
+        # use 'keep_last' strategy as it introduces least overheads
+        merge_strategy = 'keep_last'
+    if merge_strategy == 'vote':
+        logging.warning(
+            "Currently, a naive Python-implemented cache is used for aggregating voting results. "
+            "For higher performance in inferring large images, please set `merge_strategy` to 'keep_first', "
+            "'keep_last', or 'accum'.")
+        cache = SlowCache()
+    elif merge_strategy == 'accum':
+        cache = ProbCache(height, width, *block_size, *step)
+
+    prev_yoff, prev_xoff = None, None
+
+    for yoff in range(0, height, step[1]):
+        for xoff in range(0, width, step[0]):
+            xsize, ysize = block_size
+            if xoff + xsize > width:
+                xoff = width - xsize
+            if yoff + ysize > height:
+                yoff = height - ysize
+
+            # Read and fill
+            im = src_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose(
+                (1, 2, 0))
+
+            if isinstance(img_file, tuple):
+                im2 = src2_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose(
+                    (1, 2, 0))
+                # Predict
+                out = predictor.predict((im, im2), transforms)
+            else:
+                # Predict
+                out = predictor.predict(im, transforms)
+
+            pred = out['label_map'].astype('uint8')
+            pred = pred[:ysize, :xsize]
+
+            # Deal with overlapping pixels
+            if merge_strategy == 'vote':
+                cache.push_block(yoff, xoff, ysize, xsize, pred)
+                pred = cache.get_block(yoff, xoff, ysize, xsize)
+                pred = pred.astype('uint8')
+                if prev_yoff is not None:
+                    pop_h = yoff - prev_yoff
+                else:
+                    pop_h = 0
+                if prev_xoff is not None:
+                    if xoff < prev_xoff:
+                        pop_w = xsize
+                    else:
+                        pop_w = xoff - prev_xoff
+                else:
+                    pop_w = 0
+                cache.pop_block(prev_yoff, prev_xoff, pop_h, pop_w)
+            elif merge_strategy == 'keep_first':
+                rd_block = band.ReadAsArray(xoff, yoff, xsize, ysize)
+                mask = rd_block != invalid_value
+                pred = np.where(mask, rd_block, pred)
+            elif merge_strategy == 'keep_last':
+                pass
+            elif merge_strategy == 'accum':
+                prob = out['score_map']
+                prob = prob[:ysize, :xsize]
+                cache.update_block(0, yoff, ysize, xsize, prob)
+                pred = cache.get_block(0, yoff, ysize, xsize)
+                if xoff + step[0] >= width:
+                    cache.roll_cache()
+
+            # Write to file
+            band.WriteArray(pred, xoff, yoff)
+            dst_data.FlushCache()
+
+            prev_xoff = xoff
+        prev_yoff = yoff
+
+    dst_data = None
+    logging.info("GeoTiff file saved in {}.".format(save_file))

+ 36 - 14
tests/tasks/test_slider_predict.py

@@ -75,13 +75,9 @@ class TestSegSliderPredict(CommonTest):
 
             # `block_size` larger than image size
             save_dir = osp.join(td, 'pred5')
-            self.model.slider_predict(self.image_path, save_dir, 512, 0,
-                                      self.transforms)
-            pred5 = T.decode_image(
-                osp.join(save_dir, self.basename),
-                to_uint8=False,
-                decode_sar=False)
-            self.check_output_equal(pred5.shape, pred_whole.shape)
+            with self.assertRaises(ValueError):
+                self.model.slider_predict(self.image_path, save_dir, 512, 0,
+                                          self.transforms)
 
     def test_merge_strategy(self):
         with tempfile.TemporaryDirectory() as td:
@@ -134,6 +130,21 @@ class TestSegSliderPredict(CommonTest):
                 decode_sar=False)
             self.check_output_equal(pred_vote.shape, pred_whole.shape)
 
+            # 'accum'
+            save_dir = osp.join(td, 'accum')
+            self.model.slider_predict(
+                self.image_path,
+                save_dir,
+                128,
+                64,
+                self.transforms,
+                merge_strategy='vote')
+            pred_accum = T.decode_image(
+                osp.join(save_dir, self.basename),
+                to_uint8=False,
+                decode_sar=False)
+            self.check_output_equal(pred_accum.shape, pred_whole.shape)
+
     def test_geo_info(self):
         with tempfile.TemporaryDirectory() as td:
             _, geo_info_in = T.decode_image(self.image_path, read_geo_info=True)
@@ -202,13 +213,9 @@ class TestCDSliderPredict(CommonTest):
 
             # `block_size` larger than image size
             save_dir = osp.join(td, 'pred5')
-            self.model.slider_predict(self.image_paths, save_dir, 512, 0,
-                                      self.transforms)
-            pred5 = T.decode_image(
-                osp.join(save_dir, self.basename),
-                to_uint8=False,
-                decode_sar=False)
-            self.check_output_equal(pred5.shape, pred_whole.shape)
+            with self.assertRaises(ValueError):
+                self.model.slider_predict(self.image_paths, save_dir, 512, 0,
+                                          self.transforms)
 
     def test_merge_strategy(self):
         with tempfile.TemporaryDirectory() as td:
@@ -261,6 +268,21 @@ class TestCDSliderPredict(CommonTest):
                 decode_sar=False)
             self.check_output_equal(pred_vote.shape, pred_whole.shape)
 
+            # 'accum'
+            save_dir = osp.join(td, 'accum')
+            self.model.slider_predict(
+                self.image_paths,
+                save_dir,
+                128,
+                64,
+                self.transforms,
+                merge_strategy='vote')
+            pred_accum = T.decode_image(
+                osp.join(save_dir, self.basename),
+                to_uint8=False,
+                decode_sar=False)
+            self.check_output_equal(pred_accum.shape, pred_whole.shape)
+
     def test_geo_info(self):
         with tempfile.TemporaryDirectory() as td:
             _, geo_info_in = T.decode_image(