Pārlūkot izejas kodu

Optimize visualize

Bobholamovic 2 gadi atpakaļ
vecāks
revīzija
66f65d6967
1 mainītis faili ar 24 papildinājumiem un 19 dzēšanām
  1. 24 19
      paddlers/tasks/utils/visualize.py

+ 24 - 19
paddlers/tasks/utils/visualize.py

@@ -17,6 +17,7 @@ import cv2
 import numpy as np
 import time
 
+import paddlers.transforms as T
 import paddlers.utils.logging as logging
 from paddlers.utils import is_pic
 from .det_metrics.coco_utils import loadRes
@@ -44,17 +45,20 @@ def visualize_detection(image, result, threshold=0.5, save_dir='./',
         return image
 
 
-def visualize_segmentation(image, result, weight=0.6, save_dir='./',
+def visualize_segmentation(result,
+                           image=None,
+                           weight=0.6,
+                           save_dir=None,
                            color=None):
     """
-    Convert segment result to color image, and save added image.
+    Convert segmentation result to color image, and save the mixed image.
 
     Args:
-        image (str): Path of original image.
         result (dict): Predicted results.
+        image (str|numpy.ndarray|None, optional): Path of original image. Defaults to None.
         weight (float, optional): Weight used to mix the original image with the predicted image.
             Defaults to 0.6.
-        save_dir (str, optional): Directory for saving visualized image. Defaults to './'.
+        save_dir (str|None, optional): Directory for saving visualized image. Defaults to None.
         color (list|None): None or list of BGR indices for each label. Defaults to None.
     """
 
@@ -71,26 +75,27 @@ def visualize_segmentation(image, result, weight=0.6, save_dir='./',
     c3 = cv2.LUT(label_map, color_map[:, 2])
     pseudo_img = np.dstack((c1, c2, c3))
 
-    if isinstance(image, np.ndarray):
-        im = image
-        image_name = str(int(time.time() * 1000)) + '.jpg'
-        if image.shape[2] != 3:
-            logging.info(
-                "The image is not 3-channel array, so predicted label map is shown as a pseudo color image."
-            )
-            weight = 0.
-    else:
+    if isinstance(image, str):
+        # Automatically parse image name
         image_name = os.path.split(image)[-1]
         if not is_pic(image):
-            logging.info(
-                "The image cannot be opened by opencv, so predicted label map is shown as a pseudo color image."
-            )
             image_name = image_name.split('.')[0] + '.jpg'
             weight = 0.
         else:
-            im = cv2.imread(image)
+            # Use BGR for backward compatibility
+            im = T.decode_image(image, to_rgb=False)
+
+        if im.shape[2] != 3:
+            logging.warning(
+                "The image does not have exactly 3 channels, so the visualized result will be shown as a pseudo color image."
+            )
+            weight = 0.
+    else:
+        image_name = str(int(time.time() * 1000)) + '.jpg'
+        if isinstance(image, np.ndarray):
+            im = image
 
-    if abs(weight) < 1e-5:
+    if image is None or abs(weight) < 1e-5:
         vis_result = pseudo_img
     else:
         vis_result = cv2.addWeighted(im, weight,
@@ -101,7 +106,7 @@ def visualize_segmentation(image, result, weight=0.6, save_dir='./',
             os.makedirs(save_dir)
         out_path = os.path.join(save_dir, 'visualize_{}'.format(image_name))
         cv2.imwrite(out_path, vis_result)
-        logging.info('The visualized result is saved as {}'.format(out_path))
+        logging.info('The visualized result is saved in {}.'.format(out_path))
     else:
         return vis_result