|
@@ -141,6 +141,9 @@ class OverlapProcessor(metaclass=ABCMeta):
|
|
|
self.sh = sh
|
|
|
self.sw = sw
|
|
|
|
|
|
+ def update_batch_offsets(self, xoff, yoff):
|
|
|
+ return xoff, yoff
|
|
|
+
|
|
|
@abstractmethod
|
|
|
def process_pred(self, out, xoff, yoff):
|
|
|
pass
|
|
@@ -200,6 +203,21 @@ class AccumProcessor(OverlapProcessor):
|
|
|
return pred
|
|
|
|
|
|
|
|
|
+class SwellProcessor(OverlapProcessor):
|
|
|
+ def __init__(self, h, w, ch, cw, sh, sw, oh, ow):
|
|
|
+ super(SwellProcessor, self).__init__(h, w, ch, cw, sh, sw)
|
|
|
+ self.oh = oh
|
|
|
+ self.ow = ow
|
|
|
+
|
|
|
+ def update_batch_offsets(self, xoff, yoff):
|
|
|
+ return xoff + self.oh, yoff + self.ow
|
|
|
+
|
|
|
+ def process_pred(self, out, xoff, yoff):
|
|
|
+ pred = out['label_map']
|
|
|
+ pred = pred[self.oh:self.ch - self.oh, self.ow:self.cw - self.ow]
|
|
|
+ return pred
|
|
|
+
|
|
|
+
|
|
|
def assign_border_weights(array, weight=0.5, border_ratio=0.25, inplace=True):
|
|
|
if not inplace:
|
|
|
array = array.copy()
|
|
@@ -216,6 +234,8 @@ class BlockReader(metaclass=ABCMeta):
|
|
|
def __init__(self, ds):
|
|
|
super().__init__()
|
|
|
self.ds = ds
|
|
|
+ self.ww = self.ds.RasterXSize
|
|
|
+ self.wh = self.ds.RasterYSize
|
|
|
|
|
|
@abstractmethod
|
|
|
def read_block(self, xoff, yoff, xsize, ysize):
|
|
@@ -233,6 +253,22 @@ class BlockReader(metaclass=ABCMeta):
|
|
|
tar_xsize = xsize
|
|
|
if tar_ysize is None:
|
|
|
tar_ysize = ysize
|
|
|
+ # Negative index correction
|
|
|
+ lxpad = 0
|
|
|
+ lypad = 0
|
|
|
+ if xoff < 0:
|
|
|
+ lxpad = -xoff
|
|
|
+ xsize -= lxpad
|
|
|
+ xoff = 0
|
|
|
+ if yoff < 0:
|
|
|
+ lypad = -yoff
|
|
|
+ ysize -= lypad
|
|
|
+ yoff = 0
|
|
|
+ # Out of index correction
|
|
|
+ if xoff + xsize > self.ww:
|
|
|
+ xsize = self.ww - xoff
|
|
|
+ if yoff + ysize > self.wh:
|
|
|
+ ysize = self.wh - yoff
|
|
|
block = self.read_block(xoff, yoff, xsize, ysize)
|
|
|
c, real_ysize, real_xsize = block.shape
|
|
|
assert real_ysize == ysize and real_xsize == xsize
|
|
@@ -246,7 +282,8 @@ class BlockReader(metaclass=ABCMeta):
|
|
|
fill_value=pad_val,
|
|
|
dtype=block.dtype)
|
|
|
# Fill
|
|
|
- padded_block[:real_ysize, :real_xsize] = block
|
|
|
+ padded_block[lypad:real_ysize + lypad, lxpad:real_xsize +
|
|
|
+ lxpad] = block
|
|
|
return padded_block
|
|
|
else:
|
|
|
return block
|
|
@@ -298,10 +335,11 @@ def slider_predict(predict_func,
|
|
|
invalid_value (int): Value that marks invalid pixels in output image.
|
|
|
Defaults to 255.
|
|
|
merge_strategy (str): Strategy to merge overlapping blocks. Choices are
|
|
|
- {'keep_first', 'keep_last', 'accum'}. 'keep_first' and 'keep_last'
|
|
|
+ {'keep_first', 'keep_last', 'accum', 'swell'}. 'keep_first' and 'keep_last'
|
|
|
means keeping the values of the first and the last block in
|
|
|
traversal order, respectively. 'accum' means determining the class
|
|
|
of an overlapping pixel according to accumulated probabilities.
|
|
|
+ 'swell' means keeping only the center part of each block prediction.
|
|
|
batch_size (int): Batch size used in inference.
|
|
|
eager_load (bool, optional): Whether to load the whole image(s) eagerly.
|
|
|
Defaults to False.
|
|
@@ -363,6 +401,9 @@ def slider_predict(predict_func,
|
|
|
height = src_data.RasterYSize
|
|
|
bands = src_data.RasterCount
|
|
|
|
|
|
+ start = (0, 0)
|
|
|
+ end = (width, height)
|
|
|
+
|
|
|
# XXX: GDAL read behavior conforms to paddlers.transforms.decode_image(read_raw=True)
|
|
|
# except for SAR images.
|
|
|
if bands == 1:
|
|
@@ -410,6 +451,13 @@ def slider_predict(predict_func,
|
|
|
elif merge_strategy == 'accum':
|
|
|
overlap_processor = AccumProcessor(height, width, *block_size[::-1],
|
|
|
*step[::-1])
|
|
|
+ elif merge_strategy == 'swell':
|
|
|
+ start = tuple([-o for o in overlap])
|
|
|
+ end = tuple([o + e for o, e in zip(overlap, end)])
|
|
|
+ step = np.array(block_size, dtype=np.int32)
|
|
|
+ block_size = tuple([b + 2 * o for b, o in zip(block_size, overlap)])
|
|
|
+ overlap_processor = SwellProcessor(height, width, *block_size[::-1],
|
|
|
+ *step[::-1], *overlap[::-1])
|
|
|
else:
|
|
|
raise ValueError("{} is not a supported stragegy for block merging.".
|
|
|
format(merge_strategy))
|
|
@@ -421,29 +469,35 @@ def slider_predict(predict_func,
|
|
|
pb = tqdm(total=num_blocks)
|
|
|
batch_data = []
|
|
|
batch_offsets = []
|
|
|
- for yoff in range(0, height, step[1]):
|
|
|
- for xoff in range(0, width, step[0]):
|
|
|
+ start_h, start_w = start[::-1]
|
|
|
+ end_h, end_w = end[::-1]
|
|
|
+ for yoff in range(start_h, height, step[1]):
|
|
|
+ for xoff in range(start_w, width, step[0]):
|
|
|
if xoff + xsize > width:
|
|
|
- xoff = width - xsize
|
|
|
+ xoff = end_w - xsize
|
|
|
is_end_of_row = True
|
|
|
else:
|
|
|
is_end_of_row = False
|
|
|
if yoff + ysize > height:
|
|
|
- yoff = height - ysize
|
|
|
+ yoff = end_h - ysize
|
|
|
is_end_of_col = True
|
|
|
else:
|
|
|
is_end_of_col = False
|
|
|
|
|
|
# Read
|
|
|
- im = reader.get_block(xoff, yoff, xsize, ysize)
|
|
|
+ tar_xsize, tar_ysize = block_size
|
|
|
+ im = reader.get_block(xoff, yoff, xsize, ysize, tar_xsize,
|
|
|
+ tar_ysize)
|
|
|
|
|
|
if isinstance(img_file, tuple):
|
|
|
- im2 = reader2.get_block(xoff, yoff, xsize, ysize)
|
|
|
+ im2 = reader2.get_block(xoff, yoff, xsize, ysize, tar_xsize,
|
|
|
+ tar_ysize)
|
|
|
batch_data.append((im, im2))
|
|
|
else:
|
|
|
batch_data.append(im)
|
|
|
|
|
|
- batch_offsets.append((xoff, yoff))
|
|
|
+ batch_offsets.append(
|
|
|
+ overlap_processor.update_batch_offsets(xoff, yoff))
|
|
|
|
|
|
len_batch = len(batch_data)
|
|
|
|