crf.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. from skimage.color import gray2rgb
  16. import pydensecrf.densecrf as dcrf
  17. from pydensecrf.utils import unary_from_labels
  18. def conditional_random_field(original_image: np.ndarray,
  19. mask: np.ndarray) -> np.ndarray:
  20. """
  21. Conditional random field.
  22. The original article refers to
  23. Krhenbühl, Philipp, Koltun V. "Efficient Inference in Fully Connected CRFs with Gaussian Edge Potentials."
  24. (https://arxiv.org/abs/1210.5644v1).
  25. The implementation procedure refers to this repo:
  26. https://github.com/apletea/Computer-Vision
  27. Args:
  28. original_image (np.ndarray): Original image. Shape is [H, W, 3].
  29. mask (np.ndarray): Mask to refine. Shape is [H, W].
  30. Returns:
  31. np.ndarray: Mask after CRF.
  32. """
  33. n_labels = len(np.unique(mask))
  34. mask3 = gray2rgb(mask)
  35. annotated_label = mask3[:, :, 0] + (mask3[:, :, 1] << 8) + (mask3[:, :, 2]
  36. << 16)
  37. _, labels = np.unique(annotated_label, return_inverse=True)
  38. img_shape = original_image.shape
  39. d = dcrf.DenseCRF2D(img_shape[1], img_shape[0], n_labels)
  40. U = unary_from_labels(labels, n_labels, gt_prob=0.7, zero_unsure=False)
  41. d.setUnaryEnergy(U)
  42. d.addPairwiseGaussian(
  43. sxy=(3, 3),
  44. compat=3,
  45. kernel=dcrf.DIAG_KERNEL,
  46. normalization=dcrf.NORMALIZE_SYMMETRIC)
  47. Q = d.inference(10)
  48. MAP = np.argmax(Q, axis=0)
  49. MAP = MAP.reshape((img_shape[0], img_shape[1]))
  50. return MAP.astype("uint8")