Bläddra i källkod

Refactor and fix bugs

Bobholamovic 2 år sedan
förälder
incheckning
2f60abd899
1 ändrade filer med 172 tillägg och 40 borttagningar
  1. 172 40
      paddlers/tasks/utils/slider_predict.py

+ 172 - 40
paddlers/tasks/utils/slider_predict.py

@@ -19,6 +19,7 @@ from abc import ABCMeta, abstractmethod
 from collections import Counter, defaultdict
 
 import numpy as np
+from tqdm import tqdm
 
 import paddlers.utils.logging as logging
 
@@ -31,6 +32,7 @@ class Cache(metaclass=ABCMeta):
 
 class SlowCache(Cache):
     def __init__(self):
+        super(SlowCache, self).__init__()
         self.cache = defaultdict(Counter)
 
     def push_pixel(self, i, j, l):
@@ -66,6 +68,7 @@ class SlowCache(Cache):
 
 class ProbCache(Cache):
     def __init__(self, h, w, ch, cw, sh, sw, dtype=np.float32, order='c'):
+        super(ProbCache, self).__init__()
         self.cache = None
         self.h = h
         self.w = w
@@ -116,20 +119,139 @@ class ProbCache(Cache):
             self._alloc_memory(nc)
         self.cache[i_st:i_st + h, j_st:j_st + w] += prob_map
 
-    def roll_cache(self):
+    def roll_cache(self, shift):
         if self.order == 'c':
-            self.cache[:-self.sh] = self.cache[self.sh:]
-            self.cache[-self.sh:, :] = 0
+            self.cache[:-shift] = self.cache[shift:]
+            self.cache[-shift:, :] = 0
         elif self.order == 'f':
-            self.cache[:, :-self.sw] = self.cache[:, self.sw:]
-            self.cache[:, -self.sw:] = 0
+            self.cache[:, :-shift] = self.cache[:, shift:]
+            self.cache[:, -shift:] = 0
 
     def get_block(self, i_st, j_st, h, w):
         return np.argmax(self.cache[i_st:i_st + h, j_st:j_st + w], axis=2)
 
 
-def slider_predict(predict_func, img_file, save_dir, block_size, overlap,
-                   transforms, invalid_value, merge_strategy, batch_size):
+class OverlapProcessor(metaclass=ABCMeta):
+    def __init__(self, h, w, ch, cw, sh, sw):
+        super(OverlapProcessor, self).__init__()
+        self.h = h
+        self.w = w
+        self.ch = ch
+        self.cw = cw
+        self.sh = sh
+        self.sw = sw
+
+    @abstractmethod
+    def process_pred(self, out, xoff, yoff):
+        pass
+
+
+class KeepFirstProcessor(OverlapProcessor):
+    def __init__(self, h, w, ch, cw, sh, sw, ds, inval=255):
+        super(KeepFirstProcessor, self).__init__(h, w, ch, cw, sh, sw)
+        self.ds = ds
+        self.inval = inval
+
+    def process_pred(self, out, xoff, yoff):
+        pred = out['label_map']
+        pred = pred[:self.ch, :self.cw]
+        rd_block = self.ds.ReadAsArray(xoff, yoff, self.cw, self.ch)
+        mask = rd_block != self.inval
+        pred = np.where(mask, rd_block, pred)
+        return pred
+
+
+class KeepLastProcessor(OverlapProcessor):
+    def process_pred(self, out, xoff, yoff):
+        pred = out['label_map']
+        pred = pred[:self.ch, :self.cw]
+        return pred
+
+
+class AccumProcessor(OverlapProcessor):
+    def __init__(self,
+                 h,
+                 w,
+                 ch,
+                 cw,
+                 sh,
+                 sw,
+                 dtype=np.float16,
+                 assign_weight=True):
+        super(AccumProcessor, self).__init__(h, w, ch, cw, sh, sw)
+        self.cache = ProbCache(h, w, ch, cw, sh, sw, dtype=dtype, order='c')
+        self.prev_yoff = None
+        self.assign_weight = assign_weight
+
+    def process_pred(self, out, xoff, yoff):
+        if self.prev_yoff is not None and yoff != self.prev_yoff:
+            if yoff < self.prev_yoff:
+                raise RuntimeError
+            self.cache.roll_cache(yoff - self.prev_yoff)
+        pred = out['label_map']
+        pred = pred[:self.ch, :self.cw]
+        prob = out['score_map']
+        prob = prob[:self.ch, :self.cw]
+        if self.assign_weight:
+            prob = assign_border_weights(prob, border_ratio=0.25, inplace=True)
+        self.cache.update_block(0, xoff, self.ch, self.cw, prob)
+        pred = self.cache.get_block(0, xoff, self.ch, self.cw)
+        self.prev_yoff = yoff
+        return pred
+
+
+def assign_border_weights(array, weight=0.5, border_ratio=0.25, inplace=True):
+    if not inplace:
+        array = array.copy()
+    h, w = array.shape[:2]
+    hm, wm = int(h * border_ratio), int(w * border_ratio)
+    array[:hm] *= weight
+    array[-hm:] *= weight
+    array[:, :wm] *= weight
+    array[:, -wm:] *= weight
+    return array
+
+
+def read_block(ds,
+               xoff,
+               yoff,
+               xsize,
+               ysize,
+               tar_xsize=None,
+               tar_ysize=None,
+               pad_val=0):
+    if tar_xsize is None:
+        tar_xsize = xsize
+    if tar_ysize is None:
+        tar_ysize = ysize
+    # Read data from dataset
+    block = ds.ReadAsArray(xoff, yoff, xsize, ysize)
+    c, real_ysize, real_xsize = block.shape
+    assert real_ysize == ysize and real_xsize == xsize
+    # [c, h, w] -> [h, w, c]
+    block = block.transpose((1, 2, 0))
+    if (real_ysize, real_xsize) != (tar_ysize, tar_xsize):
+        if real_ysize >= tar_ysize or real_xsize >= tar_xsize:
+            raise ValueError
+        padded_block = np.full(
+            (tar_ysize, tar_xsize, c), fill_value=pad_val, dtype=block.dtype)
+        # Fill
+        padded_block[:real_ysize, :real_xsize] = block
+        return padded_block
+    else:
+        return block
+
+
+def slider_predict(predict_func,
+                   img_file,
+                   save_dir,
+                   block_size,
+                   overlap,
+                   transforms,
+                   invalid_value,
+                   merge_strategy,
+                   batch_size,
+                   show_progress=False):
     """
     Do inference using sliding windows.
 
@@ -153,6 +275,8 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap,
             traversal order, respectively. 'accum' means determining the class 
             of an overlapping pixel according to accumulated probabilities.
         batch_size (int): Batch size used in inference.
+        show_progress (bool, optional): Whether to show prediction progress with a 
+            progress bar. Defaults to True.
     """
 
     try:
@@ -175,10 +299,6 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap,
         raise ValueError(
             "`overlap` must be a tuple/list of length 2 or an integer.")
 
-    if merge_strategy not in ('keep_first', 'keep_last', 'accum'):
-        raise ValueError("{} is not a supported stragegy for block merging.".
-                         format(merge_strategy))
-
     step = np.array(
         block_size, dtype=np.int32) - np.array(
             overlap, dtype=np.int32)
@@ -234,29 +354,50 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap,
         # When there is no overlap or the whole image is used as input, 
         # use 'keep_last' strategy as it introduces least overheads
         merge_strategy = 'keep_last'
-    if merge_strategy == 'accum':
-        cache = ProbCache(height, width, *block_size[::-1], *step[::-1])
 
+    if merge_strategy == 'keep_first':
+        overlap_processor = KeepFirstProcessor(
+            height,
+            width,
+            *block_size[::-1],
+            *step[::-1],
+            band,
+            inval=invalid_value)
+    elif merge_strategy == 'keep_last':
+        overlap_processor = KeepLastProcessor(height, width, *block_size[::-1],
+                                              *step[::-1])
+    elif merge_strategy == 'accum':
+        overlap_processor = AccumProcessor(height, width, *block_size[::-1],
+                                           *step[::-1])
+    else:
+        raise ValueError("{} is not a supported stragegy for block merging.".
+                         format(merge_strategy))
+
+    xsize, ysize = block_size
+    num_blocks = math.ceil(height / step[1]) * math.ceil(width / step[0])
+    cnt = 0
+    if show_progress:
+        pb = tqdm(total=num_blocks)
     batch_data = []
     batch_offsets = []
     for yoff in range(0, height, step[1]):
         for xoff in range(0, width, step[0]):
-            xsize, ysize = block_size
             if xoff + xsize > width:
                 xoff = width - xsize
+                is_end_of_row = True
+            else:
+                is_end_of_row = False
             if yoff + ysize > height:
                 yoff = height - ysize
+                is_end_of_col = True
+            else:
+                is_end_of_col = False
 
-            is_end_of_col = yoff + ysize >= height
-            is_end_of_row = xoff + xsize >= width
-
-            # Read and fill
-            im = src_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose(
-                (1, 2, 0))
+            # Read
+            im = read_block(src_data, xoff, yoff, xsize, ysize)
 
             if isinstance(img_file, tuple):
-                im2 = src2_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose(
-                    (1, 2, 0))
+                im2 = read_block(src2_data, xoff, yoff, xsize, ysize)
                 batch_data.append((im, im2))
             else:
                 batch_data.append(im)
@@ -276,24 +417,8 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap,
                 batch_out = predict_func(batch_data, transforms=transforms)
 
                 for out, (xoff_, yoff_) in zip(batch_out, batch_offsets):
-                    pred = out['label_map'].astype('uint8')
-                    pred = pred[:ysize, :xsize]
-
-                    # Deal with overlapping pixels
-                    if merge_strategy == 'keep_first':
-                        rd_block = band.ReadAsArray(xoff_, yoff_, xsize, ysize)
-                        mask = rd_block != invalid_value
-                        pred = np.where(mask, rd_block, pred)
-                    elif merge_strategy == 'keep_last':
-                        pass
-                    elif merge_strategy == 'accum':
-                        prob = out['score_map']
-                        prob = prob[:ysize, :xsize]
-                        cache.update_block(0, xoff_, ysize, xsize, prob)
-                        pred = cache.get_block(0, xoff_, ysize, xsize)
-                        if xoff_ + xsize >= width:
-                            cache.roll_cache()
-
+                    # Get processed result
+                    pred = overlap_processor.process_pred(out, xoff_, yoff_)
                     # Write to file
                     band.WriteArray(pred, xoff_, yoff_)
 
@@ -301,5 +426,12 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap,
                 batch_data.clear()
                 batch_offsets.clear()
 
+            cnt += 1
+
+            if show_progress:
+                pb.update(1)
+                pb.set_description("{} out of {} blocks processed.".format(
+                    cnt, num_blocks))
+
     dst_data = None
     logging.info("GeoTiff file saved in {}.".format(save_file))