Ver Fonte

[Feat] Add swelling predict (#92)

* [Feat] Init add swelling predict

* [Feat] Init add swelling predict

* [Refactor] Update OverlapProcessor

* [Fix] Fix update_batch_offsets
Yizhou Chen há 2 anos atrás
pai
commit
cb49a9b0ee

+ 63 - 9
paddlers/tasks/utils/slider_predict.py

@@ -141,6 +141,9 @@ class OverlapProcessor(metaclass=ABCMeta):
         self.sh = sh
         self.sh = sh
         self.sw = sw
         self.sw = sw
 
 
+    def update_batch_offsets(self, xoff, yoff):
+        """Map a block's read offsets to the offsets recorded for its prediction.
+
+        The base implementation is the identity mapping; subclasses may
+        override it when the kept part of a prediction does not start at the
+        block's read origin.
+
+        Args:
+            xoff: Horizontal offset of the block.
+            yoff: Vertical offset of the block.
+
+        Returns:
+            tuple: `(xoff, yoff)`, unchanged.
+        """
+        offsets = (xoff, yoff)
+        return offsets
+
     @abstractmethod
     @abstractmethod
     def process_pred(self, out, xoff, yoff):
     def process_pred(self, out, xoff, yoff):
         pass
         pass
@@ -200,6 +203,21 @@ class AccumProcessor(OverlapProcessor):
         return pred
         return pred
 
 
 
 
+class SwellProcessor(OverlapProcessor):
+    """Overlap processor for the 'swell' merge strategy.
+
+    Only the central part of each block prediction is kept: a margin of
+    `oh` rows and `ow` columns is cropped from every side.
+
+    Args:
+        h, w, ch, cw, sh, sw: Forwarded unchanged to
+            `OverlapProcessor.__init__`.
+        oh (int): Margin cropped from the top and the bottom of a block
+            prediction (rows).
+        ow (int): Margin cropped from the left and the right of a block
+            prediction (columns).
+    """
+
+    def __init__(self, h, w, ch, cw, sh, sw, oh, ow):
+        super(SwellProcessor, self).__init__(h, w, ch, cw, sh, sw)
+        self.oh = oh
+        self.ow = ow
+
+    def update_batch_offsets(self, xoff, yoff):
+        """Shift block offsets to the origin of the kept central region.
+
+        Args:
+            xoff: Horizontal offset at which the block was read.
+            yoff: Vertical offset at which the block was read.
+
+        Returns:
+            tuple: `(xoff + ow, yoff + oh)`.
+        """
+        # The kept region starts `ow` columns to the right of and `oh` rows
+        # below the block origin (see the crop in `process_pred`), so the
+        # horizontal offset is shifted by the width margin and the vertical
+        # offset by the height margin. Pairing them the other way round
+        # misplaces predictions whenever `oh != ow`.
+        return xoff + self.ow, yoff + self.oh
+
+    def process_pred(self, out, xoff, yoff):
+        """Crop the margins off a block prediction.
+
+        Args:
+            out (dict): Prediction result; only `out['label_map']` is read.
+            xoff: Unused here.
+            yoff: Unused here.
+
+        Returns:
+            `out['label_map']` with `oh` rows and `ow` columns removed from
+            each side.
+        """
+        # NOTE(review): relies on `self.ch` / `self.cw` being set by the base
+        # initializer, and on the label map being indexed as (rows, columns)
+        # -- confirm against `OverlapProcessor`.
+        pred = out['label_map']
+        pred = pred[self.oh:self.ch - self.oh, self.ow:self.cw - self.ow]
+        return pred
+
+
 def assign_border_weights(array, weight=0.5, border_ratio=0.25, inplace=True):
 def assign_border_weights(array, weight=0.5, border_ratio=0.25, inplace=True):
     if not inplace:
     if not inplace:
         array = array.copy()
         array = array.copy()
@@ -216,6 +234,8 @@ class BlockReader(metaclass=ABCMeta):
     def __init__(self, ds):
     def __init__(self, ds):
         super().__init__()
         super().__init__()
         self.ds = ds
         self.ds = ds
+        self.ww = self.ds.RasterXSize
+        self.wh = self.ds.RasterYSize
 
 
     @abstractmethod
     @abstractmethod
     def read_block(self, xoff, yoff, xsize, ysize):
     def read_block(self, xoff, yoff, xsize, ysize):
@@ -233,6 +253,22 @@ class BlockReader(metaclass=ABCMeta):
             tar_xsize = xsize
             tar_xsize = xsize
         if tar_ysize is None:
         if tar_ysize is None:
             tar_ysize = ysize
             tar_ysize = ysize
+        # Negative index correction
+        lxpad = 0
+        lypad = 0
+        if xoff < 0:
+            lxpad = -xoff
+            xsize -= lxpad
+            xoff = 0
+        if yoff < 0:
+            lypad = -yoff
+            ysize -= lypad
+            yoff = 0
+        # Out of index correction
+        if xoff + xsize > self.ww:
+            xsize = self.ww - xoff
+        if yoff + ysize > self.wh:
+            ysize = self.wh - yoff
         block = self.read_block(xoff, yoff, xsize, ysize)
         block = self.read_block(xoff, yoff, xsize, ysize)
         c, real_ysize, real_xsize = block.shape
         c, real_ysize, real_xsize = block.shape
         assert real_ysize == ysize and real_xsize == xsize
         assert real_ysize == ysize and real_xsize == xsize
@@ -246,7 +282,8 @@ class BlockReader(metaclass=ABCMeta):
                 fill_value=pad_val,
                 fill_value=pad_val,
                 dtype=block.dtype)
                 dtype=block.dtype)
             # Fill
             # Fill
-            padded_block[:real_ysize, :real_xsize] = block
+            padded_block[lypad:real_ysize + lypad, lxpad:real_xsize +
+                         lxpad] = block
             return padded_block
             return padded_block
         else:
         else:
             return block
             return block
@@ -298,10 +335,11 @@ def slider_predict(predict_func,
         invalid_value (int): Value that marks invalid pixels in output image. 
         invalid_value (int): Value that marks invalid pixels in output image. 
             Defaults to 255.
             Defaults to 255.
         merge_strategy (str): Strategy to merge overlapping blocks. Choices are 
         merge_strategy (str): Strategy to merge overlapping blocks. Choices are 
-            {'keep_first', 'keep_last', 'accum'}. 'keep_first' and 'keep_last' 
+            {'keep_first', 'keep_last', 'accum', 'swell'}. 'keep_first' and 'keep_last' 
             means keeping the values of the first and the last block in 
             means keeping the values of the first and the last block in 
             traversal order, respectively. 'accum' means determining the class 
             traversal order, respectively. 'accum' means determining the class 
             of an overlapping pixel according to accumulated probabilities.
             of an overlapping pixel according to accumulated probabilities.
+            'swell' means keeping only the center part of each block prediction.
         batch_size (int): Batch size used in inference.
         batch_size (int): Batch size used in inference.
         eager_load (bool, optional): Whether to load the whole image(s) eagerly.
         eager_load (bool, optional): Whether to load the whole image(s) eagerly.
             Defaults to False.
             Defaults to False.
@@ -363,6 +401,9 @@ def slider_predict(predict_func,
     height = src_data.RasterYSize
     height = src_data.RasterYSize
     bands = src_data.RasterCount
     bands = src_data.RasterCount
 
 
+    start = (0, 0)
+    end = (width, height)
+
     # XXX: GDAL read behavior conforms to paddlers.transforms.decode_image(read_raw=True)
     # XXX: GDAL read behavior conforms to paddlers.transforms.decode_image(read_raw=True)
     # except for SAR images.
     # except for SAR images.
     if bands == 1:
     if bands == 1:
@@ -410,6 +451,13 @@ def slider_predict(predict_func,
     elif merge_strategy == 'accum':
     elif merge_strategy == 'accum':
         overlap_processor = AccumProcessor(height, width, *block_size[::-1],
         overlap_processor = AccumProcessor(height, width, *block_size[::-1],
                                            *step[::-1])
                                            *step[::-1])
+    elif merge_strategy == 'swell':
+        start = tuple([-o for o in overlap])
+        end = tuple([o + e for o, e in zip(overlap, end)])
+        step = np.array(block_size, dtype=np.int32)
+        block_size = tuple([b + 2 * o for b, o in zip(block_size, overlap)])
+        overlap_processor = SwellProcessor(height, width, *block_size[::-1],
+                                           *step[::-1], *overlap[::-1])
     else:
     else:
         raise ValueError("{} is not a supported stragegy for block merging.".
         raise ValueError("{} is not a supported stragegy for block merging.".
                          format(merge_strategy))
                          format(merge_strategy))
@@ -421,29 +469,35 @@ def slider_predict(predict_func,
         pb = tqdm(total=num_blocks)
         pb = tqdm(total=num_blocks)
     batch_data = []
     batch_data = []
     batch_offsets = []
     batch_offsets = []
-    for yoff in range(0, height, step[1]):
-        for xoff in range(0, width, step[0]):
+    start_h, start_w = start[::-1]
+    end_h, end_w = end[::-1]
+    for yoff in range(start_h, height, step[1]):
+        for xoff in range(start_w, width, step[0]):
             if xoff + xsize > width:
             if xoff + xsize > width:
-                xoff = width - xsize
+                xoff = end_w - xsize
                 is_end_of_row = True
                 is_end_of_row = True
             else:
             else:
                 is_end_of_row = False
                 is_end_of_row = False
             if yoff + ysize > height:
             if yoff + ysize > height:
-                yoff = height - ysize
+                yoff = end_h - ysize
                 is_end_of_col = True
                 is_end_of_col = True
             else:
             else:
                 is_end_of_col = False
                 is_end_of_col = False
 
 
             # Read
             # Read
-            im = reader.get_block(xoff, yoff, xsize, ysize)
+            tar_xsize, tar_ysize = block_size
+            im = reader.get_block(xoff, yoff, xsize, ysize, tar_xsize,
+                                  tar_ysize)
 
 
             if isinstance(img_file, tuple):
             if isinstance(img_file, tuple):
-                im2 = reader2.get_block(xoff, yoff, xsize, ysize)
+                im2 = reader2.get_block(xoff, yoff, xsize, ysize, tar_xsize,
+                                        tar_ysize)
                 batch_data.append((im, im2))
                 batch_data.append((im, im2))
             else:
             else:
                 batch_data.append(im)
                 batch_data.append(im)
 
 
-            batch_offsets.append((xoff, yoff))
+            batch_offsets.append(
+                overlap_processor.update_batch_offsets(xoff, yoff))
 
 
             len_batch = len(batch_data)
             len_batch = len(batch_data)
 
 

+ 15 - 0
tests/tasks/test_slider_predict.py

@@ -151,6 +151,21 @@ class _TestSliderPredictNamespace:
                     decode_sar=False)
                     decode_sar=False)
                 self.check_output_equal(pred_accum.shape, pred_whole.shape)
                 self.check_output_equal(pred_accum.shape, pred_whole.shape)
 
 
+                # 'swell'
+                save_dir = osp.join(td, 'swell')
+                self.model.slider_predict(
+                    self.image_path,
+                    save_dir,
+                    128,
+                    64,
+                    self.transforms,
+                    merge_strategy='swell')
+                pred_swell = T.decode_image(
+                    osp.join(save_dir, self.basename),
+                    read_raw=True,
+                    decode_sar=False)
+                self.check_output_equal(pred_swell.shape, pred_whole.shape)
+
         def test_geo_info(self):
         def test_geo_info(self):
             with tempfile.TemporaryDirectory() as td:
             with tempfile.TemporaryDirectory() as td:
                 _, geo_info_in = T.decode_image(
                 _, geo_info_in = T.decode_image(