visualize_feats.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. #!/usr/bin/env python
  2. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import argparse
  16. import sys
  17. import os
  18. import os.path as osp
  19. from collections import OrderedDict
  20. import numpy as np
  21. import cv2
  22. import paddle
  23. import paddlers
  24. from sklearn.decomposition import PCA
  25. _dir = osp.dirname(osp.abspath(__file__))
  26. sys.path.append(osp.abspath(osp.join(_dir, '../')))
  27. import custom_model
  28. import custom_trainer
  29. FILENAME_PATTERN = "{key}_{idx}_vis.png"
  30. class FeatureContainer:
  31. def __init__(self):
  32. self._dict = OrderedDict()
  33. def __setitem__(self, key, val):
  34. if key not in self._dict:
  35. self._dict[key] = list()
  36. self._dict[key].append(val)
  37. def __getitem__(self, key):
  38. return self._dict[key]
  39. def __repr__(self):
  40. return self._dict.__repr__()
  41. def items(self):
  42. return self._dict.items()
  43. def keys(self):
  44. return self._dict.keys()
  45. def values(self):
  46. return self._dict.values()
  47. class HookHelper:
  48. def __init__(self,
  49. model,
  50. fetch_dict,
  51. out_dict,
  52. hook_type='forward_out',
  53. auto_key=True):
  54. # XXX: A HookHelper object should only be used as a context manager and should not
  55. # persist in memory since it may keep references to some very large objects.
  56. self.model = model
  57. self.fetch_dict = fetch_dict
  58. self.out_dict = out_dict
  59. self._handles = []
  60. self.hook_type = hook_type
  61. self.auto_key = auto_key
  62. def __enter__(self):
  63. def _hook_proto(x, entry):
  64. # `x` should be a tensor or a tuple;
  65. # entry is expected to be a string or a non-nested tuple.
  66. if isinstance(entry, tuple):
  67. for key, f in zip(entry, x):
  68. self.out_dict[key] = f.detach().clone()
  69. else:
  70. if isinstance(x, tuple) and self.auto_key:
  71. for i, f in enumerate(x):
  72. key = self._gen_key(entry, i)
  73. self.out_dict[key] = f.detach().clone()
  74. else:
  75. self.out_dict[entry] = x.detach().clone()
  76. if self.hook_type == 'forward_in':
  77. # NOTE: Register forward hooks for LAYERs
  78. for name, layer in self.model.named_sublayers():
  79. if name in self.fetch_dict:
  80. entry = self.fetch_dict[name]
  81. self._handles.append(
  82. layer.register_forward_pre_hook(
  83. lambda l, x, entry=entry:
  84. # x is a tuple
  85. _hook_proto(x[0] if len(x)==1 else x, entry)
  86. )
  87. )
  88. elif self.hook_type == 'forward_out':
  89. # NOTE: Register forward hooks for LAYERs.
  90. for name, module in self.model.named_sublayers():
  91. if name in self.fetch_dict:
  92. entry = self.fetch_dict[name]
  93. self._handles.append(
  94. module.register_forward_post_hook(
  95. lambda l, x, y, entry=entry:
  96. # y is a tensor or a tuple
  97. _hook_proto(y, entry)
  98. )
  99. )
  100. elif self.hook_type == 'backward':
  101. # NOTE: Register backward hooks for TENSORs.
  102. for name, param in self.model.named_parameters():
  103. if name in self.fetch_dict:
  104. entry = self.fetch_dict[name]
  105. self._handles.append(
  106. param.register_hook(
  107. lambda grad, entry=entry: _hook_proto(grad, entry)))
  108. else:
  109. raise RuntimeError("Hook type is not implemented.")
  110. def __exit__(self, exc_type, exc_val, ext_tb):
  111. for handle in self._handles:
  112. handle.remove()
  113. def _gen_key(self, key, i):
  114. return key + f'_{i}'
  115. def parse_args():
  116. parser = argparse.ArgumentParser()
  117. parser.add_argument(
  118. "--model_dir", default=None, type=str, help="Path of saved model.")
  119. parser.add_argument(
  120. "--hook_type", default='forward_out', type=str, help="Type of hook.")
  121. parser.add_argument(
  122. "--layer_names",
  123. nargs='+',
  124. default=[],
  125. type=str,
  126. help="Layers that accepts or produces the features to visualize.")
  127. parser.add_argument(
  128. "--im_paths", nargs='+', type=str, help="Paths of input images.")
  129. parser.add_argument(
  130. "--save_dir",
  131. type=str,
  132. help="Path of directory to save prediction results.")
  133. parser.add_argument(
  134. "--to_pseudo_color",
  135. action='store_true',
  136. help="Whether to save pseudo-color images.")
  137. parser.add_argument(
  138. "--output_size",
  139. nargs='+',
  140. type=int,
  141. default=None,
  142. help="Resize the visualized image to `output_size`.")
  143. return parser.parse_args()
  144. def normalize_minmax(x):
  145. EPS = 1e-32
  146. return (x - x.min()) / (x.max() - x.min() + EPS)
  147. def quantize_8bit(x):
  148. # [0.0,1.0] float => [0,255] uint8
  149. # or [0,1] int => [0,255] uint8
  150. return (x * 255).astype('uint8')
  151. def to_pseudo_color(gray, color_map=cv2.COLORMAP_JET):
  152. return cv2.applyColorMap(gray, color_map)
  153. def process_fetched_feat(feat, to_pcolor=True):
  154. # Convert tensor to array
  155. feat = feat.squeeze(0).numpy()
  156. # Get principal component
  157. shape = feat.shape
  158. x = feat.reshape(shape[0], -1).transpose((1, 0))
  159. pca = PCA(n_components=1)
  160. y = pca.fit_transform(x)
  161. feat = y.reshape(shape[1:])
  162. feat = normalize_minmax(feat)
  163. feat = quantize_8bit(feat)
  164. if to_pcolor:
  165. feat = to_pseudo_color(feat)
  166. return feat
  167. if __name__ == '__main__':
  168. args = parse_args()
  169. # Load model
  170. model = paddlers.tasks.load_model(args.model_dir)
  171. fetch_dict = dict(zip(args.layer_names, args.layer_names))
  172. out_dict = FeatureContainer()
  173. with HookHelper(model.net, fetch_dict, out_dict, hook_type=args.hook_type):
  174. if len(args.im_paths) == 1:
  175. model.predict(args.im_paths[0])
  176. else:
  177. if len(args.im_paths) != 2:
  178. raise ValueError
  179. model.predict(tuple(args.im_paths))
  180. if not osp.exists(args.save_dir):
  181. os.makedirs(args.save_dir)
  182. for key, feats in out_dict.items():
  183. for idx, feat in enumerate(feats):
  184. im_vis = process_fetched_feat(feat, to_pcolor=args.to_pseudo_color)
  185. if args.output_size is not None:
  186. im_vis = cv2.resize(im_vis, tuple(args.output_size))
  187. out_path = osp.join(
  188. args.save_dir,
  189. FILENAME_PATTERN.format(
  190. key=key.replace('.', '_'), idx=idx))
  191. cv2.imwrite(out_path, im_vis)
  192. print(f"Write feature map to {out_path}")