# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is based on https://github.com/AgentMaker/Paddle-Image-Models
The copyright of AgentMaker/Paddle-Image-Models is as follows:
Apache License [see LICENSE for details]
"""
import paddle
import paddle.nn as nn

__all__ = ["CondenseNetV2_a", "CondenseNetV2_b", "CondenseNetV2_c"]


class SELayer(nn.Layer):
    """Squeeze-and-Excitation channel attention."""

    def __init__(self, inplanes, reduction=16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2D(1)
        self.fc = nn.Sequential(
            nn.Linear(
                inplanes, inplanes // reduction, bias_attr=False),
            nn.ReLU(),
            nn.Linear(
                inplanes // reduction, inplanes, bias_attr=False),
            nn.Sigmoid(), )

    def forward(self, x):
        b, c, _, _ = x.shape
        # Squeeze: global average pool to (b, c); excite: two-layer MLP;
        # then rescale the input channel-wise.
        y = self.avg_pool(x).reshape((b, c))
        y = self.fc(y).reshape((b, c, 1, 1))
        return x * y.expand_as(x)


class HS(nn.Layer):
    """Hard-swish activation: x * ReLU6(x + 3) / 6."""

    def __init__(self):
        super(HS, self).__init__()
        self.relu6 = nn.ReLU6()

    def forward(self, inputs):
        return inputs * self.relu6(inputs + 3) / 6


class Conv(nn.Sequential):
    """Pre-activation conv block: BatchNorm2D -> activation -> Conv2D."""

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride=1,
            padding=0,
            groups=1,
            activation="ReLU",
            bn_momentum=0.9, ):
        super(Conv, self).__init__()
        self.add_sublayer(
            "norm", nn.BatchNorm2D(
                in_channels, momentum=bn_momentum))
        if activation == "ReLU":
            self.add_sublayer("activation", nn.ReLU())
        elif activation == "HS":
            self.add_sublayer("activation", HS())
        else:
            raise NotImplementedError
        self.add_sublayer(
            "conv",
            nn.Conv2D(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias_attr=False,
                groups=groups, ), )


def ShuffleLayer(x, groups):
    """Channel shuffle: interleave channels across groups (group-major to
    channel-major order), as in ShuffleNet."""
    batchsize, num_channels, height, width = x.shape
    channels_per_group = num_channels // groups
    # reshape
    x = x.reshape((batchsize, groups, channels_per_group, height, width))
    # transpose
    x = x.transpose((0, 2, 1, 3, 4))
    # reshape
    x = x.reshape((batchsize, -1, height, width))
    return x


def ShuffleLayerTrans(x, groups):
    """Inverse of ShuffleLayer: restores the original channel order."""
    batchsize, num_channels, height, width = x.shape
    channels_per_group = num_channels // groups
    # reshape
    x = x.reshape((batchsize, channels_per_group, groups, height, width))
    # transpose
    x = x.transpose((0, 2, 1, 3, 4))
    # reshape
    x = x.reshape((batchsize, -1, height, width))
    return x
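

# Worked example for the two shuffles above (illustration, not part of the
# original file): with 6 channels and groups=2, ShuffleLayer permutes the
# channel order [0, 1, 2, 3, 4, 5] -> [0, 3, 1, 4, 2, 5]; ShuffleLayerTrans
# applies the inverse permutation, so
# ShuffleLayerTrans(ShuffleLayer(x, 2), 2) == x.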


class CondenseLGC(nn.Layer):
    """Learned group convolution: a persistent "index" buffer first selects
    (re-orders) input channels, then a grouped conv and a channel shuffle
    are applied. The buffer is zero-initialized here, so the learned channel
    selection must come from loaded (pretrained) weights."""

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride=1,
            padding=0,
            groups=1,
            activation="ReLU", ):
        super(CondenseLGC, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.groups = groups
        self.norm = nn.BatchNorm2D(self.in_channels)
        if activation == "ReLU":
            self.activation = nn.ReLU()
        elif activation == "HS":
            self.activation = HS()
        else:
            raise NotImplementedError
        self.conv = nn.Conv2D(
            self.in_channels,
            self.out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=self.groups,
            bias_attr=False, )
        self.register_buffer(
            "index", paddle.zeros(
                (self.in_channels, ), dtype="int64"))

    def forward(self, x):
        x = paddle.index_select(x, self.index, axis=1)
        x = self.norm(x)
        x = self.activation(x)
        x = self.conv(x)
        x = ShuffleLayer(x, self.groups)
        return x


class CondenseSFR(nn.Layer):
    """Sparse feature reactivation: after a grouped 1x1 conv, maps the new
    features back onto the dense input channels via a persistent
    (out_channels x out_channels) "index" matrix applied as a per-pixel
    matmul."""

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride=1,
            padding=0,
            groups=1,
            activation="ReLU", ):
        super(CondenseSFR, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.groups = groups
        self.norm = nn.BatchNorm2D(self.in_channels)
        if activation == "ReLU":
            self.activation = nn.ReLU()
        elif activation == "HS":
            self.activation = HS()
        else:
            raise NotImplementedError
        self.conv = nn.Conv2D(
            self.in_channels,
            self.out_channels,
            kernel_size=kernel_size,
            padding=padding,
            groups=self.groups,
            bias_attr=False,
            stride=stride, )
        self.register_buffer("index",
                             paddle.zeros(
                                 (self.out_channels, self.out_channels)))

    def forward(self, x):
        x = self.norm(x)
        x = self.activation(x)
        x = ShuffleLayerTrans(x, self.groups)
        x = self.conv(x)  # SIZE: N, C, H, W
        N, C, H, W = x.shape
        x = x.reshape((N, C, H * W))
        x = x.transpose((0, 2, 1))  # SIZE: N, HW, C
        # x SIZE: N, HW, C; self.index SIZE: C, C; OUTPUT SIZE: N, HW, C
        x = paddle.matmul(x, self.index)
        x = x.transpose((0, 2, 1))  # SIZE: N, C, HW
        x = x.reshape((N, C, H, W))  # SIZE: N, C, H, W
        return x


class _SFR_DenseLayer(nn.Layer):
    def __init__(
            self,
            in_channels,
            growth_rate,
            group_1x1,
            group_3x3,
            group_trans,
            bottleneck,
            activation,
            use_se=False, ):
        super(_SFR_DenseLayer, self).__init__()
        self.group_1x1 = group_1x1
        self.group_3x3 = group_3x3
        self.group_trans = group_trans
        self.use_se = use_se
        # 1x1 conv i --> b*k
        self.conv_1 = CondenseLGC(
            in_channels,
            bottleneck * growth_rate,
            kernel_size=1,
            groups=self.group_1x1,
            activation=activation, )
        # 3x3 conv b*k --> k
        self.conv_2 = Conv(
            bottleneck * growth_rate,
            growth_rate,
            kernel_size=3,
            padding=1,
            groups=self.group_3x3,
            activation=activation, )
        # 1x1 res conv k (8-16-32) --> i (k*l)
        self.sfr = CondenseSFR(
            growth_rate,
            in_channels,
            kernel_size=1,
            groups=self.group_trans,
            activation=activation, )
        if self.use_se:
            self.se = SELayer(inplanes=growth_rate, reduction=1)

    def forward(self, x):
        x_ = x
        x = self.conv_1(x)
        x = self.conv_2(x)
        if self.use_se:
            x = self.se(x)
        sfr_feature = self.sfr(x)
        y = x_ + sfr_feature
        return paddle.concat([y, x], 1)
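

# Data flow of one _SFR_DenseLayer (summary added for clarity): with i input
# channels, growth rate k, and bottleneck b, the layer computes
#   x (i ch) -> LGC 1x1 (b*k ch) -> grouped 3x3 (k ch) [-> SE]
# The SFR branch then maps the k new channels back to i channels, which are
# added to the input residually, so the concat returns i + k channels.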


class _SFR_DenseBlock(nn.Sequential):
    def __init__(
            self,
            num_layers,
            in_channels,
            growth_rate,
            group_1x1,
            group_3x3,
            group_trans,
            bottleneck,
            activation,
            use_se, ):
        super(_SFR_DenseBlock, self).__init__()
        for i in range(num_layers):
            layer = _SFR_DenseLayer(
                in_channels + i * growth_rate,
                growth_rate,
                group_1x1,
                group_3x3,
                group_trans,
                bottleneck,
                activation,
                use_se, )
            self.add_sublayer("denselayer_%d" % (i + 1), layer)


class _Transition(nn.Layer):
    def __init__(self):
        super(_Transition, self).__init__()
        self.pool = nn.AvgPool2D(kernel_size=2, stride=2)

    def forward(self, x):
        x = self.pool(x)
        return x


class CondenseNetV2(nn.Layer):
    def __init__(
            self,
            stages,
            growth,
            HS_start_block,
            SE_start_block,
            fc_channel,
            group_1x1,
            group_3x3,
            group_trans,
            bottleneck,
            last_se_reduction,
            in_channels=3,
            class_num=1000, ):
        super(CondenseNetV2, self).__init__()
        self.stages = stages
        self.growth = growth
        self.in_channels = in_channels
        self.class_num = class_num
        self.last_se_reduction = last_se_reduction
        assert len(self.stages) == len(self.growth)
        self.progress = 0.0
        self.init_stride = 2
        self.pool_size = 7
        self.features = nn.Sequential()
        # Stem width is twice the first growth rate.
        self.num_features = 2 * self.growth[0]
        # Initial conv preceding dense-block 1 (224x224 input).
        self.features.add_sublayer(
            "init_conv",
            nn.Conv2D(
                in_channels,
                self.num_features,
                kernel_size=3,
                stride=self.init_stride,
                padding=1,
                bias_attr=False, ), )
        for i in range(len(self.stages)):
            activation = "HS" if i >= HS_start_block else "ReLU"
            use_se = i >= SE_start_block
            # Dense-block i
            self.add_block(i, group_1x1, group_3x3, group_trans, bottleneck,
                           activation, use_se)
        self.fc = nn.Linear(self.num_features, fc_channel)
        self.fc_act = HS()
        # Classifier layer
        if class_num > 0:
            self.classifier = nn.Linear(fc_channel, class_num)
        self._initialize()

    def add_block(self, i, group_1x1, group_3x3, group_trans, bottleneck,
                  activation, use_se):
        # Check whether the ith block is the last one
        last = i == len(self.stages) - 1
        block = _SFR_DenseBlock(
            num_layers=self.stages[i],
            in_channels=self.num_features,
            growth_rate=self.growth[i],
            group_1x1=group_1x1,
            group_3x3=group_3x3,
            group_trans=group_trans,
            bottleneck=bottleneck,
            activation=activation,
            use_se=use_se, )
        self.features.add_sublayer("denseblock_%d" % (i + 1), block)
        self.num_features += self.stages[i] * self.growth[i]
        if not last:
            trans = _Transition()
            self.features.add_sublayer("transition_%d" % (i + 1), trans)
        else:
            self.features.add_sublayer("norm_last",
                                       nn.BatchNorm2D(self.num_features))
            self.features.add_sublayer("relu_last", nn.ReLU())
            self.features.add_sublayer("pool_last",
                                       nn.AvgPool2D(self.pool_size))
            self.features.add_sublayer(
                "se_last",
                SELayer(
                    self.num_features, reduction=self.last_se_reduction))

    def forward(self, x):
        features = self.features(x)
        out = features.reshape((features.shape[0], -1))
        out = self.fc(out)
        out = self.fc_act(out)
        if self.class_num > 0:
            out = self.classifier(out)
        return out

    def _initialize(self):
        # Kaiming init for convs; BN weights to 1 and biases to 0.
        for m in self.sublayers():
            if isinstance(m, nn.Conv2D):
                nn.initializer.KaimingNormal()(m.weight)
            elif isinstance(m, nn.BatchNorm2D):
                nn.initializer.Constant(value=1.0)(m.weight)
                nn.initializer.Constant(value=0.0)(m.bias)


def CondenseNetV2_a(**kwargs):
    model = CondenseNetV2(
        stages=[1, 1, 4, 6, 8],
        growth=[8, 8, 16, 32, 64],
        HS_start_block=2,
        SE_start_block=3,
        fc_channel=828,
        group_1x1=8,
        group_3x3=8,
        group_trans=8,
        bottleneck=4,
        last_se_reduction=16,
        **kwargs)
    return model


def CondenseNetV2_b(**kwargs):
    model = CondenseNetV2(
        stages=[2, 4, 6, 8, 6],
        growth=[6, 12, 24, 48, 96],
        HS_start_block=2,
        SE_start_block=3,
        fc_channel=1024,
        group_1x1=6,
        group_3x3=6,
        group_trans=6,
        bottleneck=4,
        last_se_reduction=16,
        **kwargs)
    return model


def CondenseNetV2_c(**kwargs):
    model = CondenseNetV2(
        stages=[4, 6, 8, 10, 8],
        growth=[8, 16, 32, 64, 128],
        HS_start_block=2,
        SE_start_block=3,
        fc_channel=1024,
        group_1x1=8,
        group_3x3=8,
        group_trans=8,
        bottleneck=4,
        last_se_reduction=16,
        **kwargs)
    return model
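

if __name__ == "__main__":
    # Minimal smoke test (added for illustration; not part of the original
    # file). The "index" buffers are zero-initialized above, so this checks
    # construction and output shape only; meaningful predictions require
    # loading pretrained weights.
    model = CondenseNetV2_a()
    x = paddle.randn([1, 3, 224, 224])
    out = model(x)
    print(out.shape)  # expected: [1, 1000]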