|
@@ -23,15 +23,16 @@ import paddle
|
|
|
import paddle.nn.functional as F
|
|
|
from paddle.static import InputSpec
|
|
|
|
|
|
-import paddlers.models.ppseg as paddleseg
|
|
|
-import paddlers.rs_models.seg as cmseg
|
|
|
import paddlers
|
|
|
-from paddlers.utils import get_single_card_bs, DisablePrint
|
|
|
+import paddlers.models.ppseg as ppseg
|
|
|
+import paddlers.rs_models.seg as cmseg
|
|
|
import paddlers.utils.logging as logging
|
|
|
+from paddlers.models import seg_losses
|
|
|
+from paddlers.transforms import Resize, decode_image
|
|
|
+from paddlers.utils import get_single_card_bs, DisablePrint
|
|
|
+from paddlers.utils.checkpoint import seg_pretrain_weights_dict
|
|
|
from .base import BaseModel
|
|
|
from .utils import seg_metrics as metrics
|
|
|
-from paddlers.utils.checkpoint import seg_pretrain_weights_dict
|
|
|
-from paddlers.transforms import Resize, decode_image
|
|
|
|
|
|
__all__ = ["UNet", "DeepLabV3P", "FastSCNN", "HRNet", "BiSeNetV2", "FarSeg"]
|
|
|
|
|
@@ -41,19 +42,20 @@ class BaseSegmenter(BaseModel):
|
|
|
model_name,
|
|
|
num_classes=2,
|
|
|
use_mixed_loss=False,
|
|
|
+ losses=None,
|
|
|
**params):
|
|
|
self.init_params = locals()
|
|
|
if 'with_net' in self.init_params:
|
|
|
del self.init_params['with_net']
|
|
|
super(BaseSegmenter, self).__init__('segmenter')
|
|
|
- if not hasattr(paddleseg.models, model_name) and \
|
|
|
+ if not hasattr(ppseg.models, model_name) and \
|
|
|
not hasattr(cmseg, model_name):
|
|
|
raise ValueError("ERROR: There is no model named {}.".format(
|
|
|
model_name))
|
|
|
self.model_name = model_name
|
|
|
self.num_classes = num_classes
|
|
|
self.use_mixed_loss = use_mixed_loss
|
|
|
- self.losses = None
|
|
|
+ self.losses = losses
|
|
|
self.labels = None
|
|
|
if params.get('with_net', True):
|
|
|
params.pop('with_net', None)
|
|
@@ -63,9 +65,8 @@ class BaseSegmenter(BaseModel):
|
|
|
def build_net(self, **params):
|
|
|
# TODO: when using paddle.utils.unique_name.guard,
|
|
|
# DeepLabv3p and HRNet will raise a error
|
|
|
- net = dict(paddleseg.models.__dict__,
|
|
|
- **cmseg.__dict__)[self.model_name](
|
|
|
- num_classes=self.num_classes, **params)
|
|
|
+ net = dict(ppseg.models.__dict__, **cmseg.__dict__)[self.model_name](
|
|
|
+ num_classes=self.num_classes, **params)
|
|
|
return net
|
|
|
|
|
|
def _fix_transforms_shape(self, image_shape):
|
|
@@ -143,7 +144,7 @@ class BaseSegmenter(BaseModel):
|
|
|
origin_shape = [label.shape[-2:]]
|
|
|
pred = self._postprocess(
|
|
|
pred, origin_shape, transforms=inputs[2])[0] # NCHW
|
|
|
- intersect_area, pred_area, label_area = paddleseg.utils.metrics.calculate_area(
|
|
|
+ intersect_area, pred_area, label_area = ppseg.utils.metrics.calculate_area(
|
|
|
pred, label, self.num_classes)
|
|
|
outputs['intersect_area'] = intersect_area
|
|
|
outputs['pred_area'] = pred_area
|
|
@@ -161,16 +162,13 @@ class BaseSegmenter(BaseModel):
|
|
|
if isinstance(self.use_mixed_loss, bool):
|
|
|
if self.use_mixed_loss:
|
|
|
losses = [
|
|
|
- paddleseg.models.CrossEntropyLoss(),
|
|
|
- paddleseg.models.LovaszSoftmaxLoss()
|
|
|
+ seg_losses.CrossEntropyLoss(),
|
|
|
+ seg_losses.LovaszSoftmaxLoss()
|
|
|
]
|
|
|
coef = [.8, .2]
|
|
|
- loss_type = [
|
|
|
- paddleseg.models.MixedLoss(
|
|
|
- losses=losses, coef=coef),
|
|
|
- ]
|
|
|
+ loss_type = [seg_losses.MixedLoss(losses=losses, coef=coef), ]
|
|
|
else:
|
|
|
- loss_type = [paddleseg.models.CrossEntropyLoss()]
|
|
|
+ loss_type = [seg_losses.CrossEntropyLoss()]
|
|
|
else:
|
|
|
losses, coef = list(zip(*self.use_mixed_loss))
|
|
|
if not set(losses).issubset(
|
|
@@ -178,11 +176,8 @@ class BaseSegmenter(BaseModel):
|
|
|
raise ValueError(
|
|
|
"Only 'CrossEntropyLoss', 'DiceLoss', 'LovaszSoftmaxLoss' are supported."
|
|
|
)
|
|
|
- losses = [getattr(paddleseg.models, loss)() for loss in losses]
|
|
|
- loss_type = [
|
|
|
- paddleseg.models.MixedLoss(
|
|
|
- losses=losses, coef=list(coef))
|
|
|
- ]
|
|
|
+ losses = [getattr(seg_losses, loss)() for loss in losses]
|
|
|
+ loss_type = [seg_losses.MixedLoss(losses=losses, coef=list(coef))]
|
|
|
if self.model_name == 'FastSCNN':
|
|
|
loss_type *= 2
|
|
|
loss_coef = [1.0, 0.4]
|
|
@@ -475,13 +470,13 @@ class BaseSegmenter(BaseModel):
|
|
|
pred_area_all = pred_area_all + pred_area
|
|
|
label_area_all = label_area_all + label_area
|
|
|
conf_mat_all.append(conf_mat)
|
|
|
- class_iou, miou = paddleseg.utils.metrics.mean_iou(
|
|
|
+ class_iou, miou = ppseg.utils.metrics.mean_iou(
|
|
|
intersect_area_all, pred_area_all, label_area_all)
|
|
|
# TODO 确认是按oacc还是macc
|
|
|
- class_acc, oacc = paddleseg.utils.metrics.accuracy(intersect_area_all,
|
|
|
- pred_area_all)
|
|
|
- kappa = paddleseg.utils.metrics.kappa(intersect_area_all, pred_area_all,
|
|
|
- label_area_all)
|
|
|
+ class_acc, oacc = ppseg.utils.metrics.accuracy(intersect_area_all,
|
|
|
+ pred_area_all)
|
|
|
+ kappa = ppseg.utils.metrics.kappa(intersect_area_all, pred_area_all,
|
|
|
+ label_area_all)
|
|
|
category_f1score = metrics.f1_score(intersect_area_all, pred_area_all,
|
|
|
label_area_all)
|
|
|
eval_metrics = OrderedDict(
|
|
@@ -613,15 +608,15 @@ class BaseSegmenter(BaseModel):
|
|
|
ysize = int(height - yoff)
|
|
|
im = src_data.ReadAsArray(int(xoff), int(yoff), xsize,
|
|
|
ysize).transpose((1, 2, 0))
|
|
|
- # fill
|
|
|
+ # Fill
|
|
|
h, w = im.shape[:2]
|
|
|
im_fill = np.zeros(
|
|
|
(block_size[1], block_size[0], bands), dtype=im.dtype)
|
|
|
im_fill[:h, :w, :] = im
|
|
|
- # predict
|
|
|
+ # Predict
|
|
|
pred = self.predict(im_fill,
|
|
|
transforms)["label_map"].astype("uint8")
|
|
|
- # overlap
|
|
|
+ # Overlap
|
|
|
rd_block = band.ReadAsArray(int(xoff), int(yoff), xsize, ysize)
|
|
|
mask = (rd_block == pred[:h, :w]) | (rd_block == 255)
|
|
|
temp = pred[:h, :w].copy()
|
|
@@ -778,12 +773,18 @@ class BaseSegmenter(BaseModel):
|
|
|
raise TypeError(
|
|
|
"`transforms.arrange` must be an ArrangeSegmenter object.")
|
|
|
|
|
|
+ def set_losses(self, losses, weights=None):
|
|
|
+ if weights is None:
|
|
|
+ weights = [1. for _ in range(len(losses))]
|
|
|
+ self.losses = {'types': losses, 'coef': weights}
|
|
|
+
|
|
|
|
|
|
class UNet(BaseSegmenter):
|
|
|
def __init__(self,
|
|
|
input_channel=3,
|
|
|
num_classes=2,
|
|
|
use_mixed_loss=False,
|
|
|
+ losses=None,
|
|
|
use_deconv=False,
|
|
|
align_corners=False,
|
|
|
**params):
|
|
@@ -796,6 +797,7 @@ class UNet(BaseSegmenter):
|
|
|
input_channel=input_channel,
|
|
|
num_classes=num_classes,
|
|
|
use_mixed_loss=use_mixed_loss,
|
|
|
+ losses=losses,
|
|
|
**params)
|
|
|
|
|
|
|
|
@@ -805,6 +807,7 @@ class DeepLabV3P(BaseSegmenter):
|
|
|
num_classes=2,
|
|
|
backbone='ResNet50_vd',
|
|
|
use_mixed_loss=False,
|
|
|
+ losses=None,
|
|
|
output_stride=8,
|
|
|
backbone_indices=(0, 3),
|
|
|
aspp_ratios=(1, 12, 24, 36),
|
|
@@ -818,7 +821,7 @@ class DeepLabV3P(BaseSegmenter):
|
|
|
"{'ResNet50_vd', 'ResNet101_vd'}.".format(backbone))
|
|
|
if params.get('with_net', True):
|
|
|
with DisablePrint():
|
|
|
- backbone = getattr(paddleseg.models, backbone)(
|
|
|
+ backbone = getattr(ppseg.models, backbone)(
|
|
|
input_channel=input_channel, output_stride=output_stride)
|
|
|
else:
|
|
|
backbone = None
|
|
@@ -833,6 +836,7 @@ class DeepLabV3P(BaseSegmenter):
|
|
|
model_name='DeepLabV3P',
|
|
|
num_classes=num_classes,
|
|
|
use_mixed_loss=use_mixed_loss,
|
|
|
+ losses=losses,
|
|
|
**params)
|
|
|
|
|
|
|
|
@@ -840,6 +844,7 @@ class FastSCNN(BaseSegmenter):
|
|
|
def __init__(self,
|
|
|
num_classes=2,
|
|
|
use_mixed_loss=False,
|
|
|
+ losses=None,
|
|
|
align_corners=False,
|
|
|
**params):
|
|
|
params.update({'align_corners': align_corners})
|
|
@@ -847,6 +852,7 @@ class FastSCNN(BaseSegmenter):
|
|
|
model_name='FastSCNN',
|
|
|
num_classes=num_classes,
|
|
|
use_mixed_loss=use_mixed_loss,
|
|
|
+ losses=losses,
|
|
|
**params)
|
|
|
|
|
|
|
|
@@ -855,6 +861,7 @@ class HRNet(BaseSegmenter):
|
|
|
num_classes=2,
|
|
|
width=48,
|
|
|
use_mixed_loss=False,
|
|
|
+ losses=None,
|
|
|
align_corners=False,
|
|
|
**params):
|
|
|
if width not in (18, 48):
|
|
@@ -864,7 +871,7 @@ class HRNet(BaseSegmenter):
|
|
|
self.backbone_name = 'HRNet_W{}'.format(width)
|
|
|
if params.get('with_net', True):
|
|
|
with DisablePrint():
|
|
|
- backbone = getattr(paddleseg.models, self.backbone_name)(
|
|
|
+ backbone = getattr(ppseg.models, self.backbone_name)(
|
|
|
align_corners=align_corners)
|
|
|
else:
|
|
|
backbone = None
|
|
@@ -874,6 +881,7 @@ class HRNet(BaseSegmenter):
|
|
|
model_name='FCN',
|
|
|
num_classes=num_classes,
|
|
|
use_mixed_loss=use_mixed_loss,
|
|
|
+ losses=losses,
|
|
|
**params)
|
|
|
self.model_name = 'HRNet'
|
|
|
|
|
@@ -882,6 +890,7 @@ class BiSeNetV2(BaseSegmenter):
|
|
|
def __init__(self,
|
|
|
num_classes=2,
|
|
|
use_mixed_loss=False,
|
|
|
+ losses=None,
|
|
|
align_corners=False,
|
|
|
**params):
|
|
|
params.update({'align_corners': align_corners})
|
|
@@ -889,13 +898,19 @@ class BiSeNetV2(BaseSegmenter):
|
|
|
model_name='BiSeNetV2',
|
|
|
num_classes=num_classes,
|
|
|
use_mixed_loss=use_mixed_loss,
|
|
|
+ losses=losses,
|
|
|
**params)
|
|
|
|
|
|
|
|
|
class FarSeg(BaseSegmenter):
|
|
|
- def __init__(self, num_classes=2, use_mixed_loss=False, **params):
|
|
|
+ def __init__(self,
|
|
|
+ num_classes=2,
|
|
|
+ use_mixed_loss=False,
|
|
|
+ losses=None,
|
|
|
+ **params):
|
|
|
super(FarSeg, self).__init__(
|
|
|
model_name='FarSeg',
|
|
|
num_classes=num_classes,
|
|
|
use_mixed_loss=use_mixed_loss,
|
|
|
+ losses=losses,
|
|
|
**params)
|