resnet_vd.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400
  1. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle
  15. import paddle.nn as nn
  16. import paddle.nn.functional as F
  17. from paddlers.models.ppseg.cvlibs import manager
  18. from paddlers.models.ppseg.models import layers
  19. from paddlers.models.ppseg.utils import utils
  20. __all__ = [
  21. "ResNet18_vd", "ResNet34_vd", "ResNet50_vd", "ResNet101_vd", "ResNet152_vd"
  22. ]
  23. class ConvBNLayer(nn.Layer):
  24. def __init__(self,
  25. in_channels,
  26. out_channels,
  27. kernel_size,
  28. stride=1,
  29. dilation=1,
  30. groups=1,
  31. is_vd_mode=False,
  32. act=None,
  33. data_format='NCHW'):
  34. super(ConvBNLayer, self).__init__()
  35. if dilation != 1 and kernel_size != 3:
  36. raise RuntimeError("When the dilation isn't 1," \
  37. "the kernel_size should be 3.")
  38. self.is_vd_mode = is_vd_mode
  39. self._pool2d_avg = nn.AvgPool2D(
  40. kernel_size=2,
  41. stride=2,
  42. padding=0,
  43. ceil_mode=True,
  44. data_format=data_format)
  45. self._conv = nn.Conv2D(
  46. in_channels=in_channels,
  47. out_channels=out_channels,
  48. kernel_size=kernel_size,
  49. stride=stride,
  50. padding=(kernel_size - 1) // 2 \
  51. if dilation == 1 else dilation,
  52. dilation=dilation,
  53. groups=groups,
  54. bias_attr=False,
  55. data_format=data_format)
  56. self._batch_norm = layers.SyncBatchNorm(
  57. out_channels, data_format=data_format)
  58. self._act_op = layers.Activation(act=act)
  59. def forward(self, inputs):
  60. if self.is_vd_mode:
  61. inputs = self._pool2d_avg(inputs)
  62. y = self._conv(inputs)
  63. y = self._batch_norm(y)
  64. y = self._act_op(y)
  65. return y
  66. class BottleneckBlock(nn.Layer):
  67. def __init__(self,
  68. in_channels,
  69. out_channels,
  70. stride,
  71. shortcut=True,
  72. if_first=False,
  73. dilation=1,
  74. data_format='NCHW'):
  75. super(BottleneckBlock, self).__init__()
  76. self.data_format = data_format
  77. self.conv0 = ConvBNLayer(
  78. in_channels=in_channels,
  79. out_channels=out_channels,
  80. kernel_size=1,
  81. act='relu',
  82. data_format=data_format)
  83. self.dilation = dilation
  84. self.conv1 = ConvBNLayer(
  85. in_channels=out_channels,
  86. out_channels=out_channels,
  87. kernel_size=3,
  88. stride=stride,
  89. act='relu',
  90. dilation=dilation,
  91. data_format=data_format)
  92. self.conv2 = ConvBNLayer(
  93. in_channels=out_channels,
  94. out_channels=out_channels * 4,
  95. kernel_size=1,
  96. act=None,
  97. data_format=data_format)
  98. if not shortcut:
  99. self.short = ConvBNLayer(
  100. in_channels=in_channels,
  101. out_channels=out_channels * 4,
  102. kernel_size=1,
  103. stride=1,
  104. is_vd_mode=False if if_first or stride == 1 else True,
  105. data_format=data_format)
  106. self.shortcut = shortcut
  107. # NOTE: Use the wrap layer for quantization training
  108. self.add = layers.Add()
  109. self.relu = layers.Activation(act="relu")
  110. def forward(self, inputs):
  111. y = self.conv0(inputs)
  112. conv1 = self.conv1(y)
  113. conv2 = self.conv2(conv1)
  114. if self.shortcut:
  115. short = inputs
  116. else:
  117. short = self.short(inputs)
  118. y = self.add(short, conv2)
  119. y = self.relu(y)
  120. return y
  121. class BasicBlock(nn.Layer):
  122. def __init__(self,
  123. in_channels,
  124. out_channels,
  125. stride,
  126. dilation=1,
  127. shortcut=True,
  128. if_first=False,
  129. data_format='NCHW'):
  130. super(BasicBlock, self).__init__()
  131. self.conv0 = ConvBNLayer(
  132. in_channels=in_channels,
  133. out_channels=out_channels,
  134. kernel_size=3,
  135. stride=stride,
  136. dilation=dilation,
  137. act='relu',
  138. data_format=data_format)
  139. self.conv1 = ConvBNLayer(
  140. in_channels=out_channels,
  141. out_channels=out_channels,
  142. kernel_size=3,
  143. dilation=dilation,
  144. act=None,
  145. data_format=data_format)
  146. if not shortcut:
  147. self.short = ConvBNLayer(
  148. in_channels=in_channels,
  149. out_channels=out_channels,
  150. kernel_size=1,
  151. stride=1,
  152. is_vd_mode=False if if_first or stride == 1 else True,
  153. data_format=data_format)
  154. self.shortcut = shortcut
  155. self.dilation = dilation
  156. self.data_format = data_format
  157. self.add = layers.Add()
  158. self.relu = layers.Activation(act="relu")
  159. def forward(self, inputs):
  160. y = self.conv0(inputs)
  161. conv1 = self.conv1(y)
  162. if self.shortcut:
  163. short = inputs
  164. else:
  165. short = self.short(inputs)
  166. y = self.add(short, conv1)
  167. y = self.relu(y)
  168. return y
  169. class ResNet_vd(nn.Layer):
  170. """
  171. The ResNet_vd implementation based on PaddlePaddle.
  172. The original article refers to Jingdong
  173. Tong He, et, al. "Bag of Tricks for Image Classification with Convolutional Neural Networks"
  174. (https://arxiv.org/pdf/1812.01187.pdf).
  175. Args:
  176. layers (int, optional): The layers of ResNet_vd. The supported layers are (18, 34, 50, 101, 152, 200). Default: 50.
  177. output_stride (int, optional): The stride of output features compared to input images. It is 8 or 16. Default: 8.
  178. multi_grid (tuple|list, optional): The grid of stage4. Defult: (1, 1, 1).
  179. in_channels (int, optional): The channels of input image. Default: 3.
  180. pretrained (str, optional): The path of pretrained model.
  181. """
  182. def __init__(self,
  183. layers=50,
  184. output_stride=8,
  185. multi_grid=(1, 1, 1),
  186. in_channels=3,
  187. pretrained=None,
  188. data_format='NCHW'):
  189. super(ResNet_vd, self).__init__()
  190. self.data_format = data_format
  191. self.conv1_logit = None # for gscnn shape stream
  192. self.layers = layers
  193. supported_layers = [18, 34, 50, 101, 152, 200]
  194. assert layers in supported_layers, \
  195. "supported layers are {} but input layer is {}".format(
  196. supported_layers, layers)
  197. if layers == 18:
  198. depth = [2, 2, 2, 2]
  199. elif layers == 34 or layers == 50:
  200. depth = [3, 4, 6, 3]
  201. elif layers == 101:
  202. depth = [3, 4, 23, 3]
  203. elif layers == 152:
  204. depth = [3, 8, 36, 3]
  205. elif layers == 200:
  206. depth = [3, 12, 48, 3]
  207. num_channels = [64, 256, 512,
  208. 1024] if layers >= 50 else [64, 64, 128, 256]
  209. num_filters = [64, 128, 256, 512]
  210. # for channels of four returned stages
  211. self.feat_channels = [c * 4 for c in num_filters
  212. ] if layers >= 50 else num_filters
  213. dilation_dict = None
  214. if output_stride == 8:
  215. dilation_dict = {2: 2, 3: 4}
  216. elif output_stride == 16:
  217. dilation_dict = {3: 2}
  218. self.conv1_1 = ConvBNLayer(
  219. in_channels=in_channels,
  220. out_channels=32,
  221. kernel_size=3,
  222. stride=2,
  223. act='relu',
  224. data_format=data_format)
  225. self.conv1_2 = ConvBNLayer(
  226. in_channels=32,
  227. out_channels=32,
  228. kernel_size=3,
  229. stride=1,
  230. act='relu',
  231. data_format=data_format)
  232. self.conv1_3 = ConvBNLayer(
  233. in_channels=32,
  234. out_channels=64,
  235. kernel_size=3,
  236. stride=1,
  237. act='relu',
  238. data_format=data_format)
  239. self.pool2d_max = nn.MaxPool2D(
  240. kernel_size=3, stride=2, padding=1, data_format=data_format)
  241. # self.block_list = []
  242. self.stage_list = []
  243. if layers >= 50:
  244. for block in range(len(depth)):
  245. shortcut = False
  246. block_list = []
  247. for i in range(depth[block]):
  248. if layers in [101, 152] and block == 2:
  249. if i == 0:
  250. conv_name = "res" + str(block + 2) + "a"
  251. else:
  252. conv_name = "res" + str(block + 2) + "b" + str(i)
  253. else:
  254. conv_name = "res" + str(block + 2) + chr(97 + i)
  255. ###############################################################################
  256. # Add dilation rate for some segmentation tasks, if dilation_dict is not None.
  257. dilation_rate = dilation_dict[
  258. block] if dilation_dict and block in dilation_dict else 1
  259. # Actually block here is 'stage', and i is 'block' in 'stage'
  260. # At the stage 4, expand the the dilation_rate if given multi_grid
  261. if block == 3:
  262. dilation_rate = dilation_rate * multi_grid[i]
  263. ###############################################################################
  264. bottleneck_block = self.add_sublayer(
  265. 'bb_%d_%d' % (block, i),
  266. BottleneckBlock(
  267. in_channels=num_channels[block]
  268. if i == 0 else num_filters[block] * 4,
  269. out_channels=num_filters[block],
  270. stride=2 if i == 0 and block != 0 and
  271. dilation_rate == 1 else 1,
  272. shortcut=shortcut,
  273. if_first=block == i == 0,
  274. dilation=dilation_rate,
  275. data_format=data_format))
  276. block_list.append(bottleneck_block)
  277. shortcut = True
  278. self.stage_list.append(block_list)
  279. else:
  280. for block in range(len(depth)):
  281. shortcut = False
  282. block_list = []
  283. for i in range(depth[block]):
  284. dilation_rate = dilation_dict[block] \
  285. if dilation_dict and block in dilation_dict else 1
  286. if block == 3:
  287. dilation_rate = dilation_rate * multi_grid[i]
  288. basic_block = self.add_sublayer(
  289. 'bb_%d_%d' % (block, i),
  290. BasicBlock(
  291. in_channels=num_channels[block]
  292. if i == 0 else num_filters[block],
  293. out_channels=num_filters[block],
  294. stride=2 if i == 0 and block != 0 \
  295. and dilation_rate == 1 else 1,
  296. dilation=dilation_rate,
  297. shortcut=shortcut,
  298. if_first=block == i == 0,
  299. data_format=data_format))
  300. block_list.append(basic_block)
  301. shortcut = True
  302. self.stage_list.append(block_list)
  303. self.pretrained = pretrained
  304. self.init_weight()
  305. def forward(self, inputs):
  306. y = self.conv1_1(inputs)
  307. y = self.conv1_2(y)
  308. y = self.conv1_3(y)
  309. self.conv1_logit = y.clone()
  310. y = self.pool2d_max(y)
  311. # A feature list saves the output feature map of each stage.
  312. feat_list = []
  313. for stage in self.stage_list:
  314. for block in stage:
  315. y = block(y)
  316. feat_list.append(y)
  317. return feat_list
  318. def init_weight(self):
  319. utils.load_pretrained_model(self, self.pretrained)
  320. @manager.BACKBONES.add_component
  321. def ResNet18_vd(**args):
  322. model = ResNet_vd(layers=18, **args)
  323. return model
  324. def ResNet34_vd(**args):
  325. model = ResNet_vd(layers=34, **args)
  326. return model
  327. @manager.BACKBONES.add_component
  328. def ResNet50_vd(**args):
  329. model = ResNet_vd(layers=50, **args)
  330. return model
  331. @manager.BACKBONES.add_component
  332. def ResNet101_vd(**args):
  333. model = ResNet_vd(layers=101, **args)
  334. return model
  335. def ResNet152_vd(**args):
  336. model = ResNet_vd(layers=152, **args)
  337. return model
  338. def ResNet200_vd(**args):
  339. model = ResNet_vd(layers=200, **args)
  340. return model