浏览代码

[Feat] Add predict() method for ChangeDetector (#74)

Lin Manhui 3 年之前
父节点
当前提交
8024d83606
共有 1 个文件被更改,包括 30 次插入14 次删除
  1. 30 14
      paddlers/tasks/change_detector.py

+ 30 - 14
paddlers/tasks/change_detector.py

@@ -381,6 +381,12 @@ class BaseChangeDetector(BaseModel):
 
 
         Returns:
         Returns:
             collections.OrderedDict with key-value pairs:
             collections.OrderedDict with key-value pairs:
+                For binary change detection (number of classes == 2), the key-value pairs are like:
+                {"iou": `intersection over union for the change class`,
+                 "f1": `F1 score for the change class`,
+                 "oacc": `overall accuracy`,
+                 "kappa": `kappa coefficient`}.
+                For multi-class change detection (number of classes > 2), the key-value pairs are like:
                 {"miou": `mean intersection over union`,
                 {"miou": `mean intersection over union`,
                  "category_iou": `category-wise mean intersection over union`,
                  "category_iou": `category-wise mean intersection over union`,
                  "oacc": `overall accuracy`,
                  "oacc": `overall accuracy`,
@@ -408,7 +414,7 @@ class BaseChangeDetector(BaseModel):
             batch_size_each_card = 1
             batch_size_each_card = 1
             batch_size = batch_size_each_card * paddlers.env_info['num']
             batch_size = batch_size_each_card * paddlers.env_info['num']
             logging.warning(
             logging.warning(
-                "Segmenter only supports batch_size=1 for each gpu/cpu card " \
+                "ChangeDetector only supports batch_size=1 for each gpu/cpu card " \
                 "during evaluation, so batch_size " \
                 "during evaluation, so batch_size " \
                 "is forcibly set to {}.".format(batch_size)
                 "is forcibly set to {}.".format(batch_size)
             )
             )
@@ -471,11 +477,17 @@ class BaseChangeDetector(BaseModel):
                                               label_area_all)
                                               label_area_all)
         category_f1score = metrics.f1_score(intersect_area_all, pred_area_all,
         category_f1score = metrics.f1_score(intersect_area_all, pred_area_all,
                                             label_area_all)
                                             label_area_all)
-        eval_metrics = OrderedDict(
-            zip([
-                'miou', 'category_iou', 'oacc', 'category_acc', 'kappa',
-                'category_F1-score'
-            ], [miou, class_iou, oacc, class_acc, kappa, category_f1score]))
+
+        if len(class_acc) > 2:
+            eval_metrics = OrderedDict(
+                zip([
+                    'miou', 'category_iou', 'oacc', 'category_acc', 'kappa',
+                    'category_F1-score'
+                ], [miou, class_iou, oacc, class_acc, kappa, category_f1score]))
+        else:
+            eval_metrics = OrderedDict(
+                zip(['iou', 'f1', 'oacc', 'kappa'],
+                    [class_iou[1], category_f1score[1], oacc, kappa]))
 
 
         if return_details:
         if return_details:
             conf_mat = sum(conf_mat_all)
             conf_mat = sum(conf_mat_all)
@@ -488,14 +500,14 @@ class BaseChangeDetector(BaseModel):
         Do inference.
         Do inference.
         Args:
         Args:
             Args:
             Args:
-            img_file(List[np.ndarray or str], str or np.ndarray):
-                Image path or decoded image data in a BGR format, which also could constitute a list,
-                meaning all images to be predicted as a mini-batch.
+            img_file(List[tuple], Tuple[str or np.ndarray]):
+                Tuple of image paths or decoded image data in a BGR format for bi-temporal images, which also could constitute 
+                a list, meaning all image pairs to be predicted as a mini-batch.
             transforms(paddlers.transforms.Compose or None, optional):
             transforms(paddlers.transforms.Compose or None, optional):
                 Transforms for inputs. If None, the transforms for evaluation process will be used. Defaults to None.
                 Transforms for inputs. If None, the transforms for evaluation process will be used. Defaults to None.
 
 
         Returns:
         Returns:
-            If img_file is a string or np.array, the result is a dict with key-value pairs:
+            If img_file is a tuple of strings or np.ndarrays, the result is a dict with key-value pairs:
             {"label_map": `label map`, "score_map": `score map`}.
             {"label_map": `label map`, "score_map": `score map`}.
             If img_file is a list, the result is a list composed of dicts with the corresponding fields:
             If img_file is a list, the result is a list composed of dicts with the corresponding fields:
             label_map(np.ndarray): the predicted label map (HW)
             label_map(np.ndarray): the predicted label map (HW)
@@ -506,14 +518,18 @@ class BaseChangeDetector(BaseModel):
             raise Exception("transforms need to be defined, now is None.")
             raise Exception("transforms need to be defined, now is None.")
         if transforms is None:
         if transforms is None:
             transforms = self.test_transforms
             transforms = self.test_transforms
-        if isinstance(img_file, (str, np.ndarray)):
+        if isinstance(img_file, tuple):
+            if not len(img_file) == 2 and any(
+                    map(lambda obj: not isinstance(obj, (str, np.ndarray)),
+                        img_file)):
+                raise TypeError
             images = [img_file]
             images = [img_file]
         else:
         else:
             images = img_file
             images = img_file
-        batch_im, batch_origin_shape = self._preprocess(images, transforms,
-                                                        self.model_type)
+        batch_im1, batch_im2, batch_origin_shape = self._preprocess(
+            images, transforms, self.model_type)
         self.net.eval()
         self.net.eval()
-        data = (batch_im, batch_origin_shape, transforms.transforms)
+        data = (batch_im1, batch_im2, batch_origin_shape, transforms.transforms)
         outputs = self.run(self.net, data, 'test')
         outputs = self.run(self.net, data, 'test')
         label_map_list = outputs['label_map']
         label_map_list = outputs['label_map']
         score_map_list = outputs['score_map']
         score_map_list = outputs['score_map']