visualize_feats.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. #!/usr/bin/env python
  2. import argparse
  3. import sys
  4. import os
  5. import os.path as osp
  6. from collections import OrderedDict
  7. import numpy as np
  8. import cv2
  9. import paddle
  10. import paddlers
  11. from sklearn.decomposition import PCA
# Make the directory one level above this script importable so that the
# project-local modules imported below can be found.
_dir = osp.dirname(osp.abspath(__file__))
sys.path.append(osp.abspath(osp.join(_dir, '../')))
# NOTE(review): neither module is referenced elsewhere in this script; they
# appear to be imported only for their import-time side effects (presumably
# registering custom models/trainers with paddlers) -- confirm before removing.
# They must stay after the sys.path tweak above.
import custom_model
import custom_trainer
# File name template for saved visualizations: `key` is the capture key (with
# '.' replaced by '_' by the caller) and `idx` is the capture index for that key.
FILENAME_PATTERN = "{key}_{idx}_vis.png"
  17. class FeatureContainer:
  18. def __init__(self):
  19. self._dict = OrderedDict()
  20. def __setitem__(self, key, val):
  21. if key not in self._dict:
  22. self._dict[key] = list()
  23. self._dict[key].append(val)
  24. def __getitem__(self, key):
  25. return self._dict[key]
  26. def __repr__(self):
  27. return self._dict.__repr__()
  28. def items(self):
  29. return self._dict.items()
  30. def keys(self):
  31. return self._dict.keys()
  32. def values(self):
  33. return self._dict.values()
  34. class HookHelper:
  35. def __init__(self,
  36. model,
  37. fetch_dict,
  38. out_dict,
  39. hook_type='forward_out',
  40. auto_key=True):
  41. # XXX: A HookHelper object should only be used as a context manager and should not
  42. # persist in memory since it may keep references to some very large objects.
  43. self.model = model
  44. self.fetch_dict = fetch_dict
  45. self.out_dict = out_dict
  46. self._handles = []
  47. self.hook_type = hook_type
  48. self.auto_key = auto_key
  49. def __enter__(self):
  50. def _hook_proto(x, entry):
  51. # `x` should be a tensor or a tuple;
  52. # entry is expected to be a string or a non-nested tuple.
  53. if isinstance(entry, tuple):
  54. for key, f in zip(entry, x):
  55. self.out_dict[key] = f.detach().clone()
  56. else:
  57. if isinstance(x, tuple) and self.auto_key:
  58. for i, f in enumerate(x):
  59. key = self._gen_key(entry, i)
  60. self.out_dict[key] = f.detach().clone()
  61. else:
  62. self.out_dict[entry] = x.detach().clone()
  63. if self.hook_type == 'forward_in':
  64. # NOTE: Register forward hooks for LAYERs
  65. for name, layer in self.model.named_sublayers():
  66. if name in self.fetch_dict:
  67. entry = self.fetch_dict[name]
  68. self._handles.append(
  69. layer.register_forward_pre_hook(
  70. lambda l, x, entry=entry:
  71. # x is a tuple
  72. _hook_proto(x[0] if len(x)==1 else x, entry)
  73. )
  74. )
  75. elif self.hook_type == 'forward_out':
  76. # NOTE: Register forward hooks for LAYERs.
  77. for name, module in self.model.named_sublayers():
  78. if name in self.fetch_dict:
  79. entry = self.fetch_dict[name]
  80. self._handles.append(
  81. module.register_forward_post_hook(
  82. lambda l, x, y, entry=entry:
  83. # y is a tensor or a tuple
  84. _hook_proto(y, entry)
  85. )
  86. )
  87. elif self.hook_type == 'backward':
  88. # NOTE: Register backward hooks for TENSORs.
  89. for name, param in self.model.named_parameters():
  90. if name in self.fetch_dict:
  91. entry = self.fetch_dict[name]
  92. self._handles.append(
  93. param.register_hook(
  94. lambda grad, entry=entry: _hook_proto(grad, entry)))
  95. else:
  96. raise RuntimeError("Hook type is not implemented.")
  97. def __exit__(self, exc_type, exc_val, ext_tb):
  98. for handle in self._handles:
  99. handle.remove()
  100. def _gen_key(self, key, i):
  101. return key + f'_{i}'
  102. def parse_args():
  103. parser = argparse.ArgumentParser()
  104. parser.add_argument(
  105. "--model_dir", default=None, type=str, help="Path of saved model.")
  106. parser.add_argument(
  107. "--hook_type", default='forward_out', type=str, help="Type of hook.")
  108. parser.add_argument(
  109. "--layer_names",
  110. nargs='+',
  111. default=[],
  112. type=str,
  113. help="Layers that accepts or produces the features to visualize.")
  114. parser.add_argument(
  115. "--im_paths", nargs='+', type=str, help="Paths of input images.")
  116. parser.add_argument(
  117. "--save_dir",
  118. type=str,
  119. help="Path of directory to save prediction results.")
  120. parser.add_argument(
  121. "--to_pseudo_color",
  122. action='store_true',
  123. help="Whether to save pseudo-color images.")
  124. parser.add_argument(
  125. "--output_size",
  126. nargs='+',
  127. type=int,
  128. default=None,
  129. help="Resize the visualized image to `output_size`.")
  130. return parser.parse_args()
  131. def normalize_minmax(x):
  132. EPS = 1e-32
  133. return (x - x.min()) / (x.max() - x.min() + EPS)
  134. def quantize_8bit(x):
  135. # [0.0,1.0] float => [0,255] uint8
  136. # or [0,1] int => [0,255] uint8
  137. return (x * 255).astype('uint8')
  138. def to_pseudo_color(gray, color_map=cv2.COLORMAP_JET):
  139. return cv2.applyColorMap(gray, color_map)
  140. def process_fetched_feat(feat, to_pcolor=True):
  141. # Convert tensor to array
  142. feat = feat.squeeze(0).numpy()
  143. # Get principal component
  144. shape = feat.shape
  145. x = feat.reshape(shape[0], -1).transpose((1, 0))
  146. pca = PCA(n_components=1)
  147. y = pca.fit_transform(x)
  148. feat = y.reshape(shape[1:])
  149. feat = normalize_minmax(feat)
  150. feat = quantize_8bit(feat)
  151. if to_pcolor:
  152. feat = to_pseudo_color(feat)
  153. return feat
  154. if __name__ == '__main__':
  155. args = parse_args()
  156. # Load model
  157. model = paddlers.tasks.load_model(args.model_dir)
  158. fetch_dict = dict(zip(args.layer_names, args.layer_names))
  159. out_dict = FeatureContainer()
  160. with HookHelper(model.net, fetch_dict, out_dict, hook_type=args.hook_type):
  161. if len(args.im_paths) == 1:
  162. model.predict(args.im_paths[0])
  163. else:
  164. if len(args.im_paths) != 2:
  165. raise ValueError
  166. model.predict(tuple(args.im_paths))
  167. if not osp.exists(args.save_dir):
  168. os.makedirs(args.save_dir)
  169. for key, feats in out_dict.items():
  170. for idx, feat in enumerate(feats):
  171. im_vis = process_fetched_feat(feat, to_pcolor=args.to_pseudo_color)
  172. if args.output_size is not None:
  173. im_vis = cv2.resize(im_vis, tuple(args.output_size))
  174. out_path = osp.join(
  175. args.save_dir,
  176. FILENAME_PATTERN.format(
  177. key=key.replace('.', '_'), idx=idx))
  178. cv2.imwrite(out_path, im_vis)
  179. print(f"Write feature map to {out_path}")