|
|
@@ -118,22 +118,22 @@ class ProbCache(Cache):
|
|
|
def roll_cache(self):
|
|
|
if self.order == 'c':
|
|
|
self.cache = np.roll(self.cache, -self.sh, axis=0)
|
|
|
- self.cache[self.sh:self.ch, :] = 0
|
|
|
+ self.cache[-self.sh:, :] = 0
|
|
|
elif self.order == 'f':
|
|
|
self.cache = np.roll(self.cache, -self.sw, axis=1)
|
|
|
- self.cache[:, self.sw:self.cw] = 0
|
|
|
+ self.cache[:, -self.sw:] = 0
|
|
|
|
|
|
def get_block(self, i_st, j_st, h, w):
|
|
|
return np.argmax(self.cache[i_st:i_st + h, j_st:j_st + w], axis=2)
|
|
|
|
|
|
|
|
|
-def slider_predict(predictor, img_file, save_dir, block_size, overlap,
|
|
|
+def slider_predict(predict_func, img_file, save_dir, block_size, overlap,
|
|
|
transforms, invalid_value, merge_strategy):
|
|
|
"""
|
|
|
Do inference using sliding windows.
|
|
|
|
|
|
Args:
|
|
|
- predictor (object): Object that implements `predict()` method.
|
|
|
+ predict_func (callable): A callable object that makes the prediction.
|
|
|
img_file (str|tuple[str]): Image path(s).
|
|
|
save_dir (str): Directory that contains saved geotiff file.
|
|
|
block_size (list[int] | tuple[int] | int):
|
|
|
@@ -147,12 +147,10 @@ def slider_predict(predictor, img_file, save_dir, block_size, overlap,
|
|
|
invalid_value (int): Value that marks invalid pixels in output image.
|
|
|
Defaults to 255.
|
|
|
merge_strategy (str): Strategy to merge overlapping blocks. Choices are
|
|
|
- {'keep_first', 'keep_last', 'vote', 'accum'}. 'keep_first' and
|
|
|
- 'keep_last' means keeping the values of the first and the last block in
|
|
|
- traversal order, respectively. 'vote' means applying a simple voting
|
|
|
- strategy when there are conflicts in the overlapping pixels. 'accum'
|
|
|
- means determining the class of an overlapping pixel according to
|
|
|
- accumulated probabilities.
|
|
|
+ {'keep_first', 'keep_last', 'accum'}. 'keep_first' and 'keep_last'
|
|
|
+ means keeping the values of the first and the last block in
|
|
|
+ traversal order, respectively. 'accum' means determining the class
|
|
|
+ of an overlapping pixel according to accumulated probabilities.
|
|
|
"""
|
|
|
|
|
|
try:
|
|
|
@@ -175,7 +173,7 @@ def slider_predict(predictor, img_file, save_dir, block_size, overlap,
|
|
|
raise ValueError(
|
|
|
"`overlap` must be a tuple/list of length 2 or an integer.")
|
|
|
|
|
|
- if merge_strategy not in ('keep_first', 'keep_last', 'vote', 'accum'):
|
|
|
+ if merge_strategy not in ('keep_first', 'keep_last', 'accum'):
|
|
|
raise ValueError("{} is not a supported stragegy for block merging.".
|
|
|
format(merge_strategy))
|
|
|
|
|
|
@@ -227,16 +225,8 @@ def slider_predict(predictor, img_file, save_dir, block_size, overlap,
|
|
|
# When there is no overlap or the whole image is used as input,
|
|
|
# use 'keep_last' strategy as it introduces least overheads
|
|
|
merge_strategy = 'keep_last'
|
|
|
- if merge_strategy == 'vote':
|
|
|
- logging.warning(
|
|
|
- "Currently, a naive Python-implemented cache is used for aggregating voting results. "
|
|
|
- "For higher performance in inferring large images, please set `merge_strategy` to 'keep_first', "
|
|
|
- "'keep_last', or 'accum'.")
|
|
|
- cache = SlowCache()
|
|
|
- elif merge_strategy == 'accum':
|
|
|
- cache = ProbCache(height, width, *block_size, *step)
|
|
|
-
|
|
|
- prev_yoff, prev_xoff = None, None
|
|
|
+ if merge_strategy == 'accum':
|
|
|
+ cache = ProbCache(height, width, *block_size[::-1], *step[::-1])
|
|
|
|
|
|
for yoff in range(0, height, step[1]):
|
|
|
for xoff in range(0, width, step[0]):
|
|
|
@@ -254,32 +244,16 @@ def slider_predict(predictor, img_file, save_dir, block_size, overlap,
|
|
|
im2 = src2_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose(
|
|
|
(1, 2, 0))
|
|
|
# Predict
|
|
|
- out = predictor.predict((im, im2), transforms)
|
|
|
+ out = predict_func((im, im2), transforms=transforms)
|
|
|
else:
|
|
|
# Predict
|
|
|
- out = predictor.predict(im, transforms)
|
|
|
+ out = predict_func(im, transforms=transforms)
|
|
|
|
|
|
pred = out['label_map'].astype('uint8')
|
|
|
pred = pred[:ysize, :xsize]
|
|
|
|
|
|
# Deal with overlapping pixels
|
|
|
- if merge_strategy == 'vote':
|
|
|
- cache.push_block(yoff, xoff, ysize, xsize, pred)
|
|
|
- pred = cache.get_block(yoff, xoff, ysize, xsize)
|
|
|
- pred = pred.astype('uint8')
|
|
|
- if prev_yoff is not None:
|
|
|
- pop_h = yoff - prev_yoff
|
|
|
- else:
|
|
|
- pop_h = 0
|
|
|
- if prev_xoff is not None:
|
|
|
- if xoff < prev_xoff:
|
|
|
- pop_w = xsize
|
|
|
- else:
|
|
|
- pop_w = xoff - prev_xoff
|
|
|
- else:
|
|
|
- pop_w = 0
|
|
|
- cache.pop_block(prev_yoff, prev_xoff, pop_h, pop_w)
|
|
|
- elif merge_strategy == 'keep_first':
|
|
|
+ if merge_strategy == 'keep_first':
|
|
|
rd_block = band.ReadAsArray(xoff, yoff, xsize, ysize)
|
|
|
mask = rd_block != invalid_value
|
|
|
pred = np.where(mask, rd_block, pred)
|
|
|
@@ -288,17 +262,14 @@ def slider_predict(predictor, img_file, save_dir, block_size, overlap,
|
|
|
elif merge_strategy == 'accum':
|
|
|
prob = out['score_map']
|
|
|
prob = prob[:ysize, :xsize]
|
|
|
- cache.update_block(0, yoff, ysize, xsize, prob)
|
|
|
- pred = cache.get_block(0, yoff, ysize, xsize)
|
|
|
- if xoff + step[0] >= width:
|
|
|
+ cache.update_block(0, xoff, ysize, xsize, prob)
|
|
|
+ pred = cache.get_block(0, xoff, ysize, xsize)
|
|
|
+ if xoff + xsize >= width:
|
|
|
cache.roll_cache()
|
|
|
|
|
|
# Write to file
|
|
|
band.WriteArray(pred, xoff, yoff)
|
|
|
dst_data.FlushCache()
|
|
|
|
|
|
- prev_xoff = xoff
|
|
|
- prev_yoff = yoff
|
|
|
-
|
|
|
dst_data = None
|
|
|
logging.info("GeoTiff file saved in {}.".format(save_file))
|