Эх сурвалжийг харах

[Feat] Add some postprogress functions (#87)

* [Feat] Init new postpro apis

* [Feat] Add some utils postpro apis

* [Feat] Update postpros

* [Feat] Init add crf

* [Feat] Init add mrf

* [Test] Add test of postpro

* [Fix] Bugfix in postpro

* [Docs] Update a note

* [Feat] Add postprogress change detection filter

* Fix OpenCV version

* Update opencv version

* [Docs] Update postprogress docstring

* [Test] Fix postpro test image ext

* [Fix] Update postpro bugfix

* [Test] Update postpro test

---------

Co-authored-by: Bobholamovic <mhlin425@whu.edu.cn>
Yizhou Chen 2 жил өмнө
parent
commit
faa67ac48a

+ 12 - 1
paddlers/utils/postprocs/__init__.py

@@ -1,4 +1,4 @@
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,3 +14,14 @@
 
 from .regularization import building_regularization
 from .connection import cut_road_connection
+from .mrf import markov_random_field
+from .utils import (prepro_mask, del_small_connection, fill_small_holes,
+                    morphological_operation, deal_one_class)
+from .change_filter import change_detection_filter
+
+try:
+    from .crf import conditional_random_field
+except ImportError:
+    print(
+        "Can not use `conditional_random_field`. Please install pydensecrf first!"
+    )

+ 60 - 0
paddlers/utils/postprocs/change_filter.py

@@ -0,0 +1,60 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, Optional, Any
+
+import numpy as np
+
+from paddlers.transforms.operators import AppendIndex
+
+
+def change_detection_filter(mask: np.ndarray,
+                            t1: np.ndarray,
+                            t2: np.ndarray,
+                            threshold1: float,
+                            threshold2: float,
+                            index_type: str="NDVI",
+                            band_indices: Optional[Dict]=None,
+                            satellite: Optional[str]=None,
+                            **kwargs: Dict[str, Any]) -> np.ndarray:
+    """
+    Remote sensing index filter. It is a postprocessing method for change detection tasks.
+
+    E.g. Filter plant seasonal variations in non-urban scenes
+    1. Calculate NDVI of the two images separately
+    2. Obtain vegetation mask by threshold filter
+    3. Take the intersection of the two vegetation masks, called veg_mask
+    4. Filter mask through veg_mask
+
+    Args:
+        mask (np.ndarray): Change mask predicted by a change detection model. Shape is [H, W].
+        t1 (np.ndarray): Original image of time 1.
+        t2 (np.ndarray): Original image of time 2.
+        threshold1 (float): Threshold of time 1.
+        threshold2 (float): Threshold of time 2.
+        
+        For other arguments please refer to the data transformation operator `AppendIndex`
+        (paddlers/transforms/operators.py)
+
+    Returns:
+        np.ndarray: Filtered mask.
+    """
+    index_calculator = AppendIndex(index_type, band_indices, satellite,
+                                   **kwargs)
+    index1 = index_calculator._compute_index(t1)
+    index2 = index_calculator._compute_index(t2)
+    imask1 = (index1 > threshold1).astype("uint8")
+    imask2 = (index2 > threshold2).astype("uint8")
+    imask = (imask1 + imask2 != 2).astype("uint8")
+    return mask * imask

+ 5 - 4
paddlers/utils/postprocs/connection.py

@@ -25,7 +25,7 @@ with warnings.catch_warnings():
     from sklearn import metrics
     from sklearn.cluster import KMeans
 
-from .utils import prepro_mask, calc_distance
+from .utils import del_small_connection, calc_distance
 
 
 def cut_road_connection(mask: np.ndarray,
@@ -46,21 +46,22 @@ def cut_road_connection(mask: np.ndarray,
     2. We unmark the breakpoints if the angle between the two road extensions is less than 90°.
 
     Args:
-        mask (np.ndarray): Mask of road.
+        mask (np.ndarray): Mask of road. Shape is [H, W] and values are 0 or 1.
         area_threshold (int, optional): Threshold to filter out small connected area. Default is 32.
         line_width (int, optional): Width of the line used for patching. Default is 6.
 
     Returns:
         np.ndarray: Mask of road after connecting cut road lines.
     """
-    mask = prepro_mask(mask, area_threshold)
+    mask = del_small_connection(mask, area_threshold)
     skeleton = morphology.skeletonize(mask).astype("uint8")
     break_points = _find_breakpoint(skeleton)
     labels = _k_means(break_points)
     if labels is None:
-        return mask * 255
+        return mask
     match_points = _get_match_points(break_points, labels)
     res = _draw_curve(mask, skeleton, match_points, line_width)
+    res = np.clip(res, 0, 1)
     return res
 
 

+ 57 - 0
paddlers/utils/postprocs/crf.py

@@ -0,0 +1,57 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+from skimage.color import gray2rgb
+import pydensecrf.densecrf as dcrf
+from pydensecrf.utils import unary_from_labels
+
+
+def conditional_random_field(original_image: np.ndarray,
+                             mask: np.ndarray) -> np.ndarray:
+    """
+    Conditional random field.
+
+    The original article refers to
+    Krhenbühl, Philipp, Koltun V. "Efficient Inference in Fully Connected CRFs with Gaussian Edge Potentials."
+    (https://arxiv.org/abs/1210.5644v1).
+
+    The implementation procedure refers to this repo: 
+    https://github.com/apletea/Computer-Vision
+
+    Args:
+        original_image (np.ndarray): Original image. Shape is [H, W, 3]. 
+        mask (np.ndarray): Mask to refine. Shape is [H, W].
+
+    Returns:
+        np.ndarray: Mask after CRF.
+    """
+    n_labels = len(np.unique(mask))
+    mask3 = gray2rgb(mask)
+    annotated_label = mask3[:, :, 0] + (mask3[:, :, 1] << 8) + (mask3[:, :, 2]
+                                                                << 16)
+    _, labels = np.unique(annotated_label, return_inverse=True)
+    img_shape = original_image.shape
+    d = dcrf.DenseCRF2D(img_shape[1], img_shape[0], n_labels)
+    U = unary_from_labels(labels, n_labels, gt_prob=0.7, zero_unsure=False)
+    d.setUnaryEnergy(U)
+    d.addPairwiseGaussian(
+        sxy=(3, 3),
+        compat=3,
+        kernel=dcrf.DIAG_KERNEL,
+        normalization=dcrf.NORMALIZE_SYMMETRIC)
+    Q = d.inference(10)
+    MAP = np.argmax(Q, axis=0)
+    MAP = MAP.reshape((img_shape[0], img_shape[1]))
+    return MAP.astype("uint8")

+ 99 - 0
paddlers/utils/postprocs/mrf.py

@@ -0,0 +1,99 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import cv2
+import numpy as np
+
+
+def markov_random_field(original_image: np.ndarray,
+                        mask: np.ndarray,
+                        max_iter: int=2) -> np.ndarray:
+    """
+    Markov random field.
+
+    Args:
+        original_image (np.ndarray): Original image. Shape is [H, W, 3]. 
+        mask (np.ndarray): Mask to refine. Shape is [H, W].
+        max_iter (int, optional): Maximum number of iterations. Defaults to 2.
+
+    Returns:
+        np.ndarray: Mask after MRF.
+    """
+    img = cv2.cvtColor(original_image, cv2.COLOR_BGR2GRAY).astype("double")
+    classes = sorted(np.unique(mask).tolist())
+    cluster_num = len(classes)
+    zlab = np.zeros_like(mask)
+    for idx, pix in enumerate(classes, start=1):
+        zlab[mask == pix] = idx
+    mask = zlab.astype('int64')
+    res = _MRF(img, mask, max_iter, cluster_num)
+    return res.astype("uint8") - 1
+
+
+def _MRF(img, label, max_iter, cluster_num):
+    f_u = np.array([0, 1, 0, 0, 0, 0, 0, 0, 0]).reshape(3, 3)
+    f_d = np.array([0, 0, 0, 0, 0, 0, 0, 1, 0]).reshape(3, 3)
+    f_l = np.array([0, 0, 0, 1, 0, 0, 0, 0, 0]).reshape(3, 3)
+    f_r = np.array([0, 0, 0, 0, 0, 1, 0, 0, 0]).reshape(3, 3)
+    f_ul = np.array([1, 0, 0, 0, 0, 0, 0, 0, 0]).reshape(3, 3)
+    f_ur = np.array([0, 0, 1, 0, 0, 0, 0, 0, 0]).reshape(3, 3)
+    f_dl = np.array([0, 0, 0, 0, 0, 0, 1, 0, 0]).reshape(3, 3)
+    f_dr = np.array([0, 0, 0, 0, 0, 0, 0, 0, 1]).reshape(3, 3)
+    iter = 0
+    while iter < max_iter:
+        iter = iter + 1
+        # print(iter)
+        label_u = cv2.filter2D(np.array(label, dtype=np.uint8), -1, f_u)
+        label_d = cv2.filter2D(np.array(label, dtype=np.uint8), -1, f_d)
+        label_l = cv2.filter2D(np.array(label, dtype=np.uint8), -1, f_l)
+        label_r = cv2.filter2D(np.array(label, dtype=np.uint8), -1, f_r)
+        label_ul = cv2.filter2D(np.array(label, dtype=np.uint8), -1, f_ul)
+        label_ur = cv2.filter2D(np.array(label, dtype=np.uint8), -1, f_ur)
+        label_dl = cv2.filter2D(np.array(label, dtype=np.uint8), -1, f_dl)
+        label_dr = cv2.filter2D(np.array(label, dtype=np.uint8), -1, f_dr)
+        m, n = label.shape
+        p_c = np.zeros((cluster_num, m, n))
+        for i in range(cluster_num):
+            label_i = (i + 1) * np.ones((m, n))
+            u_T = 1 * np.logical_not(label_i - label_u)
+            d_T = 1 * np.logical_not(label_i - label_d)
+            l_T = 1 * np.logical_not(label_i - label_l)
+            r_T = 1 * np.logical_not(label_i - label_r)
+            ul_T = 1 * np.logical_not(label_i - label_ul)
+            ur_T = 1 * np.logical_not(label_i - label_ur)
+            dl_T = 1 * np.logical_not(label_i - label_dl)
+            dr_T = 1 * np.logical_not(label_i - label_dr)
+            temp = u_T + d_T + l_T + r_T + ul_T + ur_T + dl_T + dr_T
+            p_c[i, :] = (1.0 / 8) * temp
+        p_c[p_c == 0] = 0.0001
+        mu = np.zeros((1, cluster_num))
+        sigma = np.zeros((1, cluster_num))
+        for i in range(cluster_num):
+            index = np.where(label == (i + 1))
+            data_c = img[index]
+            mu[0, i] = np.mean(data_c)
+            sigma[0, i] = np.var(data_c)
+        p_sc = np.zeros((cluster_num, m, n))
+        one_a = np.ones((m, n))
+        for j in range(cluster_num):
+            MU = mu[0, j] * one_a
+            p_sc[j, :] = (1. / np.sqrt(2. * np.pi * sigma[0, j])) * np.exp(
+                -1. * ((img - MU)**2) / (2 * sigma[0, j]))
+        X_out = np.log(p_c) + np.log(p_sc)
+        label_c = X_out.reshape(cluster_num, m * n)
+        label_c_t = label_c.T
+        label_m = np.argmax(label_c_t, axis=1)
+        label_m = label_m + np.ones(label_m.shape)
+        label = label_m.reshape(m, n)
+    return label

+ 5 - 6
paddlers/utils/postprocs/regularization.py

@@ -17,7 +17,7 @@ import math
 import cv2
 import numpy as np
 
-from .utils import prepro_mask, calc_distance
+from .utils import del_small_connection, calc_distance, morphological_operation
 
 S = 20
 TD = 3
@@ -44,7 +44,7 @@ def building_regularization(mask: np.ndarray, W: int=32) -> np.ndarray:
     The implementation is not fully consistent with the article.
 
     Args:
-        mask (np.ndarray): Mask of building.
+        mask (np.ndarray): Mask of building. Shape is [H, W] and values are 0 or 1.
         W (int, optional): Minimum threshold in main direction. Default is 32.
             The larger W, the more regular the image, but the worse the image detail.
 
@@ -52,7 +52,7 @@ def building_regularization(mask: np.ndarray, W: int=32) -> np.ndarray:
         np.ndarray: Mask of building after regularized.
     """
     # check and pro processing
-    mask = prepro_mask(mask)
+    mask = del_small_connection(mask)
     mask_shape = mask.shape
     # find contours
     contours, hierarchys = cv2.findContours(mask, cv2.RETR_TREE,
@@ -68,9 +68,8 @@ def building_regularization(mask: np.ndarray, W: int=32) -> np.ndarray:
         contour = _fine(contour, W)  # fine
         res_contours.append((contour, _get_priority(hierarchy)))
     result = _fill(mask, res_contours)  # fill
-    result = cv2.morphologyEx(result, cv2.MORPH_OPEN,
-                              cv2.getStructuringElement(cv2.MORPH_RECT,
-                                                        (3, 3)))  # open
+    result = morphological_operation(result, "open")
+    result = np.clip(result, 0, 1)
     return result
 
 

+ 137 - 19
paddlers/utils/postprocs/utils.py

@@ -1,4 +1,4 @@
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,31 +12,67 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import numpy as np
-import cv2
+from typing import Union, Callable, Dict, Any
 
-
-def prepro_mask(mask: np.ndarray, area_threshold: int=32) -> np.ndarray:
-    mask_shape = mask.shape
-    if len(mask_shape) != 2:
-        mask = mask[..., 0]
-    mask = mask.astype("uint8")
-    mask = _del_small_connection(mask, area_threshold)
-    class_num = len(np.unique(mask))
-    if class_num != 2:
-        _, mask = cv2.threshold(mask, 0, 255, cv2.THRESH_BINARY |
-                                cv2.THRESH_OTSU)
-    mask = np.clip(mask, 0, 1).astype("uint8")  # 0-255 / 0-1 -> 0-1
-    return mask
+import cv2
+import numpy as np
+import paddle
 
 
 def calc_distance(p1: np.ndarray, p2: np.ndarray) -> float:
     return float(np.sqrt(np.sum(np.power((p1[0] - p2[0]), 2))))
 
 
-def _del_small_connection(pred: np.ndarray, threshold: int=32) -> np.ndarray:
-    result = np.zeros_like(pred)
-    contours, reals = cv2.findContours(pred, cv2.RETR_TREE,
+def prepro_mask(input: Union[paddle.Tensor, np.ndarray]) -> np.ndarray:
+    """
+    Standardized mask.
+
+    Args:
+        input (Union[paddle.Tensor, np.ndarray]): Mask to refine, or user's mask.
+
+    Returns:
+        np.ndarray: Standard mask.
+    """
+    input_shape = input.shape
+    if isinstance(input, paddle.Tensor):
+        if len(input_shape) == 4:
+            mask = paddle.argmax(input, axis=1).squeeze_().numpy()
+        else:
+            raise ValueError("Invalid tensor, shape must be 4, not " + str(
+                input_shape) + ".")
+    else:
+        if len(input_shape) == 3:
+            mask = input[..., 0]
+        elif len(input_shape) == 2:
+            mask = input
+        else:
+            raise ValueError("Invalid ndarray, shape must be 2 or 3, not " +
+                             str(input_shape) + ".")
+        mask = mask.astype("uint8")
+        class_mask = np.unique(mask)
+        if len(class_mask) == 2:
+            mask = np.clip(mask, 0, 1)  # 0-255 / 0-1 -> 0-1
+        else:
+            if (max(class_mask) > (len(class_mask - 1))):
+                _, mask = cv2.threshold(mask, 0, 255, cv2.THRESH_BINARY |
+                                        cv2.THRESH_OTSU)
+                mask = np.clip(mask, 0, 1)
+    return mask.astype("uint8")
+
+
+def del_small_connection(mask: np.ndarray, threshold: int=32) -> np.ndarray:
+    """
+    Delete the connected region whose pixel area is less than the threshold from mask.
+
+    Args:
+        mask (np.ndarray): Mask to refine. Shape is [H, W] and values are 0 or 1.
+        threshold (int, optional): Threshold of deleted area. Default is 32.
+
+    Returns:
+        np.ndarray: Mask after deleted samll connection.
+    """
+    result = np.zeros_like(mask)
+    contours, reals = cv2.findContours(mask, cv2.RETR_TREE,
                                        cv2.CHAIN_APPROX_NONE)
     for contour, real in zip(contours, reals[0]):
         if real[-1] == -1:
@@ -45,3 +81,85 @@ def _del_small_connection(pred: np.ndarray, threshold: int=32) -> np.ndarray:
         else:
             cv2.fillPoly(result, [contour], (0))
     return result.astype("uint8")
+
+
+def fill_small_holes(mask: np.ndarray, threshold: int=32) -> np.ndarray:
+    """
+    Fill the holed region whose pixel area is less than the threshold from mask.
+
+    Args:
+        mask (np.ndarray): Mask to refine. Shape is [H, W] and values are 0 or 1.
+        threshold (int, optional): Threshold of filled area. Default is 32.
+
+    Returns:
+        np.ndarray: Mask after deleted samll connection.
+    """
+    result = np.zeros_like(mask)
+    contours, reals = cv2.findContours(mask, cv2.RETR_TREE,
+                                       cv2.CHAIN_APPROX_NONE)
+    for contour, real in zip(contours, reals[0]):
+        # Fill father
+        if real[-1] == -1:
+            cv2.fillPoly(result, [contour], (1))
+        # Fill children whose area less than threshold
+        elif real[-1] != -1 and cv2.contourArea(contour) < threshold:
+            cv2.fillPoly(result, [contour], (1))
+        else:
+            cv2.fillPoly(result, [contour], (0))
+    return result.astype("uint8")
+
+
+def morphological_operation(mask: np.ndarray,
+                            ops: str="open",
+                            k_size: int=3,
+                            iterations: int=1) -> np.ndarray:
+    """
+    Morphological operation.
+    Open: It is used to separate objects and eliminate small areas.
+    Close: It is used to eliminating small holes.
+    Erode: It is used to refine goals.
+    Dilate: It is used to Coarse goals.
+
+    Args:
+        mask (np.ndarray): Mask to refine. Shape is [H, W].
+        ops (str): . Defaults to "open".
+        k_size (int, optional): Size of the structuring element. Defaults to 3.
+        iterations (int, optional): Number of times erosion and dilation are applied. Defaults to 1.
+
+    Returns:
+        np.ndarray: Morphologically processed mask.
+    """
+    kv = {
+        "open": cv2.MORPH_OPEN,
+        "close": cv2.MORPH_CLOSE,
+        "erode": cv2.MORPH_ERODE,
+        "dilate": cv2.MORPH_DILATE,
+    }
+    if ops.lower() not in kv.keys():
+        raise ValueError("Invalid ops: " + ops +
+                         ", `ops` must be `open/close/erode/dilate`.")
+    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (k_size, k_size))
+    opened = cv2.morphologyEx(
+        mask, kv[ops.lower()], kernel, iterations=iterations)
+    return opened.astype("uint8")
+
+
+def deal_one_class(mask: np.ndarray,
+                   class_index: int,
+                   func: Callable,
+                   **kwargs: Dict[str, Any]) -> np.ndarray:
+    """
+    Only a single category is processed. 
+
+    Args:
+        mask (np.ndarray): Mask to refine. Shape is [H, W].
+        class_index (int): Index of class of need processed.
+        func (Callable): Function of processed.
+
+    Returns:
+        np.ndarray: Processed Mask.
+    """
+    btmp = (mask == class_index).astype("uint8")
+    res = func(btmp, **kwargs)
+    res *= class_index
+    return np.where(btmp == 0, mask, res).astype("uint8")

+ 1 - 0
requirements.txt

@@ -18,6 +18,7 @@ openpyxl
 paddleslim >= 2.2.1,< 2.3.5
 pandas
 pycocotools
+# pydensecrf
 scikit-learn == 0.23.2
 scikit-image >= 0.14.0
 scipy

+ 1 - 0
tests/fast_tests.py

@@ -15,3 +15,4 @@
 from rs_models import *
 from tasks import *
 from transforms import *
+from postpros import *

+ 15 - 0
tests/postpros/__init__.py

@@ -0,0 +1,15 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .test_postpros import *

+ 123 - 0
tests/postpros/test_postpros.py

@@ -0,0 +1,123 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+from PIL import Image
+
+import numpy as np
+
+import paddle
+import paddlers.utils.postprocs as P
+from testing_utils import CpuCommonTest
+
+__all__ = ['TestPostProgress']
+
+
+class TestPostProgress(CpuCommonTest):
+    def setUp(self):
+        self.image1 = np.asarray(Image.open("data/ssmt/optical_t2.bmp"))
+        self.image2 = np.asarray(Image.open("data/ssmt/optical_t2.bmp"))
+        self.b_label = np.asarray(Image.open("data/ssmt/binary_gt.bmp")).clip(0,
+                                                                              1)
+        self.m_label = np.asarray(Image.open("data/ssmt/multiclass_gt2.png"))
+
+    def test_prepro_mask(self):
+        mask = copy.deepcopy(self.b_label)
+        mask = P.prepro_mask(mask)
+        self.check_output_equal(len(mask.shape), 2)
+        self.assertEqual(mask.dtype, np.uint8)
+        self.check_output_equal(np.unique(mask), np.array([0, 1]))
+        mask_tensor = paddle.randn((1, 3, 256, 256), dtype="float32")
+        mask_tensor = P.prepro_mask(mask_tensor)
+        self.check_output_equal(len(mask_tensor.shape), 2)
+        self.assertEqual(mask_tensor.dtype, np.uint8)
+        self.check_output_equal(np.unique(mask_tensor), np.array([0, 1, 2]))
+
+    def test_del_small_connection(self):
+        mask = copy.deepcopy(self.b_label)
+        mask = P.prepro_mask(mask)
+        mask = P.del_small_connection(mask)
+        self.check_output_equal(mask.shape, self.b_label.shape)
+        self.assertEqual(mask.dtype, self.b_label.dtype)
+        self.check_output_equal(np.unique(mask), np.unique(self.b_label))
+
+    def test_fill_small_holes(self):
+        mask = copy.deepcopy(self.b_label)
+        mask = P.prepro_mask(mask)
+        mask = P.fill_small_holes(mask)
+        self.check_output_equal(mask.shape, self.b_label.shape)
+        self.assertEqual(mask.dtype, self.b_label.dtype)
+        self.check_output_equal(np.unique(mask), np.unique(self.b_label))
+
+    def test_morphological_operation(self):
+        mask = copy.deepcopy(self.b_label)
+        mask = P.prepro_mask(mask)
+        for op in ["open", "close", "erode", "dilate"]:
+            mask = P.morphological_operation(mask, op)
+            self.check_output_equal(mask.shape, self.b_label.shape)
+            self.assertEqual(mask.dtype, self.b_label.dtype)
+            self.check_output_equal(np.unique(mask), np.unique(self.b_label))
+
+    def test_building_regularization(self):
+        mask = copy.deepcopy(self.b_label)
+        mask = P.prepro_mask(mask)
+        mask = P.building_regularization(mask)
+        self.check_output_equal(mask.shape, self.b_label.shape)
+        self.assertEqual(mask.dtype, self.b_label.dtype)
+        self.check_output_equal(np.unique(mask), np.unique(self.b_label))
+
+    def test_cut_road_connection(self):
+        mask = copy.deepcopy(self.b_label)
+        mask = P.prepro_mask(mask)
+        mask = P.cut_road_connection(mask)
+        self.check_output_equal(mask.shape, self.b_label.shape)
+        self.assertEqual(mask.dtype, self.b_label.dtype)
+        self.check_output_equal(np.unique(mask), np.unique(self.b_label))
+
+    def test_conditional_random_field(self):
+        if "conditional_random_field" in dir(P):
+            mask = copy.deepcopy(self.m_label)
+            mask = P.prepro_mask(mask)
+            mask = P.conditional_random_field(self.image2, mask)
+            self.check_output_equal(mask.shape, self.m_label.shape)
+            self.assertEqual(mask.dtype, self.m_label.dtype)
+            self.check_output_equal(np.unique(mask), np.unique(self.m_label))
+
+    def test_markov_random_field(self):
+        mask = copy.deepcopy(self.m_label)
+        mask = P.prepro_mask(mask)
+        mask = P.markov_random_field(self.image2, mask)
+        self.check_output_equal(mask.shape, self.m_label.shape)
+        self.assertEqual(mask.dtype, self.m_label.dtype)
+        self.check_output_equal(np.unique(mask), np.unique(self.m_label))
+
+    def test_deal_one_class(self):
+        mask = copy.deepcopy(self.m_label)
+        mask = P.prepro_mask(mask)
+        func = P.morphological_operation
+        mask = P.deal_one_class(mask, 1, func, ops="dilate")
+        self.check_output_equal(mask.shape, self.m_label.shape)
+        self.assertEqual(mask.dtype, self.m_label.dtype)
+        self.check_output_equal(np.unique(mask), np.unique(self.m_label))
+
+    def test_change_(self):
+        mask = copy.deepcopy(self.m_label)
+        mask = P.prepro_mask(mask)
+        mask = P.change_detection_filter(mask, self.image1, self.image2, 0.8,
+                                         0.8, "GLI", {"b": 3,
+                                                      "g": 2,
+                                                      "r": 1})
+        self.check_output_equal(mask.shape, self.m_label.shape)
+        self.assertEqual(mask.dtype, self.m_label.dtype)
+        self.check_output_equal(np.unique(mask), np.unique(self.m_label))

+ 3 - 0
tests/run_ci_dev.sh

@@ -23,6 +23,9 @@ echo -e '*****************paddlers_version****'
 git rev-parse HEAD
 
 python -m pip install --user -r requirements.txt
+# According to 
+# https://stackoverflow.com/questions/74972995/opencv-aws-lambda-lib64-libz-so-1-version-zlib-1-2-9-not-found
+python -m pip install opencv-contrib-python==4.6.0.66
 python -m pip install --user -e .
 python -m pip install --user https://versaweb.dl.sourceforge.net/project/gdal-wheels-for-linux/GDAL-3.4.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl