@@ -118,13 +118,12 @@ class BaseSegmenter(BaseModel):
         logit = net_out[0]
         outputs = OrderedDict()
         if mode == 'test':
-            origin_shape = inputs[1]
+            batch_restore_list = inputs[-1]
             if self.status == 'Infer':
                 label_map_list, score_map_list = self.postprocess(
-                    net_out, origin_shape, transforms=inputs[2])
+                    net_out, batch_restore_list)
             else:
-                logit_list = self.postprocess(
-                    logit, origin_shape, transforms=inputs[2])
+                logit_list = self.postprocess(logit, batch_restore_list)
                 label_map_list = []
                 score_map_list = []
                 for logit in logit_list:
@@ -140,6 +139,7 @@ class BaseSegmenter(BaseModel):
             outputs['score_map'] = score_map_list

         if mode == 'eval':
+            batch_restore_list = inputs[-1]
             if self.status == 'Infer':
                 pred = paddle.unsqueeze(net_out[0], axis=1)  # NCHW
             else:
@@ -150,9 +150,7 @@ class BaseSegmenter(BaseModel):
             if label.ndim != 4:
                 raise ValueError("Expected label.ndim == 4 but got {}".format(
                     label.ndim))
-            origin_shape = [label.shape[-2:]]
-            pred = self.postprocess(
-                pred, origin_shape, transforms=inputs[2])[0]  # NCHW
+            pred = self.postprocess(pred, batch_restore_list)[0]  # NCHW
             intersect_area, pred_area, label_area = ppseg.utils.metrics.calculate_area(
                 pred, label, self.num_classes)
             outputs['intersect_area'] = intersect_area
@@ -441,7 +439,6 @@ class BaseSegmenter(BaseModel):
                 math.ceil(eval_dataset.num_samples * 1.0 / batch_size)))
         with paddle.no_grad():
             for step, data in enumerate(self.eval_data_loader):
-                data.append(eval_dataset.transforms.transforms)
                 outputs = self.run(self.net, data, 'eval')
                 pred_area = outputs['pred_area']
                 label_area = outputs['label_area']
@@ -529,10 +526,10 @@ class BaseSegmenter(BaseModel):
             images = [img_file]
         else:
            images = img_file
-        batch_im, batch_origin_shape = self.preprocess(images, transforms,
-                                                       self.model_type)
+        batch_im, batch_trans_info = self.preprocess(images, transforms,
+                                                     self.model_type)
         self.net.eval()
-        data = (batch_im, batch_origin_shape, transforms.transforms)
+        data = (batch_im, batch_trans_info)
         outputs = self.run(self.net, data, 'test')
         label_map_list = outputs['label_map']
         score_map_list = outputs['score_map']
@@ -594,75 +591,24 @@ class BaseSegmenter(BaseModel):
     def preprocess(self, images, transforms, to_tensor=True):
         self._check_transforms(transforms, 'test')
         batch_im = list()
-        batch_ori_shape = list()
+        batch_trans_info = list()
         for im in images:
             if isinstance(im, str):
                 im = decode_image(im, read_raw=True)
-            ori_shape = im.shape[:2]
             sample = {'image': im}
-            im = transforms(sample)[0]
+            data = transforms(sample)
+            im = data[0]
+            trans_info = data[-1]
             batch_im.append(im)
-            batch_ori_shape.append(ori_shape)
+            batch_trans_info.append(trans_info)
         if to_tensor:
             batch_im = paddle.to_tensor(batch_im)
         else:
             batch_im = np.asarray(batch_im)

-        return batch_im, batch_ori_shape
-
-    @staticmethod
-    def get_transforms_shape_info(batch_ori_shape, transforms):
-        # TODO: Store transform meta info when applying transforms
-        # and not here
-        batch_restore_list = list()
-        for ori_shape in batch_ori_shape:
-            restore_list = list()
-            h, w = ori_shape[0], ori_shape[1]
-            for op in transforms:
-                if op.__class__.__name__ == 'Resize':
-                    restore_list.append(('resize', (h, w)))
-                    h, w = op.target_size
-                elif op.__class__.__name__ == 'ResizeByShort':
-                    restore_list.append(('resize', (h, w)))
-                    im_short_size = min(h, w)
-                    im_long_size = max(h, w)
-                    scale = float(op.short_size) / float(im_short_size)
-                    if 0 < op.max_size < np.round(scale * im_long_size):
-                        scale = float(op.max_size) / float(im_long_size)
-                    h = int(round(h * scale))
-                    w = int(round(w * scale))
-                elif op.__class__.__name__ == 'ResizeByLong':
-                    restore_list.append(('resize', (h, w)))
-                    im_long_size = max(h, w)
-                    scale = float(op.long_size) / float(im_long_size)
-                    h = int(round(h * scale))
-                    w = int(round(w * scale))
-                elif op.__class__.__name__ == 'Pad':
-                    if op.target_size:
-                        target_h, target_w = op.target_size
-                    else:
-                        target_h = int(
-                            (np.ceil(h / op.size_divisor) * op.size_divisor))
-                        target_w = int(
-                            (np.ceil(w / op.size_divisor) * op.size_divisor))
-
-                    if op.pad_mode == -1:
-                        offsets = op.offsets
-                    elif op.pad_mode == 0:
-                        offsets = [0, 0]
-                    elif op.pad_mode == 1:
-                        offsets = [(target_h - h) // 2, (target_w - w) // 2]
-                    else:
-                        offsets = [target_h - h, target_w - w]
-                    restore_list.append(('padding', (h, w), offsets))
-                    h, w = target_h, target_w
+        return batch_im, batch_trans_info

-            batch_restore_list.append(restore_list)
-        return batch_restore_list
-
-    def postprocess(self, batch_pred, batch_origin_shape, transforms):
-        batch_restore_list = BaseSegmenter.get_transforms_shape_info(
-            batch_origin_shape, transforms)
+    def postprocess(self, batch_pred, batch_restore_list):
         if isinstance(batch_pred, (tuple, list)) and self.status == 'Infer':
             return self._infer_postprocess(
                 batch_label_map=batch_pred[0],
@@ -979,17 +925,18 @@ class C2FNet(BaseSegmenter):
         pre_coarse = self.coarse_model(inputs[0])
         pre_coarse = pre_coarse[0]
         heatmaps = pre_coarse
+
         if mode == 'test':
+            batch_restore_list = inputs[-1]
             net_out = net(inputs[0], heatmaps)
             logit = net_out[0]
             outputs = OrderedDict()
             origin_shape = inputs[1]
             if self.status == 'Infer':
                 label_map_list, score_map_list = self.postprocess(
-                    net_out, origin_shape, transforms=inputs[2])
+                    net_out, batch_restore_list)
             else:
-                logit_list = self.postprocess(
-                    logit, origin_shape, transforms=inputs[2])
+                logit_list = self.postprocess(logit, batch_restore_list)
                 label_map_list = []
                 score_map_list = []
                 for logit in logit_list:
@@ -1005,6 +952,7 @@ class C2FNet(BaseSegmenter):
             outputs['score_map'] = score_map_list

         if mode == 'eval':
+            batch_restore_list = inputs[-1]
             net_out = net(inputs[0], heatmaps)
             logit = net_out[0]
             outputs = OrderedDict()
@@ -1018,9 +966,7 @@ class C2FNet(BaseSegmenter):
             if label.ndim != 4:
                 raise ValueError("Expected label.ndim == 4 but got {}".format(
                     label.ndim))
-            origin_shape = [label.shape[-2:]]
-            pred = self.postprocess(
-                pred, origin_shape, transforms=inputs[2])[0]  # NCHW
+            pred = self.postprocess(pred, batch_restore_list)[0]  # NCHW
             intersect_area, pred_area, label_area = ppseg.utils.metrics.calculate_area(
                 pred, label, self.num_classes)
             outputs['intersect_area'] = intersect_area
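
Note: this change assumes the per-image trans_info recorded by the transform pipeline keeps the same restore-list format that the deleted get_transforms_shape_info() used to build, i.e. a list of ('resize', (h, w)) and ('padding', (h, w), offsets) entries; postprocess() now reads that list straight from inputs[-1] instead of re-deriving it from origin_shape and the transform objects (which is exactly what the removed TODO asked for). The sketch below is illustrative only and not part of the patch; restore_to_origin() is a hypothetical helper showing how such a list can be undone on an NCHW prediction.

# Illustrative sketch -- restore_to_origin() is hypothetical, not the library's API.
import paddle
import paddle.nn.functional as F


def restore_to_origin(pred, restore_list):
    """Undo the recorded transforms in reverse order. `pred` is a 1xCxHxW tensor."""
    for item in reversed(restore_list):
        if item[0] == 'resize':
            h, w = item[1]
            # Resize back to the shape the image had before this transform was applied.
            pred = F.interpolate(pred, size=(h, w), mode='nearest')
        elif item[0] == 'padding':
            h, w = item[1]
            off_h, off_w = item[2]
            # Crop away the padding that was added on top of the (h, w) image.
            pred = pred[:, :, off_h:off_h + h, off_w:off_w + w]
    return pred


# Example: a 480x640 image that was resized to 512x512 and then padded to 512x544.
restore_list = [('resize', (480, 640)), ('padding', (512, 512), [0, 0])]
pred = paddle.zeros([1, 1, 512, 544])
print(restore_to_origin(pred, restore_list).shape)  # [1, 1, 480, 640]

Recording the restore info while the transforms run means evaluate() and predict() no longer have to pass the transform list around, which is why the data.append(eval_dataset.transforms.transforms) line and the transforms.transforms element of data are dropped above.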