123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437 |
- import os
- import os.path as osp
- import math
- from abc import ABCMeta, abstractmethod
- from collections import Counter, defaultdict
- import numpy as np
- from tqdm import tqdm
- import paddlers.utils.logging as logging
- class Cache(metaclass=ABCMeta):
- @abstractmethod
- def get_block(self, i_st, j_st, h, w):
- pass
- class SlowCache(Cache):
- def __init__(self):
- super(SlowCache, self).__init__()
- self.cache = defaultdict(Counter)
- def push_pixel(self, i, j, l):
- self.cache[(i, j)][l] += 1
- def push_block(self, i_st, j_st, h, w, data):
- for i in range(0, h):
- for j in range(0, w):
- self.push_pixel(i_st + i, j_st + j, data[i, j])
- def pop_pixel(self, i, j):
- self.cache.pop((i, j))
- def pop_block(self, i_st, j_st, h, w):
- for i in range(0, h):
- for j in range(0, w):
- self.pop_pixel(i_st + i, j_st + j)
- def get_pixel(self, i, j):
- winners = self.cache[(i, j)].most_common(1)
- winner = winners[0]
- return winner[0]
- def get_block(self, i_st, j_st, h, w):
- block = []
- for i in range(i_st, i_st + h):
- row = []
- for j in range(j_st, j_st + w):
- row.append(self.get_pixel(i, j))
- block.append(row)
- return np.asarray(block)
- class ProbCache(Cache):
- def __init__(self, h, w, ch, cw, sh, sw, dtype=np.float32, order='c'):
- super(ProbCache, self).__init__()
- self.cache = None
- self.h = h
- self.w = w
- self.ch = ch
- self.cw = cw
- self.sh = sh
- self.sw = sw
- if not issubclass(dtype, np.floating):
- raise TypeError("`dtype` must be one of the floating types.")
- self.dtype = dtype
- order = order.lower()
- if order not in ('c', 'f'):
- raise ValueError("`order` other than 'c' and 'f' is not supported.")
- self.order = order
- def _alloc_memory(self, nc):
- if self.order == 'c':
-
-
-
-
-
-
-
-
-
- self.cache = np.zeros((self.ch, self.w, nc), dtype=self.dtype)
- elif self.order == 'f':
-
-
-
-
-
-
-
-
-
-
-
-
-
- self.cache = np.zeros((self.h, self.cw, nc), dtype=self.dtype)
- def update_block(self, i_st, j_st, h, w, prob_map):
- if self.cache is None:
- nc = prob_map.shape[2]
-
- self._alloc_memory(nc)
- self.cache[i_st:i_st + h, j_st:j_st + w] += prob_map
- def roll_cache(self, shift):
- if self.order == 'c':
- self.cache[:-shift] = self.cache[shift:]
- self.cache[-shift:, :] = 0
- elif self.order == 'f':
- self.cache[:, :-shift] = self.cache[:, shift:]
- self.cache[:, -shift:] = 0
- def get_block(self, i_st, j_st, h, w):
- return np.argmax(self.cache[i_st:i_st + h, j_st:j_st + w], axis=2)
- class OverlapProcessor(metaclass=ABCMeta):
- def __init__(self, h, w, ch, cw, sh, sw):
- super(OverlapProcessor, self).__init__()
- self.h = h
- self.w = w
- self.ch = ch
- self.cw = cw
- self.sh = sh
- self.sw = sw
- @abstractmethod
- def process_pred(self, out, xoff, yoff):
- pass
- class KeepFirstProcessor(OverlapProcessor):
- def __init__(self, h, w, ch, cw, sh, sw, ds, inval=255):
- super(KeepFirstProcessor, self).__init__(h, w, ch, cw, sh, sw)
- self.ds = ds
- self.inval = inval
- def process_pred(self, out, xoff, yoff):
- pred = out['label_map']
- pred = pred[:self.ch, :self.cw]
- rd_block = self.ds.ReadAsArray(xoff, yoff, self.cw, self.ch)
- mask = rd_block != self.inval
- pred = np.where(mask, rd_block, pred)
- return pred
- class KeepLastProcessor(OverlapProcessor):
- def process_pred(self, out, xoff, yoff):
- pred = out['label_map']
- pred = pred[:self.ch, :self.cw]
- return pred
- class AccumProcessor(OverlapProcessor):
- def __init__(self,
- h,
- w,
- ch,
- cw,
- sh,
- sw,
- dtype=np.float16,
- assign_weight=True):
- super(AccumProcessor, self).__init__(h, w, ch, cw, sh, sw)
- self.cache = ProbCache(h, w, ch, cw, sh, sw, dtype=dtype, order='c')
- self.prev_yoff = None
- self.assign_weight = assign_weight
- def process_pred(self, out, xoff, yoff):
- if self.prev_yoff is not None and yoff != self.prev_yoff:
- if yoff < self.prev_yoff:
- raise RuntimeError
- self.cache.roll_cache(yoff - self.prev_yoff)
- pred = out['label_map']
- pred = pred[:self.ch, :self.cw]
- prob = out['score_map']
- prob = prob[:self.ch, :self.cw]
- if self.assign_weight:
- prob = assign_border_weights(prob, border_ratio=0.25, inplace=True)
- self.cache.update_block(0, xoff, self.ch, self.cw, prob)
- pred = self.cache.get_block(0, xoff, self.ch, self.cw)
- self.prev_yoff = yoff
- return pred
- def assign_border_weights(array, weight=0.5, border_ratio=0.25, inplace=True):
- if not inplace:
- array = array.copy()
- h, w = array.shape[:2]
- hm, wm = int(h * border_ratio), int(w * border_ratio)
- array[:hm] *= weight
- array[-hm:] *= weight
- array[:, :wm] *= weight
- array[:, -wm:] *= weight
- return array
- def read_block(ds,
- xoff,
- yoff,
- xsize,
- ysize,
- tar_xsize=None,
- tar_ysize=None,
- pad_val=0):
- if tar_xsize is None:
- tar_xsize = xsize
- if tar_ysize is None:
- tar_ysize = ysize
-
- block = ds.ReadAsArray(xoff, yoff, xsize, ysize)
- c, real_ysize, real_xsize = block.shape
- assert real_ysize == ysize and real_xsize == xsize
-
- block = block.transpose((1, 2, 0))
- if (real_ysize, real_xsize) != (tar_ysize, tar_xsize):
- if real_ysize >= tar_ysize or real_xsize >= tar_xsize:
- raise ValueError
- padded_block = np.full(
- (tar_ysize, tar_xsize, c), fill_value=pad_val, dtype=block.dtype)
-
- padded_block[:real_ysize, :real_xsize] = block
- return padded_block
- else:
- return block
- def slider_predict(predict_func,
- img_file,
- save_dir,
- block_size,
- overlap,
- transforms,
- invalid_value,
- merge_strategy,
- batch_size,
- show_progress=False):
- """
- Do inference using sliding windows.
- Args:
- predict_func (callable): A callable object that makes the prediction.
- img_file (str|tuple[str]): Image path(s).
- save_dir (str): Directory that contains saved geotiff file.
- block_size (list[int] | tuple[int] | int):
- Size of block. If `block_size` is list or tuple, it should be in
- (W, H) format.
- overlap (list[int] | tuple[int] | int):
- Overlap between two blocks. If `overlap` is list or tuple, it should
- be in (W, H) format.
- transforms (paddlers.transforms.Compose|None): Transforms for inputs. If
- None, the transforms for evaluation process will be used.
- invalid_value (int): Value that marks invalid pixels in output image.
- Defaults to 255.
- merge_strategy (str): Strategy to merge overlapping blocks. Choices are
- {'keep_first', 'keep_last', 'accum'}. 'keep_first' and 'keep_last'
- means keeping the values of the first and the last block in
- traversal order, respectively. 'accum' means determining the class
- of an overlapping pixel according to accumulated probabilities.
- batch_size (int): Batch size used in inference.
- show_progress (bool, optional): Whether to show prediction progress with a
- progress bar. Defaults to True.
- """
- try:
- from osgeo import gdal
- except:
- import gdal
- if isinstance(block_size, int):
- block_size = (block_size, block_size)
- elif isinstance(block_size, (tuple, list)) and len(block_size) == 2:
- block_size = tuple(block_size)
- else:
- raise ValueError(
- "`block_size` must be a tuple/list of length 2 or an integer.")
- if isinstance(overlap, int):
- overlap = (overlap, overlap)
- elif isinstance(overlap, (tuple, list)) and len(overlap) == 2:
- overlap = tuple(overlap)
- else:
- raise ValueError(
- "`overlap` must be a tuple/list of length 2 or an integer.")
- step = np.array(
- block_size, dtype=np.int32) - np.array(
- overlap, dtype=np.int32)
- if step[0] == 0 or step[1] == 0:
- raise ValueError("`block_size` and `overlap` should not be equal.")
- if isinstance(img_file, tuple):
- if len(img_file) != 2:
- raise ValueError("Tuple `img_file` must have the length of two.")
-
- src_data = gdal.Open(img_file[0])
- src2_data = gdal.Open(img_file[1])
-
- file_name = osp.basename(osp.normpath(img_file[0]))
- else:
- src_data = gdal.Open(img_file)
- file_name = osp.basename(osp.normpath(img_file))
-
- width = src_data.RasterXSize
- height = src_data.RasterYSize
- bands = src_data.RasterCount
-
-
- if bands == 1:
- logging.warning(
- f"Detected `bands=1`. Please note that currently `slider_predict()` does not properly handle SAR images."
- )
- if block_size[0] > width or block_size[1] > height:
- raise ValueError("`block_size` should not be larger than image size.")
- driver = gdal.GetDriverByName("GTiff")
- if not osp.exists(save_dir):
- os.makedirs(save_dir)
-
- file_name = osp.splitext(file_name)[0] + ".tif"
- save_file = osp.join(save_dir, file_name)
- dst_data = driver.Create(save_file, width, height, 1, gdal.GDT_Byte)
-
- dst_data.SetGeoTransform(src_data.GetGeoTransform())
- dst_data.SetProjection(src_data.GetProjection())
-
- band = dst_data.GetRasterBand(1)
- band.WriteArray(
- np.full(
- (height, width), fill_value=invalid_value, dtype="uint8"))
- if overlap == (0, 0) or block_size == (width, height):
-
-
- merge_strategy = 'keep_last'
- if merge_strategy == 'keep_first':
- overlap_processor = KeepFirstProcessor(
- height,
- width,
- *block_size[::-1],
- *step[::-1],
- band,
- inval=invalid_value)
- elif merge_strategy == 'keep_last':
- overlap_processor = KeepLastProcessor(height, width, *block_size[::-1],
- *step[::-1])
- elif merge_strategy == 'accum':
- overlap_processor = AccumProcessor(height, width, *block_size[::-1],
- *step[::-1])
- else:
- raise ValueError("{} is not a supported stragegy for block merging.".
- format(merge_strategy))
- xsize, ysize = block_size
- num_blocks = math.ceil(height / step[1]) * math.ceil(width / step[0])
- cnt = 0
- if show_progress:
- pb = tqdm(total=num_blocks)
- batch_data = []
- batch_offsets = []
- for yoff in range(0, height, step[1]):
- for xoff in range(0, width, step[0]):
- if xoff + xsize > width:
- xoff = width - xsize
- is_end_of_row = True
- else:
- is_end_of_row = False
- if yoff + ysize > height:
- yoff = height - ysize
- is_end_of_col = True
- else:
- is_end_of_col = False
-
- im = read_block(src_data, xoff, yoff, xsize, ysize)
- if isinstance(img_file, tuple):
- im2 = read_block(src2_data, xoff, yoff, xsize, ysize)
- batch_data.append((im, im2))
- else:
- batch_data.append(im)
- batch_offsets.append((xoff, yoff))
- len_batch = len(batch_data)
- if is_end_of_row and is_end_of_col and len_batch < batch_size:
-
- batch_data = batch_data + [batch_data[-1]] * (batch_size -
- len_batch)
-
- if len(batch_data) == batch_size:
-
- batch_out = predict_func(batch_data, transforms=transforms)
- for out, (xoff_, yoff_) in zip(batch_out, batch_offsets):
-
- pred = overlap_processor.process_pred(out, xoff_, yoff_)
-
- band.WriteArray(pred, xoff_, yoff_)
- dst_data.FlushCache()
- batch_data.clear()
- batch_offsets.clear()
- cnt += 1
- if show_progress:
- pb.update(1)
- pb.set_description("{} out of {} blocks processed.".format(
- cnt, num_blocks))
- dst_data = None
- logging.info("GeoTiff file saved in {}.".format(save_file))
|