|
@@ -142,7 +142,7 @@ class BaseSegmenter(BaseModel):
|
|
|
pred = paddle.unsqueeze(net_out[0], axis=1) # NCHW
|
|
|
else:
|
|
|
pred = paddle.argmax(logit, axis=1, keepdim=True, dtype='int32')
|
|
|
- label = inputs['mask']
|
|
|
+ label = inputs['mask'].astype('int64')
|
|
|
if label.ndim == 3:
|
|
|
paddle.unsqueeze_(label, axis=1)
|
|
|
if label.ndim != 4:
|
|
@@ -158,7 +158,9 @@ class BaseSegmenter(BaseModel):
|
|
|
self.num_classes)
|
|
|
if mode == 'train':
|
|
|
loss_list = metrics.loss_computation(
|
|
|
- logits_list=net_out, labels=inputs['mask'], losses=self.losses)
|
|
|
+ logits_list=net_out,
|
|
|
+ labels=inputs['mask'].astype('int64'),
|
|
|
+ losses=self.losses)
|
|
|
loss = sum(loss_list)
|
|
|
outputs['loss'] = loss
|
|
|
return outputs
|
|
@@ -947,7 +949,7 @@ class C2FNet(BaseSegmenter):
|
|
|
pred = paddle.unsqueeze(net_out[0], axis=1) # NCHW
|
|
|
else:
|
|
|
pred = paddle.argmax(logit, axis=1, keepdim=True, dtype='int32')
|
|
|
- label = inputs['mask']
|
|
|
+ label = inputs['mask'].astype('int64')
|
|
|
if label.ndim == 3:
|
|
|
paddle.unsqueeze_(label, axis=1)
|
|
|
if label.ndim != 4:
|
|
@@ -962,7 +964,8 @@ class C2FNet(BaseSegmenter):
|
|
|
outputs['conf_mat'] = metrics.confusion_matrix(pred, label,
|
|
|
self.num_classes)
|
|
|
if mode == 'train':
|
|
|
- net_out = net(inputs['image'], heatmaps, inputs['mask'])
|
|
|
+ net_out = net(inputs['image'], heatmaps,
|
|
|
+ inputs['mask'].astype('int64'))
|
|
|
logit = [net_out[0], ]
|
|
|
labels = net_out[1]
|
|
|
outputs = OrderedDict()
|