@@ -1,4 +1,4 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,11 +14,13 @@

 import paddle
 import paddle.nn as nn
-import paddle.nn.functional as F
+from paddle import ParamAttr
+from paddle.regularizer import L2Decay
+from paddle.nn import AdaptiveAvgPool2D, BatchNorm, Conv2D, Dropout, Linear

-from paddlers.models.ppseg.cvlibs import manager
-from paddlers.models.ppseg.utils import utils
-from paddlers.models.ppseg.models import layers
+from paddleseg.cvlibs import manager
+from paddleseg.utils import utils, logger
+from paddleseg.models import layers

 __all__ = [
     "MobileNetV3_small_x0_35", "MobileNetV3_small_x0_5",
@@ -28,8 +30,92 @@ __all__ = [
     "MobileNetV3_large_x1_0", "MobileNetV3_large_x1_25"
 ]

-
-def make_divisible(v, divisor=8, min_value=None):
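+# Layer-name patterns marking the stage boundaries of each variant; passed to
+# MobileNetV3.init_res() as `stages_pattern` by the builder functions below.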
+MODEL_STAGES_PATTERN = {
+    "MobileNetV3_small": ["blocks[0]", "blocks[2]", "blocks[7]", "blocks[10]"],
+    "MobileNetV3_large":
+    ["blocks[0]", "blocks[2]", "blocks[5]", "blocks[11]", "blocks[14]"]
+}
+
+# "large", "small" is just for MobinetV3_large, MobileNetV3_small respectively.
|
|
|
+# The type of "large" or "small" config is a list. Each element(list) represents a depthwise block, which is composed of k, exp, se, act, s.
|
|
|
+# k: kernel_size
+# exp: middle channel number in depthwise block
+# c: output channel number in depthwise block
+# se: whether to use SE block
+# act: which activation to use
+# s: stride in depthwise block
+# d: dilation rate in depthwise block
+NET_CONFIG = {
+    "large": [
+        # k, exp, c, se, act, s
+        [3, 16, 16, False, "relu", 1],
+        [3, 64, 24, False, "relu", 2],
+        [3, 72, 24, False, "relu", 1],  # x4
+        [5, 72, 40, True, "relu", 2],
+        [5, 120, 40, True, "relu", 1],
+        [5, 120, 40, True, "relu", 1],  # x8
+        [3, 240, 80, False, "hardswish", 2],
+        [3, 200, 80, False, "hardswish", 1],
+        [3, 184, 80, False, "hardswish", 1],
+        [3, 184, 80, False, "hardswish", 1],
+        [3, 480, 112, True, "hardswish", 1],
+        [3, 672, 112, True, "hardswish", 1],  # x16
+        [5, 672, 160, True, "hardswish", 2],
+        [5, 960, 160, True, "hardswish", 1],
+        [5, 960, 160, True, "hardswish", 1],  # x32
+    ],
+    "small": [
+        # k, exp, c, se, act, s
+        [3, 16, 16, True, "relu", 2],
+        [3, 72, 24, False, "relu", 2],
+        [3, 88, 24, False, "relu", 1],
+        [5, 96, 40, True, "hardswish", 2],
+        [5, 240, 40, True, "hardswish", 1],
+        [5, 240, 40, True, "hardswish", 1],
+        [5, 120, 48, True, "hardswish", 1],
+        [5, 144, 48, True, "hardswish", 1],
+        [5, 288, 96, True, "hardswish", 2],
+        [5, 576, 96, True, "hardswish", 1],
+        [5, 576, 96, True, "hardswish", 1],
+    ],
+    "large_os8": [
+        # k, exp, c, se, act, s, {d}
+        [3, 16, 16, False, "relu", 1],
+        [3, 64, 24, False, "relu", 2],
+        [3, 72, 24, False, "relu", 1],  # x4
+        [5, 72, 40, True, "relu", 2],
+        [5, 120, 40, True, "relu", 1],
+        [5, 120, 40, True, "relu", 1],  # x8
+        [3, 240, 80, False, "hardswish", 1],
+        [3, 200, 80, False, "hardswish", 1, 2],
+        [3, 184, 80, False, "hardswish", 1, 2],
+        [3, 184, 80, False, "hardswish", 1, 2],
+        [3, 480, 112, True, "hardswish", 1, 2],
+        [3, 672, 112, True, "hardswish", 1, 2],
+        [5, 672, 160, True, "hardswish", 1, 2],
+        [5, 960, 160, True, "hardswish", 1, 4],
+        [5, 960, 160, True, "hardswish", 1, 4],
+    ],
+    "small_os8": [
+        # k, exp, c, se, act, s, {d}
+        [3, 16, 16, True, "relu", 2],
+        [3, 72, 24, False, "relu", 2],
+        [3, 88, 24, False, "relu", 1],
+        [5, 96, 40, True, "hardswish", 1],
+        [5, 240, 40, True, "hardswish", 1, 2],
+        [5, 240, 40, True, "hardswish", 1, 2],
+        [5, 120, 48, True, "hardswish", 1, 2],
+        [5, 144, 48, True, "hardswish", 1, 2],
+        [5, 288, 96, True, "hardswish", 1, 2],
+        [5, 576, 96, True, "hardswish", 1, 4],
+        [5, 576, 96, True, "hardswish", 1, 4],
+    ]
+}
+
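+# Indices of the blocks whose outputs are collected as multi-scale feature maps
+# in MobileNetV3.forward().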
+OUT_INDEX = {"large": [2, 5, 11, 14], "small": [0, 2, 7, 10]}
+
+
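+# Round `v` to the nearest multiple of `divisor`, never dropping below `min_value`.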
+def _make_divisible(v, divisor=8, min_value=None):
     if min_value is None:
         min_value = divisor
     new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
@@ -38,156 +124,113 @@ def make_divisible(v, divisor=8, min_value=None):
     return new_v


-class MobileNetV3(nn.Layer):
-    """
-    The MobileNetV3 implementation based on PaddlePaddle.
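+# Map an activation name from the config ("hardswish", "relu" or None) to a layer.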
+def _create_act(act):
+    if act == "hardswish":
+        return nn.Hardswish()
+    elif act == "relu":
+        return nn.ReLU()
+    elif act is None:
+        return None
+    else:
+        raise RuntimeError(
+            "The activation function is not supported: {}".format(act))

-    The original article refers to Jingdong
-    Andrew Howard, et, al. "Searching for MobileNetV3"
-    (https://arxiv.org/pdf/1905.02244.pdf).

+class MobileNetV3(nn.Layer):
+    """
+    MobileNetV3
     Args:
-        pretrained (str, optional): The path of pretrained model.
-        scale (float, optional): The scale of channels . Default: 1.0.
-        model_name (str, optional): Model name. It determines the type of MobileNetV3. The value is 'small' or 'large'. Defualt: 'small'.
-        output_stride (int, optional): The stride of output features compared to input images. The value should be one of (2, 4, 8, 16, 32). Default: None.
-
+        config (list): MobileNetV3 depthwise block configuration.
+        in_channels (int, optional): The channels of the input image. Default: 3.
+        scale (float, optional): The coefficient that controls the size of network parameters. Default: 1.0.
+    Returns:
+        model: nn.Layer. A MobileNetV3 model whose variant and width are determined by the args.
     """

     def __init__(self,
-                 pretrained=None,
+                 config,
+                 stages_pattern,
+                 out_index,
+                 in_channels=3,
                  scale=1.0,
-                 model_name="small",
-                 output_stride=None):
-        super(MobileNetV3, self).__init__()
+                 pretrained=None):
+        super().__init__()

+        self.cfg = config
+        self.out_index = out_index
+        self.scale = scale
+        self.pretrained = pretrained
         inplanes = 16
-        if model_name == "large":
-            self.cfg = [
-                # k, exp, c, se, nl, s,
-                [3, 16, 16, False, "relu", 1],
-                [3, 64, 24, False, "relu", 2],
-                [3, 72, 24, False, "relu", 1],  # output 1 -> out_index=2
-                [5, 72, 40, True, "relu", 2],
-                [5, 120, 40, True, "relu", 1],
-                [5, 120, 40, True, "relu", 1],  # output 2 -> out_index=5
-                [3, 240, 80, False, "hard_swish", 2],
-                [3, 200, 80, False, "hard_swish", 1],
-                [3, 184, 80, False, "hard_swish", 1],
-                [3, 184, 80, False, "hard_swish", 1],
-                [3, 480, 112, True, "hard_swish", 1],
-                [3, 672, 112, True, "hard_swish",
-                 1],  # output 3 -> out_index=11
-                [5, 672, 160, True, "hard_swish", 2],
-                [5, 960, 160, True, "hard_swish", 1],
-                [5, 960, 160, True, "hard_swish",
-                 1],  # output 3 -> out_index=14
-            ]
-            self.out_indices = [2, 5, 11, 14]
-            self.feat_channels = [
-                make_divisible(i * scale) for i in [24, 40, 112, 160]
-            ]
-
-            self.cls_ch_squeeze = 960
-            self.cls_ch_expand = 1280
-        elif model_name == "small":
-            self.cfg = [
-                # k, exp, c, se, nl, s,
-                [3, 16, 16, True, "relu", 2],  # output 1 -> out_index=0
-                [3, 72, 24, False, "relu", 2],
-                [3, 88, 24, False, "relu", 1],  # output 2 -> out_index=3
-                [5, 96, 40, True, "hard_swish", 2],
-                [5, 240, 40, True, "hard_swish", 1],
-                [5, 240, 40, True, "hard_swish", 1],
-                [5, 120, 48, True, "hard_swish", 1],
-                [5, 144, 48, True, "hard_swish", 1],  # output 3 -> out_index=7
-                [5, 288, 96, True, "hard_swish", 2],
-                [5, 576, 96, True, "hard_swish", 1],
-                [5, 576, 96, True, "hard_swish", 1],  # output 4 -> out_index=10
-            ]
-            self.out_indices = [0, 3, 7, 10]
-            self.feat_channels = [
-                make_divisible(i * scale) for i in [16, 24, 48, 96]
-            ]
-
-            self.cls_ch_squeeze = 576
-            self.cls_ch_expand = 1280
-        else:
-            raise NotImplementedError(
-                "mode[{}_model] is not implemented!".format(model_name))
-
-        ###################################################
-        # modify stride and dilation based on output_stride
-        self.dilation_cfg = [1] * len(self.cfg)
-        self.modify_bottle_params(output_stride=output_stride)
-        ###################################################
-
-        self.conv1 = ConvBNLayer(
-            in_c=3,
-            out_c=make_divisible(inplanes * scale),
+
+        self.conv = ConvBNLayer(
+            in_c=in_channels,
+            out_c=_make_divisible(inplanes * self.scale),
             filter_size=3,
             stride=2,
             padding=1,
             num_groups=1,
             if_act=True,
-            act="hard_swish")
-
-        self.block_list = []
-
-        inplanes = make_divisible(inplanes * scale)
-        for i, (k, exp, c, se, nl, s) in enumerate(self.cfg):
-            ######################################
-            # add dilation rate
-            dilation_rate = self.dilation_cfg[i]
-            ######################################
-            self.block_list.append(
-                ResidualUnit(
-                    in_c=inplanes,
-                    mid_c=make_divisible(scale * exp),
-                    out_c=make_divisible(scale * c),
-                    filter_size=k,
-                    stride=s,
-                    dilation=dilation_rate,
-                    use_se=se,
-                    act=nl,
-                    name="conv" + str(i + 2)))
-            self.add_sublayer(
-                sublayer=self.block_list[-1], name="conv" + str(i + 2))
-            inplanes = make_divisible(scale * c)
-
-        self.pretrained = pretrained
+            act="hardswish")
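+        # Build the stack of inverted residual blocks from the config rows
+        # (k, exp, c, se, act, s[, d]); each block's in_c comes from the previous
+        # row's output channels (or from the stem width for the first block).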
+        self.blocks = nn.Sequential(*[
+            ResidualUnit(
+                in_c=_make_divisible(inplanes * self.scale if i == 0 else
+                                     self.cfg[i - 1][2] * self.scale),
+                mid_c=_make_divisible(self.scale * exp),
+                out_c=_make_divisible(self.scale * c),
+                filter_size=k,
+                stride=s,
+                use_se=se,
+                act=act,
+                dilation=td[0] if td else 1)
+            for i, (k, exp, c, se, act, s, *td) in enumerate(self.cfg)
+        ])
+
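+        # Channel counts of the returned feature maps, after width scaling.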
+        out_channels = [config[idx][2] for idx in self.out_index]
+        self.feat_channels = [
+            _make_divisible(self.scale * c) for c in out_channels
+        ]
+
+        self.init_res(stages_pattern)
         self.init_weight()

-    def modify_bottle_params(self, output_stride=None):
-
-        if output_stride is not None and output_stride % 2 != 0:
-            raise ValueError("output stride must to be even number")
-        if output_stride is not None:
-            stride = 2
-            rate = 1
-            for i, _cfg in enumerate(self.cfg):
-                stride = stride * _cfg[-1]
-                if stride > output_stride:
-                    rate = rate * _cfg[-1]
-                    self.cfg[i][-1] = 1
+    def init_weight(self):
+        if self.pretrained is not None:
+            utils.load_entire_model(self, self.pretrained)
+
+    def init_res(self, stages_pattern, return_patterns=None,
+                 return_stages=None):
+        if return_patterns and return_stages:
+            msg = f"The 'return_patterns' would be ignored when 'return_stages' is set."
+            logger.warning(msg)
+            return_stages = None
+
+        if return_stages is True:
+            return_patterns = stages_pattern
+        # return_stages is int or bool
+        if type(return_stages) is int:
+            return_stages = [return_stages]
+        if isinstance(return_stages, list):
+            if max(return_stages) > len(stages_pattern) or min(
+                    return_stages) < 0:
+                msg = f"The 'return_stages' set error. Illegal value(s) have been ignored. The stages' pattern list is {stages_pattern}."
+                logger.warning(msg)
+                return_stages = [
+                    val for val in return_stages
+                    if val >= 0 and val < len(stages_pattern)
+                ]
+            return_patterns = [stages_pattern[i] for i in return_stages]

-            self.dilation_cfg[i] = rate
+    def forward(self, x):
+        x = self.conv(x)

-    def forward(self, inputs, label=None):
-        x = self.conv1(inputs)
-        # A feature list saves each downsampling feature.
         feat_list = []
-        for i, block in enumerate(self.block_list):
+        for idx, block in enumerate(self.blocks):
             x = block(x)
-            if i in self.out_indices:
+            if idx in self.out_index:
                 feat_list.append(x)

         return feat_list

-    def init_weight(self):
-        if self.pretrained is not None:
-            utils.load_pretrained_model(self, self.pretrained)
-

 class ConvBNLayer(nn.Layer):
     def __init__(self,
@@ -196,36 +239,34 @@ class ConvBNLayer(nn.Layer):
                  filter_size,
                  stride,
                  padding,
-                 dilation=1,
                  num_groups=1,
                  if_act=True,
-                 act=None):
-        super(ConvBNLayer, self).__init__()
-        self.if_act = if_act
-        self.act = act
+                 act=None,
+                 dilation=1):
+        super().__init__()

-        self.conv = nn.Conv2D(
+        self.conv = Conv2D(
             in_channels=in_c,
             out_channels=out_c,
             kernel_size=filter_size,
             stride=stride,
             padding=padding,
-            dilation=dilation,
             groups=num_groups,
-            bias_attr=False)
-        self.bn = layers.SyncBatchNorm(
-            num_features=out_c,
-            weight_attr=paddle.ParamAttr(
-                regularizer=paddle.regularizer.L2Decay(0.0)),
-            bias_attr=paddle.ParamAttr(
-                regularizer=paddle.regularizer.L2Decay(0.0)))
-        self._act_op = layers.Activation(act='hardswish')
+            bias_attr=False,
+            dilation=dilation)
+        self.bn = BatchNorm(
+            num_channels=out_c,
+            act=None,
+            param_attr=ParamAttr(regularizer=L2Decay(0.0)),
+            bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
+        self.if_act = if_act
+        self.act = _create_act(act)

     def forward(self, x):
         x = self.conv(x)
         x = self.bn(x)
         if self.if_act:
-            x = self._act_op(x)
+            x = self.act(x)
         return x


@@ -237,10 +278,9 @@ class ResidualUnit(nn.Layer):
                  filter_size,
                  stride,
                  use_se,
-                 dilation=1,
                  act=None,
-                 name=''):
-        super(ResidualUnit, self).__init__()
+                 dilation=1):
+        super().__init__()
         self.if_shortcut = stride == 1 and in_c == out_c
         self.if_se = use_se
@@ -252,19 +292,18 @@ class ResidualUnit(nn.Layer):
             padding=0,
             if_act=True,
             act=act)
-
         self.bottleneck_conv = ConvBNLayer(
             in_c=mid_c,
             out_c=mid_c,
             filter_size=filter_size,
             stride=stride,
-            padding='same',
-            dilation=dilation,
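+            # "same" padding for the dilated depthwise kernel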
+            padding=int((filter_size - 1) // 2) * dilation,
             num_groups=mid_c,
             if_act=True,
-            act=act)
+            act=act,
+            dilation=dilation)
         if self.if_se:
-            self.mid_se = SEModule(mid_c, name=name + "_se")
+            self.mid_se = SEModule(mid_c)
         self.linear_conv = ConvBNLayer(
             in_c=mid_c,
             out_c=out_c,
@@ -273,92 +312,187 @@ class ResidualUnit(nn.Layer):
             padding=0,
             if_act=False,
             act=None)
-        self.dilation = dilation

-    def forward(self, inputs):
-        x = self.expand_conv(inputs)
+    def forward(self, x):
+        identity = x
+        x = self.expand_conv(x)
         x = self.bottleneck_conv(x)
         if self.if_se:
             x = self.mid_se(x)
         x = self.linear_conv(x)
         if self.if_shortcut:
-            x = inputs + x
+            x = paddle.add(identity, x)
         return x


+# nn.Hardsigmoid does not expose the "slope" and "offset" arguments of nn.functional.hardsigmoid, so wrap the functional form instead.
+class Hardsigmoid(nn.Layer):
+    def __init__(self, slope=0.2, offset=0.5):
+        super().__init__()
+        self.slope = slope
+        self.offset = offset
+
+    def forward(self, x):
+        return nn.functional.hardsigmoid(
+            x, slope=self.slope, offset=self.offset)
+
+
 class SEModule(nn.Layer):
-    def __init__(self, channel, reduction=4, name=""):
-        super(SEModule, self).__init__()
-        self.avg_pool = nn.AdaptiveAvgPool2D(1)
-        self.conv1 = nn.Conv2D(
+    def __init__(self, channel, reduction=4):
+        super().__init__()
+        self.avg_pool = AdaptiveAvgPool2D(1)
+        self.conv1 = Conv2D(
             in_channels=channel,
             out_channels=channel // reduction,
             kernel_size=1,
             stride=1,
             padding=0)
-        self.conv2 = nn.Conv2D(
+        self.relu = nn.ReLU()
+        self.conv2 = Conv2D(
             in_channels=channel // reduction,
             out_channels=channel,
             kernel_size=1,
             stride=1,
             padding=0)
+        self.hardsigmoid = Hardsigmoid(slope=0.2, offset=0.5)

-    def forward(self, inputs):
-        outputs = self.avg_pool(inputs)
-        outputs = self.conv1(outputs)
-        outputs = F.relu(outputs)
-        outputs = self.conv2(outputs)
-        outputs = F.hardsigmoid(outputs)
-        return paddle.multiply(x=inputs, y=outputs)
+    def forward(self, x):
+        identity = x
+        x = self.avg_pool(x)
+        x = self.conv1(x)
+        x = self.relu(x)
+        x = self.conv2(x)
+        x = self.hardsigmoid(x)
+        return paddle.multiply(x=identity, y=x)


+@manager.BACKBONES.add_component
 def MobileNetV3_small_x0_35(**kwargs):
-    model = MobileNetV3(model_name="small", scale=0.35, **kwargs)
+    model = MobileNetV3(
+        config=NET_CONFIG["small"],
+        scale=0.35,
+        stages_pattern=MODEL_STAGES_PATTERN["MobileNetV3_small"],
+        out_index=OUT_INDEX["small"],
+        **kwargs)
     return model


+@manager.BACKBONES.add_component
 def MobileNetV3_small_x0_5(**kwargs):
-    model = MobileNetV3(model_name="small", scale=0.5, **kwargs)
+    model = MobileNetV3(
+        config=NET_CONFIG["small"],
+        scale=0.5,
+        stages_pattern=MODEL_STAGES_PATTERN["MobileNetV3_small"],
+        out_index=OUT_INDEX["small"],
+        **kwargs)
     return model


+@manager.BACKBONES.add_component
 def MobileNetV3_small_x0_75(**kwargs):
-    model = MobileNetV3(model_name="small", scale=0.75, **kwargs)
+    model = MobileNetV3(
+        config=NET_CONFIG["small"],
+        scale=0.75,
+        stages_pattern=MODEL_STAGES_PATTERN["MobileNetV3_small"],
+        out_index=OUT_INDEX["small"],
+        **kwargs)
     return model


 @manager.BACKBONES.add_component
 def MobileNetV3_small_x1_0(**kwargs):
-    model = MobileNetV3(model_name="small", scale=1.0, **kwargs)
+    model = MobileNetV3(
+        config=NET_CONFIG["small"],
+        scale=1.0,
+        stages_pattern=MODEL_STAGES_PATTERN["MobileNetV3_small"],
+        out_index=OUT_INDEX["small"],
+        **kwargs)
     return model


+@manager.BACKBONES.add_component
 def MobileNetV3_small_x1_25(**kwargs):
-    model = MobileNetV3(model_name="small", scale=1.25, **kwargs)
+    model = MobileNetV3(
+        config=NET_CONFIG["small"],
+        scale=1.25,
+        stages_pattern=MODEL_STAGES_PATTERN["MobileNetV3_small"],
+        out_index=OUT_INDEX["small"],
+        **kwargs)
     return model


+@manager.BACKBONES.add_component
 def MobileNetV3_large_x0_35(**kwargs):
-    model = MobileNetV3(model_name="large", scale=0.35, **kwargs)
+    model = MobileNetV3(
+        config=NET_CONFIG["large"],
+        scale=0.35,
+        stages_pattern=MODEL_STAGES_PATTERN["MobileNetV3_large"],
+        out_index=OUT_INDEX["large"],
+        **kwargs)
     return model


+@manager.BACKBONES.add_component
 def MobileNetV3_large_x0_5(**kwargs):
-    model = MobileNetV3(model_name="large", scale=0.5, **kwargs)
+    model = MobileNetV3(
+        config=NET_CONFIG["large"],
+        scale=0.5,
+        stages_pattern=MODEL_STAGES_PATTERN["MobileNetV3_large"],
+        out_index=OUT_INDEX["large"],
+        **kwargs)
     return model


+@manager.BACKBONES.add_component
 def MobileNetV3_large_x0_75(**kwargs):
-    model = MobileNetV3(model_name="large", scale=0.75, **kwargs)
+    model = MobileNetV3(
+        config=NET_CONFIG["large"],
+        scale=0.75,
+        stages_pattern=MODEL_STAGES_PATTERN["MobileNetV3_large"],
+        out_index=OUT_INDEX["large"],
+        **kwargs)
     return model


 @manager.BACKBONES.add_component
 def MobileNetV3_large_x1_0(**kwargs):
-    model = MobileNetV3(model_name="large", scale=1.0, **kwargs)
+    model = MobileNetV3(
+        config=NET_CONFIG["large"],
+        scale=1.0,
+        stages_pattern=MODEL_STAGES_PATTERN["MobileNetV3_large"],
+        out_index=OUT_INDEX["large"],
+        **kwargs)
     return model


+@manager.BACKBONES.add_component
 def MobileNetV3_large_x1_25(**kwargs):
-    model = MobileNetV3(model_name="large", scale=1.25, **kwargs)
+    model = MobileNetV3(
+        config=NET_CONFIG["large"],
+        scale=1.25,
+        stages_pattern=MODEL_STAGES_PATTERN["MobileNetV3_large"],
+        out_index=OUT_INDEX["large"],
+        **kwargs)
+    return model
+
+
+@manager.BACKBONES.add_component
+def MobileNetV3_large_x1_0_os8(**kwargs):
+    model = MobileNetV3(
+        config=NET_CONFIG["large_os8"],
+        scale=1.0,
+        stages_pattern=MODEL_STAGES_PATTERN["MobileNetV3_large"],
+        out_index=OUT_INDEX["large"],
+        **kwargs)
+    return model
+
+
+@manager.BACKBONES.add_component
+def MobileNetV3_small_x1_0_os8(**kwargs):
+    model = MobileNetV3(
+        config=NET_CONFIG["small_os8"],
+        scale=1.0,
+        stages_pattern=MODEL_STAGES_PATTERN["MobileNetV3_small"],
+        out_index=OUT_INDEX["small"],
+        **kwargs)
     return model