# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import os.path as osp
from collections import OrderedDict

import numpy as np
import cv2
import paddle
import paddle.nn.functional as F
from paddle.static import InputSpec

import paddlers
import paddlers.models.ppgan as ppgan
import paddlers.rs_models.res as cmres
import paddlers.models.ppgan.metrics as metrics
import paddlers.utils.logging as logging
from paddlers.models import res_losses
from paddlers.transforms import Resize, decode_image
from paddlers.transforms.functions import calc_hr_shape
from paddlers.utils import get_single_card_bs
from paddlers.utils.checkpoint import res_pretrain_weights_dict

from .base import BaseModel
from .utils.res_adapters import GANAdapter, OptimizerAdapter
from .utils.infer_nets import InferResNet

__all__ = ["DRN", "LESRCNN", "ESRGAN", "RCAN"]


class BaseRestorer(BaseModel):
    MIN_MAX = (0., 1.)
    TEST_OUT_KEY = None

    def __init__(self, model_name, losses=None, sr_factor=None, **params):
        self.init_params = locals()
        if 'with_net' in self.init_params:
            del self.init_params['with_net']
        super(BaseRestorer, self).__init__('restorer')
        self.model_name = model_name
        self.losses = losses
        self.sr_factor = sr_factor
        if params.get('with_net', True):
            params.pop('with_net', None)
            self.net = self.build_net(**params)
        self.find_unused_parameters = True

    def build_net(self, **params):
        # Currently, only models from cmres are used.
        if not hasattr(cmres, self.model_name):
            raise ValueError("ERROR: There is no model named {}.".format(
                self.model_name))
        net = dict(**cmres.__dict__)[self.model_name](**params)
        return net

    def _build_inference_net(self):
        # For GAN models, only the generator will be used for inference.
        if isinstance(self.net, GANAdapter):
            infer_net = InferResNet(
                self.net.generator, out_key=self.TEST_OUT_KEY)
        else:
            infer_net = InferResNet(self.net, out_key=self.TEST_OUT_KEY)
        infer_net.eval()
        return infer_net

    def _fix_transforms_shape(self, image_shape):
        if hasattr(self, 'test_transforms'):
            if self.test_transforms is not None:
                has_resize_op = False
                resize_op_idx = -1
                normalize_op_idx = len(self.test_transforms.transforms)
                for idx, op in enumerate(self.test_transforms.transforms):
                    name = op.__class__.__name__
                    if name == 'Normalize':
                        normalize_op_idx = idx
                    if 'Resize' in name:
                        has_resize_op = True
                        resize_op_idx = idx
                if not has_resize_op:
                    self.test_transforms.transforms.insert(
                        normalize_op_idx, Resize(target_size=image_shape))
                else:
                    self.test_transforms.transforms[resize_op_idx] = Resize(
                        target_size=image_shape)

    def _get_test_inputs(self, image_shape):
        if image_shape is not None:
            if len(image_shape) == 2:
                image_shape = [1, 3] + image_shape
            self._fix_transforms_shape(image_shape[-2:])
        else:
            image_shape = [None, 3, -1, -1]
        self.fixed_input_shape = image_shape
        input_spec = [
            InputSpec(
                shape=image_shape, name='image', dtype='float32')
        ]
        return input_spec

    def run(self, net, inputs, mode):
        outputs = OrderedDict()
        if mode == 'test':
            tar_shape = inputs[1]
            if self.status == 'Infer':
                net_out = net(inputs[0])
                res_map_list = self.postprocess(
                    net_out, tar_shape, transforms=inputs[2])
            else:
                if isinstance(net, GANAdapter):
                    net_out = net.generator(inputs[0])
                else:
                    net_out = net(inputs[0])
                if self.TEST_OUT_KEY is not None:
                    net_out = net_out[self.TEST_OUT_KEY]
                pred = self.postprocess(
                    net_out, tar_shape, transforms=inputs[2])
                res_map_list = []
                for res_map in pred:
                    res_map = self._tensor_to_images(res_map)
                    res_map_list.append(res_map)
            outputs['res_map'] = res_map_list
        if mode == 'eval':
            if isinstance(net, GANAdapter):
                net_out = net.generator(inputs[0])
            else:
                net_out = net(inputs[0])
            if self.TEST_OUT_KEY is not None:
                net_out = net_out[self.TEST_OUT_KEY]
            tar = inputs[1]
            tar_shape = [tar.shape[-2:]]
            pred = self.postprocess(
                net_out, tar_shape, transforms=inputs[2])[0]  # NCHW
            pred = self._tensor_to_images(pred)
            outputs['pred'] = pred
            tar = self._tensor_to_images(tar)
            outputs['tar'] = tar
        if mode == 'train':
            # This branch is used by non-GAN models.
            # GAN models should use self.run_gan() instead.
            net_out = net(inputs[0])
            loss = self.losses(net_out, inputs[1])
            outputs['loss'] = loss
        return outputs

    def run_gan(self, net, inputs, mode, gan_mode):
        raise NotImplementedError

    def default_loss(self):
        return res_losses.L1Loss()

    def default_optimizer(self,
                          parameters,
                          learning_rate,
                          num_epochs,
                          num_steps_each_epoch,
                          lr_decay_power=0.9):
        decay_step = num_epochs * num_steps_each_epoch
        lr_scheduler = paddle.optimizer.lr.PolynomialDecay(
            learning_rate, decay_step, end_lr=0, power=lr_decay_power)
        optimizer = paddle.optimizer.Momentum(
            learning_rate=lr_scheduler,
            parameters=parameters,
            momentum=0.9,
            weight_decay=4e-5)
        return optimizer

    def train(self,
              num_epochs,
              train_dataset,
              train_batch_size=2,
              eval_dataset=None,
              optimizer=None,
              save_interval_epochs=1,
              log_interval_steps=2,
              save_dir='output',
              pretrain_weights=None,
              learning_rate=0.01,
              lr_decay_power=0.9,
              early_stop=False,
              early_stop_patience=5,
              use_vdl=True,
              resume_checkpoint=None):
        """
        Train the model.

        Args:
            num_epochs (int): Number of epochs.
            train_dataset (paddlers.datasets.ResDataset): Training dataset.
            train_batch_size (int, optional): Total batch size among all cards used in
                training. Defaults to 2.
            eval_dataset (paddlers.datasets.ResDataset|None, optional): Evaluation dataset.
                If None, the model will not be evaluated during the training process.
                Defaults to None.
            optimizer (paddle.optimizer.Optimizer|None, optional): Optimizer used in
                training. If None, a default optimizer will be used. Defaults to None.
            save_interval_epochs (int, optional): Epoch interval for saving the model.
                Defaults to 1.
            log_interval_steps (int, optional): Step interval for printing training
                information. Defaults to 2.
            save_dir (str, optional): Directory to save the model. Defaults to 'output'.
            pretrain_weights (str|None, optional): None or name/path of pretrained
                weights. If None, no pretrained weights will be loaded.
                Defaults to None.
            learning_rate (float, optional): Learning rate for training. Defaults to .01.
            lr_decay_power (float, optional): Learning decay power. Defaults to .9.
            early_stop (bool, optional): Whether to adopt the early stop strategy.
                Defaults to False.
            early_stop_patience (int, optional): Early stop patience. Defaults to 5.
            use_vdl (bool, optional): Whether to use VisualDL to monitor the training
                process. Defaults to True.
            resume_checkpoint (str|None, optional): Path of the checkpoint to resume
                training from. If None, no training checkpoint will be resumed. At most
                one of `resume_checkpoint` and `pretrain_weights` can be set
                simultaneously. Defaults to None.
        """
        if self.status == 'Infer':
            logging.error(
                "Exported inference model does not support training.",
                exit=True)
        if pretrain_weights is not None and resume_checkpoint is not None:
            logging.error(
                "`pretrain_weights` and `resume_checkpoint` cannot be set simultaneously.",
                exit=True)
        if self.losses is None:
            self.losses = self.default_loss()
        if optimizer is None:
            num_steps_each_epoch = train_dataset.num_samples // train_batch_size
            if isinstance(self.net, GANAdapter):
                parameters = {'params_g': [], 'params_d': []}
                for net_g in self.net.generators:
                    parameters['params_g'].append(net_g.parameters())
                for net_d in self.net.discriminators:
                    parameters['params_d'].append(net_d.parameters())
            else:
                parameters = self.net.parameters()
            self.optimizer = self.default_optimizer(
                parameters, learning_rate, num_epochs, num_steps_each_epoch,
                lr_decay_power)
        else:
            self.optimizer = optimizer

        if pretrain_weights is not None:
            if not osp.exists(pretrain_weights):
                if self.model_name not in res_pretrain_weights_dict:
                    logging.warning(
                        "Path of pretrained weights ('{}') does not exist!".
                        format(pretrain_weights))
                    pretrain_weights = None
                elif pretrain_weights not in res_pretrain_weights_dict[
                        self.model_name]:
                    logging.warning(
                        "Path of pretrained weights ('{}') does not exist!".
                        format(pretrain_weights))
                    pretrain_weights = res_pretrain_weights_dict[
                        self.model_name][0]
                    logging.warning(
                        "`pretrain_weights` is forcibly set to '{}'. "
                        "If you don't want to use pretrained weights, "
                        "please set `pretrain_weights` to None.".format(
                            pretrain_weights))
            else:
                if osp.splitext(pretrain_weights)[-1] != '.pdparams':
                    logging.error(
                        "Invalid pretrained weights. Please specify a .pdparams file.",
                        exit=True)
        pretrained_dir = osp.join(save_dir, 'pretrain')
        is_backbone_weights = pretrain_weights == 'IMAGENET'
        self.initialize_net(
            pretrain_weights=pretrain_weights,
            save_dir=pretrained_dir,
            resume_checkpoint=resume_checkpoint,
            is_backbone_weights=is_backbone_weights)

        self.train_loop(
            num_epochs=num_epochs,
            train_dataset=train_dataset,
            train_batch_size=train_batch_size,
            eval_dataset=eval_dataset,
            save_interval_epochs=save_interval_epochs,
            log_interval_steps=log_interval_steps,
            save_dir=save_dir,
            early_stop=early_stop,
            early_stop_patience=early_stop_patience,
            use_vdl=use_vdl)
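
    # Illustrative usage (editor's sketch, not executed): a minimal single-card
    # training run, assuming a `paddlers.datasets.ResDataset` named `train_ds`
    # has been built elsewhere:
    #
    #     model = LESRCNN(sr_factor=4)
    #     model.train(
    #         num_epochs=10,
    #         train_dataset=train_ds,
    #         train_batch_size=2,
    #         save_dir='output/lesrcnn')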

    def quant_aware_train(self,
                          num_epochs,
                          train_dataset,
                          train_batch_size=2,
                          eval_dataset=None,
                          optimizer=None,
                          save_interval_epochs=1,
                          log_interval_steps=2,
                          save_dir='output',
                          learning_rate=0.0001,
                          lr_decay_power=0.9,
                          early_stop=False,
                          early_stop_patience=5,
                          use_vdl=True,
                          resume_checkpoint=None,
                          quant_config=None):
        """
        Quantization-aware training.

        Args:
            num_epochs (int): Number of epochs.
            train_dataset (paddlers.datasets.ResDataset): Training dataset.
            train_batch_size (int, optional): Total batch size among all cards used in
                training. Defaults to 2.
            eval_dataset (paddlers.datasets.ResDataset|None, optional): Evaluation dataset.
                If None, the model will not be evaluated during the training process.
                Defaults to None.
            optimizer (paddle.optimizer.Optimizer|None, optional): Optimizer used in
                training. If None, a default optimizer will be used. Defaults to None.
            save_interval_epochs (int, optional): Epoch interval for saving the model.
                Defaults to 1.
            log_interval_steps (int, optional): Step interval for printing training
                information. Defaults to 2.
            save_dir (str, optional): Directory to save the model. Defaults to 'output'.
            learning_rate (float, optional): Learning rate for training.
                Defaults to .0001.
            lr_decay_power (float, optional): Learning decay power. Defaults to .9.
            early_stop (bool, optional): Whether to adopt the early stop strategy.
                Defaults to False.
            early_stop_patience (int, optional): Early stop patience. Defaults to 5.
            use_vdl (bool, optional): Whether to use VisualDL to monitor the training
                process. Defaults to True.
            quant_config (dict|None, optional): Quantization configuration. If None,
                a default rule-of-thumb configuration will be used. Defaults to None.
            resume_checkpoint (str|None, optional): Path of the checkpoint to resume
                quantization-aware training from. If None, no training checkpoint will
                be resumed. Defaults to None.
        """
        self._prepare_qat(quant_config)
        self.train(
            num_epochs=num_epochs,
            train_dataset=train_dataset,
            train_batch_size=train_batch_size,
            eval_dataset=eval_dataset,
            optimizer=optimizer,
            save_interval_epochs=save_interval_epochs,
            log_interval_steps=log_interval_steps,
            save_dir=save_dir,
            pretrain_weights=None,
            learning_rate=learning_rate,
            lr_decay_power=lr_decay_power,
            early_stop=early_stop,
            early_stop_patience=early_stop_patience,
            use_vdl=use_vdl,
            resume_checkpoint=resume_checkpoint)

    def evaluate(self, eval_dataset, batch_size=1, return_details=False):
        """
        Evaluate the model.

        Args:
            eval_dataset (paddlers.datasets.ResDataset): Evaluation dataset.
            batch_size (int, optional): Total batch size among all cards used for
                evaluation. Defaults to 1.
            return_details (bool, optional): Whether to return evaluation details.
                Defaults to False.

        Returns:
            If `return_details` is False, return collections.OrderedDict with
                key-value pairs:
                {"psnr": peak signal-to-noise ratio,
                 "ssim": structural similarity}.
        """
        self._check_transforms(eval_dataset.transforms, 'eval')

        self.net.eval()
        nranks = paddle.distributed.get_world_size()
        local_rank = paddle.distributed.get_rank()
        if nranks > 1:
            # Initialize the parallel environment if it has not been initialized yet.
            if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
            ):
                paddle.distributed.init_parallel_env()

        # TODO: Distributed evaluation
        if batch_size > 1:
            logging.warning(
                "Restorer only supports single card evaluation with batch_size=1 "
                "during evaluation, so batch_size is forcibly set to 1.")
            batch_size = 1

        if nranks < 2 or local_rank == 0:
            self.eval_data_loader = self.build_data_loader(
                eval_dataset, batch_size=batch_size, mode='eval')
            # XXX: Hard-coded crop_border and test_y_channel
            psnr = metrics.PSNR(crop_border=4, test_y_channel=True)
            ssim = metrics.SSIM(crop_border=4, test_y_channel=True)
            logging.info(
                "Start to evaluate (total_samples={}, total_steps={})...".format(
                    eval_dataset.num_samples, eval_dataset.num_samples))
            with paddle.no_grad():
                for step, data in enumerate(self.eval_data_loader):
                    data.append(eval_dataset.transforms.transforms)
                    outputs = self.run(self.net, data, 'eval')
                    psnr.update(outputs['pred'], outputs['tar'])
                    ssim.update(outputs['pred'], outputs['tar'])

            # DO NOT use psnr.accumulate() here, otherwise the program hangs in
            # multi-card training.
            assert len(psnr.results) > 0
            assert len(ssim.results) > 0
            eval_metrics = OrderedDict(
                zip(['psnr', 'ssim'],
                    [np.mean(psnr.results), np.mean(ssim.results)]))

            if return_details:
                # TODO: Add details
                return eval_metrics, None
            return eval_metrics
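
    # Illustrative usage (editor's sketch, not executed), assuming an evaluation
    # `paddlers.datasets.ResDataset` named `eval_ds`:
    #
    #     metrics_dict = model.evaluate(eval_ds)
    #     print(metrics_dict['psnr'], metrics_dict['ssim'])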

    def predict(self, img_file, transforms=None):
        """
        Do inference.

        Args:
            img_file (list[np.ndarray|str] | str | np.ndarray): Image path or decoded
                image data, which also could constitute a list, meaning all images to be
                predicted as a mini-batch.
            transforms (paddlers.transforms.Compose|None, optional): Transforms for
                inputs. If None, the transforms for the evaluation process will be used.
                Defaults to None.

        Returns:
            If `img_file` is a string or np.ndarray, the result is a dict with the
                following key-value pair:
                res_map (np.ndarray): Restored image (HWC).
            If `img_file` is a list, the result is a list composed of dicts with the
                above keys.
        """
        if transforms is None and not hasattr(self, 'test_transforms'):
            raise ValueError("transforms need to be defined, now is None.")
        if transforms is None:
            transforms = self.test_transforms
        if isinstance(img_file, (str, np.ndarray)):
            images = [img_file]
        else:
            images = img_file
        batch_im, batch_tar_shape = self.preprocess(images, transforms,
                                                    self.model_type)
        self.net.eval()
        data = (batch_im, batch_tar_shape, transforms.transforms)
        outputs = self.run(self.net, data, 'test')
        res_map_list = outputs['res_map']
        if isinstance(img_file, list):
            prediction = [{'res_map': m} for m in res_map_list]
        else:
            prediction = {'res_map': res_map_list[0]}
        return prediction
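
    # Illustrative usage (editor's sketch, not executed; the file path is assumed):
    #
    #     pred = model.predict('demo.png')
    #     restored = pred['res_map']  # HWC uint8 image
    #     cv2.imwrite('restored.png', restored)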

    def preprocess(self, images, transforms, to_tensor=True):
        self._check_transforms(transforms, 'test')
        batch_im = list()
        batch_tar_shape = list()
        for im in images:
            if isinstance(im, str):
                im = decode_image(im, read_raw=True)
            ori_shape = im.shape[:2]
            sample = {'image': im}
            im = transforms(sample)[0]
            batch_im.append(im)
            batch_tar_shape.append(self._get_target_shape(ori_shape))
        if to_tensor:
            batch_im = paddle.to_tensor(batch_im)
        else:
            batch_im = np.asarray(batch_im)
        return batch_im, batch_tar_shape

    def _get_target_shape(self, ori_shape):
        if self.sr_factor is None:
            return ori_shape
        else:
            return calc_hr_shape(ori_shape, self.sr_factor)

    @staticmethod
    def get_transforms_shape_info(batch_tar_shape, transforms):
        batch_restore_list = list()
        for tar_shape in batch_tar_shape:
            restore_list = list()
            h, w = tar_shape[0], tar_shape[1]
            for op in transforms:
                if op.__class__.__name__ == 'Resize':
                    restore_list.append(('resize', (h, w)))
                    h, w = op.target_size
                elif op.__class__.__name__ == 'ResizeByShort':
                    restore_list.append(('resize', (h, w)))
                    im_short_size = min(h, w)
                    im_long_size = max(h, w)
                    scale = float(op.short_size) / float(im_short_size)
                    if 0 < op.max_size < np.round(scale * im_long_size):
                        scale = float(op.max_size) / float(im_long_size)
                    h = int(round(h * scale))
                    w = int(round(w * scale))
                elif op.__class__.__name__ == 'ResizeByLong':
                    restore_list.append(('resize', (h, w)))
                    im_long_size = max(h, w)
                    scale = float(op.long_size) / float(im_long_size)
                    h = int(round(h * scale))
                    w = int(round(w * scale))
                elif op.__class__.__name__ == 'Pad':
                    if op.target_size:
                        target_h, target_w = op.target_size
                    else:
                        target_h = int(
                            (np.ceil(h / op.size_divisor) * op.size_divisor))
                        target_w = int(
                            (np.ceil(w / op.size_divisor) * op.size_divisor))
                    if op.pad_mode == -1:
                        offsets = op.offsets
                    elif op.pad_mode == 0:
                        offsets = [0, 0]
                    elif op.pad_mode == 1:
                        offsets = [(target_h - h) // 2, (target_w - w) // 2]
                    else:
                        offsets = [target_h - h, target_w - w]
                    restore_list.append(('padding', (h, w), offsets))
                    h, w = target_h, target_w
            batch_restore_list.append(restore_list)
        return batch_restore_list
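
    # Editor's note on the structure returned above: each `restore_list` records,
    # in application order, the shape each op saw *before* it ran, so that
    # postprocess() can replay the list in reverse. For example, a 512x512 target
    # passed through Resize(target_size=(256, 256)) yields
    # [('resize', (512, 512))], and postprocess() resizes the output back to
    # 512x512.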

    def postprocess(self, batch_pred, batch_tar_shape, transforms):
        batch_restore_list = BaseRestorer.get_transforms_shape_info(
            batch_tar_shape, transforms)
        if self.status == 'Infer':
            return self._infer_postprocess(
                batch_res_map=batch_pred, batch_restore_list=batch_restore_list)
        results = []
        if batch_pred.dtype == paddle.float32:
            mode = 'bilinear'
        else:
            mode = 'nearest'
        for pred, restore_list in zip(batch_pred, batch_restore_list):
            pred = paddle.unsqueeze(pred, axis=0)
            for item in restore_list[::-1]:
                h, w = item[1][0], item[1][1]
                if item[0] == 'resize':
                    pred = F.interpolate(
                        pred, (h, w), mode=mode, data_format='NCHW')
                elif item[0] == 'padding':
                    x, y = item[2]
                    pred = pred[:, :, y:y + h, x:x + w]
                else:
                    pass
            results.append(pred)
        return results

    def _infer_postprocess(self, batch_res_map, batch_restore_list):
        res_maps = []
        for res_map, restore_list in zip(batch_res_map, batch_restore_list):
            if not isinstance(res_map, np.ndarray):
                res_map = paddle.unsqueeze(res_map, axis=0)
            for item in restore_list[::-1]:
                h, w = item[1][0], item[1][1]
                if item[0] == 'resize':
                    if isinstance(res_map, np.ndarray):
                        res_map = cv2.resize(
                            res_map, (w, h), interpolation=cv2.INTER_LINEAR)
                    else:
                        res_map = F.interpolate(
                            res_map, (h, w),
                            mode='bilinear',
                            data_format='NHWC')
                elif item[0] == 'padding':
                    x, y = item[2]
                    if isinstance(res_map, np.ndarray):
                        res_map = res_map[y:y + h, x:x + w]
                    else:
                        res_map = res_map[:, y:y + h, x:x + w, :]
                else:
                    pass
            res_map = res_map.squeeze()
            if not isinstance(res_map, np.ndarray):
                res_map = res_map.numpy()
            res_map = self._normalize(res_map)
            res_maps.append(res_map.squeeze())
        return res_maps

    def _check_transforms(self, transforms, mode):
        super()._check_transforms(transforms, mode)
        if not isinstance(transforms.arrange,
                          paddlers.transforms.ArrangeRestorer):
            raise TypeError(
                "`transforms.arrange` must be an ArrangeRestorer object.")

    def build_data_loader(self, dataset, batch_size, mode='train'):
        if dataset.num_samples < batch_size:
            raise ValueError(
                'The volume of the dataset ({}) must be no smaller than the batch size ({}).'
                .format(dataset.num_samples, batch_size))
        if mode != 'train':
            return paddle.io.DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=dataset.shuffle,
                drop_last=False,
                collate_fn=dataset.batch_transforms,
                num_workers=dataset.num_workers,
                return_list=True,
                use_shared_memory=False)
        else:
            return super(BaseRestorer, self).build_data_loader(dataset,
                                                               batch_size, mode)

    def set_losses(self, losses):
        self.losses = losses

    def _tensor_to_images(self,
                          tensor,
                          transpose=True,
                          squeeze=True,
                          quantize=True):
        if transpose:
            tensor = paddle.transpose(tensor, perm=[0, 2, 3, 1])  # NHWC
        if squeeze:
            tensor = tensor.squeeze()
        images = tensor.numpy().astype('float32')
        images = self._normalize(
            images, copy=True, clip=True, quantize=quantize)
        return images

    def _normalize(self, im, copy=False, clip=True, quantize=True):
        if copy:
            im = im.copy()
        if clip:
            im = np.clip(im, self.MIN_MAX[0], self.MIN_MAX[1])
        im -= im.min()
        im /= im.max() + 1e-32
        if quantize:
            im *= 255
            im = im.astype('uint8')
        return im
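
    # Editor's worked example for _normalize(): with MIN_MAX = (0., 1.) and
    # quantize=True, an input such as [-0.2, 0.3, 1.4] is first clipped to
    # [0.0, 0.3, 1.0], shifted so its minimum is 0 and rescaled by its maximum
    # (still [0.0, 0.3, 1.0] here), then mapped to uint8 as [0, 76, 255].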


class DRN(BaseRestorer):
    TEST_OUT_KEY = -1

    def __init__(self,
                 losses=None,
                 sr_factor=4,
                 scales=(2, 4),
                 n_blocks=30,
                 n_feats=16,
                 n_colors=3,
                 rgb_range=1.0,
                 negval=0.2,
                 lq_loss_weight=0.1,
                 dual_loss_weight=0.1,
                 **params):
        if sr_factor != max(scales):
            raise ValueError("`sr_factor` must be equal to `max(scales)`.")
        params.update({
            'scale': scales,
            'n_blocks': n_blocks,
            'n_feats': n_feats,
            'n_colors': n_colors,
            'rgb_range': rgb_range,
            'negval': negval
        })
        self.lq_loss_weight = lq_loss_weight
        self.dual_loss_weight = dual_loss_weight
        self.scales = scales
        super(DRN, self).__init__(
            model_name='DRN', losses=losses, sr_factor=sr_factor, **params)

    def build_net(self, **params):
        from ppgan.modules.init import init_weights
        generators = [ppgan.models.generators.DRNGenerator(**params)]
        init_weights(generators[-1])
        for scale in params['scale']:
            dual_model = ppgan.models.generators.drn.DownBlock(
                params['negval'], params['n_feats'], params['n_colors'], 2)
            generators.append(dual_model)
            init_weights(generators[-1])
        return GANAdapter(generators, [])

    def default_optimizer(self, parameters, *args, **kwargs):
        optims_g = [
            super(DRN, self).default_optimizer(params_g, *args, **kwargs)
            for params_g in parameters['params_g']
        ]
        return OptimizerAdapter(*optims_g)

    def run_gan(self, net, inputs, mode, gan_mode='forward_primary'):
        if mode != 'train':
            raise ValueError("`mode` is not 'train'.")
        outputs = OrderedDict()
        if gan_mode == 'forward_primary':
            sr = net.generator(inputs[0])
            lr = [inputs[0]]
            lr.extend([
                F.interpolate(
                    inputs[0], scale_factor=s, mode='bicubic')
                for s in self.scales[:-1]
            ])
            loss = self.losses(sr[-1], inputs[1])
            for i in range(1, len(sr)):
                if self.lq_loss_weight > 0:
                    loss += self.losses(sr[i - 1 - len(sr)],
                                        lr[i - len(sr)]) * self.lq_loss_weight
            outputs['loss_prim'] = loss
            outputs['sr'] = sr
            outputs['lr'] = lr
        elif gan_mode == 'forward_dual':
            sr, lr = inputs[0], inputs[1]
            sr2lr = []
            n_scales = len(self.scales)
            for i in range(n_scales):
                sr2lr_i = net.generators[1 + i](sr[i - n_scales])
                sr2lr.append(sr2lr_i)
            loss = self.losses(sr2lr[0], lr[0])
            for i in range(1, n_scales):
                if self.dual_loss_weight > 0.0:
                    loss += self.losses(sr2lr[i], lr[i]) * self.dual_loss_weight
            outputs['loss_dual'] = loss
        else:
            raise ValueError("Invalid `gan_mode`!")
        return outputs

    def train_step(self, step, data, net):
        outputs = self.run_gan(
            net, data, mode='train', gan_mode='forward_primary')
        outputs.update(
            self.run_gan(
                net, (outputs['sr'], outputs['lr']),
                mode='train',
                gan_mode='forward_dual'))
        self.optimizer.clear_grad()
        (outputs['loss_prim'] + outputs['loss_dual']).backward()
        self.optimizer.step()
        return {
            'loss': outputs['loss_prim'] + outputs['loss_dual'],
            'loss_prim': outputs['loss_prim'],
            'loss_dual': outputs['loss_dual']
        }
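
# Illustrative usage (editor's sketch, not executed): DRN trains a primary
# super-resolution generator plus one dual (downsampling) regression branch per
# scale, so `sr_factor` must equal the largest entry in `scales`:
#
#     model = DRN(sr_factor=4, scales=(2, 4))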


class LESRCNN(BaseRestorer):
    def __init__(self,
                 losses=None,
                 sr_factor=4,
                 multi_scale=False,
                 group=1,
                 **params):
        params.update({
            'scale': sr_factor if sr_factor is not None else 1,
            'multi_scale': multi_scale,
            'group': group
        })
        super(LESRCNN, self).__init__(
            model_name='LESRCNN', losses=losses, sr_factor=sr_factor, **params)

    def build_net(self, **params):
        net = ppgan.models.generators.LESRCNNGenerator(**params)
        return net


class ESRGAN(BaseRestorer):
    def __init__(self,
                 losses=None,
                 sr_factor=4,
                 use_gan=True,
                 in_channels=3,
                 out_channels=3,
                 nf=64,
                 nb=23,
                 **params):
        if sr_factor != 4:
            raise ValueError("`sr_factor` must be 4.")
        params.update({
            'in_nc': in_channels,
            'out_nc': out_channels,
            'nf': nf,
            'nb': nb
        })
        self.use_gan = use_gan
        super(ESRGAN, self).__init__(
            model_name='ESRGAN', losses=losses, sr_factor=sr_factor, **params)

    def build_net(self, **params):
        from ppgan.modules.init import init_weights
        generator = ppgan.models.generators.RRDBNet(**params)
        init_weights(generator)
        if self.use_gan:
            discriminator = ppgan.models.discriminators.VGGDiscriminator128(
                in_channels=params['out_nc'], num_feat=64)
            net = GANAdapter(
                generators=[generator], discriminators=[discriminator])
        else:
            net = generator
        return net

    def default_loss(self):
        if self.use_gan:
            return {
                'pixel': res_losses.L1Loss(loss_weight=0.01),
                'perceptual': res_losses.PerceptualLoss(
                    layer_weights={'34': 1.0},
                    perceptual_weight=1.0,
                    style_weight=0.0,
                    norm_img=False),
                'gan': res_losses.GANLoss(
                    gan_mode='vanilla', loss_weight=0.005)
            }
        else:
            return res_losses.L1Loss()

    def default_optimizer(self, parameters, *args, **kwargs):
        if self.use_gan:
            optim_g = super(ESRGAN, self).default_optimizer(
                parameters['params_g'][0], *args, **kwargs)
            optim_d = super(ESRGAN, self).default_optimizer(
                parameters['params_d'][0], *args, **kwargs)
            return OptimizerAdapter(optim_g, optim_d)
        else:
            return super(ESRGAN, self).default_optimizer(parameters, *args,
                                                         **kwargs)

    def run_gan(self, net, inputs, mode, gan_mode='forward_g'):
        if mode != 'train':
            raise ValueError("`mode` is not 'train'.")
        outputs = OrderedDict()
        if gan_mode == 'forward_g':
            loss_g = 0
            g_pred = net.generator(inputs[0])
            loss_pix = self.losses['pixel'](g_pred, inputs[1])
            loss_perc, loss_sty = self.losses['perceptual'](g_pred, inputs[1])
            loss_g += loss_pix
            if loss_perc is not None:
                loss_g += loss_perc
            if loss_sty is not None:
                loss_g += loss_sty
            self._set_requires_grad(net.discriminator, False)
            real_d_pred = net.discriminator(inputs[1]).detach()
            fake_g_pred = net.discriminator(g_pred)
            loss_g_real = self.losses['gan'](
                real_d_pred - paddle.mean(fake_g_pred), False,
                is_disc=False) * 0.5
            loss_g_fake = self.losses['gan'](
                fake_g_pred - paddle.mean(real_d_pred), True,
                is_disc=False) * 0.5
            loss_g_gan = loss_g_real + loss_g_fake
            outputs['g_pred'] = g_pred.detach()
            outputs['loss_g_pps'] = loss_g
            outputs['loss_g_gan'] = loss_g_gan
        elif gan_mode == 'forward_d':
            self._set_requires_grad(net.discriminator, True)
            # Real
            fake_d_pred = net.discriminator(inputs[0]).detach()
            real_d_pred = net.discriminator(inputs[1])
            loss_d_real = self.losses['gan'](
                real_d_pred - paddle.mean(fake_d_pred), True,
                is_disc=True) * 0.5
            # Fake
            fake_d_pred = net.discriminator(inputs[0].detach())
            loss_d_fake = self.losses['gan'](
                fake_d_pred - paddle.mean(real_d_pred.detach()),
                False,
                is_disc=True) * 0.5
            outputs['loss_d'] = loss_d_real + loss_d_fake
        else:
            raise ValueError("Invalid `gan_mode`!")
        return outputs

    def train_step(self, step, data, net):
        if self.use_gan:
            optim_g, optim_d = self.optimizer
            outputs = self.run_gan(
                net, data, mode='train', gan_mode='forward_g')
            optim_g.clear_grad()
            (outputs['loss_g_pps'] + outputs['loss_g_gan']).backward()
            optim_g.step()
            outputs.update(
                self.run_gan(
                    net, (outputs['g_pred'], data[1]),
                    mode='train',
                    gan_mode='forward_d'))
            optim_d.clear_grad()
            outputs['loss_d'].backward()
            optim_d.step()
            outputs['loss'] = (outputs['loss_g_pps'] + outputs['loss_g_gan'] +
                               outputs['loss_d'])
            return {
                'loss': outputs['loss'],
                'loss_g_pps': outputs['loss_g_pps'],
                'loss_g_gan': outputs['loss_g_gan'],
                'loss_d': outputs['loss_d']
            }
        else:
            return super(ESRGAN, self).train_step(step, data, net)

    def _set_requires_grad(self, net, requires_grad):
        for p in net.parameters():
            p.trainable = requires_grad
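
# Illustrative usage (editor's sketch, not executed): with `use_gan=False`,
# ESRGAN degenerates to plain RRDBNet training with an L1 loss, which is a
# cheaper starting point than full adversarial training:
#
#     model = ESRGAN(sr_factor=4, use_gan=False)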


class RCAN(BaseRestorer):
    def __init__(self,
                 losses=None,
                 sr_factor=4,
                 n_resgroups=10,
                 n_resblocks=20,
                 n_feats=64,
                 n_colors=3,
                 rgb_range=1.0,
                 kernel_size=3,
                 reduction=16,
                 **params):
        params.update({
            'n_resgroups': n_resgroups,
            'n_resblocks': n_resblocks,
            'n_feats': n_feats,
            'n_colors': n_colors,
            'rgb_range': rgb_range,
            'kernel_size': kernel_size,
            'reduction': reduction
        })
        super(RCAN, self).__init__(
            model_name='RCAN', losses=losses, sr_factor=sr_factor, **params)
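
# Illustrative usage (editor's sketch, not executed): RCAN stacks `n_resgroups`
# residual groups of `n_resblocks` channel-attention blocks each; the defaults
# above follow the common RCAN configuration:
#
#     model = RCAN(sr_factor=4, n_resgroups=10, n_resblocks=20)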