visualize_feats.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. #!/usr/bin/env python
  2. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import argparse
  16. import sys
  17. import os
  18. import os.path as osp
  19. from collections import OrderedDict
  20. import numpy as np
  21. import cv2
  22. import paddle
  23. import paddlers
  24. from sklearn.decomposition import PCA
# Make the parent directory importable so the local `bootstrap` module can be
# found; this must run before `import bootstrap` below.
_dir = osp.dirname(osp.abspath(__file__))
sys.path.append(osp.abspath(osp.join(_dir, '../')))
# NOTE(review): none of `bootstrap`'s names are used in the code visible here,
# so it is presumably imported only for its import-time side effects — confirm.
import bootstrap

# File-name template for saved visualizations: `key` is the feature key with
# '.' replaced by '_', `idx` is the index of the capture under that key.
FILENAME_PATTERN = "{key}_{idx}_vis.png"
  29. class FeatureContainer:
  30. def __init__(self):
  31. self._dict = OrderedDict()
  32. def __setitem__(self, key, val):
  33. if key not in self._dict:
  34. self._dict[key] = list()
  35. self._dict[key].append(val)
  36. def __getitem__(self, key):
  37. return self._dict[key]
  38. def __repr__(self):
  39. return self._dict.__repr__()
  40. def items(self):
  41. return self._dict.items()
  42. def keys(self):
  43. return self._dict.keys()
  44. def values(self):
  45. return self._dict.values()
  46. class HookHelper:
  47. def __init__(self,
  48. model,
  49. fetch_dict,
  50. out_dict,
  51. hook_type='forward_out',
  52. auto_key=True):
  53. # XXX: A HookHelper object should only be used as a context manager and should not
  54. # persist in memory since it may keep references to some very large objects.
  55. self.model = model
  56. self.fetch_dict = fetch_dict
  57. self.out_dict = out_dict
  58. self._handles = []
  59. self.hook_type = hook_type
  60. self.auto_key = auto_key
  61. def __enter__(self):
  62. def _hook_proto(x, entry):
  63. # `x` should be a tensor or a tuple;
  64. # entry is expected to be a string or a non-nested tuple.
  65. if isinstance(entry, tuple):
  66. for key, f in zip(entry, x):
  67. self.out_dict[key] = f.detach().clone()
  68. else:
  69. if isinstance(x, tuple) and self.auto_key:
  70. for i, f in enumerate(x):
  71. key = self._gen_key(entry, i)
  72. self.out_dict[key] = f.detach().clone()
  73. else:
  74. self.out_dict[entry] = x.detach().clone()
  75. if self.hook_type == 'forward_in':
  76. # NOTE: Register forward hooks for LAYERs
  77. for name, layer in self.model.named_sublayers():
  78. if name in self.fetch_dict:
  79. entry = self.fetch_dict[name]
  80. self._handles.append(
  81. layer.register_forward_pre_hook(
  82. lambda l, x, entry=entry:
  83. # x is a tuple
  84. _hook_proto(x[0] if len(x)==1 else x, entry)
  85. )
  86. )
  87. elif self.hook_type == 'forward_out':
  88. # NOTE: Register forward hooks for LAYERs.
  89. for name, module in self.model.named_sublayers():
  90. if name in self.fetch_dict:
  91. entry = self.fetch_dict[name]
  92. self._handles.append(
  93. module.register_forward_post_hook(
  94. lambda l, x, y, entry=entry:
  95. # y is a tensor or a tuple
  96. _hook_proto(y, entry)
  97. )
  98. )
  99. elif self.hook_type == 'backward':
  100. # NOTE: Register backward hooks for TENSORs.
  101. for name, param in self.model.named_parameters():
  102. if name in self.fetch_dict:
  103. entry = self.fetch_dict[name]
  104. self._handles.append(
  105. param.register_hook(
  106. lambda grad, entry=entry: _hook_proto(grad, entry)))
  107. else:
  108. raise RuntimeError("Hook type is not implemented.")
  109. def __exit__(self, exc_type, exc_val, ext_tb):
  110. for handle in self._handles:
  111. handle.remove()
  112. def _gen_key(self, key, i):
  113. return key + f'_{i}'
  114. def parse_args():
  115. parser = argparse.ArgumentParser()
  116. parser.add_argument(
  117. "--model_dir", default=None, type=str, help="Path of saved model.")
  118. parser.add_argument(
  119. "--hook_type", default='forward_out', type=str, help="Type of hook.")
  120. parser.add_argument(
  121. "--layer_names",
  122. nargs='+',
  123. default=[],
  124. type=str,
  125. help="Layers that accepts or produces the features to visualize.")
  126. parser.add_argument(
  127. "--im_paths", nargs='+', type=str, help="Paths of input images.")
  128. parser.add_argument(
  129. "--save_dir",
  130. type=str,
  131. help="Path of directory to save prediction results.")
  132. parser.add_argument(
  133. "--to_pseudo_color",
  134. action='store_true',
  135. help="Whether to save pseudo-color images.")
  136. parser.add_argument(
  137. "--output_size",
  138. nargs='+',
  139. type=int,
  140. default=None,
  141. help="Resize the visualized image to `output_size`.")
  142. return parser.parse_args()
  143. def normalize_minmax(x):
  144. EPS = 1e-32
  145. return (x - x.min()) / (x.max() - x.min() + EPS)
  146. def quantize_8bit(x):
  147. # [0.0,1.0] float => [0,255] uint8
  148. # or [0,1] int => [0,255] uint8
  149. return (x * 255).astype('uint8')
  150. def to_pseudo_color(gray, color_map=cv2.COLORMAP_JET):
  151. return cv2.applyColorMap(gray, color_map)
  152. def process_fetched_feat(feat, to_pcolor=True):
  153. # Convert tensor to array
  154. feat = feat.squeeze(0).numpy()
  155. # Get principal component
  156. shape = feat.shape
  157. x = feat.reshape(shape[0], -1).transpose((1, 0))
  158. pca = PCA(n_components=1)
  159. y = pca.fit_transform(x)
  160. feat = y.reshape(shape[1:])
  161. feat = normalize_minmax(feat)
  162. feat = quantize_8bit(feat)
  163. if to_pcolor:
  164. feat = to_pseudo_color(feat)
  165. return feat
  166. if __name__ == '__main__':
  167. args = parse_args()
  168. # Load model
  169. model = paddlers.tasks.load_model(args.model_dir)
  170. fetch_dict = dict(zip(args.layer_names, args.layer_names))
  171. out_dict = FeatureContainer()
  172. with HookHelper(model.net, fetch_dict, out_dict, hook_type=args.hook_type):
  173. if len(args.im_paths) == 1:
  174. model.predict(args.im_paths[0])
  175. else:
  176. if len(args.im_paths) != 2:
  177. raise ValueError
  178. model.predict(tuple(args.im_paths))
  179. if not osp.exists(args.save_dir):
  180. os.makedirs(args.save_dir)
  181. for key, feats in out_dict.items():
  182. for idx, feat in enumerate(feats):
  183. im_vis = process_fetched_feat(feat, to_pcolor=args.to_pseudo_color)
  184. if args.output_size is not None:
  185. im_vis = cv2.resize(im_vis, tuple(args.output_size))
  186. out_path = osp.join(
  187. args.save_dir,
  188. FILENAME_PATTERN.format(
  189. key=key.replace('.', '_'), idx=idx))
  190. cv2.imwrite(out_path, im_vis)
  191. print(f"Write feature map to {out_path}")