Browse Source

Refactor and fix bugs

Bobholamovic 2 years ago
parent
commit
2f60abd899
1 changed files with 172 additions and 40 deletions
  1. 172 40
      paddlers/tasks/utils/slider_predict.py

+ 172 - 40
paddlers/tasks/utils/slider_predict.py

@@ -19,6 +19,7 @@ from abc import ABCMeta, abstractmethod
 from collections import Counter, defaultdict
 from collections import Counter, defaultdict
 
 
 import numpy as np
 import numpy as np
+from tqdm import tqdm
 
 
 import paddlers.utils.logging as logging
 import paddlers.utils.logging as logging
 
 
@@ -31,6 +32,7 @@ class Cache(metaclass=ABCMeta):
 
 
 class SlowCache(Cache):
 class SlowCache(Cache):
     def __init__(self):
     def __init__(self):
+        super(SlowCache, self).__init__()
         self.cache = defaultdict(Counter)
         self.cache = defaultdict(Counter)
 
 
     def push_pixel(self, i, j, l):
     def push_pixel(self, i, j, l):
@@ -66,6 +68,7 @@ class SlowCache(Cache):
 
 
 class ProbCache(Cache):
 class ProbCache(Cache):
     def __init__(self, h, w, ch, cw, sh, sw, dtype=np.float32, order='c'):
     def __init__(self, h, w, ch, cw, sh, sw, dtype=np.float32, order='c'):
+        super(ProbCache, self).__init__()
         self.cache = None
         self.cache = None
         self.h = h
         self.h = h
         self.w = w
         self.w = w
@@ -116,20 +119,139 @@ class ProbCache(Cache):
             self._alloc_memory(nc)
             self._alloc_memory(nc)
         self.cache[i_st:i_st + h, j_st:j_st + w] += prob_map
         self.cache[i_st:i_st + h, j_st:j_st + w] += prob_map
 
 
-    def roll_cache(self):
+    def roll_cache(self, shift):
         if self.order == 'c':
         if self.order == 'c':
-            self.cache[:-self.sh] = self.cache[self.sh:]
+            self.cache[:-shift] = self.cache[shift:]
-            self.cache[-self.sh:, :] = 0
+            self.cache[-shift:, :] = 0
         elif self.order == 'f':
         elif self.order == 'f':
-            self.cache[:, :-self.sw] = self.cache[:, self.sw:]
+            self.cache[:, :-shift] = self.cache[:, shift:]
-            self.cache[:, -self.sw:] = 0
+            self.cache[:, -shift:] = 0
 
 
     def get_block(self, i_st, j_st, h, w):
     def get_block(self, i_st, j_st, h, w):
         return np.argmax(self.cache[i_st:i_st + h, j_st:j_st + w], axis=2)
         return np.argmax(self.cache[i_st:i_st + h, j_st:j_st + w], axis=2)
 
 
 
 
-def slider_predict(predict_func, img_file, save_dir, block_size, overlap,
+class OverlapProcessor(metaclass=ABCMeta):
-                   transforms, invalid_value, merge_strategy, batch_size):
+    def __init__(self, h, w, ch, cw, sh, sw):
+        super(OverlapProcessor, self).__init__()
+        self.h = h
+        self.w = w
+        self.ch = ch
+        self.cw = cw
+        self.sh = sh
+        self.sw = sw
+
+    @abstractmethod
+    def process_pred(self, out, xoff, yoff):
+        pass
+
+
+class KeepFirstProcessor(OverlapProcessor):
+    def __init__(self, h, w, ch, cw, sh, sw, ds, inval=255):
+        super(KeepFirstProcessor, self).__init__(h, w, ch, cw, sh, sw)
+        self.ds = ds
+        self.inval = inval
+
+    def process_pred(self, out, xoff, yoff):
+        pred = out['label_map']
+        pred = pred[:self.ch, :self.cw]
+        rd_block = self.ds.ReadAsArray(xoff, yoff, self.cw, self.ch)
+        mask = rd_block != self.inval
+        pred = np.where(mask, rd_block, pred)
+        return pred
+
+
+class KeepLastProcessor(OverlapProcessor):
+    def process_pred(self, out, xoff, yoff):
+        pred = out['label_map']
+        pred = pred[:self.ch, :self.cw]
+        return pred
+
+
+class AccumProcessor(OverlapProcessor):
+    def __init__(self,
+                 h,
+                 w,
+                 ch,
+                 cw,
+                 sh,
+                 sw,
+                 dtype=np.float16,
+                 assign_weight=True):
+        super(AccumProcessor, self).__init__(h, w, ch, cw, sh, sw)
+        self.cache = ProbCache(h, w, ch, cw, sh, sw, dtype=dtype, order='c')
+        self.prev_yoff = None
+        self.assign_weight = assign_weight
+
+    def process_pred(self, out, xoff, yoff):
+        if self.prev_yoff is not None and yoff != self.prev_yoff:
+            if yoff < self.prev_yoff:
+                raise RuntimeError
+            self.cache.roll_cache(yoff - self.prev_yoff)
+        pred = out['label_map']
+        pred = pred[:self.ch, :self.cw]
+        prob = out['score_map']
+        prob = prob[:self.ch, :self.cw]
+        if self.assign_weight:
+            prob = assign_border_weights(prob, border_ratio=0.25, inplace=True)
+        self.cache.update_block(0, xoff, self.ch, self.cw, prob)
+        pred = self.cache.get_block(0, xoff, self.ch, self.cw)
+        self.prev_yoff = yoff
+        return pred
+
+
+def assign_border_weights(array, weight=0.5, border_ratio=0.25, inplace=True):
+    if not inplace:
+        array = array.copy()
+    h, w = array.shape[:2]
+    hm, wm = int(h * border_ratio), int(w * border_ratio)
+    array[:hm] *= weight
+    array[-hm:] *= weight
+    array[:, :wm] *= weight
+    array[:, -wm:] *= weight
+    return array
+
+
+def read_block(ds,
+               xoff,
+               yoff,
+               xsize,
+               ysize,
+               tar_xsize=None,
+               tar_ysize=None,
+               pad_val=0):
+    if tar_xsize is None:
+        tar_xsize = xsize
+    if tar_ysize is None:
+        tar_ysize = ysize
+    # Read data from dataset
+    block = ds.ReadAsArray(xoff, yoff, xsize, ysize)
+    c, real_ysize, real_xsize = block.shape
+    assert real_ysize == ysize and real_xsize == xsize
+    # [c, h, w] -> [h, w, c]
+    block = block.transpose((1, 2, 0))
+    if (real_ysize, real_xsize) != (tar_ysize, tar_xsize):
+        if real_ysize >= tar_ysize or real_xsize >= tar_xsize:
+            raise ValueError
+        padded_block = np.full(
+            (tar_ysize, tar_xsize, c), fill_value=pad_val, dtype=block.dtype)
+        # Fill
+        padded_block[:real_ysize, :real_xsize] = block
+        return padded_block
+    else:
+        return block
+
+
+def slider_predict(predict_func,
+                   img_file,
+                   save_dir,
+                   block_size,
+                   overlap,
+                   transforms,
+                   invalid_value,
+                   merge_strategy,
+                   batch_size,
+                   show_progress=False):
     """
     """
     Do inference using sliding windows.
     Do inference using sliding windows.
 
 
@@ -153,6 +275,8 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap,
             traversal order, respectively. 'accum' means determining the class 
             traversal order, respectively. 'accum' means determining the class 
             of an overlapping pixel according to accumulated probabilities.
             of an overlapping pixel according to accumulated probabilities.
         batch_size (int): Batch size used in inference.
         batch_size (int): Batch size used in inference.
+        show_progress (bool, optional): Whether to show prediction progress with a 
+            progress bar. Defaults to True.
     """
     """
 
 
     try:
     try:
@@ -175,10 +299,6 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap,
         raise ValueError(
         raise ValueError(
             "`overlap` must be a tuple/list of length 2 or an integer.")
             "`overlap` must be a tuple/list of length 2 or an integer.")
 
 
-    if merge_strategy not in ('keep_first', 'keep_last', 'accum'):
-        raise ValueError("{} is not a supported stragegy for block merging.".
-                         format(merge_strategy))
-
     step = np.array(
     step = np.array(
         block_size, dtype=np.int32) - np.array(
         block_size, dtype=np.int32) - np.array(
             overlap, dtype=np.int32)
             overlap, dtype=np.int32)
@@ -234,29 +354,50 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap,
         # When there is no overlap or the whole image is used as input, 
         # When there is no overlap or the whole image is used as input, 
         # use 'keep_last' strategy as it introduces least overheads
         # use 'keep_last' strategy as it introduces least overheads
         merge_strategy = 'keep_last'
         merge_strategy = 'keep_last'
-    if merge_strategy == 'accum':
-        cache = ProbCache(height, width, *block_size[::-1], *step[::-1])
 
 
+    if merge_strategy == 'keep_first':
+        overlap_processor = KeepFirstProcessor(
+            height,
+            width,
+            *block_size[::-1],
+            *step[::-1],
+            band,
+            inval=invalid_value)
+    elif merge_strategy == 'keep_last':
+        overlap_processor = KeepLastProcessor(height, width, *block_size[::-1],
+                                              *step[::-1])
+    elif merge_strategy == 'accum':
+        overlap_processor = AccumProcessor(height, width, *block_size[::-1],
+                                           *step[::-1])
+    else:
+        raise ValueError("{} is not a supported stragegy for block merging.".
+                         format(merge_strategy))
+
+    xsize, ysize = block_size
+    num_blocks = math.ceil(height / step[1]) * math.ceil(width / step[0])
+    cnt = 0
+    if show_progress:
+        pb = tqdm(total=num_blocks)
     batch_data = []
     batch_data = []
     batch_offsets = []
     batch_offsets = []
     for yoff in range(0, height, step[1]):
     for yoff in range(0, height, step[1]):
         for xoff in range(0, width, step[0]):
         for xoff in range(0, width, step[0]):
-            xsize, ysize = block_size
             if xoff + xsize > width:
             if xoff + xsize > width:
                 xoff = width - xsize
                 xoff = width - xsize
+                is_end_of_row = True
+            else:
+                is_end_of_row = False
             if yoff + ysize > height:
             if yoff + ysize > height:
                 yoff = height - ysize
                 yoff = height - ysize
+                is_end_of_col = True
+            else:
+                is_end_of_col = False
 
 
-            is_end_of_col = yoff + ysize >= height
+            # Read
-            is_end_of_row = xoff + xsize >= width
+            im = read_block(src_data, xoff, yoff, xsize, ysize)
-
-            # Read and fill
-            im = src_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose(
-                (1, 2, 0))
 
 
             if isinstance(img_file, tuple):
             if isinstance(img_file, tuple):
-                im2 = src2_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose(
+                im2 = read_block(src2_data, xoff, yoff, xsize, ysize)
-                    (1, 2, 0))
                 batch_data.append((im, im2))
                 batch_data.append((im, im2))
             else:
             else:
                 batch_data.append(im)
                 batch_data.append(im)
@@ -276,24 +417,8 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap,
                 batch_out = predict_func(batch_data, transforms=transforms)
                 batch_out = predict_func(batch_data, transforms=transforms)
 
 
                 for out, (xoff_, yoff_) in zip(batch_out, batch_offsets):
                 for out, (xoff_, yoff_) in zip(batch_out, batch_offsets):
-                    pred = out['label_map'].astype('uint8')
+                    # Get processed result
-                    pred = pred[:ysize, :xsize]
+                    pred = overlap_processor.process_pred(out, xoff_, yoff_)
-
-                    # Deal with overlapping pixels
-                    if merge_strategy == 'keep_first':
-                        rd_block = band.ReadAsArray(xoff_, yoff_, xsize, ysize)
-                        mask = rd_block != invalid_value
-                        pred = np.where(mask, rd_block, pred)
-                    elif merge_strategy == 'keep_last':
-                        pass
-                    elif merge_strategy == 'accum':
-                        prob = out['score_map']
-                        prob = prob[:ysize, :xsize]
-                        cache.update_block(0, xoff_, ysize, xsize, prob)
-                        pred = cache.get_block(0, xoff_, ysize, xsize)
-                        if xoff_ + xsize >= width:
-                            cache.roll_cache()
-
                     # Write to file
                     # Write to file
                     band.WriteArray(pred, xoff_, yoff_)
                     band.WriteArray(pred, xoff_, yoff_)
 
 
@@ -301,5 +426,12 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap,
                 batch_data.clear()
                 batch_data.clear()
                 batch_offsets.clear()
                 batch_offsets.clear()
 
 
+            cnt += 1
+
+            if show_progress:
+                pb.update(1)
+                pb.set_description("{} out of {} blocks processed.".format(
+                    cnt, num_blocks))
+
     dst_data = None
     dst_data = None
     logging.info("GeoTiff file saved in {}.".format(save_file))
     logging.info("GeoTiff file saved in {}.".format(save_file))