|
@@ -103,11 +103,11 @@ class Predictor(object):
|
|
|
config.enable_use_gpu(200, gpu_id)
|
|
|
config.switch_ir_optim(True)
|
|
|
if use_trt:
|
|
|
- if self._model.model_type == 'segmenter':
|
|
|
+ if self.model_type == 'segmenter':
|
|
|
logging.warning(
|
|
|
"Semantic segmentation models do not support TensorRT acceleration, "
|
|
|
"TensorRT is forcibly disabled.")
|
|
|
- elif self._model.model_type == 'detector' and 'RCNN' in self._model.__class__.__name__:
|
|
|
+ elif self.model_type == 'detector' and 'RCNN' in self._model.__class__.__name__:
|
|
|
logging.warning(
|
|
|
"RCNN models do not support TensorRT acceleration, "
|
|
|
"TensorRT is forcibly disabled.")
|
|
@@ -150,30 +150,29 @@ class Predictor(object):
|
|
|
def preprocess(self, images, transforms):
|
|
|
preprocessed_samples = self._model.preprocess(
|
|
|
images, transforms, to_tensor=False)
|
|
|
- if self._model.model_type == 'classifier':
|
|
|
+ if self.model_type == 'classifier':
|
|
|
preprocessed_samples = {'image': preprocessed_samples[0]}
|
|
|
- elif self._model.model_type == 'segmenter':
|
|
|
+ elif self.model_type == 'segmenter':
|
|
|
preprocessed_samples = {
|
|
|
'image': preprocessed_samples[0],
|
|
|
'ori_shape': preprocessed_samples[1]
|
|
|
}
|
|
|
- elif self._model.model_type == 'detector':
|
|
|
+ elif self.model_type == 'detector':
|
|
|
pass
|
|
|
- elif self._model.model_type == 'change_detector':
|
|
|
+ elif self.model_type == 'change_detector':
|
|
|
preprocessed_samples = {
|
|
|
'image': preprocessed_samples[0],
|
|
|
'image2': preprocessed_samples[1],
|
|
|
'ori_shape': preprocessed_samples[2]
|
|
|
}
|
|
|
- elif self._model.model_type == 'restorer':
|
|
|
+ elif self.model_type == 'restorer':
|
|
|
preprocessed_samples = {
|
|
|
'image': preprocessed_samples[0],
|
|
|
'tar_shape': preprocessed_samples[1]
|
|
|
}
|
|
|
else:
|
|
|
logging.error(
|
|
|
- "Invalid model type {}".format(self._model.model_type),
|
|
|
- exit=True)
|
|
|
+ "Invalid model type {}".format(self.model_type), exit=True)
|
|
|
return preprocessed_samples
|
|
|
|
|
|
def postprocess(self,
|
|
@@ -182,7 +181,7 @@ class Predictor(object):
|
|
|
ori_shape=None,
|
|
|
tar_shape=None,
|
|
|
transforms=None):
|
|
|
- if self._model.model_type == 'classifier':
|
|
|
+ if self.model_type == 'classifier':
|
|
|
true_topk = min(self._model.num_classes, topk)
|
|
|
if self._model.postprocess is None:
|
|
|
self._model.build_postprocess_from_labels(topk)
|
|
@@ -198,7 +197,7 @@ class Predictor(object):
|
|
|
'scores_map': s,
|
|
|
'label_names_map': n,
|
|
|
} for l, s, n in zip(class_ids, scores, label_names)]
|
|
|
- elif self._model.model_type in ('segmenter', 'change_detector'):
|
|
|
+ elif self.model_type in ('segmenter', 'change_detector'):
|
|
|
label_map, score_map = self._model.postprocess(
|
|
|
net_outputs,
|
|
|
batch_origin_shape=ori_shape,
|
|
@@ -207,13 +206,13 @@ class Predictor(object):
|
|
|
'label_map': l,
|
|
|
'score_map': s
|
|
|
} for l, s in zip(label_map, score_map)]
|
|
|
- elif self._model.model_type == 'detector':
|
|
|
+ elif self.model_type == 'detector':
|
|
|
net_outputs = {
|
|
|
k: v
|
|
|
for k, v in zip(['bbox', 'bbox_num', 'mask'], net_outputs)
|
|
|
}
|
|
|
preds = self._model.postprocess(net_outputs)
|
|
|
- elif self._model.model_type == 'restorer':
|
|
|
+ elif self.model_type == 'restorer':
|
|
|
res_maps = self._model.postprocess(
|
|
|
net_outputs[0],
|
|
|
batch_tar_shape=tar_shape,
|
|
@@ -221,8 +220,7 @@ class Predictor(object):
|
|
|
preds = [{'res_map': res_map} for res_map in res_maps]
|
|
|
else:
|
|
|
logging.error(
|
|
|
- "Invalid model type {}.".format(self._model.model_type),
|
|
|
- exit=True)
|
|
|
+ "Invalid model type {}.".format(self.model_type), exit=True)
|
|
|
|
|
|
return preds
|
|
|
|
|
@@ -360,6 +358,12 @@ class Predictor(object):
|
|
|
batch_size (int, optional): Batch size used in inference. Defaults to 1.
|
|
|
quiet (bool, optional): If True, disable the progress bar. Defaults to False.
|
|
|
"""
|
|
|
+
|
|
|
+ if self.model_type not in ('segmenter', 'change_detector'):
|
|
|
+ raise RuntimeError(
|
|
|
+ "Model type is {}, which does not support inference with sliding windows.".
|
|
|
+ format(self.model_type))
|
|
|
+
|
|
|
slider_predict(
|
|
|
partial(
|
|
|
self.predict, quiet=True),
|
|
@@ -375,3 +379,7 @@ class Predictor(object):
|
|
|
|
|
|
def batch_predict(self, image_list, **params):
|
|
|
return self.predict(img_file=image_list, **params)
|
|
|
+
|
|
|
+ @property
|
|
|
+ def model_type(self):
|
|
|
+ return self._model.model_type
|