|
@@ -212,36 +212,63 @@ def assign_border_weights(array, weight=0.5, border_ratio=0.25, inplace=True):
|
|
|
return array
|
|
|
|
|
|
|
|
|
-def read_block(ds,
|
|
|
- xoff,
|
|
|
- yoff,
|
|
|
- xsize,
|
|
|
- ysize,
|
|
|
- tar_xsize=None,
|
|
|
- tar_ysize=None,
|
|
|
- pad_val=0):
|
|
|
- if tar_xsize is None:
|
|
|
- tar_xsize = xsize
|
|
|
- if tar_ysize is None:
|
|
|
- tar_ysize = ysize
|
|
|
- # Read data from dataset
|
|
|
- block = ds.ReadAsArray(xoff, yoff, xsize, ysize)
|
|
|
- c, real_ysize, real_xsize = block.shape
|
|
|
- assert real_ysize == ysize and real_xsize == xsize
|
|
|
- # [c, h, w] -> [h, w, c]
|
|
|
- block = block.transpose((1, 2, 0))
|
|
|
- if (real_ysize, real_xsize) != (tar_ysize, tar_xsize):
|
|
|
- if real_ysize >= tar_ysize or real_xsize >= tar_xsize:
|
|
|
- raise ValueError
|
|
|
- padded_block = np.full(
|
|
|
- (tar_ysize, tar_xsize, c), fill_value=pad_val, dtype=block.dtype)
|
|
|
- # Fill
|
|
|
- padded_block[:real_ysize, :real_xsize] = block
|
|
|
- return padded_block
|
|
|
- else:
|
|
|
+class BlockReader(metaclass=ABCMeta):
|
|
|
+ def __init__(self, ds):
|
|
|
+ super().__init__()
|
|
|
+ self.ds = ds
|
|
|
+
|
|
|
+ @abstractmethod
|
|
|
+ def read_block(self, xoff, yoff, xsize, ysize):
|
|
|
+ pass
|
|
|
+
|
|
|
+ def get_block(self,
|
|
|
+ xoff,
|
|
|
+ yoff,
|
|
|
+ xsize,
|
|
|
+ ysize,
|
|
|
+ tar_xsize=None,
|
|
|
+ tar_ysize=None,
|
|
|
+ pad_val=0):
|
|
|
+ if tar_xsize is None:
|
|
|
+ tar_xsize = xsize
|
|
|
+ if tar_ysize is None:
|
|
|
+ tar_ysize = ysize
|
|
|
+ block = self.read_block(xoff, yoff, xsize, ysize)
|
|
|
+ c, real_ysize, real_xsize = block.shape
|
|
|
+ assert real_ysize == ysize and real_xsize == xsize
|
|
|
+ # [c, h, w] -> [h, w, c]
|
|
|
+ block = block.transpose((1, 2, 0))
|
|
|
+ if (real_ysize, real_xsize) != (tar_ysize, tar_xsize):
|
|
|
+ if real_ysize >= tar_ysize or real_xsize >= tar_xsize:
|
|
|
+ raise ValueError
|
|
|
+ padded_block = np.full(
|
|
|
+ (tar_ysize, tar_xsize, c),
|
|
|
+ fill_value=pad_val,
|
|
|
+ dtype=block.dtype)
|
|
|
+ # Fill
|
|
|
+ padded_block[:real_ysize, :real_xsize] = block
|
|
|
+ return padded_block
|
|
|
+ else:
|
|
|
+ return block
|
|
|
+
|
|
|
+
|
|
|
+class GDALLazyBlockReader(BlockReader):
|
|
|
+ def read_block(self, xoff, yoff, xsize, ysize):
|
|
|
+ block = self.ds.ReadAsArray(xoff, yoff, xsize, ysize)
|
|
|
return block
|
|
|
|
|
|
|
|
|
+class EagerBlockReader(BlockReader):
|
|
|
+ def __init__(self, ds):
|
|
|
+ super().__init__(ds)
|
|
|
+ # Read the whole image eagerly
|
|
|
+ self._whole_image = self.ds.ReadAsArray()
|
|
|
+
|
|
|
+ def read_block(self, xoff, yoff, xsize, ysize):
|
|
|
+ # First dim is channel
|
|
|
+ return self._whole_image[:, yoff:yoff + ysize, xoff:xoff + xsize]
|
|
|
+
|
|
|
+
|
|
|
def slider_predict(predict_func,
|
|
|
img_file,
|
|
|
save_dir,
|
|
@@ -251,6 +278,7 @@ def slider_predict(predict_func,
|
|
|
invalid_value,
|
|
|
merge_strategy,
|
|
|
batch_size,
|
|
|
+ eager_load=False,
|
|
|
show_progress=False):
|
|
|
"""
|
|
|
Do inference using sliding windows.
|
|
@@ -275,10 +303,19 @@ def slider_predict(predict_func,
|
|
|
traversal order, respectively. 'accum' means determining the class
|
|
|
of an overlapping pixel according to accumulated probabilities.
|
|
|
batch_size (int): Batch size used in inference.
|
|
|
+ eager_load (bool, optional): Whether to load the whole image(s) eagerly.
|
|
|
+ Defaults to False.
|
|
|
show_progress (bool, optional): Whether to show prediction progress with a
|
|
|
progress bar. Defaults to True.
|
|
|
"""
|
|
|
|
|
|
+ def _construct_reader(eager_load, *args, **kwargs):
|
|
|
+ if eager_load:
|
|
|
+ reader = EagerBlockReader(*args, **kwargs)
|
|
|
+ else:
|
|
|
+ reader = GDALLazyBlockReader(*args, **kwargs)
|
|
|
+ return reader
|
|
|
+
|
|
|
try:
|
|
|
from osgeo import gdal
|
|
|
except:
|
|
@@ -311,11 +348,14 @@ def slider_predict(predict_func,
|
|
|
raise ValueError("Tuple `img_file` must have the length of two.")
|
|
|
# Assume that two input images have the same size
|
|
|
src_data = gdal.Open(img_file[0])
|
|
|
+ reader = _construct_reader(eager_load=eager_load, ds=src_data)
|
|
|
src2_data = gdal.Open(img_file[1])
|
|
|
+ reader2 = _construct_reader(eager_load=eager_load, ds=src2_data)
|
|
|
# Output name is the same as the name of the first image
|
|
|
file_name = osp.basename(osp.normpath(img_file[0]))
|
|
|
else:
|
|
|
src_data = gdal.Open(img_file)
|
|
|
+ reader = _construct_reader(eager_load=eager_load, ds=src_data)
|
|
|
file_name = osp.basename(osp.normpath(img_file))
|
|
|
|
|
|
# Get size of original raster
|
|
@@ -395,10 +435,10 @@ def slider_predict(predict_func,
|
|
|
is_end_of_col = False
|
|
|
|
|
|
# Read
|
|
|
- im = read_block(src_data, xoff, yoff, xsize, ysize)
|
|
|
+ im = reader.get_block(xoff, yoff, xsize, ysize)
|
|
|
|
|
|
if isinstance(img_file, tuple):
|
|
|
- im2 = read_block(src2_data, xoff, yoff, xsize, ysize)
|
|
|
+ im2 = reader2.get_block(xoff, yoff, xsize, ysize)
|
|
|
batch_data.append((im, im2))
|
|
|
else:
|
|
|
batch_data.append(im)
|
|
@@ -423,7 +463,6 @@ def slider_predict(predict_func,
|
|
|
# Write to file
|
|
|
band.WriteArray(pred, xoff_, yoff_)
|
|
|
|
|
|
- dst_data.FlushCache()
|
|
|
batch_data.clear()
|
|
|
batch_offsets.clear()
|
|
|
|
|
@@ -433,6 +472,9 @@ def slider_predict(predict_func,
|
|
|
pb.update(1)
|
|
|
pb.set_description("{} out of {} blocks processed.".format(
|
|
|
cnt, num_blocks))
|
|
|
+ # Flush cache when finishing each row
|
|
|
+ dst_data.FlushCache()
|
|
|
|
|
|
+ dst_data.FlushCache()
|
|
|
dst_data = None
|
|
|
logging.info("GeoTiff file saved in {}.".format(save_file))
|