Browse Source

input_channel->in_channels

Bobholamovic 3 years ago
parent
commit
7696de288c
1 changed files with 4 additions and 4 deletions
  1. 4 4
      paddlers/tasks/segmenter.py

+ 4 - 4
paddlers/tasks/segmenter.py

@@ -781,7 +781,7 @@ class BaseSegmenter(BaseModel):
 
 class UNet(BaseSegmenter):
     def __init__(self,
-                 input_channel=3,
+                 in_channels=3,
                  num_classes=2,
                  use_mixed_loss=False,
                  use_deconv=False,
@@ -793,7 +793,7 @@ class UNet(BaseSegmenter):
         })
         super(UNet, self).__init__(
             model_name='UNet',
-            input_channel=input_channel,
+            input_channel=in_channels,
             num_classes=num_classes,
             use_mixed_loss=use_mixed_loss,
             **params)
@@ -801,7 +801,7 @@ class UNet(BaseSegmenter):
 
 class DeepLabV3P(BaseSegmenter):
     def __init__(self,
-                 input_channel=3,
+                 in_channels=3,
                  num_classes=2,
                  backbone='ResNet50_vd',
                  use_mixed_loss=False,
@@ -819,7 +819,7 @@ class DeepLabV3P(BaseSegmenter):
         if params.get('with_net', True):
             with DisablePrint():
                 backbone = getattr(paddleseg.models, backbone)(
-                    input_channel=input_channel, output_stride=output_stride)
+                    input_channel=in_channels, output_stride=output_stride)
         else:
             backbone = None
         params.update({