Browse Source

Remove vote mode and fix bugs

Bobholamovic 3 năm trước cách đây
mục cha
commit
7a0f5405f6

+ 6 - 1
docs/apis/infer.md

@@ -170,7 +170,7 @@ def slider_predict(self,
 |`overlap`|`list[int]` \| `tuple[int]` \| `int`|滑窗的滑动步长(以列表或元组指定宽度、高度或以一个整数指定相同的宽高)。|`36`|
 |`overlap`|`list[int]` \| `tuple[int]` \| `int`|滑窗的滑动步长(以列表或元组指定宽度、高度或以一个整数指定相同的宽高)。|`36`|
 |`transforms`|`paddlers.transforms.Compose` \| `None`|对输入数据应用的数据变换算子。若为`None`,则使用训练器在验证阶段使用的数据变换算子。|`None`|
 |`transforms`|`paddlers.transforms.Compose` \| `None`|对输入数据应用的数据变换算子。若为`None`,则使用训练器在验证阶段使用的数据变换算子。|`None`|
 |`invalid_value`|`int`|输出影像中用于标记无效像素的数值。|`255`|
 |`invalid_value`|`int`|输出影像中用于标记无效像素的数值。|`255`|
-|`merge_strategy`|`str`|合并滑窗重叠区域使用的策略。`'keep_first'`表示保留遍历顺序(从左至右,从上往下,列优先)最靠前的窗口的预测类别;`'keep_last'`表示保留遍历顺序最靠后的窗口的预测类别;`'vote'`表示使用投票策略,即对于每个像素,最终预测类别为所有覆盖该像素的滑窗给出的预测类别中出现频率最高者;`'accum'`表示通过将各窗口在重叠区域给出的预测概率累加,计算最终预测类别。需要注意的是,在对大尺寸影像进行`overlap`较大的密集推理时,使用`'vote'`策略可能导致较长的推理时间。|`'keep_last'`|
+|`merge_strategy`|`str`|合并滑窗重叠区域使用的策略。`'keep_first'`表示保留遍历顺序(从左至右,从上往下,列优先)最靠前的窗口的预测类别;`'keep_last'`表示保留遍历顺序最靠后的窗口的预测类别;`'accum'`表示通过将各窗口在重叠区域给出的预测概率累加,计算最终预测类别。需要注意的是,在对大尺寸影像进行`overlap`较大的密集推理时,使用`'accum'`策略可能导致较长的推理时间,但一般能够在窗口交界部分取得更好的表现。|`'keep_last'`|
 
 
 变化检测任务的滑窗推理API与图像分割任务类似,但需要注意的是输出结果中存储的地理变换、投影等信息以从第一时相影像中读取的信息为准,存储滑窗推理结果的文件名也与第一时相影像文件相同。
 变化检测任务的滑窗推理API与图像分割任务类似,但需要注意的是输出结果中存储的地理变换、投影等信息以从第一时相影像中读取的信息为准,存储滑窗推理结果的文件名也与第一时相影像文件相同。
 
 
@@ -220,5 +220,10 @@ def predict(self,
 |`transforms`|`paddlers.transforms.Compose`\|`None`|对输入数据应用的数据变换算子。若为`None`,则使用从`model.yml`中读取的算子。|`None`|
 |`transforms`|`paddlers.transforms.Compose`\|`None`|对输入数据应用的数据变换算子。若为`None`,则使用从`model.yml`中读取的算子。|`None`|
 |`warmup_iters`|`int`|预热轮数,用于评估模型推理以及前后处理速度。若大于1,将预先重复执行`warmup_iters`次推理,而后才开始正式的预测及其速度评估。|`0`|
 |`warmup_iters`|`int`|预热轮数,用于评估模型推理以及前后处理速度。若大于1,将预先重复执行`warmup_iters`次推理,而后才开始正式的预测及其速度评估。|`0`|
 |`repeats`|`int`|重复次数,用于评估模型推理以及前后处理速度。若大于1,将执行`repeats`次预测并取时间平均值。|`1`|
 |`repeats`|`int`|重复次数,用于评估模型推理以及前后处理速度。若大于1,将执行`repeats`次预测并取时间平均值。|`1`|
+|`quiet`|`bool`|若为`True`,不打印计时信息。|`False`|
 
 
 `Predictor.predict()`的返回格式与相应的动态图推理API的返回格式完全相同,详情请参考[动态图推理API](#动态图推理api)。
 `Predictor.predict()`的返回格式与相应的动态图推理API的返回格式完全相同,详情请参考[动态图推理API](#动态图推理api)。
+
+### `Predictor.slider_predict()`
+
+实现滑窗推理功能。用法与`BaseSegmenter`和`BaseChangeDetector`的`slider_predict()`方法相同。

+ 51 - 4
paddlers/deploy/predictor.py

@@ -14,6 +14,7 @@
 
 
 import os.path as osp
 import os.path as osp
 from operator import itemgetter
 from operator import itemgetter
+from functools import partial
 
 
 import numpy as np
 import numpy as np
 import paddle
 import paddle
@@ -23,6 +24,7 @@ from paddle.inference import PrecisionType
 
 
 from paddlers.tasks import load_model
 from paddlers.tasks import load_model
 from paddlers.utils import logging, Timer
 from paddlers.utils import logging, Timer
+from paddlers.tasks.utils.slider_predict import slider_predict
 
 
 
 
 class Predictor(object):
 class Predictor(object):
@@ -271,22 +273,24 @@ class Predictor(object):
                 topk=1,
                 topk=1,
                 transforms=None,
                 transforms=None,
                 warmup_iters=0,
                 warmup_iters=0,
-                repeats=1):
+                repeats=1,
+                quiet=False):
         """
         """
-        Do prediction.
+        Do inference.
 
 
         Args:
         Args:
             img_file(list[str|tuple|np.ndarray] | str | tuple | np.ndarray): For scene classification, image restoration, 
             img_file(list[str|tuple|np.ndarray] | str | tuple | np.ndarray): For scene classification, image restoration, 
                 object detection and semantic segmentation tasks, `img_file` should be either the path of the image to predict,
                 object detection and semantic segmentation tasks, `img_file` should be either the path of the image to predict,
                 a decoded image (a np.ndarray, which should be consistent with what you get from passing image path to
                 a decoded image (a np.ndarray, which should be consistent with what you get from passing image path to
                 paddlers.transforms.decode_image()), or a list of image paths or decoded images. For change detection tasks,
                 paddlers.transforms.decode_image()), or a list of image paths or decoded images. For change detection tasks,
-                img_file should be a tuple of image paths, a tuple of decoded images, or a list of tuples.
+                `img_file` should be a tuple of image paths, a tuple of decoded images, or a list of tuples.
             topk(int, optional): Top-k values to reserve in a classification result. Defaults to 1.
             topk(int, optional): Top-k values to reserve in a classification result. Defaults to 1.
             transforms (paddlers.transforms.Compose|None, optional): Pipeline of data preprocessing. If None, load transforms
             transforms (paddlers.transforms.Compose|None, optional): Pipeline of data preprocessing. If None, load transforms
                 from `model.yml`. Defaults to None.
                 from `model.yml`. Defaults to None.
             warmup_iters (int, optional): Warm-up iterations before measuring the execution time. Defaults to 0.
             warmup_iters (int, optional): Warm-up iterations before measuring the execution time. Defaults to 0.
             repeats (int, optional): Number of repetitions to evaluate model inference and data processing speed. If greater than
             repeats (int, optional): Number of repetitions to evaluate model inference and data processing speed. If greater than
                 1, the reported time consumption is the average of all repeats. Defaults to 1.
                 1, the reported time consumption is the average of all repeats. Defaults to 1.
+            quiet (bool, optional): If True, do not display the timing information. Defaults to False.
         """
         """
 
 
         if repeats < 1:
         if repeats < 1:
@@ -313,12 +317,55 @@ class Predictor(object):
 
 
         self.timer.repeats = repeats
         self.timer.repeats = repeats
         self.timer.img_num = len(images)
         self.timer.img_num = len(images)
-        self.timer.info(average=True)
+        if not quiet:
+            self.timer.info(average=True)
 
 
         if isinstance(img_file, (str, np.ndarray, tuple)):
         if isinstance(img_file, (str, np.ndarray, tuple)):
             results = results[0]
             results = results[0]
 
 
         return results
         return results
 
 
+    def slider_predict(self,
+                       img_file,
+                       save_dir,
+                       block_size,
+                       overlap=36,
+                       transforms=None,
+                       invalid_value=255,
+                       merge_strategy='keep_last'):
+        """
+        Do inference using sliding windows. Only semantic segmentation and change detection models are supported in the 
+            sliding-predicting mode.
+
+        Args:
+            img_file(list[str|tuple|np.ndarray] | str | tuple | np.ndarray): For semantic segmentation tasks, `img_file` 
+                should be either the path of the image to predict, a decoded image (a np.ndarray, which should be 
+                consistent with what you get from passing image path to paddlers.transforms.decode_image()), or a list of 
+                image paths or decoded images. For change detection tasks, `img_file` should be a tuple of image paths, a 
+                tuple of decoded images, or a list of tuples.
+            save_dir (str): Directory that contains saved geotiff file.
+            block_size (list[int] | tuple[int] | int): Size of block. If `block_size` is a list or tuple, it should be in 
+                (W, H) format.
+            overlap (list[int] | tuple[int] | int, optional): Overlap between two blocks. If `overlap` is a list or tuple, 
+                it should be in (W, H) format. Defaults to 36.
+            transforms (paddlers.transforms.Compose|None, optional): Pipeline of data preprocessing. If None, load transforms
+                from `model.yml`. Defaults to None.
+            invalid_value (int, optional): Value that marks invalid pixels in output image. Defaults to 255.
+            merge_strategy (str, optional): Strategy to merge overlapping blocks. Choices are 
+                {'keep_first', 'keep_last', 'accum'}. 'keep_first' and 'keep_last' means keeping the values of the first and 
+                the last block in traversal order, respectively. 'accum' means determining the class of an overlapping pixel 
+                according to accumulated probabilities. Defaults to 'keep_last'.
+        """
+        slider_predict(
+            partial(
+                self.predict, quiet=True),
+            img_file,
+            save_dir,
+            block_size,
+            overlap,
+            transforms,
+            invalid_value,
+            merge_strategy)
+
     def batch_predict(self, image_list, **params):
     def batch_predict(self, image_list, **params):
         return self.predict(img_file=image_list, **params)
         return self.predict(img_file=image_list, **params)

+ 17 - 46
paddlers/tasks/utils/slider_predict.py

@@ -118,22 +118,22 @@ class ProbCache(Cache):
     def roll_cache(self):
     def roll_cache(self):
         if self.order == 'c':
         if self.order == 'c':
             self.cache = np.roll(self.cache, -self.sh, axis=0)
             self.cache = np.roll(self.cache, -self.sh, axis=0)
-            self.cache[self.sh:self.ch, :] = 0
+            self.cache[-self.sh:, :] = 0
         elif self.order == 'f':
         elif self.order == 'f':
             self.cache = np.roll(self.cache, -self.sw, axis=1)
             self.cache = np.roll(self.cache, -self.sw, axis=1)
-            self.cache[:, self.sw:self.cw] = 0
+            self.cache[:, -self.sw:] = 0
 
 
     def get_block(self, i_st, j_st, h, w):
     def get_block(self, i_st, j_st, h, w):
         return np.argmax(self.cache[i_st:i_st + h, j_st:j_st + w], axis=2)
         return np.argmax(self.cache[i_st:i_st + h, j_st:j_st + w], axis=2)
 
 
 
 
-def slider_predict(predictor, img_file, save_dir, block_size, overlap,
+def slider_predict(predict_func, img_file, save_dir, block_size, overlap,
                    transforms, invalid_value, merge_strategy):
                    transforms, invalid_value, merge_strategy):
     """
     """
     Do inference using sliding windows.
     Do inference using sliding windows.
 
 
     Args:
     Args:
-        predictor (object): Object that implements `predict()` method.
+        predict_func (callable): A callable object that makes the prediction.
         img_file (str|tuple[str]): Image path(s).
         img_file (str|tuple[str]): Image path(s).
         save_dir (str): Directory that contains saved geotiff file.
         save_dir (str): Directory that contains saved geotiff file.
         block_size (list[int] | tuple[int] | int):
         block_size (list[int] | tuple[int] | int):
@@ -147,12 +147,10 @@ def slider_predict(predictor, img_file, save_dir, block_size, overlap,
         invalid_value (int): Value that marks invalid pixels in output image. 
         invalid_value (int): Value that marks invalid pixels in output image. 
             Defaults to 255.
             Defaults to 255.
         merge_strategy (str): Strategy to merge overlapping blocks. Choices are 
         merge_strategy (str): Strategy to merge overlapping blocks. Choices are 
-            {'keep_first', 'keep_last', 'vote', 'accum'}. 'keep_first' and 
-            'keep_last' means keeping the values of the first and the last block in 
-            traversal order, respectively. 'vote' means applying a simple voting 
-            strategy when there are conflicts in the overlapping pixels. 'accum' 
-            means determining the class of an overlapping pixel according to 
-            accumulated probabilities.
+            {'keep_first', 'keep_last', 'accum'}. 'keep_first' and 'keep_last' 
+            means keeping the values of the first and the last block in 
+            traversal order, respectively. 'accum' means determining the class 
+            of an overlapping pixel according to accumulated probabilities.
     """
     """
 
 
     try:
     try:
@@ -175,7 +173,7 @@ def slider_predict(predictor, img_file, save_dir, block_size, overlap,
         raise ValueError(
         raise ValueError(
             "`overlap` must be a tuple/list of length 2 or an integer.")
             "`overlap` must be a tuple/list of length 2 or an integer.")
 
 
-    if merge_strategy not in ('keep_first', 'keep_last', 'vote', 'accum'):
+    if merge_strategy not in ('keep_first', 'keep_last', 'accum'):
         raise ValueError("{} is not a supported stragegy for block merging.".
         raise ValueError("{} is not a supported stragegy for block merging.".
                          format(merge_strategy))
                          format(merge_strategy))
 
 
@@ -227,16 +225,8 @@ def slider_predict(predictor, img_file, save_dir, block_size, overlap,
         # When there is no overlap or the whole image is used as input, 
         # When there is no overlap or the whole image is used as input, 
         # use 'keep_last' strategy as it introduces least overheads
         # use 'keep_last' strategy as it introduces least overheads
         merge_strategy = 'keep_last'
         merge_strategy = 'keep_last'
-    if merge_strategy == 'vote':
-        logging.warning(
-            "Currently, a naive Python-implemented cache is used for aggregating voting results. "
-            "For higher performance in inferring large images, please set `merge_strategy` to 'keep_first', "
-            "'keep_last', or 'accum'.")
-        cache = SlowCache()
-    elif merge_strategy == 'accum':
-        cache = ProbCache(height, width, *block_size, *step)
-
-    prev_yoff, prev_xoff = None, None
+    if merge_strategy == 'accum':
+        cache = ProbCache(height, width, *block_size[::-1], *step[::-1])
 
 
     for yoff in range(0, height, step[1]):
     for yoff in range(0, height, step[1]):
         for xoff in range(0, width, step[0]):
         for xoff in range(0, width, step[0]):
@@ -254,32 +244,16 @@ def slider_predict(predictor, img_file, save_dir, block_size, overlap,
                 im2 = src2_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose(
                 im2 = src2_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose(
                     (1, 2, 0))
                     (1, 2, 0))
                 # Predict
                 # Predict
-                out = predictor.predict((im, im2), transforms)
+                out = predict_func((im, im2), transforms=transforms)
             else:
             else:
                 # Predict
                 # Predict
-                out = predictor.predict(im, transforms)
+                out = predict_func(im, transforms=transforms)
 
 
             pred = out['label_map'].astype('uint8')
             pred = out['label_map'].astype('uint8')
             pred = pred[:ysize, :xsize]
             pred = pred[:ysize, :xsize]
 
 
             # Deal with overlapping pixels
             # Deal with overlapping pixels
-            if merge_strategy == 'vote':
-                cache.push_block(yoff, xoff, ysize, xsize, pred)
-                pred = cache.get_block(yoff, xoff, ysize, xsize)
-                pred = pred.astype('uint8')
-                if prev_yoff is not None:
-                    pop_h = yoff - prev_yoff
-                else:
-                    pop_h = 0
-                if prev_xoff is not None:
-                    if xoff < prev_xoff:
-                        pop_w = xsize
-                    else:
-                        pop_w = xoff - prev_xoff
-                else:
-                    pop_w = 0
-                cache.pop_block(prev_yoff, prev_xoff, pop_h, pop_w)
-            elif merge_strategy == 'keep_first':
+            if merge_strategy == 'keep_first':
                 rd_block = band.ReadAsArray(xoff, yoff, xsize, ysize)
                 rd_block = band.ReadAsArray(xoff, yoff, xsize, ysize)
                 mask = rd_block != invalid_value
                 mask = rd_block != invalid_value
                 pred = np.where(mask, rd_block, pred)
                 pred = np.where(mask, rd_block, pred)
@@ -288,17 +262,14 @@ def slider_predict(predictor, img_file, save_dir, block_size, overlap,
             elif merge_strategy == 'accum':
             elif merge_strategy == 'accum':
                 prob = out['score_map']
                 prob = out['score_map']
                 prob = prob[:ysize, :xsize]
                 prob = prob[:ysize, :xsize]
-                cache.update_block(0, yoff, ysize, xsize, prob)
-                pred = cache.get_block(0, yoff, ysize, xsize)
-                if xoff + step[0] >= width:
+                cache.update_block(0, xoff, ysize, xsize, prob)
+                pred = cache.get_block(0, xoff, ysize, xsize)
+                if xoff + xsize >= width:
                     cache.roll_cache()
                     cache.roll_cache()
 
 
             # Write to file
             # Write to file
             band.WriteArray(pred, xoff, yoff)
             band.WriteArray(pred, xoff, yoff)
             dst_data.FlushCache()
             dst_data.FlushCache()
 
 
-            prev_xoff = xoff
-        prev_yoff = yoff
-
     dst_data = None
     dst_data = None
     logging.info("GeoTiff file saved in {}.".format(save_file))
     logging.info("GeoTiff file saved in {}.".format(save_file))

+ 2 - 32
tests/tasks/test_slider_predict.py

@@ -115,21 +115,6 @@ class TestSegSliderPredict(CommonTest):
                 decode_sar=False)
                 decode_sar=False)
             self.check_output_equal(pred_keeplast.shape, pred_whole.shape)
             self.check_output_equal(pred_keeplast.shape, pred_whole.shape)
 
 
-            # 'vote'
-            save_dir = osp.join(td, 'vote')
-            self.model.slider_predict(
-                self.image_path,
-                save_dir,
-                128,
-                64,
-                self.transforms,
-                merge_strategy='vote')
-            pred_vote = T.decode_image(
-                osp.join(save_dir, self.basename),
-                to_uint8=False,
-                decode_sar=False)
-            self.check_output_equal(pred_vote.shape, pred_whole.shape)
-
             # 'accum'
             # 'accum'
             save_dir = osp.join(td, 'accum')
             save_dir = osp.join(td, 'accum')
             self.model.slider_predict(
             self.model.slider_predict(
@@ -138,7 +123,7 @@ class TestSegSliderPredict(CommonTest):
                 128,
                 128,
                 64,
                 64,
                 self.transforms,
                 self.transforms,
-                merge_strategy='vote')
+                merge_strategy='accum')
             pred_accum = T.decode_image(
             pred_accum = T.decode_image(
                 osp.join(save_dir, self.basename),
                 osp.join(save_dir, self.basename),
                 to_uint8=False,
                 to_uint8=False,
@@ -253,21 +238,6 @@ class TestCDSliderPredict(CommonTest):
                 decode_sar=False)
                 decode_sar=False)
             self.check_output_equal(pred_keeplast.shape, pred_whole.shape)
             self.check_output_equal(pred_keeplast.shape, pred_whole.shape)
 
 
-            # 'vote'
-            save_dir = osp.join(td, 'vote')
-            self.model.slider_predict(
-                self.image_paths,
-                save_dir,
-                128,
-                64,
-                self.transforms,
-                merge_strategy='vote')
-            pred_vote = T.decode_image(
-                osp.join(save_dir, self.basename),
-                to_uint8=False,
-                decode_sar=False)
-            self.check_output_equal(pred_vote.shape, pred_whole.shape)
-
             # 'accum'
             # 'accum'
             save_dir = osp.join(td, 'accum')
             save_dir = osp.join(td, 'accum')
             self.model.slider_predict(
             self.model.slider_predict(
@@ -276,7 +246,7 @@ class TestCDSliderPredict(CommonTest):
                 128,
                 128,
                 64,
                 64,
                 self.transforms,
                 self.transforms,
-                merge_strategy='vote')
+                merge_strategy='accum')
             pred_accum = T.decode_image(
             pred_accum = T.decode_image(
                 osp.join(save_dir, self.basename),
                 osp.join(save_dir, self.basename),
                 to_uint8=False,
                 to_uint8=False,