# meta_arch.py
  1. from __future__ import absolute_import
  2. from __future__ import division
  3. from __future__ import print_function
  4. import numpy as np
  5. import paddle
  6. import paddle.nn as nn
  7. import typing
  8. from paddlers.models.ppdet.core.workspace import register
  9. from paddlers.models.ppdet.modeling.post_process import nms
  10. __all__ = ['BaseArch']
  11. @register
  12. class BaseArch(nn.Layer):
  13. def __init__(self, data_format='NCHW'):
  14. super(BaseArch, self).__init__()
  15. self.data_format = data_format
  16. self.inputs = {}
  17. self.fuse_norm = False
  18. def load_meanstd(self, cfg_transform):
  19. self.scale = 1.
  20. self.mean = paddle.to_tensor([0.485, 0.456, 0.406]).reshape(
  21. (1, 3, 1, 1))
  22. self.std = paddle.to_tensor([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))
  23. for item in cfg_transform:
  24. if 'NormalizeImage' in item:
  25. self.mean = paddle.to_tensor(item['NormalizeImage'][
  26. 'mean']).reshape((1, 3, 1, 1))
  27. self.std = paddle.to_tensor(item['NormalizeImage'][
  28. 'std']).reshape((1, 3, 1, 1))
  29. if item['NormalizeImage'].get('is_scale', True):
  30. self.scale = 1. / 255.
  31. break
  32. if self.data_format == 'NHWC':
  33. self.mean = self.mean.reshape(1, 1, 1, 3)
  34. self.std = self.std.reshape(1, 1, 1, 3)
  35. def forward(self, inputs):
  36. if self.data_format == 'NHWC':
  37. image = inputs['image']
  38. inputs['image'] = paddle.transpose(image, [0, 2, 3, 1])
  39. if self.fuse_norm:
  40. image = inputs['image']
  41. self.inputs['image'] = (image * self.scale - self.mean) / self.std
  42. self.inputs['im_shape'] = inputs['im_shape']
  43. self.inputs['scale_factor'] = inputs['scale_factor']
  44. else:
  45. self.inputs = inputs
  46. self.model_arch()
  47. if self.training:
  48. out = self.get_loss()
  49. else:
  50. inputs_list = []
  51. # multi-scale input
  52. if not isinstance(inputs, typing.Sequence):
  53. inputs_list.append(inputs)
  54. else:
  55. inputs_list.extend(inputs)
  56. outs = []
  57. for inp in inputs_list:
  58. self.inputs = inp
  59. outs.append(self.get_pred())
  60. # multi-scale test
  61. if len(outs) > 1:
  62. out = self.merge_multi_scale_predictions(outs)
  63. else:
  64. out = outs[0]
  65. return out
  66. def merge_multi_scale_predictions(self, outs):
  67. # default values for architectures not included in following list
  68. num_classes = 80
  69. nms_threshold = 0.5
  70. keep_top_k = 100
  71. if self.__class__.__name__ in ('CascadeRCNN', 'FasterRCNN', 'MaskRCNN'):
  72. num_classes = self.bbox_head.num_classes
  73. keep_top_k = self.bbox_post_process.nms.keep_top_k
  74. nms_threshold = self.bbox_post_process.nms.nms_threshold
  75. else:
  76. raise Exception(
  77. "Multi scale test only supports CascadeRCNN, FasterRCNN and MaskRCNN for now"
  78. )
  79. final_boxes = []
  80. all_scale_outs = paddle.concat([o['bbox'] for o in outs]).numpy()
  81. for c in range(num_classes):
  82. idxs = all_scale_outs[:, 0] == c
  83. if np.count_nonzero(idxs) == 0:
  84. continue
  85. r = nms(all_scale_outs[idxs, 1:], nms_threshold)
  86. final_boxes.append(
  87. np.concatenate([np.full((r.shape[0], 1), c), r], 1))
  88. out = np.concatenate(final_boxes)
  89. out = np.concatenate(sorted(
  90. out, key=lambda e: e[1])[-keep_top_k:]).reshape((-1, 6))
  91. out = {
  92. 'bbox': paddle.to_tensor(out),
  93. 'bbox_num': paddle.to_tensor(np.array([out.shape[0], ]))
  94. }
  95. return out
  96. def build_inputs(self, data, input_def):
  97. inputs = {}
  98. for i, k in enumerate(input_def):
  99. inputs[k] = data[i]
  100. return inputs
  101. def model_arch(self, ):
  102. pass
  103. def get_loss(self, ):
  104. raise NotImplementedError("Should implement get_loss method!")
  105. def get_pred(self, ):
  106. raise NotImplementedError("Should implement get_pred method!")
  107. @classmethod
  108. def convert_sync_batchnorm(cls, layer):
  109. layer_output = layer
  110. if getattr(layer, 'norm_type', None) == 'sync_bn':
  111. layer_output = nn.SyncBatchNorm.convert_sync_batchnorm(layer)
  112. else:
  113. for name, sublayer in layer.named_children():
  114. layer_output.add_sublayer(name,
  115. cls.convert_sync_batchnorm(sublayer))
  116. del layer
  117. return layer_output