Просмотр исходного кода

[Feat] Add swelling predict (#92)

* [Feat] Init add swelling predict

* [Feat] Init add swelling predict

* [Refactor] Update OverlapProcessor

* [Fix] Fix update_batch_offsets
Yizhou Chen 2 года назад
Родитель
Коммит
cb49a9b0ee
2 изменённых файла, 78 добавлений и 9 удалений
  1. 63 9
      paddlers/tasks/utils/slider_predict.py
  2. 15 0
      tests/tasks/test_slider_predict.py

+ 63 - 9
paddlers/tasks/utils/slider_predict.py

@@ -141,6 +141,9 @@ class OverlapProcessor(metaclass=ABCMeta):
         self.sh = sh
         self.sw = sw
 
+    def update_batch_offsets(self, xoff, yoff):
+        """
+        Map the offsets at which a block was read to the offsets that are 
+        recorded for that block in the batch.
+
+        The base implementation returns both offsets unchanged; subclasses 
+        may override it to shift them.
+
+        Args:
+            xoff: Horizontal offset of the block.
+            yoff: Vertical offset of the block.
+
+        Returns:
+            tuple: The pair `(xoff, yoff)`.
+        """
+        return xoff, yoff
+
     @abstractmethod
     def process_pred(self, out, xoff, yoff):
         pass
@@ -200,6 +203,21 @@ class AccumProcessor(OverlapProcessor):
         return pred
 
 
+class SwellProcessor(OverlapProcessor):
+    """
+    Overlap processor that keeps only the central part of each block 
+    prediction, discarding a margin of `oh` rows and `ow` columns on every 
+    side.
+
+    Args:
+        h, w, ch, cw, sh, sw: Forwarded unchanged to 
+            `OverlapProcessor.__init__`.
+        oh (int): Number of rows cropped from the top and from the bottom of 
+            each block prediction.
+        ow (int): Number of columns cropped from the left and from the right 
+            of each block prediction.
+    """
+
+    def __init__(self, h, w, ch, cw, sh, sw, oh, ow):
+        super().__init__(h, w, ch, cw, sh, sw)
+        self.oh = oh
+        self.ow = ow
+
+    def update_batch_offsets(self, xoff, yoff):
+        """
+        Shift the read offsets of a block to the offsets of its kept central 
+        region.
+
+        Returns:
+            tuple: `(xoff + ow, yoff + oh)`.
+        """
+        # x runs along columns and y along rows, so the horizontal offset is 
+        # shifted by the column margin and the vertical offset by the row 
+        # margin, matching the crop done in `process_pred`.
+        return xoff + self.ow, yoff + self.oh
+
+    def process_pred(self, out, xoff, yoff):
+        """
+        Crop the margins from the label map of a block prediction.
+
+        Args:
+            out (dict): Prediction result; its 'label_map' entry is used.
+            xoff, yoff: Unused here; kept for interface compatibility.
+
+        Returns:
+            The 'label_map' restricted to rows `oh:ch - oh` and columns 
+            `ow:cw - ow`.
+        """
+        pred = out['label_map']
+        return pred[self.oh:self.ch - self.oh, self.ow:self.cw - self.ow]
+
+
+
+
 def assign_border_weights(array, weight=0.5, border_ratio=0.25, inplace=True):
     if not inplace:
         array = array.copy()
@@ -216,6 +234,8 @@ class BlockReader(metaclass=ABCMeta):
     def __init__(self, ds):
         super().__init__()
         self.ds = ds
+        # Raster width (RasterXSize) and height (RasterYSize) of the dataset, 
+        # cached so that block reads can be clipped to the raster extent.
+        self.ww = self.ds.RasterXSize
+        self.wh = self.ds.RasterYSize
 
     @abstractmethod
     def read_block(self, xoff, yoff, xsize, ysize):
@@ -233,6 +253,22 @@ class BlockReader(metaclass=ABCMeta):
             tar_xsize = xsize
         if tar_ysize is None:
             tar_ysize = ysize
+        # Negative index correction
+        lxpad = 0
+        lypad = 0
+        if xoff < 0:
+            lxpad = -xoff
+            xsize -= lxpad
+            xoff = 0
+        if yoff < 0:
+            lypad = -yoff
+            ysize -= lypad
+            yoff = 0
+        # Out of index correction
+        if xoff + xsize > self.ww:
+            xsize = self.ww - xoff
+        if yoff + ysize > self.wh:
+            ysize = self.wh - yoff
         block = self.read_block(xoff, yoff, xsize, ysize)
         c, real_ysize, real_xsize = block.shape
         assert real_ysize == ysize and real_xsize == xsize
@@ -246,7 +282,8 @@ class BlockReader(metaclass=ABCMeta):
                 fill_value=pad_val,
                 dtype=block.dtype)
             # Fill
-            padded_block[:real_ysize, :real_xsize] = block
+            padded_block[lypad:real_ysize + lypad, lxpad:real_xsize +
+                         lxpad] = block
             return padded_block
         else:
             return block
@@ -298,10 +335,11 @@ def slider_predict(predict_func,
         invalid_value (int): Value that marks invalid pixels in output image. 
             Defaults to 255.
         merge_strategy (str): Strategy to merge overlapping blocks. Choices are 
-            {'keep_first', 'keep_last', 'accum'}. 'keep_first' and 'keep_last' 
+            {'keep_first', 'keep_last', 'accum', 'swell'}. 'keep_first' and 'keep_last' 
             means keeping the values of the first and the last block in 
             traversal order, respectively. 'accum' means determining the class 
             of an overlapping pixel according to accumulated probabilities.
+            'swell' means keeping only the center part of each block prediction.
         batch_size (int): Batch size used in inference.
         eager_load (bool, optional): Whether to load the whole image(s) eagerly.
             Defaults to False.
@@ -363,6 +401,9 @@ def slider_predict(predict_func,
     height = src_data.RasterYSize
     bands = src_data.RasterCount
 
+    start = (0, 0)
+    end = (width, height)
+
     # XXX: GDAL read behavior conforms to paddlers.transforms.decode_image(read_raw=True)
     # except for SAR images.
     if bands == 1:
@@ -410,6 +451,13 @@ def slider_predict(predict_func,
     elif merge_strategy == 'accum':
         overlap_processor = AccumProcessor(height, width, *block_size[::-1],
                                            *step[::-1])
+    elif merge_strategy == 'swell':
+        start = tuple([-o for o in overlap])
+        end = tuple([o + e for o, e in zip(overlap, end)])
+        step = np.array(block_size, dtype=np.int32)
+        block_size = tuple([b + 2 * o for b, o in zip(block_size, overlap)])
+        overlap_processor = SwellProcessor(height, width, *block_size[::-1],
+                                           *step[::-1], *overlap[::-1])
     else:
         raise ValueError("{} is not a supported stragegy for block merging.".
                          format(merge_strategy))
@@ -421,29 +469,35 @@ def slider_predict(predict_func,
         pb = tqdm(total=num_blocks)
     batch_data = []
     batch_offsets = []
-    for yoff in range(0, height, step[1]):
-        for xoff in range(0, width, step[0]):
+    start_h, start_w = start[::-1]
+    end_h, end_w = end[::-1]
+    for yoff in range(start_h, height, step[1]):
+        for xoff in range(start_w, width, step[0]):
             if xoff + xsize > width:
-                xoff = width - xsize
+                xoff = end_w - xsize
                 is_end_of_row = True
             else:
                 is_end_of_row = False
             if yoff + ysize > height:
-                yoff = height - ysize
+                yoff = end_h - ysize
                 is_end_of_col = True
             else:
                 is_end_of_col = False
 
             # Read
-            im = reader.get_block(xoff, yoff, xsize, ysize)
+            tar_xsize, tar_ysize = block_size
+            im = reader.get_block(xoff, yoff, xsize, ysize, tar_xsize,
+                                  tar_ysize)
 
             if isinstance(img_file, tuple):
-                im2 = reader2.get_block(xoff, yoff, xsize, ysize)
+                im2 = reader2.get_block(xoff, yoff, xsize, ysize, tar_xsize,
+                                        tar_ysize)
                 batch_data.append((im, im2))
             else:
                 batch_data.append(im)
 
-            batch_offsets.append((xoff, yoff))
+            batch_offsets.append(
+                overlap_processor.update_batch_offsets(xoff, yoff))
 
             len_batch = len(batch_data)
 

+ 15 - 0
tests/tasks/test_slider_predict.py

@@ -151,6 +151,21 @@ class _TestSliderPredictNamespace:
                     decode_sar=False)
                 self.check_output_equal(pred_accum.shape, pred_whole.shape)
 
+                # 'swell'
+                save_dir = osp.join(td, 'swell')
+                self.model.slider_predict(
+                    self.image_path,
+                    save_dir,
+                    128,
+                    64,
+                    self.transforms,
+                    merge_strategy='swell')
+                pred_swell = T.decode_image(
+                    osp.join(save_dir, self.basename),
+                    read_raw=True,
+                    decode_sar=False)
+                self.check_output_equal(pred_swell.shape, pred_whole.shape)
+
         def test_geo_info(self):
             with tempfile.TemporaryDirectory() as td:
                 _, geo_info_in = T.decode_image(