Explorar o código

[Feature] Update multispectral scene classification (#36)

Yizhou Chen %!s(int64=3) %!d(string=hai) anos
pai
achega
39c82d943a
Modificáronse 37 ficheiros con 1155 adicións e 723 borrados
  1. 10 1
      paddlers/custom_models/cd/__init__.py
  2. 0 0
      paddlers/custom_models/cd/backbones/__init__.py
  3. 0 0
      paddlers/custom_models/cd/backbones/resnet.py
  4. 0 0
      paddlers/custom_models/cd/bit.py
  5. 0 0
      paddlers/custom_models/cd/cdnet.py
  6. 0 0
      paddlers/custom_models/cd/changestar.py
  7. 0 0
      paddlers/custom_models/cd/dsamnet.py
  8. 0 0
      paddlers/custom_models/cd/dsifn.py
  9. 0 0
      paddlers/custom_models/cd/layers/__init__.py
  10. 0 0
      paddlers/custom_models/cd/layers/attention.py
  11. 0 0
      paddlers/custom_models/cd/layers/blocks.py
  12. 0 24
      paddlers/custom_models/cd/models/__init__.py
  13. 0 0
      paddlers/custom_models/cd/param_init.py
  14. 0 0
      paddlers/custom_models/cd/snunet.py
  15. 0 0
      paddlers/custom_models/cd/stanet.py
  16. 0 0
      paddlers/custom_models/cd/unet_ef.py
  17. 0 0
      paddlers/custom_models/cd/unet_siamconc.py
  18. 0 0
      paddlers/custom_models/cd/unet_siamdiff.py
  19. 2 0
      paddlers/custom_models/cls/__init__.py
  20. 441 0
      paddlers/custom_models/cls/condensenet_v2.py
  21. 1 1
      paddlers/custom_models/seg/__init__.py
  22. 84 2
      paddlers/custom_models/seg/farseg.py
  23. 0 1
      paddlers/custom_models/seg/layers/__init__.py
  24. 8 0
      paddlers/custom_models/seg/layers/layers_lib.py
  25. 0 0
      paddlers/custom_models/seg/layers/param_init.py
  26. 0 15
      paddlers/custom_models/seg/models/__init__.py
  27. 0 15
      paddlers/custom_models/seg/models/farseg/__init__.py
  28. 0 98
      paddlers/custom_models/seg/models/farseg/fpn.py
  29. 0 23
      paddlers/custom_models/seg/models/utils/torch_nn.py
  30. 2 2
      paddlers/datasets/clas_dataset.py
  31. 3 2
      paddlers/tasks/base.py
  32. 3 3
      paddlers/tasks/changedetector.py
  33. 544 520
      paddlers/tasks/classifier.py
  34. 3 3
      paddlers/tasks/segmenter.py
  35. 1 1
      paddlers/transforms/functions.py
  36. 49 0
      tutorials/train/classification/condensenetv2_b_rs_mul.py
  37. 4 12
      tutorials/train/classification/resnet50_vd_rs.py

+ 10 - 1
paddlers/custom_models/cd/__init__.py

@@ -12,4 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from . import models
+from .bit import BIT
+from .cdnet import CDNet
+from .dsifn import DSIFN
+from .stanet import STANet
+from .snunet import SNUNet
+from .dsamnet import DSAMNet
+from .changestar import ChangeStar
+from .unet_ef import UNetEarlyFusion
+from .unet_siamconc import UNetSiamConc
+from .unet_siamdiff import UNetSiamDiff

+ 0 - 0
paddlers/custom_models/cd/models/backbones/__init__.py → paddlers/custom_models/cd/backbones/__init__.py


+ 0 - 0
paddlers/custom_models/cd/models/backbones/resnet.py → paddlers/custom_models/cd/backbones/resnet.py


+ 0 - 0
paddlers/custom_models/cd/models/bit.py → paddlers/custom_models/cd/bit.py


+ 0 - 0
paddlers/custom_models/cd/models/cdnet.py → paddlers/custom_models/cd/cdnet.py


+ 0 - 0
paddlers/custom_models/cd/models/changestar.py → paddlers/custom_models/cd/changestar.py


+ 0 - 0
paddlers/custom_models/cd/models/dsamnet.py → paddlers/custom_models/cd/dsamnet.py


+ 0 - 0
paddlers/custom_models/cd/models/dsifn.py → paddlers/custom_models/cd/dsifn.py


+ 0 - 0
paddlers/custom_models/cd/models/layers/__init__.py → paddlers/custom_models/cd/layers/__init__.py


+ 0 - 0
paddlers/custom_models/cd/models/layers/attention.py → paddlers/custom_models/cd/layers/attention.py


+ 0 - 0
paddlers/custom_models/cd/models/layers/blocks.py → paddlers/custom_models/cd/layers/blocks.py


+ 0 - 24
paddlers/custom_models/cd/models/__init__.py

@@ -1,24 +0,0 @@
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from .bit import BIT
-from .cdnet import CDNet
-from .dsifn import DSIFN
-from .stanet import STANet
-from .snunet import SNUNet
-from .dsamnet import DSAMNet
-from .changestar import ChangeStar
-from .unet_ef import UNetEarlyFusion
-from .unet_siamconc import UNetSiamConc
-from .unet_siamdiff import UNetSiamDiff

+ 0 - 0
paddlers/custom_models/cd/models/param_init.py → paddlers/custom_models/cd/param_init.py


+ 0 - 0
paddlers/custom_models/cd/models/snunet.py → paddlers/custom_models/cd/snunet.py


+ 0 - 0
paddlers/custom_models/cd/models/stanet.py → paddlers/custom_models/cd/stanet.py


+ 0 - 0
paddlers/custom_models/cd/models/unet_ef.py → paddlers/custom_models/cd/unet_ef.py


+ 0 - 0
paddlers/custom_models/cd/models/unet_siamconc.py → paddlers/custom_models/cd/unet_siamconc.py


+ 0 - 0
paddlers/custom_models/cd/models/unet_siamdiff.py → paddlers/custom_models/cd/unet_siamdiff.py


+ 2 - 0
paddlers/custom_models/cls/__init__.py

@@ -11,3 +11,5 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+from .condensenet_v2 import CondenseNetV2_a, CondenseNetV2_b, CondenseNetV2_c

+ 441 - 0
paddlers/custom_models/cls/condensenet_v2.py

@@ -0,0 +1,441 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is based on https://github.com/AgentMaker/Paddle-Image-Models
+Ths copyright of AgentMaker/Paddle-Image-Models is as follows:
+Apache License [see LICENSE for details]
+"""
+
+import paddle
+import paddle.nn as nn
+
+__all__ = ["CondenseNetV2_a", "CondenseNetV2_b", "CondenseNetV2_c"]
+
+
+class SELayer(nn.Layer):
+    def __init__(self, inplanes, reduction=16):
+        super(SELayer, self).__init__()
+        self.avg_pool = nn.AdaptiveAvgPool2D(1)
+        self.fc = nn.Sequential(
+            nn.Linear(
+                inplanes, inplanes // reduction, bias_attr=False),
+            nn.ReLU(),
+            nn.Linear(
+                inplanes // reduction, inplanes, bias_attr=False),
+            nn.Sigmoid(), )
+
+    def forward(self, x):
+        b, c, _, _ = x.shape
+        y = self.avg_pool(x).reshape((b, c))
+        y = self.fc(y).reshape((b, c, 1, 1))
+        return x * y.expand_as(x)
+
+
+class HS(nn.Layer):
+    def __init__(self):
+        super(HS, self).__init__()
+        self.relu6 = nn.ReLU6()
+
+    def forward(self, inputs):
+        return inputs * self.relu6(inputs + 3) / 6
+
+
+class Conv(nn.Sequential):
+    def __init__(
+            self,
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride=1,
+            padding=0,
+            groups=1,
+            activation="ReLU",
+            bn_momentum=0.9, ):
+        super(Conv, self).__init__()
+        self.add_sublayer(
+            "norm", nn.BatchNorm2D(
+                in_channels, momentum=bn_momentum))
+        if activation == "ReLU":
+            self.add_sublayer("activation", nn.ReLU())
+        elif activation == "HS":
+            self.add_sublayer("activation", HS())
+        else:
+            raise NotImplementedError
+        self.add_sublayer(
+            "conv",
+            nn.Conv2D(
+                in_channels,
+                out_channels,
+                kernel_size=kernel_size,
+                stride=stride,
+                padding=padding,
+                bias_attr=False,
+                groups=groups, ), )
+
+
+def ShuffleLayer(x, groups):
+    batchsize, num_channels, height, width = x.shape
+    channels_per_group = num_channels // groups
+    # reshape
+    x = x.reshape((batchsize, groups, channels_per_group, height, width))
+    # transpose
+    x = x.transpose((0, 2, 1, 3, 4))
+    # reshape
+    x = x.reshape((batchsize, -1, height, width))
+    return x
+
+
+def ShuffleLayerTrans(x, groups):
+    batchsize, num_channels, height, width = x.shape
+    channels_per_group = num_channels // groups
+    # reshape
+    x = x.reshape((batchsize, channels_per_group, groups, height, width))
+    # transpose
+    x = x.transpose((0, 2, 1, 3, 4))
+    # reshape
+    x = x.reshape((batchsize, -1, height, width))
+    return x
+
+
+class CondenseLGC(nn.Layer):
+    def __init__(
+            self,
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride=1,
+            padding=0,
+            groups=1,
+            activation="ReLU", ):
+        super(CondenseLGC, self).__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.groups = groups
+        self.norm = nn.BatchNorm2D(self.in_channels)
+        if activation == "ReLU":
+            self.activation = nn.ReLU()
+        elif activation == "HS":
+            self.activation = HS()
+        else:
+            raise NotImplementedError
+        self.conv = nn.Conv2D(
+            self.in_channels,
+            self.out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            groups=self.groups,
+            bias_attr=False, )
+        self.register_buffer(
+            "index", paddle.zeros(
+                (self.in_channels, ), dtype="int64"))
+
+    def forward(self, x):
+        x = paddle.index_select(x, self.index, axis=1)
+        x = self.norm(x)
+        x = self.activation(x)
+        x = self.conv(x)
+        x = ShuffleLayer(x, self.groups)
+        return x
+
+
+class CondenseSFR(nn.Layer):
+    def __init__(
+            self,
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride=1,
+            padding=0,
+            groups=1,
+            activation="ReLU", ):
+        super(CondenseSFR, self).__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.groups = groups
+        self.norm = nn.BatchNorm2D(self.in_channels)
+        if activation == "ReLU":
+            self.activation = nn.ReLU()
+        elif activation == "HS":
+            self.activation = HS()
+        else:
+            raise NotImplementedError
+        self.conv = nn.Conv2D(
+            self.in_channels,
+            self.out_channels,
+            kernel_size=kernel_size,
+            padding=padding,
+            groups=self.groups,
+            bias_attr=False,
+            stride=stride, )
+        self.register_buffer("index",
+                             paddle.zeros(
+                                 (self.out_channels, self.out_channels)))
+
+    def forward(self, x):
+        x = self.norm(x)
+        x = self.activation(x)
+        x = ShuffleLayerTrans(x, self.groups)
+        x = self.conv(x)  # SIZE: N, C, H, W
+        N, C, H, W = x.shape
+        x = x.reshape((N, C, H * W))
+        x = x.transpose((0, 2, 1))  # SIZE: N, HW, C
+        # x SIZE: N, HW, C; self.index SIZE: C, C; OUTPUT SIZE: N, HW, C
+        x = paddle.matmul(x, self.index)
+        x = x.transpose((0, 2, 1))  # SIZE: N, C, HW
+        x = x.reshape((N, C, H, W))  # SIZE: N, C, HW
+        return x
+
+
+class _SFR_DenseLayer(nn.Layer):
+    def __init__(
+            self,
+            in_channels,
+            growth_rate,
+            group_1x1,
+            group_3x3,
+            group_trans,
+            bottleneck,
+            activation,
+            use_se=False, ):
+        super(_SFR_DenseLayer, self).__init__()
+        self.group_1x1 = group_1x1
+        self.group_3x3 = group_3x3
+        self.group_trans = group_trans
+        self.use_se = use_se
+        # 1x1 conv i --> b*k
+        self.conv_1 = CondenseLGC(
+            in_channels,
+            bottleneck * growth_rate,
+            kernel_size=1,
+            groups=self.group_1x1,
+            activation=activation, )
+        # 3x3 conv b*k --> k
+        self.conv_2 = Conv(
+            bottleneck * growth_rate,
+            growth_rate,
+            kernel_size=3,
+            padding=1,
+            groups=self.group_3x3,
+            activation=activation, )
+        # 1x1 res conv k(8-16-32)--> i (k*l)
+        self.sfr = CondenseSFR(
+            growth_rate,
+            in_channels,
+            kernel_size=1,
+            groups=self.group_trans,
+            activation=activation, )
+        if self.use_se:
+            self.se = SELayer(inplanes=growth_rate, reduction=1)
+
+    def forward(self, x):
+        x_ = x
+        x = self.conv_1(x)
+        x = self.conv_2(x)
+        if self.use_se:
+            x = self.se(x)
+        sfr_feature = self.sfr(x)
+        y = x_ + sfr_feature
+        return paddle.concat([y, x], 1)
+
+
+class _SFR_DenseBlock(nn.Sequential):
+    def __init__(
+            self,
+            num_layers,
+            in_channels,
+            growth_rate,
+            group_1x1,
+            group_3x3,
+            group_trans,
+            bottleneck,
+            activation,
+            use_se, ):
+        super(_SFR_DenseBlock, self).__init__()
+        for i in range(num_layers):
+            layer = _SFR_DenseLayer(
+                in_channels + i * growth_rate,
+                growth_rate,
+                group_1x1,
+                group_3x3,
+                group_trans,
+                bottleneck,
+                activation,
+                use_se, )
+            self.add_sublayer("denselayer_%d" % (i + 1), layer)
+
+
+class _Transition(nn.Layer):
+    def __init__(self):
+        super(_Transition, self).__init__()
+        self.pool = nn.AvgPool2D(kernel_size=2, stride=2)
+
+    def forward(self, x):
+        x = self.pool(x)
+        return x
+
+
+class CondenseNetV2(nn.Layer):
+    def __init__(
+            self,
+            stages,
+            growth,
+            HS_start_block,
+            SE_start_block,
+            fc_channel,
+            group_1x1,
+            group_3x3,
+            group_trans,
+            bottleneck,
+            last_se_reduction,
+            in_channels=3,
+            class_num=1000, ):
+        super(CondenseNetV2, self).__init__()
+        self.stages = stages
+        self.growth = growth
+        self.in_channels = in_channels
+        self.class_num = class_num
+        self.last_se_reduction = last_se_reduction
+        assert len(self.stages) == len(self.growth)
+        self.progress = 0.0
+
+        self.init_stride = 2
+        self.pool_size = 7
+
+        self.features = nn.Sequential()
+        # Initial nChannels should be 3
+        self.num_features = 2 * self.growth[0]
+        # Dense-block 1 (224x224)
+        self.features.add_sublayer(
+            "init_conv",
+            nn.Conv2D(
+                in_channels,
+                self.num_features,
+                kernel_size=3,
+                stride=self.init_stride,
+                padding=1,
+                bias_attr=False, ), )
+        for i in range(len(self.stages)):
+            activation = "HS" if i >= HS_start_block else "ReLU"
+            use_se = True if i >= SE_start_block else False
+            # Dense-block i
+            self.add_block(i, group_1x1, group_3x3, group_trans, bottleneck,
+                           activation, use_se)
+
+        self.fc = nn.Linear(self.num_features, fc_channel)
+        self.fc_act = HS()
+
+        # Classifier layer
+        if class_num > 0:
+            self.classifier = nn.Linear(fc_channel, class_num)
+        self._initialize()
+
+    def add_block(self, i, group_1x1, group_3x3, group_trans, bottleneck,
+                  activation, use_se):
+        # Check if ith is the last one
+        last = i == len(self.stages) - 1
+        block = _SFR_DenseBlock(
+            num_layers=self.stages[i],
+            in_channels=self.num_features,
+            growth_rate=self.growth[i],
+            group_1x1=group_1x1,
+            group_3x3=group_3x3,
+            group_trans=group_trans,
+            bottleneck=bottleneck,
+            activation=activation,
+            use_se=use_se, )
+        self.features.add_sublayer("denseblock_%d" % (i + 1), block)
+        self.num_features += self.stages[i] * self.growth[i]
+        if not last:
+            trans = _Transition()
+            self.features.add_sublayer("transition_%d" % (i + 1), trans)
+        else:
+            self.features.add_sublayer("norm_last",
+                                       nn.BatchNorm2D(self.num_features))
+            self.features.add_sublayer("relu_last", nn.ReLU())
+            self.features.add_sublayer("pool_last",
+                                       nn.AvgPool2D(self.pool_size))
+            # if useSE:
+            self.features.add_sublayer(
+                "se_last",
+                SELayer(
+                    self.num_features, reduction=self.last_se_reduction))
+
+    def forward(self, x):
+        features = self.features(x)
+        out = features.reshape((features.shape[0], -1))
+        out = self.fc(out)
+        out = self.fc_act(out)
+
+        if self.class_num > 0:
+            out = self.classifier(out)
+
+        return out
+
+    def _initialize(self):
+        # initialize
+        for m in self.sublayers():
+            if isinstance(m, nn.Conv2D):
+                nn.initializer.KaimingNormal()(m.weight)
+            elif isinstance(m, nn.BatchNorm2D):
+                nn.initializer.Constant(value=1.0)(m.weight)
+                nn.initializer.Constant(value=0.0)(m.bias)
+
+
+def CondenseNetV2_a(**kwargs):
+    model = CondenseNetV2(
+        stages=[1, 1, 4, 6, 8],
+        growth=[8, 8, 16, 32, 64],
+        HS_start_block=2,
+        SE_start_block=3,
+        fc_channel=828,
+        group_1x1=8,
+        group_3x3=8,
+        group_trans=8,
+        bottleneck=4,
+        last_se_reduction=16,
+        **kwargs)
+    return model
+
+
+def CondenseNetV2_b(**kwargs):
+    model = CondenseNetV2(
+        stages=[2, 4, 6, 8, 6],
+        growth=[6, 12, 24, 48, 96],
+        HS_start_block=2,
+        SE_start_block=3,
+        fc_channel=1024,
+        group_1x1=6,
+        group_3x3=6,
+        group_trans=6,
+        bottleneck=4,
+        last_se_reduction=16,
+        **kwargs)
+    return model
+
+
+def CondenseNetV2_c(**kwargs):
+    model = CondenseNetV2(
+        stages=[4, 6, 8, 10, 8],
+        growth=[8, 16, 32, 64, 128],
+        HS_start_block=2,
+        SE_start_block=3,
+        fc_channel=1024,
+        group_1x1=8,
+        group_3x3=8,
+        group_trans=8,
+        bottleneck=4,
+        last_se_reduction=16,
+        **kwargs)
+    return model

+ 1 - 1
paddlers/custom_models/seg/__init__.py

@@ -12,4 +12,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from . import models
+from .farseg import FarSeg

+ 84 - 2
paddlers/custom_models/seg/models/farseg/farseg.py → paddlers/custom_models/seg/farseg.py

@@ -21,8 +21,90 @@ import math
 import paddle.nn as nn
 import paddle.nn.functional as F
 from paddle.vision.models import resnet50
-from .fpn import FPN
-from ..utils import Identity
+from paddle import nn
+import paddle.nn.functional as F
+from .layers import (Identity, ConvReLU, kaiming_normal_init, constant_init)
+
+
+class FPN(nn.Layer):
+    """
+    Module that adds FPN on top of a list of feature maps.
+    The feature maps are currently supposed to be in increasing depth
+    order, and must be consecutive
+    """
+
+    def __init__(self,
+                 in_channels_list,
+                 out_channels,
+                 conv_block=ConvReLU,
+                 top_blocks=None):
+        super(FPN, self).__init__()
+        self.inner_blocks = []
+        self.layer_blocks = []
+        for idx, in_channels in enumerate(in_channels_list, 1):
+            inner_block = "fpn_inner{}".format(idx)
+            layer_block = "fpn_layer{}".format(idx)
+            if in_channels == 0:
+                continue
+            inner_block_module = conv_block(in_channels, out_channels, 1)
+            layer_block_module = conv_block(out_channels, out_channels, 3, 1)
+            self.add_sublayer(inner_block, inner_block_module)
+            self.add_sublayer(layer_block, layer_block_module)
+            for module in [inner_block_module, layer_block_module]:
+                for m in module.sublayers():
+                    if isinstance(m, nn.Conv2D):
+                        kaiming_normal_init(m.weight)
+            self.inner_blocks.append(inner_block)
+            self.layer_blocks.append(layer_block)
+        self.top_blocks = top_blocks
+
+    def forward(self, x):
+        last_inner = getattr(self, self.inner_blocks[-1])(x[-1])
+        results = [getattr(self, self.layer_blocks[-1])(last_inner)]
+        for feature, inner_block, layer_block in zip(
+                x[:-1][::-1], self.inner_blocks[:-1][::-1],
+                self.layer_blocks[:-1][::-1]):
+            if not inner_block:
+                continue
+            inner_top_down = F.interpolate(
+                last_inner, scale_factor=2, mode="nearest")
+            inner_lateral = getattr(self, inner_block)(feature)
+            last_inner = inner_lateral + inner_top_down
+            results.insert(0, getattr(self, layer_block)(last_inner))
+        if isinstance(self.top_blocks, LastLevelP6P7):
+            last_results = self.top_blocks(x[-1], results[-1])
+            results.extend(last_results)
+        elif isinstance(self.top_blocks, LastLevelMaxPool):
+            last_results = self.top_blocks(results[-1])
+            results.extend(last_results)
+        return tuple(results)
+
+
+class LastLevelMaxPool(nn.Layer):
+    def forward(self, x):
+        return [F.max_pool2d(x, 1, 2, 0)]
+
+
+class LastLevelP6P7(nn.Layer):
+    """
+    This module is used in RetinaNet to generate extra layers, P6 and P7.
+    """
+
+    def __init__(self, in_channels, out_channels):
+        super(LastLevelP6P7, self).__init__()
+        self.p6 = nn.Conv2D(in_channels, out_channels, 3, 2, 1)
+        self.p7 = nn.Conv2D(out_channels, out_channels, 3, 2, 1)
+        for module in [self.p6, self.p7]:
+            for m in module.sublayers():
+                kaiming_normal_init(m.weight)
+                constant_init(m.bias, value=0)
+        self.use_P5 = in_channels == out_channels
+
+    def forward(self, c5, p5):
+        x = p5 if self.use_P5 else c5
+        p6 = self.p6(x)
+        p7 = self.p7(F.relu(p6))
+        return [p6, p7]
 
 
 class SceneRelation(nn.Layer):

+ 0 - 1
paddlers/custom_models/seg/models/utils/__init__.py → paddlers/custom_models/seg/layers/__init__.py

@@ -12,6 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .torch_nn import *
 from .param_init import *
 from .layers_lib import *

+ 8 - 0
paddlers/custom_models/seg/models/utils/layers_lib.py → paddlers/custom_models/seg/layers/layers_lib.py

@@ -138,3 +138,11 @@ class Activation(nn.Layer):
             return self.act_func(x)
         else:
             return x
+
+
+class Identity(nn.Layer):
+    def __init__(self, *args, **kwargs):
+        super(Identity, self).__init__()
+
+    def forward(self, input):
+        return input

+ 0 - 0
paddlers/custom_models/seg/models/utils/param_init.py → paddlers/custom_models/seg/layers/param_init.py


+ 0 - 15
paddlers/custom_models/seg/models/__init__.py

@@ -1,15 +0,0 @@
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from .farseg import FarSeg

+ 0 - 15
paddlers/custom_models/seg/models/farseg/__init__.py

@@ -1,15 +0,0 @@
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from .farseg import FarSeg

+ 0 - 98
paddlers/custom_models/seg/models/farseg/fpn.py

@@ -1,98 +0,0 @@
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from paddle import nn
-import paddle.nn.functional as F
-from ..utils import (ConvReLU, kaiming_normal_init, constant_init)
-
-
-class FPN(nn.Layer):
-    """
-    Module that adds FPN on top of a list of feature maps.
-    The feature maps are currently supposed to be in increasing depth
-    order, and must be consecutive
-    """
-
-    def __init__(self,
-                 in_channels_list,
-                 out_channels,
-                 conv_block=ConvReLU,
-                 top_blocks=None):
-        super(FPN, self).__init__()
-        self.inner_blocks = []
-        self.layer_blocks = []
-        for idx, in_channels in enumerate(in_channels_list, 1):
-            inner_block = "fpn_inner{}".format(idx)
-            layer_block = "fpn_layer{}".format(idx)
-            if in_channels == 0:
-                continue
-            inner_block_module = conv_block(in_channels, out_channels, 1)
-            layer_block_module = conv_block(out_channels, out_channels, 3, 1)
-            self.add_sublayer(inner_block, inner_block_module)
-            self.add_sublayer(layer_block, layer_block_module)
-            for module in [inner_block_module, layer_block_module]:
-                for m in module.sublayers():
-                    if isinstance(m, nn.Conv2D):
-                        kaiming_normal_init(m.weight)
-            self.inner_blocks.append(inner_block)
-            self.layer_blocks.append(layer_block)
-        self.top_blocks = top_blocks
-
-    def forward(self, x):
-        last_inner = getattr(self, self.inner_blocks[-1])(x[-1])
-        results = [getattr(self, self.layer_blocks[-1])(last_inner)]
-        for feature, inner_block, layer_block in zip(
-                x[:-1][::-1], self.inner_blocks[:-1][::-1],
-                self.layer_blocks[:-1][::-1]):
-            if not inner_block:
-                continue
-            inner_top_down = F.interpolate(
-                last_inner, scale_factor=2, mode="nearest")
-            inner_lateral = getattr(self, inner_block)(feature)
-            last_inner = inner_lateral + inner_top_down
-            results.insert(0, getattr(self, layer_block)(last_inner))
-        if isinstance(self.top_blocks, LastLevelP6P7):
-            last_results = self.top_blocks(x[-1], results[-1])
-            results.extend(last_results)
-        elif isinstance(self.top_blocks, LastLevelMaxPool):
-            last_results = self.top_blocks(results[-1])
-            results.extend(last_results)
-        return tuple(results)
-
-
-class LastLevelMaxPool(nn.Layer):
-    def forward(self, x):
-        return [F.max_pool2d(x, 1, 2, 0)]
-
-
-class LastLevelP6P7(nn.Layer):
-    """
-    This module is used in RetinaNet to generate extra layers, P6 and P7.
-    """
-
-    def __init__(self, in_channels, out_channels):
-        super(LastLevelP6P7, self).__init__()
-        self.p6 = nn.Conv2D(in_channels, out_channels, 3, 2, 1)
-        self.p7 = nn.Conv2D(out_channels, out_channels, 3, 2, 1)
-        for module in [self.p6, self.p7]:
-            for m in module.sublayers():
-                kaiming_normal_init(m.weight)
-                constant_init(m.bias, value=0)
-        self.use_P5 = in_channels == out_channels
-
-    def forward(self, c5, p5):
-        x = p5 if self.use_P5 else c5
-        p6 = self.p6(x)
-        p7 = self.p7(F.relu(p6))
-        return [p6, p7]

+ 0 - 23
paddlers/custom_models/seg/models/utils/torch_nn.py

@@ -1,23 +0,0 @@
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import paddle.nn as nn
-
-
-class Identity(nn.Layer):
-    def __init__(self, *args, **kwargs):
-        super(Identity, self).__init__()
-
-    def forward(self, input):
-        return input

+ 2 - 2
paddlers/datasets/clas_dataset.py

@@ -64,10 +64,10 @@ class ClasDataset(Dataset):
                         "so the space cannot be in the image or label path, but the line[{}] of " \
                         " file_list[{}] has a space in the image or label path.".format(line, file_list))
                 items[0] = path_normalization(items[0])
-                if not is_pic(items[0]):
-                    continue
                 full_path_im = osp.join(data_dir, items[0])
                 label = items[1]
+                if not is_pic(full_path_im):
+                    continue
                 if not osp.exists(full_path_im):
                     raise IOError('Image file {} does not exist!'.format(
                         full_path_im))

+ 3 - 2
paddlers/tasks/base.py

@@ -39,6 +39,7 @@ from .utils.infer_nets import InferNet
 class BaseModel:
     def __init__(self, model_type):
         self.model_type = model_type
+        self.in_channels = None
         self.num_classes = None
         self.labels = None
         self.version = paddlers.__version__
@@ -130,8 +131,8 @@ class BaseModel:
         info['version'] = paddlers.__version__
         info['Model'] = self.__class__.__name__
         info['_Attributes'] = dict(
-            [('model_type', self.model_type), ('num_classes', self.num_classes),
-             ('labels', self.labels),
+            [('model_type', self.model_type), ('in_channels', self.in_channels),
+             ('num_classes', self.num_classes), ('labels', self.labels),
              ('fixed_input_shape', self.fixed_input_shape),
              ('best_accuracy', self.best_accuracy),
              ('best_model_epoch', self.best_model_epoch)])

+ 3 - 3
paddlers/tasks/changedetector.py

@@ -24,7 +24,7 @@ import paddle.nn.functional as F
 from paddle.static import InputSpec
 
 import paddlers
-import paddlers.custom_models.cd as cd
+import paddlers.custom_models.cd as cmcd
 import paddlers.utils.logging as logging
 import paddlers.models.ppseg as paddleseg
 from paddlers.transforms import arrange_transforms
@@ -65,8 +65,8 @@ class BaseChangeDetector(BaseModel):
 
     def build_net(self, **params):
         # TODO: add other model
-        net = cd.models.__dict__[self.model_name](num_classes=self.num_classes,
-                                                  **params)
+        net = cmcd.__dict__[self.model_name](num_classes=self.num_classes,
+                                             **params)
         return net
 
     def _fix_transforms_shape(self, image_shape):

+ 544 - 520
paddlers/tasks/classifier.py

@@ -1,520 +1,544 @@
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import math
-import os.path as osp
-import numpy as np
-from collections import OrderedDict
-import paddle
-import paddle.nn.functional as F
-from paddle.static import InputSpec
-import paddlers.models.ppcls as paddleclas
-import paddlers
-from paddlers.transforms import arrange_transforms
-from paddlers.utils import get_single_card_bs, DisablePrint
-import paddlers.utils.logging as logging
-from .base import BaseModel
-from paddlers.models.ppcls.metric import build_metrics
-from paddlers.models.ppcls.loss import build_loss
-from paddlers.models.ppcls.data.postprocess import build_postprocess
-from paddlers.utils.checkpoint import cls_pretrain_weights_dict
-from paddlers.transforms import ImgDecoder, Resize
-
-__all__ = ["ResNet50_vd", "MobileNetV3_small_x1_0", "HRNet_W18_C"]
-
-
-class BaseClassifier(BaseModel):
-    def __init__(self,
-                 model_name,
-                 num_classes=2,
-                 use_mixed_loss=False,
-                 **params):
-        self.init_params = locals()
-        if 'with_net' in self.init_params:
-            del self.init_params['with_net']
-        super(BaseClassifier, self).__init__('classifier')
-        if not hasattr(paddleclas.arch.backbone, model_name):
-            raise Exception("ERROR: There's no model named {}.".format(
-                model_name))
-        self.model_name = model_name
-        self.num_classes = num_classes
-        self.use_mixed_loss = use_mixed_loss
-        self.metrics = None
-        self.losses = None
-        self.labels = None
-        self._postprocess = None
-        if params.get('with_net', True):
-            params.pop('with_net', None)
-            self.net = self.build_net(**params)
-        self.find_unused_parameters = True
-
-    def build_net(self, **params):
-        with paddle.utils.unique_name.guard():
-            net = paddleclas.arch.backbone.__dict__[self.model_name](
-                class_num=self.num_classes, **params)
-        return net
-
-    def _fix_transforms_shape(self, image_shape):
-        if hasattr(self, 'test_transforms'):
-            if self.test_transforms is not None:
-                has_resize_op = False
-                resize_op_idx = -1
-                normalize_op_idx = len(self.test_transforms.transforms)
-                for idx, op in enumerate(self.test_transforms.transforms):
-                    name = op.__class__.__name__
-                    if name == 'Normalize':
-                        normalize_op_idx = idx
-                    if 'Resize' in name:
-                        has_resize_op = True
-                        resize_op_idx = idx
-
-                if not has_resize_op:
-                    self.test_transforms.transforms.insert(
-                        normalize_op_idx, Resize(target_size=image_shape))
-                else:
-                    self.test_transforms.transforms[resize_op_idx] = Resize(
-                        target_size=image_shape)
-
-    def _get_test_inputs(self, image_shape):
-        if image_shape is not None:
-            if len(image_shape) == 2:
-                image_shape = [1, 3] + image_shape
-            self._fix_transforms_shape(image_shape[-2:])
-        else:
-            image_shape = [None, 3, -1, -1]
-        self.fixed_input_shape = image_shape
-        input_spec = [
-            InputSpec(
-                shape=image_shape, name='image', dtype='float32')
-        ]
-        return input_spec
-
-    def run(self, net, inputs, mode):
-        net_out = net(inputs[0])
-        label = paddle.to_tensor(inputs[1], dtype="int64")
-        outputs = OrderedDict()
-        if mode == 'test':
-            result = self._postprocess(net_out)
-            outputs = result[0]
-
-        if mode == 'eval':
-            # print(self._postprocess(net_out)[0])  # for test
-            label = paddle.unsqueeze(label, axis=-1)
-            metric_dict = self.metrics(net_out, label)
-            outputs['top1'] = metric_dict["top1"]
-            outputs['top5'] = metric_dict["top5"]
-
-        if mode == 'train':
-            loss_list = self.losses(net_out, label)
-            outputs['loss'] = loss_list['loss']
-        return outputs
-
-    def default_metric(self):
-        default_config = [{"TopkAcc": {"topk": [1, 5]}}]
-        return build_metrics(default_config)
-
-    def default_loss(self):
-        # TODO: use mixed loss and other loss
-        default_config = [{"CELoss": {"weight": 1.0}}]
-        return build_loss(default_config)
-
-    def default_optimizer(self,
-                          parameters,
-                          learning_rate,
-                          num_epochs,
-                          num_steps_each_epoch,
-                          last_epoch=-1,
-                          L2_coeff=0.00007):
-        decay_step = num_epochs * num_steps_each_epoch
-        lr_scheduler = paddle.optimizer.lr.CosineAnnealingDecay(
-            learning_rate, T_max=decay_step, eta_min=0, last_epoch=last_epoch)
-        optimizer = paddle.optimizer.Momentum(
-            learning_rate=lr_scheduler,
-            parameters=parameters,
-            momentum=0.9,
-            weight_decay=paddle.regularizer.L2Decay(L2_coeff))
-        return optimizer
-
-    def default_postprocess(self, class_id_map_file):
-        default_config = {
-            "name": "Topk",
-            "topk": 1,
-            "class_id_map_file": class_id_map_file
-        }
-        return build_postprocess(default_config)
-
-    def train(self,
-              num_epochs,
-              train_dataset,
-              train_batch_size=2,
-              eval_dataset=None,
-              optimizer=None,
-              save_interval_epochs=1,
-              log_interval_steps=2,
-              save_dir='output',
-              pretrain_weights='IMAGENET',
-              learning_rate=0.1,
-              lr_decay_power=0.9,
-              early_stop=False,
-              early_stop_patience=5,
-              use_vdl=True,
-              resume_checkpoint=None):
-        """
-        Train the model.
-        Args:
-            num_epochs(int): The number of epochs.
-            train_dataset(paddlers.dataset): Training dataset.
-            train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 2.
-            eval_dataset(paddlers.dataset, optional):
-                Evaluation dataset. If None, the model will not be evaluated furing training process. Defaults to None.
-            optimizer(paddle.optimizer.Optimizer or None, optional):
-                Optimizer used in training. If None, a default optimizer is used. Defaults to None.
-            save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
-            log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
-            save_dir(str, optional): Directory to save the model. Defaults to 'output'.
-            pretrain_weights(str or None, optional):
-                None or name/path of pretrained weights. If None, no pretrained weights will be loaded. Defaults to 'CITYSCAPES'.
-            learning_rate(float, optional): Learning rate for training. Defaults to .025.
-            lr_decay_power(float, optional): Learning decay power. Defaults to .9.
-            early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
-            early_stop_patience(int, optional): Early stop patience. Defaults to 5.
-            use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
-            resume_checkpoint(str or None, optional): The path of the checkpoint to resume training from.
-                If None, no training checkpoint will be resumed. At most one of `resume_checkpoint` and
-                `pretrain_weights` can be set simultaneously. Defaults to None.
-
-        """
-        if self.status == 'Infer':
-            logging.error(
-                "Exported inference model does not support training.",
-                exit=True)
-        if pretrain_weights is not None and resume_checkpoint is not None:
-            logging.error(
-                "pretrain_weights and resume_checkpoint cannot be set simultaneously.",
-                exit=True)
-        self.labels = train_dataset.labels
-        if self.losses is None:
-            self.losses = self.default_loss()
-        self.metrics = self.default_metric()
-        self._postprocess = self.default_postprocess(train_dataset.label_list)
-        # print(self._postprocess.class_id_map)
-
-        if optimizer is None:
-            num_steps_each_epoch = train_dataset.num_samples // train_batch_size
-            self.optimizer = self.default_optimizer(
-                self.net.parameters(), learning_rate, num_epochs,
-                num_steps_each_epoch, lr_decay_power)
-        else:
-            self.optimizer = optimizer
-
-        if pretrain_weights is not None and not osp.exists(pretrain_weights):
-            if pretrain_weights not in cls_pretrain_weights_dict[
-                    self.model_name]:
-                logging.warning(
-                    "Path of pretrain_weights('{}') does not exist!".format(
-                        pretrain_weights))
-                logging.warning("Pretrain_weights is forcibly set to '{}'. "
-                                "If don't want to use pretrain weights, "
-                                "set pretrain_weights to be None.".format(
-                                    cls_pretrain_weights_dict[self.model_name][
-                                        0]))
-                pretrain_weights = cls_pretrain_weights_dict[self.model_name][0]
-        elif pretrain_weights is not None and osp.exists(pretrain_weights):
-            if osp.splitext(pretrain_weights)[-1] != '.pdparams':
-                logging.error(
-                    "Invalid pretrain weights. Please specify a '.pdparams' file.",
-                    exit=True)
-        pretrained_dir = osp.join(save_dir, 'pretrain')
-        is_backbone_weights = False  # pretrain_weights == 'IMAGENET'  # TODO: this is backbone
-        self.net_initialize(
-            pretrain_weights=pretrain_weights,
-            save_dir=pretrained_dir,
-            resume_checkpoint=resume_checkpoint,
-            is_backbone_weights=is_backbone_weights)
-
-        self.train_loop(
-            num_epochs=num_epochs,
-            train_dataset=train_dataset,
-            train_batch_size=train_batch_size,
-            eval_dataset=eval_dataset,
-            save_interval_epochs=save_interval_epochs,
-            log_interval_steps=log_interval_steps,
-            save_dir=save_dir,
-            early_stop=early_stop,
-            early_stop_patience=early_stop_patience,
-            use_vdl=use_vdl)
-
-    def quant_aware_train(self,
-                          num_epochs,
-                          train_dataset,
-                          train_batch_size=2,
-                          eval_dataset=None,
-                          optimizer=None,
-                          save_interval_epochs=1,
-                          log_interval_steps=2,
-                          save_dir='output',
-                          learning_rate=0.0001,
-                          lr_decay_power=0.9,
-                          early_stop=False,
-                          early_stop_patience=5,
-                          use_vdl=True,
-                          resume_checkpoint=None,
-                          quant_config=None):
-        """
-        Quantization-aware training.
-        Args:
-            num_epochs(int): The number of epochs.
-            train_dataset(paddlers.dataset): Training dataset.
-            train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 2.
-            eval_dataset(paddlers.dataset, optional):
-                Evaluation dataset. If None, the model will not be evaluated furing training process. Defaults to None.
-            optimizer(paddle.optimizer.Optimizer or None, optional):
-                Optimizer used in training. If None, a default optimizer is used. Defaults to None.
-            save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
-            log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
-            save_dir(str, optional): Directory to save the model. Defaults to 'output'.
-            learning_rate(float, optional): Learning rate for training. Defaults to .025.
-            lr_decay_power(float, optional): Learning decay power. Defaults to .9.
-            early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
-            early_stop_patience(int, optional): Early stop patience. Defaults to 5.
-            use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
-            quant_config(dict or None, optional): Quantization configuration. If None, a default rule of thumb
-                configuration will be used. Defaults to None.
-            resume_checkpoint(str or None, optional): The path of the checkpoint to resume quantization-aware training
-                from. If None, no training checkpoint will be resumed. Defaults to None.
-
-        """
-        self._prepare_qat(quant_config)
-        self.train(
-            num_epochs=num_epochs,
-            train_dataset=train_dataset,
-            train_batch_size=train_batch_size,
-            eval_dataset=eval_dataset,
-            optimizer=optimizer,
-            save_interval_epochs=save_interval_epochs,
-            log_interval_steps=log_interval_steps,
-            save_dir=save_dir,
-            pretrain_weights=None,
-            learning_rate=learning_rate,
-            lr_decay_power=lr_decay_power,
-            early_stop=early_stop,
-            early_stop_patience=early_stop_patience,
-            use_vdl=use_vdl,
-            resume_checkpoint=resume_checkpoint)
-
-    def evaluate(self, eval_dataset, batch_size=1, return_details=False):
-        """
-        Evaluate the model.
-        Args:
-            eval_dataset(paddlers.dataset): Evaluation dataset.
-            batch_size(int, optional): Total batch size among all cards used for evaluation. Defaults to 1.
-            return_details(bool, optional): Whether to return evaluation details. Defaults to False.
-
-        Returns:
-            collections.OrderedDict with key-value pairs:
-                {"top1": `acc of top1`,
-                 "top5": `acc of top5`}.
-
-        """
-        arrange_transforms(
-            model_type=self.model_type,
-            transforms=eval_dataset.transforms,
-            mode='eval')
-
-        self.net.eval()
-        nranks = paddle.distributed.get_world_size()
-        local_rank = paddle.distributed.get_rank()
-        if nranks > 1:
-            # Initialize parallel environment if not done.
-            if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
-            ):
-                paddle.distributed.init_parallel_env()
-
-        batch_size_each_card = get_single_card_bs(batch_size)
-        if batch_size_each_card > 1:
-            batch_size_each_card = 1
-            batch_size = batch_size_each_card * paddlers.env_info['num']
-            logging.warning(
-                "Segmenter only supports batch_size=1 for each gpu/cpu card " \
-                "during evaluation, so batch_size " \
-                "is forcibly set to {}.".format(batch_size))
-        self.eval_data_loader = self.build_data_loader(
-            eval_dataset, batch_size=batch_size, mode='eval')
-
-        logging.info(
-            "Start to evaluate(total_samples={}, total_steps={})...".format(
-                eval_dataset.num_samples,
-                math.ceil(eval_dataset.num_samples * 1.0 / batch_size)))
-
-        top1s = []
-        top5s = []
-        with paddle.no_grad():
-            for step, data in enumerate(self.eval_data_loader):
-                data.append(eval_dataset.transforms.transforms)
-                outputs = self.run(self.net, data, 'eval')
-                top1s.append(outputs["top1"])
-                top5s.append(outputs["top5"])
-
-        top1 = np.mean(top1s)
-        top5 = np.mean(top5s)
-        eval_metrics = OrderedDict(zip(['top1', 'top5'], [top1, top5]))
-        if return_details:
-            # TODO: add details
-            return eval_metrics, None
-        return eval_metrics
-
-    def predict(self, img_file, transforms=None):
-        """
-        Do inference.
-        Args:
-            Args:
-            img_file(List[np.ndarray or str], str or np.ndarray):
-                Image path or decoded image data in a BGR format, which also could constitute a list,
-                meaning all images to be predicted as a mini-batch.
-            transforms(paddlers.transforms.Compose or None, optional):
-                Transforms for inputs. If None, the transforms for evaluation process will be used. Defaults to None.
-
-        Returns:
-            If img_file is a string or np.array, the result is a dict with key-value pairs:
-            {"label map": `class_ids_map`, "scores_map": `label_names_map`}.
-            If img_file is a list, the result is a list composed of dicts with the corresponding fields:
-            class_ids_map(np.ndarray): class_ids
-            scores_map(np.ndarray): scores
-            label_names_map(np.ndarray): label_names
-
-        """
-        if transforms is None and not hasattr(self, 'test_transforms'):
-            raise Exception("transforms need to be defined, now is None.")
-        if transforms is None:
-            transforms = self.test_transforms
-        if isinstance(img_file, (str, np.ndarray)):
-            images = [img_file]
-        else:
-            images = img_file
-        batch_im, batch_origin_shape = self._preprocess(images, transforms,
-                                                        self.model_type)
-        self.net.eval()
-        data = (batch_im, batch_origin_shape, transforms.transforms)
-        outputs = self.run(self.net, data, 'test')
-        label_list = outputs['class_ids']
-        score_list = outputs['scores']
-        name_list = outputs['label_names']
-        if isinstance(img_file, list):
-            prediction = [{
-                'class_ids_map': l,
-                'scores_map': s,
-                'label_names_map': n,
-            } for l, s, n in zip(label_list, score_list, name_list)]
-        else:
-            prediction = {
-                'class_ids': label_list[0],
-                'scores': score_list[0],
-                'label_names': name_list[0]
-            }
-        return prediction
-
-    def _preprocess(self, images, transforms, to_tensor=True):
-        arrange_transforms(
-            model_type=self.model_type, transforms=transforms, mode='test')
-        batch_im = list()
-        batch_ori_shape = list()
-        for im in images:
-            sample = {'image': im}
-            if isinstance(sample['image'], str):
-                sample = ImgDecoder(to_rgb=False)(sample)
-            ori_shape = sample['image'].shape[:2]
-            im = transforms(sample)[0]
-            batch_im.append(im)
-            batch_ori_shape.append(ori_shape)
-        if to_tensor:
-            batch_im = paddle.to_tensor(batch_im)
-        else:
-            batch_im = np.asarray(batch_im)
-
-        return batch_im, batch_ori_shape
-
-    @staticmethod
-    def get_transforms_shape_info(batch_ori_shape, transforms):
-        batch_restore_list = list()
-        for ori_shape in batch_ori_shape:
-            restore_list = list()
-            h, w = ori_shape[0], ori_shape[1]
-            for op in transforms:
-                if op.__class__.__name__ == 'Resize':
-                    restore_list.append(('resize', (h, w)))
-                    h, w = op.target_size
-                elif op.__class__.__name__ == 'ResizeByShort':
-                    restore_list.append(('resize', (h, w)))
-                    im_short_size = min(h, w)
-                    im_long_size = max(h, w)
-                    scale = float(op.short_size) / float(im_short_size)
-                    if 0 < op.max_size < np.round(scale * im_long_size):
-                        scale = float(op.max_size) / float(im_long_size)
-                    h = int(round(h * scale))
-                    w = int(round(w * scale))
-                elif op.__class__.__name__ == 'ResizeByLong':
-                    restore_list.append(('resize', (h, w)))
-                    im_long_size = max(h, w)
-                    scale = float(op.long_size) / float(im_long_size)
-                    h = int(round(h * scale))
-                    w = int(round(w * scale))
-                elif op.__class__.__name__ == 'Padding':
-                    if op.target_size:
-                        target_h, target_w = op.target_size
-                    else:
-                        target_h = int(
-                            (np.ceil(h / op.size_divisor) * op.size_divisor))
-                        target_w = int(
-                            (np.ceil(w / op.size_divisor) * op.size_divisor))
-
-                    if op.pad_mode == -1:
-                        offsets = op.offsets
-                    elif op.pad_mode == 0:
-                        offsets = [0, 0]
-                    elif op.pad_mode == 1:
-                        offsets = [(target_h - h) // 2, (target_w - w) // 2]
-                    else:
-                        offsets = [target_h - h, target_w - w]
-                    restore_list.append(('padding', (h, w), offsets))
-                    h, w = target_h, target_w
-
-            batch_restore_list.append(restore_list)
-        return batch_restore_list
-
-
-class ResNet50_vd(BaseClassifier):
-    def __init__(self, num_classes=2, use_mixed_loss=False, **params):
-        super(ResNet50_vd, self).__init__(
-            model_name='ResNet50_vd',
-            num_classes=num_classes,
-            use_mixed_loss=use_mixed_loss,
-            **params)
-
-
-class MobileNetV3_small_x1_0(BaseClassifier):
-    def __init__(self, num_classes=2, use_mixed_loss=False, **params):
-        super(MobileNetV3_small_x1_0, self).__init__(
-            model_name='MobileNetV3_small_x1_0',
-            num_classes=num_classes,
-            use_mixed_loss=use_mixed_loss,
-            **params)
-
-
-class HRNet_W18_C(BaseClassifier):
-    def __init__(self, num_classes=2, use_mixed_loss=False, **params):
-        super(HRNet_W18_C, self).__init__(
-            model_name='HRNet_W18_C',
-            num_classes=num_classes,
-            use_mixed_loss=use_mixed_loss,
-            **params)
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import os.path as osp
+import numpy as np
+from collections import OrderedDict
+import paddle
+import paddle.nn.functional as F
+from paddle.static import InputSpec
+import paddlers.models.ppcls as paddleclas
+import paddlers.custom_models.cls as cmcls
+import paddlers
+from paddlers.transforms import arrange_transforms
+from paddlers.utils import get_single_card_bs, DisablePrint
+import paddlers.utils.logging as logging
+from .base import BaseModel
+from paddlers.models.ppcls.metric import build_metrics
+from paddlers.models.ppcls.loss import build_loss
+from paddlers.models.ppcls.data.postprocess import build_postprocess
+from paddlers.utils.checkpoint import cls_pretrain_weights_dict
+from paddlers.transforms import ImgDecoder, Resize
+
+__all__ = [
+    "ResNet50_vd", "MobileNetV3_small_x1_0", "HRNet_W18_C", "CondenseNetV2_b"
+]
+
+
+class BaseClassifier(BaseModel):
+    def __init__(self,
+                 model_name,
+                 in_channels=3,
+                 num_classes=2,
+                 use_mixed_loss=False,
+                 **params):
+        self.init_params = locals()
+        if 'with_net' in self.init_params:
+            del self.init_params['with_net']
+        super(BaseClassifier, self).__init__('classifier')
+        if not hasattr(paddleclas.arch.backbone, model_name) and \
+           not hasattr(cmcls, model_name):
+            raise Exception("ERROR: There's no model named {}.".format(
+                model_name))
+        self.model_name = model_name
+        self.in_channels = in_channels
+        self.num_classes = num_classes
+        self.use_mixed_loss = use_mixed_loss
+        self.metrics = None
+        self.losses = None
+        self.labels = None
+        self._postprocess = None
+        if params.get('with_net', True):
+            params.pop('with_net', None)
+            self.net = self.build_net(**params)
+        self.find_unused_parameters = True
+
+    def build_net(self, **params):
+        with paddle.utils.unique_name.guard():
+            model = dict(paddleclas.arch.backbone.__dict__,
+                         **cmcls.__dict__)[self.model_name]
+            # TODO: Determine whether there is in_channels
+            try:
+                net = model(
+                    class_num=self.num_classes,
+                    in_channels=self.in_channels,
+                    **params)
+            except:
+                net = model(class_num=self.num_classes, **params)
+                self.in_channels = 3
+        return net
+
+    def _fix_transforms_shape(self, image_shape):
+        if hasattr(self, 'test_transforms'):
+            if self.test_transforms is not None:
+                has_resize_op = False
+                resize_op_idx = -1
+                normalize_op_idx = len(self.test_transforms.transforms)
+                for idx, op in enumerate(self.test_transforms.transforms):
+                    name = op.__class__.__name__
+                    if name == 'Normalize':
+                        normalize_op_idx = idx
+                    if 'Resize' in name:
+                        has_resize_op = True
+                        resize_op_idx = idx
+
+                if not has_resize_op:
+                    self.test_transforms.transforms.insert(
+                        normalize_op_idx, Resize(target_size=image_shape))
+                else:
+                    self.test_transforms.transforms[resize_op_idx] = Resize(
+                        target_size=image_shape)
+
+    def _get_test_inputs(self, image_shape):
+        if image_shape is not None:
+            if len(image_shape) == 2:
+                image_shape = [1, 3] + image_shape
+            self._fix_transforms_shape(image_shape[-2:])
+        else:
+            image_shape = [None, 3, -1, -1]
+        self.fixed_input_shape = image_shape
+        input_spec = [
+            InputSpec(
+                shape=image_shape, name='image', dtype='float32')
+        ]
+        return input_spec
+
+    def run(self, net, inputs, mode):
+        net_out = net(inputs[0])
+        label = paddle.to_tensor(inputs[1], dtype="int64")
+        outputs = OrderedDict()
+        if mode == 'test':
+            result = self._postprocess(net_out)
+            outputs = result[0]
+
+        if mode == 'eval':
+            # print(self._postprocess(net_out)[0])  # for test
+            label = paddle.unsqueeze(label, axis=-1)
+            metric_dict = self.metrics(net_out, label)
+            outputs['top1'] = metric_dict["top1"]
+            outputs['top5'] = metric_dict["top5"]
+
+        if mode == 'train':
+            loss_list = self.losses(net_out, label)
+            outputs['loss'] = loss_list['loss']
+        return outputs
+
+    def default_metric(self):
+        default_config = [{"TopkAcc": {"topk": [1, 5]}}]
+        return build_metrics(default_config)
+
+    def default_loss(self):
+        # TODO: use mixed loss and other loss
+        default_config = [{"CELoss": {"weight": 1.0}}]
+        return build_loss(default_config)
+
+    def default_optimizer(self,
+                          parameters,
+                          learning_rate,
+                          num_epochs,
+                          num_steps_each_epoch,
+                          last_epoch=-1,
+                          L2_coeff=0.00007):
+        decay_step = num_epochs * num_steps_each_epoch
+        lr_scheduler = paddle.optimizer.lr.CosineAnnealingDecay(
+            learning_rate, T_max=decay_step, eta_min=0, last_epoch=last_epoch)
+        optimizer = paddle.optimizer.Momentum(
+            learning_rate=lr_scheduler,
+            parameters=parameters,
+            momentum=0.9,
+            weight_decay=paddle.regularizer.L2Decay(L2_coeff))
+        return optimizer
+
+    def default_postprocess(self, class_id_map_file):
+        default_config = {
+            "name": "Topk",
+            "topk": 1,
+            "class_id_map_file": class_id_map_file
+        }
+        return build_postprocess(default_config)
+
+    def train(self,
+              num_epochs,
+              train_dataset,
+              train_batch_size=2,
+              eval_dataset=None,
+              optimizer=None,
+              save_interval_epochs=1,
+              log_interval_steps=2,
+              save_dir='output',
+              pretrain_weights='IMAGENET',
+              learning_rate=0.1,
+              lr_decay_power=0.9,
+              early_stop=False,
+              early_stop_patience=5,
+              use_vdl=True,
+              resume_checkpoint=None):
+        """
+        Train the model.
+        Args:
+            num_epochs(int): The number of epochs.
+            train_dataset(paddlers.dataset): Training dataset.
+            train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 2.
+            eval_dataset(paddlers.dataset, optional):
+                Evaluation dataset. If None, the model will not be evaluated furing training process. Defaults to None.
+            optimizer(paddle.optimizer.Optimizer or None, optional):
+                Optimizer used in training. If None, a default optimizer is used. Defaults to None.
+            save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
+            log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
+            save_dir(str, optional): Directory to save the model. Defaults to 'output'.
+            pretrain_weights(str or None, optional):
+                None or name/path of pretrained weights. If None, no pretrained weights will be loaded. Defaults to 'CITYSCAPES'.
+            learning_rate(float, optional): Learning rate for training. Defaults to .025.
+            lr_decay_power(float, optional): Learning decay power. Defaults to .9.
+            early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
+            early_stop_patience(int, optional): Early stop patience. Defaults to 5.
+            use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
+            resume_checkpoint(str or None, optional): The path of the checkpoint to resume training from.
+                If None, no training checkpoint will be resumed. At most one of `resume_checkpoint` and
+                `pretrain_weights` can be set simultaneously. Defaults to None.
+
+        """
+        if self.status == 'Infer':
+            logging.error(
+                "Exported inference model does not support training.",
+                exit=True)
+        if pretrain_weights is not None and resume_checkpoint is not None:
+            logging.error(
+                "pretrain_weights and resume_checkpoint cannot be set simultaneously.",
+                exit=True)
+        self.labels = train_dataset.labels
+        if self.losses is None:
+            self.losses = self.default_loss()
+        self.metrics = self.default_metric()
+        self._postprocess = self.default_postprocess(train_dataset.label_list)
+        # print(self._postprocess.class_id_map)
+
+        if optimizer is None:
+            num_steps_each_epoch = train_dataset.num_samples // train_batch_size
+            self.optimizer = self.default_optimizer(
+                self.net.parameters(), learning_rate, num_epochs,
+                num_steps_each_epoch, lr_decay_power)
+        else:
+            self.optimizer = optimizer
+
+        if pretrain_weights is not None and not osp.exists(pretrain_weights):
+            if pretrain_weights not in cls_pretrain_weights_dict[
+                    self.model_name]:
+                logging.warning(
+                    "Path of pretrain_weights('{}') does not exist!".format(
+                        pretrain_weights))
+                logging.warning("Pretrain_weights is forcibly set to '{}'. "
+                                "If don't want to use pretrain weights, "
+                                "set pretrain_weights to be None.".format(
+                                    cls_pretrain_weights_dict[self.model_name][
+                                        0]))
+                pretrain_weights = cls_pretrain_weights_dict[self.model_name][0]
+        elif pretrain_weights is not None and osp.exists(pretrain_weights):
+            if osp.splitext(pretrain_weights)[-1] != '.pdparams':
+                logging.error(
+                    "Invalid pretrain weights. Please specify a '.pdparams' file.",
+                    exit=True)
+        pretrained_dir = osp.join(save_dir, 'pretrain')
+        is_backbone_weights = False  # pretrain_weights == 'IMAGENET'  # TODO: this is backbone
+        self.net_initialize(
+            pretrain_weights=pretrain_weights,
+            save_dir=pretrained_dir,
+            resume_checkpoint=resume_checkpoint,
+            is_backbone_weights=is_backbone_weights)
+
+        self.train_loop(
+            num_epochs=num_epochs,
+            train_dataset=train_dataset,
+            train_batch_size=train_batch_size,
+            eval_dataset=eval_dataset,
+            save_interval_epochs=save_interval_epochs,
+            log_interval_steps=log_interval_steps,
+            save_dir=save_dir,
+            early_stop=early_stop,
+            early_stop_patience=early_stop_patience,
+            use_vdl=use_vdl)
+
+    def quant_aware_train(self,
+                          num_epochs,
+                          train_dataset,
+                          train_batch_size=2,
+                          eval_dataset=None,
+                          optimizer=None,
+                          save_interval_epochs=1,
+                          log_interval_steps=2,
+                          save_dir='output',
+                          learning_rate=0.0001,
+                          lr_decay_power=0.9,
+                          early_stop=False,
+                          early_stop_patience=5,
+                          use_vdl=True,
+                          resume_checkpoint=None,
+                          quant_config=None):
+        """
+        Quantization-aware training.
+        Args:
+            num_epochs(int): The number of epochs.
+            train_dataset(paddlers.dataset): Training dataset.
+            train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 2.
+            eval_dataset(paddlers.dataset, optional):
+                Evaluation dataset. If None, the model will not be evaluated furing training process. Defaults to None.
+            optimizer(paddle.optimizer.Optimizer or None, optional):
+                Optimizer used in training. If None, a default optimizer is used. Defaults to None.
+            save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
+            log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
+            save_dir(str, optional): Directory to save the model. Defaults to 'output'.
+            learning_rate(float, optional): Learning rate for training. Defaults to .025.
+            lr_decay_power(float, optional): Learning decay power. Defaults to .9.
+            early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
+            early_stop_patience(int, optional): Early stop patience. Defaults to 5.
+            use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
+            quant_config(dict or None, optional): Quantization configuration. If None, a default rule of thumb
+                configuration will be used. Defaults to None.
+            resume_checkpoint(str or None, optional): The path of the checkpoint to resume quantization-aware training
+                from. If None, no training checkpoint will be resumed. Defaults to None.
+
+        """
+        self._prepare_qat(quant_config)
+        self.train(
+            num_epochs=num_epochs,
+            train_dataset=train_dataset,
+            train_batch_size=train_batch_size,
+            eval_dataset=eval_dataset,
+            optimizer=optimizer,
+            save_interval_epochs=save_interval_epochs,
+            log_interval_steps=log_interval_steps,
+            save_dir=save_dir,
+            pretrain_weights=None,
+            learning_rate=learning_rate,
+            lr_decay_power=lr_decay_power,
+            early_stop=early_stop,
+            early_stop_patience=early_stop_patience,
+            use_vdl=use_vdl,
+            resume_checkpoint=resume_checkpoint)
+
+    def evaluate(self, eval_dataset, batch_size=1, return_details=False):
+        """
+        Evaluate the model.
+        Args:
+            eval_dataset(paddlers.dataset): Evaluation dataset.
+            batch_size(int, optional): Total batch size among all cards used for evaluation. Defaults to 1.
+            return_details(bool, optional): Whether to return evaluation details. Defaults to False.
+
+        Returns:
+            collections.OrderedDict with key-value pairs:
+                {"top1": `acc of top1`,
+                 "top5": `acc of top5`}.
+
+        """
+        arrange_transforms(
+            model_type=self.model_type,
+            transforms=eval_dataset.transforms,
+            mode='eval')
+
+        self.net.eval()
+        nranks = paddle.distributed.get_world_size()
+        local_rank = paddle.distributed.get_rank()
+        if nranks > 1:
+            # Initialize parallel environment if not done.
+            if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
+            ):
+                paddle.distributed.init_parallel_env()
+
+        batch_size_each_card = get_single_card_bs(batch_size)
+        if batch_size_each_card > 1:
+            batch_size_each_card = 1
+            batch_size = batch_size_each_card * paddlers.env_info['num']
+            logging.warning(
+                "Segmenter only supports batch_size=1 for each gpu/cpu card " \
+                "during evaluation, so batch_size " \
+                "is forcibly set to {}.".format(batch_size))
+        self.eval_data_loader = self.build_data_loader(
+            eval_dataset, batch_size=batch_size, mode='eval')
+
+        logging.info(
+            "Start to evaluate(total_samples={}, total_steps={})...".format(
+                eval_dataset.num_samples,
+                math.ceil(eval_dataset.num_samples * 1.0 / batch_size)))
+
+        top1s = []
+        top5s = []
+        with paddle.no_grad():
+            for step, data in enumerate(self.eval_data_loader):
+                data.append(eval_dataset.transforms.transforms)
+                outputs = self.run(self.net, data, 'eval')
+                top1s.append(outputs["top1"])
+                top5s.append(outputs["top5"])
+
+        top1 = np.mean(top1s)
+        top5 = np.mean(top5s)
+        eval_metrics = OrderedDict(zip(['top1', 'top5'], [top1, top5]))
+        if return_details:
+            # TODO: add details
+            return eval_metrics, None
+        return eval_metrics
+
+    def predict(self, img_file, transforms=None):
+        """
+        Do inference.
+        Args:
+            Args:
+            img_file(List[np.ndarray or str], str or np.ndarray):
+                Image path or decoded image data in a BGR format, which also could constitute a list,
+                meaning all images to be predicted as a mini-batch.
+            transforms(paddlers.transforms.Compose or None, optional):
+                Transforms for inputs. If None, the transforms for evaluation process will be used. Defaults to None.
+
+        Returns:
+            If img_file is a string or np.array, the result is a dict with key-value pairs:
+            {"label map": `class_ids_map`, "scores_map": `label_names_map`}.
+            If img_file is a list, the result is a list composed of dicts with the corresponding fields:
+            class_ids_map(np.ndarray): class_ids
+            scores_map(np.ndarray): scores
+            label_names_map(np.ndarray): label_names
+
+        """
+        if transforms is None and not hasattr(self, 'test_transforms'):
+            raise Exception("transforms need to be defined, now is None.")
+        if transforms is None:
+            transforms = self.test_transforms
+        if isinstance(img_file, (str, np.ndarray)):
+            images = [img_file]
+        else:
+            images = img_file
+        batch_im, batch_origin_shape = self._preprocess(images, transforms,
+                                                        self.model_type)
+        self.net.eval()
+        data = (batch_im, batch_origin_shape, transforms.transforms)
+        outputs = self.run(self.net, data, 'test')
+        label_list = outputs['class_ids']
+        score_list = outputs['scores']
+        name_list = outputs['label_names']
+        if isinstance(img_file, list):
+            prediction = [{
+                'class_ids_map': l,
+                'scores_map': s,
+                'label_names_map': n,
+            } for l, s, n in zip(label_list, score_list, name_list)]
+        else:
+            prediction = {
+                'class_ids': label_list[0],
+                'scores': score_list[0],
+                'label_names': name_list[0]
+            }
+        return prediction
+
+    def _preprocess(self, images, transforms, to_tensor=True):
+        arrange_transforms(
+            model_type=self.model_type, transforms=transforms, mode='test')
+        batch_im = list()
+        batch_ori_shape = list()
+        for im in images:
+            sample = {'image': im}
+            if isinstance(sample['image'], str):
+                sample = ImgDecoder(to_rgb=False)(sample)
+            ori_shape = sample['image'].shape[:2]
+            im = transforms(sample)[0]
+            batch_im.append(im)
+            batch_ori_shape.append(ori_shape)
+        if to_tensor:
+            batch_im = paddle.to_tensor(batch_im)
+        else:
+            batch_im = np.asarray(batch_im)
+
+        return batch_im, batch_ori_shape
+
+    @staticmethod
+    def get_transforms_shape_info(batch_ori_shape, transforms):
+        batch_restore_list = list()
+        for ori_shape in batch_ori_shape:
+            restore_list = list()
+            h, w = ori_shape[0], ori_shape[1]
+            for op in transforms:
+                if op.__class__.__name__ == 'Resize':
+                    restore_list.append(('resize', (h, w)))
+                    h, w = op.target_size
+                elif op.__class__.__name__ == 'ResizeByShort':
+                    restore_list.append(('resize', (h, w)))
+                    im_short_size = min(h, w)
+                    im_long_size = max(h, w)
+                    scale = float(op.short_size) / float(im_short_size)
+                    if 0 < op.max_size < np.round(scale * im_long_size):
+                        scale = float(op.max_size) / float(im_long_size)
+                    h = int(round(h * scale))
+                    w = int(round(w * scale))
+                elif op.__class__.__name__ == 'ResizeByLong':
+                    restore_list.append(('resize', (h, w)))
+                    im_long_size = max(h, w)
+                    scale = float(op.long_size) / float(im_long_size)
+                    h = int(round(h * scale))
+                    w = int(round(w * scale))
+                elif op.__class__.__name__ == 'Padding':
+                    if op.target_size:
+                        target_h, target_w = op.target_size
+                    else:
+                        target_h = int(
+                            (np.ceil(h / op.size_divisor) * op.size_divisor))
+                        target_w = int(
+                            (np.ceil(w / op.size_divisor) * op.size_divisor))
+
+                    if op.pad_mode == -1:
+                        offsets = op.offsets
+                    elif op.pad_mode == 0:
+                        offsets = [0, 0]
+                    elif op.pad_mode == 1:
+                        offsets = [(target_h - h) // 2, (target_w - w) // 2]
+                    else:
+                        offsets = [target_h - h, target_w - w]
+                    restore_list.append(('padding', (h, w), offsets))
+                    h, w = target_h, target_w
+
+            batch_restore_list.append(restore_list)
+        return batch_restore_list
+
+
+class ResNet50_vd(BaseClassifier):
+    def __init__(self, num_classes=2, use_mixed_loss=False, **params):
+        super(ResNet50_vd, self).__init__(
+            model_name='ResNet50_vd',
+            num_classes=num_classes,
+            use_mixed_loss=use_mixed_loss,
+            **params)
+
+
+class MobileNetV3_small_x1_0(BaseClassifier):
+    def __init__(self, num_classes=2, use_mixed_loss=False, **params):
+        super(MobileNetV3_small_x1_0, self).__init__(
+            model_name='MobileNetV3_small_x1_0',
+            num_classes=num_classes,
+            use_mixed_loss=use_mixed_loss,
+            **params)
+
+
+class HRNet_W18_C(BaseClassifier):
+    def __init__(self, num_classes=2, use_mixed_loss=False, **params):
+        super(HRNet_W18_C, self).__init__(
+            model_name='HRNet_W18_C',
+            num_classes=num_classes,
+            use_mixed_loss=use_mixed_loss,
+            **params)
+
+
+class CondenseNetV2_b(BaseClassifier):
+    def __init__(self, num_classes=2, use_mixed_loss=False, **params):
+        super(CondenseNetV2_b, self).__init__(
+            model_name='CondenseNetV2_b',
+            num_classes=num_classes,
+            use_mixed_loss=use_mixed_loss,
+            **params)

+ 3 - 3
paddlers/tasks/segmenter.py

@@ -21,7 +21,7 @@ import paddle
 import paddle.nn.functional as F
 from paddle.static import InputSpec
 import paddlers.models.ppseg as paddleseg
-import paddlers.custom_models.seg as seg
+import paddlers.custom_models.seg as cmseg
 import paddlers
 from paddlers.transforms import arrange_transforms
 from paddlers.utils import get_single_card_bs, DisablePrint
@@ -45,7 +45,7 @@ class BaseSegmenter(BaseModel):
             del self.init_params['with_net']
         super(BaseSegmenter, self).__init__('segmenter')
         if not hasattr(paddleseg.models, model_name) and \
-           not hasattr(seg.models, model_name):
+           not hasattr(cmseg, model_name):
             raise Exception("ERROR: There's no model named {}.".format(
                 model_name))
         self.model_name = model_name
@@ -62,7 +62,7 @@ class BaseSegmenter(BaseModel):
         # TODO: when using paddle.utils.unique_name.guard,
         # DeepLabv3p and HRNet will raise a error
         net = dict(paddleseg.models.__dict__,
-                   **seg.models.__dict__)[self.model_name](
+                   **cmseg.__dict__)[self.model_name](
                        num_classes=self.num_classes, **params)
         return net
 

+ 1 - 1
paddlers/transforms/functions.py

@@ -315,7 +315,7 @@ def select_bands(im, band_list=[1, 2, 3]):
             raise ValueError("The element in band_list must > 1 and <= {}.".
                              format(str(total_band)))
         result.append(im[:, :, band])
-    ima = np.stack(result, axis=0)
+    ima = np.stack(result, axis=-1)
     return ima
 
 

+ 49 - 0
tutorials/train/classification/condensenetv2_b_rs_mul.py

@@ -0,0 +1,49 @@
+import paddlers as pdrs
+from paddlers import transforms as T
+
+# 定义训练和验证时的transforms
+train_transforms = T.Compose([
+    T.BandSelecting([5, 10, 15, 20, 25]),  # for tet
+    T.Resize(target_size=224),
+    T.RandomHorizontalFlip(),
+    T.Normalize(
+        mean=[0.5, 0.5, 0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5, 0.5, 0.5]),
+])
+
+eval_transforms = T.Compose([
+    T.BandSelecting([5, 10, 15, 20, 25]),
+    T.Resize(target_size=224),
+    T.Normalize(
+        mean=[0.5, 0.5, 0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5, 0.5, 0.5]),
+])
+
+# 定义训练和验证所用的数据集
+train_dataset = pdrs.datasets.ClasDataset(
+    data_dir='tutorials/train/classification/DataSet',
+    file_list='tutorials/train/classification/DataSet/train_list.txt',
+    label_list='tutorials/train/classification/DataSet/label_list.txt',
+    transforms=train_transforms,
+    num_workers=0,
+    shuffle=True)
+
+eval_dataset = pdrs.datasets.ClasDataset(
+    data_dir='tutorials/train/classification/DataSet',
+    file_list='tutorials/train/classification/DataSet/val_list.txt',
+    label_list='tutorials/train/classification/DataSet/label_list.txt',
+    transforms=eval_transforms,
+    num_workers=0,
+    shuffle=False)
+
+# 初始化模型
+num_classes = len(train_dataset.labels)
+model = pdrs.tasks.CondenseNetV2_b(in_channels=5, num_classes=num_classes)
+
+# 进行训练
+model.train(
+    num_epochs=100,
+    pretrain_weights=None,
+    train_dataset=train_dataset,
+    train_batch_size=4,
+    eval_dataset=eval_dataset,
+    learning_rate=3e-4,
+    save_dir='output/condensenetv2_b')

+ 4 - 12
tutorials/train/classification/resnet50_vd_rs.py

@@ -1,7 +1,3 @@
-import sys
-
-sys.path.append("E:/dataFiles/github/PaddleRS")
-
 import paddlers as pdrs
 from paddlers import transforms as T
 
@@ -9,7 +5,6 @@ from paddlers import transforms as T
 # https://aistudio.baidu.com/aistudio/datasetdetail/63189
 
 # 定义训练和验证时的transforms
-# API说明:https://github.com/PaddlePaddle/paddlers/blob/develop/docs/apis/transforms/transforms.md
 train_transforms = T.Compose([
     T.Resize(target_size=512),
     T.RandomHorizontalFlip(),
@@ -24,9 +19,8 @@ eval_transforms = T.Compose([
 ])
 
 # 定义训练和验证所用的数据集
-# API说明:https://github.com/PaddlePaddle/paddlers/blob/develop/docs/apis/datasets.md
 train_dataset = pdrs.datasets.ClasDataset(
-    data_dir='E:/dataFiles/github/PaddleRS/tutorials/train/classification/DataSet',
+    data_dir='tutorials/train/classification/DataSet',
     file_list='tutorials/train/classification/DataSet/train_list.txt',
     label_list='tutorials/train/classification/DataSet/label_list.txt',
     transforms=train_transforms,
@@ -34,20 +28,18 @@ train_dataset = pdrs.datasets.ClasDataset(
     shuffle=True)
 
 eval_dataset = pdrs.datasets.ClasDataset(
-    data_dir='E:/dataFiles/github/PaddleRS/tutorials/train/classification/DataSet',
+    data_dir='tutorials/train/classification/DataSet',
     file_list='tutorials/train/classification/DataSet/test_list.txt',
     label_list='tutorials/train/classification/DataSet/label_list.txt',
     transforms=eval_transforms,
     num_workers=0,
     shuffle=False)
 
-# 初始化模型,并进行训练
-# 可使用VisualDL查看训练指标,参考https://github.com/PaddlePaddle/paddlers/blob/develop/docs/visualdl.md
+# 初始化模型
 num_classes = len(train_dataset.labels)
 model = pdrs.tasks.ResNet50_vd(num_classes=num_classes)
 
-# API说明:https://github.com/PaddlePaddle/paddlers/blob/develop/docs/apis/models/semantic_segmentation.md
-# 各参数介绍与调整说明:https://github.com/PaddlePaddle/paddlers/blob/develop/docs/parameters.md
+# 进行训练
 model.train(
     num_epochs=10,
     train_dataset=train_dataset,