|
@@ -28,7 +28,6 @@ import paddlers.rs_models.seg as cmseg
|
|
|
import paddlers.utils.logging as logging
|
|
|
from paddlers.models import seg_losses
|
|
|
from paddlers.transforms import Resize, decode_image, construct_sample
|
|
|
-from paddlers.transforms.operators import ArrangeSegmenter
|
|
|
from paddlers.utils import get_single_card_bs, DisablePrint
|
|
|
from paddlers.utils.checkpoint import seg_pretrain_weights_dict
|
|
|
from .base import BaseModel
|
|
@@ -43,8 +42,6 @@ __all__ = [
|
|
|
|
|
|
|
|
|
class BaseSegmenter(BaseModel):
|
|
|
- _arrange = ArrangeSegmenter
|
|
|
-
|
|
|
def __init__(self,
|
|
|
model_name,
|
|
|
num_classes=2,
|
|
@@ -117,7 +114,7 @@ class BaseSegmenter(BaseModel):
|
|
|
|
|
|
def run(self, net, inputs, mode):
|
|
|
inputs, batch_restore_list = inputs
|
|
|
- net_out = net(inputs[0])
|
|
|
+ net_out = net(inputs['image'])
|
|
|
logit = net_out[0]
|
|
|
outputs = OrderedDict()
|
|
|
if mode == 'test':
|
|
@@ -145,7 +142,7 @@ class BaseSegmenter(BaseModel):
|
|
|
pred = paddle.unsqueeze(net_out[0], axis=1) # NCHW
|
|
|
else:
|
|
|
pred = paddle.argmax(logit, axis=1, keepdim=True, dtype='int32')
|
|
|
- label = inputs[1]
|
|
|
+ label = inputs['mask']
|
|
|
if label.ndim == 3:
|
|
|
paddle.unsqueeze_(label, axis=1)
|
|
|
if label.ndim != 4:
|
|
@@ -161,7 +158,7 @@ class BaseSegmenter(BaseModel):
|
|
|
self.num_classes)
|
|
|
if mode == 'train':
|
|
|
loss_list = metrics.loss_computation(
|
|
|
- logits_list=net_out, labels=inputs[1], losses=self.losses)
|
|
|
+ logits_list=net_out, labels=inputs['mask'], losses=self.losses)
|
|
|
loss = sum(loss_list)
|
|
|
outputs['loss'] = loss
|
|
|
return outputs
|
|
@@ -429,7 +426,6 @@ class BaseSegmenter(BaseModel):
|
|
|
"is forcibly set to {}.".format(batch_size))
|
|
|
self.eval_data_loader = self.build_data_loader(
|
|
|
eval_dataset, batch_size=batch_size, mode='eval')
|
|
|
- self._check_arrange(eval_dataset.transforms, 'eval')
|
|
|
|
|
|
intersect_area_all = 0
|
|
|
pred_area_all = 0
|
|
@@ -528,8 +524,6 @@ class BaseSegmenter(BaseModel):
|
|
|
images = [img_file]
|
|
|
else:
|
|
|
images = img_file
|
|
|
- transforms = self._build_transforms(transforms, "test")
|
|
|
- self._check_arrange(transforms, "test")
|
|
|
data = self.preprocess(images, transforms, self.model_type)
|
|
|
self.net.eval()
|
|
|
outputs = self.run(self.net, data, 'test')
|
|
@@ -599,7 +593,7 @@ class BaseSegmenter(BaseModel):
|
|
|
im = decode_image(im, read_raw=True)
|
|
|
sample = construct_sample(image=im)
|
|
|
data = transforms(sample)
|
|
|
- im = data[0][0]
|
|
|
+ im = data[0]['image']
|
|
|
trans_info = data[1]
|
|
|
batch_im.append(im)
|
|
|
batch_trans_info.append(trans_info)
|
|
@@ -608,7 +602,7 @@ class BaseSegmenter(BaseModel):
|
|
|
else:
|
|
|
batch_im = np.asarray(batch_im)
|
|
|
|
|
|
- return (batch_im, ), batch_trans_info
|
|
|
+ return {'image': batch_im}, batch_trans_info
|
|
|
|
|
|
def postprocess(self, batch_pred, batch_restore_list):
|
|
|
if isinstance(batch_pred, (tuple, list)) and self.status == 'Infer':
|
|
@@ -681,12 +675,6 @@ class BaseSegmenter(BaseModel):
|
|
|
score_maps.append(score_map.squeeze())
|
|
|
return label_maps, score_maps
|
|
|
|
|
|
- def _check_arrange(self, transforms, mode):
|
|
|
- super()._check_arrange(transforms, mode)
|
|
|
- if not isinstance(transforms.arrange, ArrangeSegmenter):
|
|
|
- raise TypeError(
|
|
|
- "`transforms.arrange` must be an `ArrangeSegmenter` object.")
|
|
|
-
|
|
|
def set_losses(self, losses, weights=None):
|
|
|
if weights is None:
|
|
|
weights = [1. for _ in range(len(losses))]
|
|
@@ -924,15 +912,14 @@ class C2FNet(BaseSegmenter):
|
|
|
def run(self, net, inputs, mode):
|
|
|
inputs, batch_restore_list = inputs
|
|
|
with paddle.no_grad():
|
|
|
- pre_coarse = self.coarse_model(inputs[0])
|
|
|
+ pre_coarse = self.coarse_model(inputs['image'])
|
|
|
pre_coarse = pre_coarse[0]
|
|
|
heatmaps = pre_coarse
|
|
|
|
|
|
if mode == 'test':
|
|
|
- net_out = net(inputs[0], heatmaps)
|
|
|
+ net_out = net(inputs['image'], heatmaps)
|
|
|
logit = net_out[0]
|
|
|
outputs = OrderedDict()
|
|
|
- origin_shape = inputs[1]
|
|
|
if self.status == 'Infer':
|
|
|
label_map_list, score_map_list = self.postprocess(
|
|
|
net_out, batch_restore_list)
|
|
@@ -953,19 +940,19 @@ class C2FNet(BaseSegmenter):
|
|
|
outputs['score_map'] = score_map_list
|
|
|
|
|
|
if mode == 'eval':
|
|
|
- net_out = net(inputs[0], heatmaps)
|
|
|
+ net_out = net(inputs['image'], heatmaps)
|
|
|
logit = net_out[0]
|
|
|
outputs = OrderedDict()
|
|
|
if self.status == 'Infer':
|
|
|
pred = paddle.unsqueeze(net_out[0], axis=1) # NCHW
|
|
|
else:
|
|
|
pred = paddle.argmax(logit, axis=1, keepdim=True, dtype='int32')
|
|
|
- label = inputs[1]
|
|
|
+ label = inputs['mask']
|
|
|
if label.ndim == 3:
|
|
|
paddle.unsqueeze_(label, axis=1)
|
|
|
if label.ndim != 4:
|
|
|
- raise ValueError("Expected label.ndim == 4 but got {}".format(
|
|
|
- label.ndim))
|
|
|
+ raise ValueError(
|
|
|
+ "Expected `label.ndim` == 4 but got {}.".format(label.ndim))
|
|
|
pred = self.postprocess(pred, batch_restore_list)[0] # NCHW
|
|
|
intersect_area, pred_area, label_area = ppseg.utils.metrics.calculate_area(
|
|
|
pred, label, self.num_classes)
|
|
@@ -975,7 +962,7 @@ class C2FNet(BaseSegmenter):
|
|
|
outputs['conf_mat'] = metrics.confusion_matrix(pred, label,
|
|
|
self.num_classes)
|
|
|
if mode == 'train':
|
|
|
- net_out = net(inputs[0], heatmaps, inputs[1])
|
|
|
+ net_out = net(inputs['image'], heatmaps, inputs['mask'])
|
|
|
logit = [net_out[0], ]
|
|
|
labels = net_out[1]
|
|
|
outputs = OrderedDict()
|