| 
					
				 | 
			
			
				@@ -103,11 +103,11 @@ class Predictor(object): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             config.enable_use_gpu(200, gpu_id) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             config.switch_ir_optim(True) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             if use_trt: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                if self._model.model_type == 'segmenter': 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                if self.model_type == 'segmenter': 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                     logging.warning( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                         "Semantic segmentation models do not support TensorRT acceleration, " 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                         "TensorRT is forcibly disabled.") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                elif self._model.model_type == 'detector' and 'RCNN' in self._model.__class__.__name__: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                elif self.model_type == 'detector' and 'RCNN' in self._model.__class__.__name__: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                     logging.warning( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                         "RCNN models do not support TensorRT acceleration, " 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                         "TensorRT is forcibly disabled.") 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -150,30 +150,29 @@ class Predictor(object): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def preprocess(self, images, transforms): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         preprocessed_samples = self._model.preprocess( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             images, transforms, to_tensor=False) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        if self._model.model_type == 'classifier': 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if self.model_type == 'classifier': 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             preprocessed_samples = {'image': preprocessed_samples[0]} 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        elif self._model.model_type == 'segmenter': 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        elif self.model_type == 'segmenter': 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             preprocessed_samples = { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 'image': preprocessed_samples[0], 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 'ori_shape': preprocessed_samples[1] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        elif self._model.model_type == 'detector': 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        elif self.model_type == 'detector': 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             pass 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        elif self._model.model_type == 'change_detector': 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        elif self.model_type == 'change_detector': 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             preprocessed_samples = { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 'image': preprocessed_samples[0], 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 'image2': preprocessed_samples[1], 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 'ori_shape': preprocessed_samples[2] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        elif self._model.model_type == 'restorer': 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        elif self.model_type == 'restorer': 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             preprocessed_samples = { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 'image': preprocessed_samples[0], 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 'tar_shape': preprocessed_samples[1] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             logging.error( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                "Invalid model type {}".format(self._model.model_type), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                exit=True) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                "Invalid model type {}".format(self.model_type), exit=True) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         return preprocessed_samples 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def postprocess(self, 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -182,7 +181,7 @@ class Predictor(object): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                     ori_shape=None, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                     tar_shape=None, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                     transforms=None): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        if self._model.model_type == 'classifier': 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if self.model_type == 'classifier': 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             true_topk = min(self._model.num_classes, topk) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             if self._model.postprocess is None: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 self._model.build_postprocess_from_labels(topk) 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -198,7 +197,7 @@ class Predictor(object): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 'scores_map': s, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 'label_names_map': n, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             } for l, s, n in zip(class_ids, scores, label_names)] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        elif self._model.model_type in ('segmenter', 'change_detector'): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        elif self.model_type in ('segmenter', 'change_detector'): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             label_map, score_map = self._model.postprocess( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 net_outputs, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 batch_origin_shape=ori_shape, 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -207,13 +206,13 @@ class Predictor(object): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 'label_map': l, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 'score_map': s 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             } for l, s in zip(label_map, score_map)] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        elif self._model.model_type == 'detector': 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        elif self.model_type == 'detector': 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             net_outputs = { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 k: v 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 for k, v in zip(['bbox', 'bbox_num', 'mask'], net_outputs) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             preds = self._model.postprocess(net_outputs) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        elif self._model.model_type == 'restorer': 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        elif self.model_type == 'restorer': 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             res_maps = self._model.postprocess( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 net_outputs[0], 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 batch_tar_shape=tar_shape, 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -221,8 +220,7 @@ class Predictor(object): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             preds = [{'res_map': res_map} for res_map in res_maps] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             logging.error( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                "Invalid model type {}.".format(self._model.model_type), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                exit=True) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                "Invalid model type {}.".format(self.model_type), exit=True) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         return preds 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -360,6 +358,12 @@ class Predictor(object): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             batch_size (int, optional): Batch size used in inference. Defaults to 1. 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             quiet (bool, optional): If True, disable the progress bar. Defaults to False. 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         """ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if self.model_type not in ('segmenter', 'change_detector'): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            raise RuntimeError( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                "Model type is {}, which does not support inference with sliding windows.". 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                format(self.model_type)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         slider_predict( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             partial( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 self.predict, quiet=True), 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -375,3 +379,7 @@ class Predictor(object): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def batch_predict(self, image_list, **params): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         return self.predict(img_file=image_list, **params) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    @property 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def model_type(self): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        return self._model.model_type 
			 |