Browse Source

Add model type check

Bobholamovic 2 năm trước cách đây
mục cha
commit
9c1b2ea2fe
1 tập tin đã thay đổi với 23 bổ sung15 xóa
  1. 23 15
      paddlers/deploy/predictor.py

+ 23 - 15
paddlers/deploy/predictor.py

@@ -103,11 +103,11 @@ class Predictor(object):
             config.enable_use_gpu(200, gpu_id)
             config.switch_ir_optim(True)
             if use_trt:
-                if self._model.model_type == 'segmenter':
+                if self.model_type == 'segmenter':
                     logging.warning(
                         "Semantic segmentation models do not support TensorRT acceleration, "
                         "TensorRT is forcibly disabled.")
-                elif self._model.model_type == 'detector' and 'RCNN' in self._model.__class__.__name__:
+                elif self.model_type == 'detector' and 'RCNN' in self._model.__class__.__name__:
                     logging.warning(
                         "RCNN models do not support TensorRT acceleration, "
                         "TensorRT is forcibly disabled.")
@@ -150,30 +150,29 @@ class Predictor(object):
     def preprocess(self, images, transforms):
         preprocessed_samples = self._model.preprocess(
             images, transforms, to_tensor=False)
-        if self._model.model_type == 'classifier':
+        if self.model_type == 'classifier':
             preprocessed_samples = {'image': preprocessed_samples[0]}
-        elif self._model.model_type == 'segmenter':
+        elif self.model_type == 'segmenter':
             preprocessed_samples = {
                 'image': preprocessed_samples[0],
                 'ori_shape': preprocessed_samples[1]
             }
-        elif self._model.model_type == 'detector':
+        elif self.model_type == 'detector':
             pass
-        elif self._model.model_type == 'change_detector':
+        elif self.model_type == 'change_detector':
             preprocessed_samples = {
                 'image': preprocessed_samples[0],
                 'image2': preprocessed_samples[1],
                 'ori_shape': preprocessed_samples[2]
             }
-        elif self._model.model_type == 'restorer':
+        elif self.model_type == 'restorer':
             preprocessed_samples = {
                 'image': preprocessed_samples[0],
                 'tar_shape': preprocessed_samples[1]
             }
         else:
             logging.error(
-                "Invalid model type {}".format(self._model.model_type),
-                exit=True)
+                "Invalid model type {}".format(self.model_type), exit=True)
         return preprocessed_samples
 
     def postprocess(self,
@@ -182,7 +181,7 @@ class Predictor(object):
                     ori_shape=None,
                     tar_shape=None,
                     transforms=None):
-        if self._model.model_type == 'classifier':
+        if self.model_type == 'classifier':
             true_topk = min(self._model.num_classes, topk)
             if self._model.postprocess is None:
                 self._model.build_postprocess_from_labels(topk)
@@ -198,7 +197,7 @@ class Predictor(object):
                 'scores_map': s,
                 'label_names_map': n,
             } for l, s, n in zip(class_ids, scores, label_names)]
-        elif self._model.model_type in ('segmenter', 'change_detector'):
+        elif self.model_type in ('segmenter', 'change_detector'):
             label_map, score_map = self._model.postprocess(
                 net_outputs,
                 batch_origin_shape=ori_shape,
@@ -207,13 +206,13 @@ class Predictor(object):
                 'label_map': l,
                 'score_map': s
             } for l, s in zip(label_map, score_map)]
-        elif self._model.model_type == 'detector':
+        elif self.model_type == 'detector':
             net_outputs = {
                 k: v
                 for k, v in zip(['bbox', 'bbox_num', 'mask'], net_outputs)
             }
             preds = self._model.postprocess(net_outputs)
-        elif self._model.model_type == 'restorer':
+        elif self.model_type == 'restorer':
             res_maps = self._model.postprocess(
                 net_outputs[0],
                 batch_tar_shape=tar_shape,
@@ -221,8 +220,7 @@ class Predictor(object):
             preds = [{'res_map': res_map} for res_map in res_maps]
         else:
             logging.error(
-                "Invalid model type {}.".format(self._model.model_type),
-                exit=True)
+                "Invalid model type {}.".format(self.model_type), exit=True)
 
         return preds
 
@@ -360,6 +358,12 @@ class Predictor(object):
             batch_size (int, optional): Batch size used in inference. Defaults to 1.
             quiet (bool, optional): If True, disable the progress bar. Defaults to False.
         """
+
+        if self.model_type not in ('segmenter', 'change_detector'):
+            raise RuntimeError(
+                "Model type is {}, which does not support inference with sliding windows.".
+                format(self.model_type))
+
         slider_predict(
             partial(
                 self.predict, quiet=True),
@@ -375,3 +379,7 @@ class Predictor(object):
 
     def batch_predict(self, image_list, **params):
         return self.predict(img_file=image_list, **params)
+
+    @property
+    def model_type(self):
+        return self._model.model_type