visualize_feats.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. #!/usr/bin/env python
  2. import argparse
  3. import sys
  4. import os
  5. import os.path as osp
  6. from collections import OrderedDict
  7. import numpy as np
  8. import cv2
  9. import paddle
  10. import paddlers
  11. _dir = osp.dirname(osp.abspath(__file__))
  12. sys.path.append(osp.abspath(osp.join(_dir, '../')))
  13. import custom_model
  14. import custom_trainer
  15. FILENAME_PATTERN = "{key}_{idx}_vis.png"
  16. class FeatureContainer:
  17. def __init__(self):
  18. self._dict = OrderedDict()
  19. def __setitem__(self, key, val):
  20. if key not in self._dict:
  21. self._dict[key] = list()
  22. self._dict[key].append(val)
  23. def __getitem__(self, key):
  24. return self._dict[key]
  25. def __repr__(self):
  26. return self._dict.__repr__()
  27. def items(self):
  28. return self._dict.items()
  29. def keys(self):
  30. return self._dict.keys()
  31. def values(self):
  32. return self._dict.values()
  33. class HookHelper:
  34. def __init__(self, model, fetch_dict, out_dict, hook_type='forward_out'):
  35. # XXX: A HookHelper object should only be used as a context manager and should not
  36. # persist in memory since it may keep references to some very large objects.
  37. self.model = model
  38. self.fetch_dict = fetch_dict
  39. self.out_dict = out_dict
  40. self._handles = []
  41. self.hook_type = hook_type
  42. def __enter__(self):
  43. def _hook_proto(x, entry):
  44. # `x` should be a tensor or a tuple;
  45. # entry is expected to be a string or a non-nested tuple.
  46. if isinstance(entry, tuple):
  47. for key, f in zip(entry, x):
  48. self.out_dict[key] = f.detach().clone()
  49. else:
  50. self.out_dict[entry] = x.detach().clone()
  51. if self.hook_type == 'forward_in':
  52. # NOTE: Register forward hooks for LAYERs
  53. for name, layer in self.model.named_sublayers():
  54. if name in self.fetch_dict:
  55. entry = self.fetch_dict[name]
  56. self._handles.append(
  57. layer.register_forward_pre_hook(
  58. lambda l, x, entry=entry:
  59. # x is a tuple
  60. _hook_proto(x[0] if len(x)==1 else x, entry)
  61. )
  62. )
  63. elif self.hook_type == 'forward_out':
  64. # NOTE: Register forward hooks for LAYERs.
  65. for name, module in self.model.named_sublayers():
  66. if name in self.fetch_dict:
  67. entry = self.fetch_dict[name]
  68. self._handles.append(
  69. module.register_forward_post_hook(
  70. lambda l, x, y, entry=entry:
  71. # y is a tensor or a tuple
  72. _hook_proto(y, entry)
  73. )
  74. )
  75. elif self.hook_type == 'backward':
  76. # NOTE: Register backward hooks for TENSORs.
  77. for name, param in self.model.named_parameters():
  78. if name in self.fetch_dict:
  79. entry = self.fetch_dict[name]
  80. self._handles.append(
  81. param.register_hook(
  82. lambda grad, entry=entry: _hook_proto(grad, entry)))
  83. else:
  84. raise RuntimeError("Hook type is not implemented.")
  85. def __exit__(self, exc_type, exc_val, ext_tb):
  86. for handle in self._handles:
  87. handle.remove()
  88. def parse_args():
  89. parser = argparse.ArgumentParser()
  90. parser.add_argument(
  91. "--model_dir", default=None, type=str, help="Path of saved model.")
  92. parser.add_argument(
  93. "--hook_type", default='forward_out', type=str, help="Type of hook.")
  94. parser.add_argument(
  95. "--layer_names",
  96. nargs='+',
  97. default=[],
  98. type=str,
  99. help="Layers that accepts or produces the features to visualize.")
  100. parser.add_argument(
  101. "--im_paths", nargs='+', type=str, help="Paths of input images.")
  102. parser.add_argument(
  103. "--save_dir",
  104. type=str,
  105. help="Path of directory to save prediction results.")
  106. parser.add_argument(
  107. "--to_pseudo_color",
  108. action='store_true',
  109. help="Whether to save pseudo-color images.")
  110. parser.add_argument(
  111. "--output_size",
  112. nargs='+',
  113. type=int,
  114. default=None,
  115. help="Resize the visualized image to `output_size`.")
  116. return parser.parse_args()
  117. def normalize_minmax(x):
  118. EPS = 1e-32
  119. return (x - x.min()) / (x.max() - x.min() + EPS)
  120. def quantize_8bit(x):
  121. # [0.0,1.0] float => [0,255] uint8
  122. # or [0,1] int => [0,255] uint8
  123. return (x * 255).astype('uint8')
  124. def to_pseudo_color(gray, color_map=cv2.COLORMAP_JET):
  125. return cv2.applyColorMap(gray, color_map)
  126. def process_fetched_feat(feat, to_pcolor=True):
  127. # Convert tensor to array
  128. feat = feat.squeeze(0).numpy()
  129. # Average along channel dimension
  130. feat = normalize_minmax(feat.mean(0))
  131. feat = quantize_8bit(feat)
  132. if to_pcolor:
  133. feat = to_pseudo_color(feat)
  134. return feat
  135. if __name__ == '__main__':
  136. args = parse_args()
  137. # Load model
  138. model = paddlers.tasks.load_model(args.model_dir)
  139. fetch_dict = dict(zip(args.layer_names, args.layer_names))
  140. out_dict = FeatureContainer()
  141. with HookHelper(model.net, fetch_dict, out_dict, hook_type=args.hook_type):
  142. if len(args.im_paths) == 1:
  143. model.predict(args.im_paths[0])
  144. else:
  145. if len(args.im_paths) != 2:
  146. raise ValueError
  147. model.predict(tuple(args.im_paths))
  148. if not osp.exists(args.save_dir):
  149. os.makedirs(args.save_dir)
  150. for key, feats in out_dict.items():
  151. for idx, feat in enumerate(feats):
  152. im_vis = process_fetched_feat(feat, to_pcolor=args.to_pseudo_color)
  153. if args.output_size is not None:
  154. im_vis = cv2.resize(im_vis, tuple(args.output_size))
  155. out_path = osp.join(
  156. args.save_dir,
  157. FILENAME_PATTERN.format(
  158. key=key.replace('.', '_'), idx=idx))
  159. cv2.imwrite(out_path, im_vis)