|
@@ -111,10 +111,10 @@ class BaseChangeDetector(BaseModel):
|
|
|
if mode == 'test':
|
|
|
origin_shape = inputs[2]
|
|
|
if self.status == 'Infer':
|
|
|
- label_map_list, score_map_list = self._postprocess(
|
|
|
+ label_map_list, score_map_list = self.postprocess(
|
|
|
net_out, origin_shape, transforms=inputs[3])
|
|
|
else:
|
|
|
- logit_list = self._postprocess(
|
|
|
+ logit_list = self.postprocess(
|
|
|
logit, origin_shape, transforms=inputs[3])
|
|
|
label_map_list = []
|
|
|
score_map_list = []
|
|
@@ -142,7 +142,7 @@ class BaseChangeDetector(BaseModel):
|
|
|
raise ValueError("Expected label.ndim == 4 but got {}".format(
|
|
|
label.ndim))
|
|
|
origin_shape = [label.shape[-2:]]
|
|
|
- pred = self._postprocess(
|
|
|
+ pred = self.postprocess(
|
|
|
pred, origin_shape, transforms=inputs[3])[0] # NCHW
|
|
|
intersect_area, pred_area, label_area = paddleseg.utils.metrics.calculate_area(
|
|
|
pred, label, self.num_classes)
|
|
@@ -553,7 +553,7 @@ class BaseChangeDetector(BaseModel):
|
|
|
images = [img_file]
|
|
|
else:
|
|
|
images = img_file
|
|
|
- batch_im1, batch_im2, batch_origin_shape = self._preprocess(
|
|
|
+ batch_im1, batch_im2, batch_origin_shape = self.preprocess(
|
|
|
images, transforms, self.model_type)
|
|
|
self.net.eval()
|
|
|
data = (batch_im1, batch_im2, batch_origin_shape, transforms.transforms)
|
|
@@ -664,7 +664,7 @@ class BaseChangeDetector(BaseModel):
|
|
|
dst_data = None
|
|
|
print("GeoTiff saved in {}.".format(save_file))
|
|
|
|
|
|
- def _preprocess(self, images, transforms, to_tensor=True):
|
|
|
+ def preprocess(self, images, transforms, to_tensor=True):
|
|
|
self._check_transforms(transforms, 'test')
|
|
|
batch_im1, batch_im2 = list(), list()
|
|
|
batch_ori_shape = list()
|
|
@@ -736,7 +736,7 @@ class BaseChangeDetector(BaseModel):
|
|
|
batch_restore_list.append(restore_list)
|
|
|
return batch_restore_list
|
|
|
|
|
|
- def _postprocess(self, batch_pred, batch_origin_shape, transforms):
|
|
|
+ def postprocess(self, batch_pred, batch_origin_shape, transforms):
|
|
|
batch_restore_list = BaseChangeDetector.get_transforms_shape_info(
|
|
|
batch_origin_shape, transforms)
|
|
|
if isinstance(batch_pred, (tuple, list)) and self.status == 'Infer':
|