
[Feat] Add predict() method for ChangeDetector (#74)

Lin Manhui 3 years ago
parent
commit
8024d83606
1 changed file with 30 additions and 14 deletions
  1. paddlers/tasks/change_detector.py (+30, -14)

+ 30 - 14
paddlers/tasks/change_detector.py

@@ -381,6 +381,12 @@ class BaseChangeDetector(BaseModel):
 
         Returns:
             collections.OrderedDict with key-value pairs:
+                For binary change detection (number of classes == 2), the key-value pairs are like:
+                {"iou": `intersection over union for the change class`,
+                 "f1": `F1 score for the change class`,
+                 "oacc": `overall accuracy`,
+                 "kappa": `kappa coefficient`}.
+                For multi-class change detection (number of classes > 2), the key-value pairs are like:
                 {"miou": `mean intersection over union`,
                  "category_iou": `category-wise mean intersection over union`,
                  "oacc": `overall accuracy`,
@@ -408,7 +414,7 @@ class BaseChangeDetector(BaseModel):
             batch_size_each_card = 1
             batch_size = batch_size_each_card * paddlers.env_info['num']
             logging.warning(
-                "Segmenter only supports batch_size=1 for each gpu/cpu card " \
+                "ChangeDetector only supports batch_size=1 for each gpu/cpu card " \
                 "during evaluation, so batch_size " \
                 "is forcibly set to {}.".format(batch_size)
             )
@@ -471,11 +477,17 @@ class BaseChangeDetector(BaseModel):
                                               label_area_all)
         category_f1score = metrics.f1_score(intersect_area_all, pred_area_all,
                                             label_area_all)
-        eval_metrics = OrderedDict(
-            zip([
-                'miou', 'category_iou', 'oacc', 'category_acc', 'kappa',
-                'category_F1-score'
-            ], [miou, class_iou, oacc, class_acc, kappa, category_f1score]))
+
+        if len(class_acc) > 2:
+            eval_metrics = OrderedDict(
+                zip([
+                    'miou', 'category_iou', 'oacc', 'category_acc', 'kappa',
+                    'category_F1-score'
+                ], [miou, class_iou, oacc, class_acc, kappa, category_f1score]))
+        else:
+            eval_metrics = OrderedDict(
+                zip(['iou', 'f1', 'oacc', 'kappa'],
+                    [class_iou[1], category_f1score[1], oacc, kappa]))
 
         if return_details:
             conf_mat = sum(conf_mat_all)
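
To make the indexing explicit, here is a toy example with made-up numbers: the per-class arrays are ordered as [unchanged, changed], so index 1 selects the change class in the binary branch.

from collections import OrderedDict
import numpy as np

# Illustrative values only (not real results): [unchanged, changed].
class_iou = np.array([0.97, 0.62])
category_f1score = np.array([0.985, 0.765])
oacc, kappa = 0.96, 0.74

# Mirrors the binary branch above: index 1 is the change class.
eval_metrics = OrderedDict(
    zip(['iou', 'f1', 'oacc', 'kappa'],
        [class_iou[1], category_f1score[1], oacc, kappa]))
print(eval_metrics)
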
@@ -488,14 +500,14 @@ class BaseChangeDetector(BaseModel):
         Do inference.
         Args:
-            img_file(List[np.ndarray or str], str or np.ndarray):
-                Image path or decoded image data in a BGR format, which also could constitute a list,
-                meaning all images to be predicted as a mini-batch.
+            img_file(List[tuple], Tuple[str or np.ndarray]):
+                Tuple of image paths or decoded image data in BGR format for the bi-temporal images, which could also
+                constitute a list, meaning all image pairs to be predicted as a mini-batch.
             transforms(paddlers.transforms.Compose or None, optional):
                 Transforms for inputs. If None, the transforms for evaluation process will be used. Defaults to None.
 
         Returns:
-            If img_file is a string or np.array, the result is a dict with key-value pairs:
+            If img_file is a tuple of strings or np.ndarrays, the result is a dict with key-value pairs:
             {"label map": `label map`, "score_map": `score map`}.
             If img_file is a list, the result is a list composed of dicts with the corresponding fields:
             label_map(np.ndarray): the predicted label map (HW)
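
A hedged call sketch of the new bi-temporal input format; `model` and the file paths are hypothetical, and the result keys are taken to be 'label_map' and 'score_map' as used by the outputs further down in this diff.

# Single bi-temporal pair: a 2-tuple of image paths (or np.ndarray images).
result = model.predict(("t1/0001.png", "t2/0001.png"))
label_map = result["label_map"]    # predicted label map, shape (H, W)
score_map = result["score_map"]    # per-class score map

# Mini-batch: a list of such tuples; the result is a list of dicts.
results = model.predict([
    ("t1/0001.png", "t2/0001.png"),
    ("t1/0002.png", "t2/0002.png"),
])
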
@@ -506,14 +518,18 @@ class BaseChangeDetector(BaseModel):
             raise Exception("transforms need to be defined, now is None.")
         if transforms is None:
             transforms = self.test_transforms
-        if isinstance(img_file, (str, np.ndarray)):
+        if isinstance(img_file, tuple):
+            if len(img_file) != 2 or any(
+                    map(lambda obj: not isinstance(obj, (str, np.ndarray)),
+                        img_file)):
+                raise TypeError(
+                    "img_file should be a 2-tuple of str or np.ndarray objects.")
             images = [img_file]
         else:
             images = img_file
-        batch_im, batch_origin_shape = self._preprocess(images, transforms,
-                                                        self.model_type)
+        batch_im1, batch_im2, batch_origin_shape = self._preprocess(
+            images, transforms, self.model_type)
         self.net.eval()
-        data = (batch_im, batch_origin_shape, transforms.transforms)
+        data = (batch_im1, batch_im2, batch_origin_shape, transforms.transforms)
         outputs = self.run(self.net, data, 'test')
         label_map_list = outputs['label_map']
         score_map_list = outputs['score_map']
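
To spell out the input check added above, a small standalone sketch with a hypothetical helper name (not part of the codebase): a valid input is a 2-tuple whose elements are image paths (str) or decoded images (np.ndarray).

import numpy as np

def _is_valid_image_pair(img_file):
    # Hypothetical helper mirroring the check in predict(): a 2-tuple of
    # image paths (str) or decoded images (np.ndarray).
    return (isinstance(img_file, tuple) and len(img_file) == 2 and
            all(isinstance(obj, (str, np.ndarray)) for obj in img_file))

print(_is_valid_image_pair(("t1.png", "t2.png")))     # True
print(_is_valid_image_pair(("t1.png", )))             # False: not a pair
print(_is_valid_image_pair(["t1.png", "t2.png"]))     # False: list, not tuple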