|
|
@@ -781,7 +781,7 @@ class BaseSegmenter(BaseModel):
|
|
|
|
|
|
class UNet(BaseSegmenter):
|
|
|
def __init__(self,
|
|
|
- input_channel=3,
|
|
|
+ in_channels=3,
|
|
|
num_classes=2,
|
|
|
use_mixed_loss=False,
|
|
|
use_deconv=False,
|
|
|
@@ -793,7 +793,7 @@ class UNet(BaseSegmenter):
|
|
|
})
|
|
|
super(UNet, self).__init__(
|
|
|
model_name='UNet',
|
|
|
- input_channel=input_channel,
|
|
|
+ input_channel=in_channels,
|
|
|
num_classes=num_classes,
|
|
|
use_mixed_loss=use_mixed_loss,
|
|
|
**params)
|
|
|
@@ -801,7 +801,7 @@ class UNet(BaseSegmenter):
|
|
|
|
|
|
class DeepLabV3P(BaseSegmenter):
|
|
|
def __init__(self,
|
|
|
- input_channel=3,
|
|
|
+ in_channels=3,
|
|
|
num_classes=2,
|
|
|
backbone='ResNet50_vd',
|
|
|
use_mixed_loss=False,
|
|
|
@@ -819,7 +819,7 @@ class DeepLabV3P(BaseSegmenter):
|
|
|
if params.get('with_net', True):
|
|
|
with DisablePrint():
|
|
|
backbone = getattr(paddleseg.models, backbone)(
|
|
|
- input_channel=input_channel, output_stride=output_stride)
|
|
|
+ input_channel=in_channels, output_stride=output_stride)
|
|
|
else:
|
|
|
backbone = None
|
|
|
params.update({
|