image_restorer.py 27 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import time
  16. import datetime
  17. import paddle
  18. from paddle.distributed import ParallelEnv
  19. from ..models.ppgan.datasets.builder import build_dataloader
  20. from ..models.ppgan.models.builder import build_model
  21. from ..models.ppgan.utils.visual import tensor2img, save_image
  22. from ..models.ppgan.utils.filesystem import makedirs, save, load
  23. from ..models.ppgan.utils.timer import TimeAverager
  24. from ..models.ppgan.utils.profiler import add_profiler_step
  25. from ..models.ppgan.utils.logger import setup_logger
  # AttrDict: a dict whose keys can also be accessed as attributes.
  class AttrDict(dict):
      """Dictionary that exposes its keys as attributes.

      ``d.key`` reads ``d['key']`` and ``d.key = v`` writes ``d['key'] = v``.
      Names that already exist in the instance ``__dict__`` are treated as
      real attributes and are assigned there instead of in the mapping.
      """

      def __getattr__(self, name):
          # Python only calls this after normal attribute lookup failed, so
          # fall back to the mapping and report a miss as AttributeError
          # (keeps hasattr() / getattr(obj, name, default) working).
          try:
              found = self[name]
          except KeyError:
              raise AttributeError(name)
          else:
              return found

      def __setattr__(self, name, value):
          # Existing instance attributes take priority; anything else is
          # stored as a dictionary item.
          store = self.__dict__ if name in self.__dict__ else self
          store[name] = value
  # Recursively convert nested plain dicts into AttrDict instances (in place).
  def create_attr_dict(config_dict):
      """Convert ``config_dict`` in place into a tree of ``AttrDict`` objects.

      Every value whose exact type is ``dict`` is replaced by an ``AttrDict``
      copy, which is then converted recursively. String values are passed
      through ``ast.literal_eval`` so that e.g. ``'(0., 255.)'`` becomes a
      tuple. Strings that are not valid Python literals (e.g. ``'L1Loss'``)
      are kept unchanged.

      Args:
          config_dict (dict): mapping to convert; it is mutated in place.

      Returns:
          None.
      """
      from ast import literal_eval

      for key, value in config_dict.items():
          if type(value) is dict:
              config_dict[key] = value = AttrDict(value)
          if isinstance(value, str):
              try:
                  value = literal_eval(value)
              except (ValueError, TypeError, SyntaxError, MemoryError,
                      RecursionError):
                  # Not a Python literal: keep the raw string. Catching these
                  # specific errors (instead of BaseException) no longer
                  # swallows KeyboardInterrupt / SystemExit.
                  pass
          if isinstance(value, AttrDict):
              create_attr_dict(config_dict[key])
          else:
              # Only existing keys are reassigned, so mutating while
              # iterating over items() is safe (the dict size is unchanged).
              config_dict[key] = value
  # Data-loading helper: an endless iterator over a re-iterable data loader.
  class IterLoader:
      """Wrap a re-iterable loader so that ``next()`` keeps producing data.

      When the wrapped loader is exhausted, a fresh iterator is created and
      the epoch counter is incremented. Callers can therefore draw batches by
      iteration count instead of by epoch. Iterating with ``for`` never ends
      for a non-empty loader.

      Args:
          dataloader: any object that supports ``iter()`` and ``len()``.
      """

      def __init__(self, dataloader):
          self._dataloader = dataloader
          self.iter_loader = iter(self._dataloader)
          self._epoch = 1

      @property
      def epoch(self):
          """1-based index of the pass over the loader currently being read."""
          return self._epoch

      def __iter__(self):
          # Complete the iterator protocol so the object works with for-loops,
          # itertools.islice, etc. (previously only __next__ existed).
          return self

      def __next__(self):
          try:
              data = next(self.iter_loader)
          except StopIteration:
              # Current pass is finished: start a new pass over the loader.
              self._epoch += 1
              self.iter_loader = iter(self._dataloader)
              # For a loader that yields nothing this raises StopIteration
              # again, ending iteration instead of looping forever.
              data = next(self.iter_loader)
          return data

      def __len__(self):
          """Number of items (batches) in one pass of the wrapped loader."""
          return len(self._dataloader)
  # Base trainer class.
  class Restorer:
      r"""
      Generic trainer that drives a PaddleGAN-style restoration model.

      # trainer calling logic:
      #
      # build_model || model(BaseModel)
      # | ||
      # build_dataloader || dataloader
      # | ||
      # model.setup_lr_schedulers || lr_scheduler
      # | ||
      # model.setup_optimizers || optimizers
      # | ||
      # train loop (model.setup_input + model.train_iter) || train loop
      # | ||
      # print log (model.get_current_losses) ||
      # | ||
      # save checkpoint (model.nets) \/
      """

      def __init__(self, cfg, logger):
          """Build the model and metrics, plus data/optimizers when training.

          Args:
              cfg (AttrDict): full run configuration (model, dataset,
                  lr_scheduler, optimizer, validate, log_config,
                  snapshot_config, output_dir, is_train, ...).
              logger: logger used for progress and metric messages.
          """
          # base config
          self.logger = logger
          self.cfg = cfg
          self.output_dir = cfg.output_dir
          self.max_eval_steps = cfg.model.get('max_eval_steps', None)
          self.local_rank = ParallelEnv().local_rank
          self.world_size = ParallelEnv().nranks
          self.log_interval = cfg.log_config.interval
          # NOTE: 'visiual_interval' (sic) is the key the config builder uses.
          self.visual_interval = cfg.log_config.visiual_interval
          self.weight_interval = cfg.snapshot_config.interval
          self.start_epoch = 1
          self.current_epoch = 1
          self.current_iter = 1
          self.inner_iter = 1
          self.batch_id = 0
          self.global_steps = 0
          # build model
          self.model = build_model(cfg.model)
          # multiple gpus prepare
          if ParallelEnv().nranks > 1:
              self.distributed_data_parallel()
          # build metrics
          self.metrics = None
          self.is_save_img = True
          validate_cfg = cfg.get('validate', None)
          if validate_cfg and 'metrics' in validate_cfg:
              self.metrics = self.model.setup_metrics(validate_cfg['metrics'])
          if validate_cfg and 'save_img' in validate_cfg:
              self.is_save_img = validate_cfg['save_img']
          self.enable_visualdl = cfg.get('enable_visualdl', False)
          if self.enable_visualdl:
              import visualdl
              self.vdl_logger = visualdl.LogWriter(logdir=cfg.output_dir)
          # evaluate only
          if not cfg.is_train:
              return
          # build train dataloader
          self.train_dataloader = build_dataloader(cfg.dataset.train)
          self.iters_per_epoch = len(self.train_dataloader)
          # build lr scheduler
          # TODO: has a better way?
          if 'lr_scheduler' in cfg and 'iters_per_epoch' in cfg.lr_scheduler:
              cfg.lr_scheduler.iters_per_epoch = self.iters_per_epoch
          self.lr_schedulers = self.model.setup_lr_schedulers(cfg.lr_scheduler)
          # build optimizers
          self.optimizers = self.model.setup_optimizers(self.lr_schedulers,
                                                        cfg.optimizer)
          self.epochs = cfg.get('epochs', None)
          if self.epochs:
              self.total_iters = self.epochs * self.iters_per_epoch
              self.by_epoch = True
          else:
              self.by_epoch = False
              self.total_iters = cfg.total_iters
          if self.by_epoch:
              # the snapshot interval counts epochs here; convert it to
              # iterations because train() compares it with current_iter
              self.weight_interval *= self.iters_per_epoch
          self.validate_interval = -1
          if cfg.get('validate', None) is not None:
              self.validate_interval = cfg.validate.get('interval', -1)
          self.time_count = {}
          self.best_metric = {}
          self.model.set_total_iter(self.total_iters)
          self.profiler_options = cfg.profiler_options

      def distributed_data_parallel(self):
          """Initialise the parallel env and wrap every net in DataParallel."""
          paddle.distributed.init_parallel_env()
          find_unused_parameters = self.cfg.get('find_unused_parameters', False)
          for net_name, net in self.model.nets.items():
              self.model.nets[net_name] = paddle.DataParallel(
                  net, find_unused_parameters=find_unused_parameters)

      def learning_rate_scheduler_step(self):
          """Advance the model's lr scheduler(s) by one step.

          Raises:
              ValueError: if ``model.lr_scheduler`` is neither a dict of
                  schedulers nor a single ``LRScheduler``.
          """
          if isinstance(self.model.lr_scheduler, dict):
              for lr_scheduler in self.model.lr_scheduler.values():
                  lr_scheduler.step()
          elif isinstance(self.model.lr_scheduler,
                          paddle.optimizer.lr.LRScheduler):
              self.model.lr_scheduler.step()
          else:
              raise ValueError(
                  'lr schedulter must be a dict or an instance of LRScheduler')

      def train(self):
          """Run the iteration-based training loop up to ``total_iters``.

          Per iteration: one optimisation step and one lr-scheduler step.
          Every ``log_interval`` iterations the losses/timings are logged,
          every ``visual_interval`` iterations rank 0 writes visuals,
          validation runs every ``validate_interval`` iterations (when
          enabled), and weights plus a checkpoint are saved every
          ``weight_interval`` iterations.
          """
          reader_cost_averager = TimeAverager()
          batch_cost_averager = TimeAverager()
          iter_loader = IterLoader(self.train_dataloader)
          # set model.is_train = True
          self.model.setup_train_mode(is_train=True)
          while self.current_iter < (self.total_iters + 1):
              # 1-based position inside the epoch (1..iters_per_epoch); a
              # plain modulo reported 0 on the last iteration of an epoch.
              self.inner_iter = (
                  (self.current_iter - 1) % self.iters_per_epoch + 1)
              # Keep the visualdl step in sync with the training iteration
              # so scalars/images are not all logged at the same step.
              self.global_steps = self.current_iter
              add_profiler_step(self.profiler_options)
              step_start_time = time.time()
              data = next(iter_loader)
              reader_cost_averager.record(time.time() - step_start_time)
              # IterLoader advances its epoch inside next(), so read it after
              # fetching the batch to report the epoch the batch belongs to.
              self.current_epoch = iter_loader.epoch
              # unpack data from dataset and apply preprocessing
              # data input should be dict
              self.model.setup_input(data)
              self.model.train_iter(self.optimizers)
              batch_cost_averager.record(
                  time.time() - step_start_time,
                  num_samples=self.cfg['dataset']['train'].get('batch_size', 1))
              if self.current_iter % self.log_interval == 0:
                  self.data_time = reader_cost_averager.get_average()
                  self.step_time = batch_cost_averager.get_average()
                  self.ips = batch_cost_averager.get_ips_average()
                  self.print_log()
                  reader_cost_averager.reset()
                  batch_cost_averager.reset()
              if self.current_iter % self.visual_interval == 0 and self.local_rank == 0:
                  self.visual('visual_train')
              self.learning_rate_scheduler_step()
              if self.validate_interval > -1 and self.current_iter % self.validate_interval == 0:
                  self.test()
                  # test() switches the nets to eval mode; switch back so the
                  # remaining iterations are trained in train mode.
                  self.model.setup_train_mode(is_train=True)
              if self.current_iter % self.weight_interval == 0:
                  self.save(self.current_iter, 'weight', keep=-1)
                  self.save(self.current_iter)
              self.current_iter += 1

      def test(self):
          """Evaluate on ``cfg.dataset.test`` for ``max_eval_steps`` batches.

          Resets and updates the configured metrics. If ``is_save_img`` is
          set, it saves every sample's visuals under
          ``<output_dir>/visual_test``. Finally, it logs each metric value.
          The model is left in evaluation mode.
          """
          if not hasattr(self, 'test_dataloader'):
              self.test_dataloader = build_dataloader(
                  self.cfg.dataset.test, is_train=False)
          iter_loader = IterLoader(self.test_dataloader)
          if self.max_eval_steps is None:
              self.max_eval_steps = len(self.test_dataloader)
          if self.metrics:
              for metric in self.metrics.values():
                  metric.reset()
          # set model.is_train = False
          self.model.setup_train_mode(is_train=False)
          for i in range(self.max_eval_steps):
              if self.max_eval_steps < self.log_interval or i % self.log_interval == 0:
                  self.logger.info('Test iter: [%d/%d]' % (
                      i * self.world_size, self.max_eval_steps * self.world_size))
              data = next(iter_loader)
              self.model.setup_input(data)
              self.model.test_iter(metrics=self.metrics)
              if self.is_save_img:
                  visual_results = {}
                  current_paths = self.model.get_image_paths()
                  current_visuals = self.model.get_current_visuals()
                  first_visual = next(iter(current_visuals.values()), None)
                  # A rank-4 visual holds one image per sample: compare the
                  # rank (len(shape)) with 4, not the shape itself.
                  if first_visual is not None and len(first_visual.shape) == 4:
                      num_samples = first_visual.shape[0]
                  else:
                      num_samples = 1
                  for j in range(num_samples):
                      if j < len(current_paths):
                          short_path = os.path.basename(current_paths[j])
                          basename = os.path.splitext(short_path)[0]
                      else:
                          basename = '{:04d}_{:04d}'.format(i, j)
                      for k, img_tensor in current_visuals.items():
                          name = '%s_%s' % (basename, k)
                          if len(img_tensor.shape) == 4:
                              visual_results.update({name: img_tensor[j]})
                          else:
                              visual_results.update({name: img_tensor})
                  self.visual(
                      'visual_test',
                      visual_results=visual_results,
                      step=self.batch_id,
                      is_save_image=True)
          if self.metrics:
              for metric_name, metric in self.metrics.items():
                  self.logger.info("Metric {}: {:.4f}".format(
                      metric_name, metric.accumulate()))

      def print_log(self):
          """Log progress, learning rate, losses, timings and ETA.

          Losses are also written to visualdl at ``global_steps`` when
          visualdl is enabled.
          """
          losses = self.model.get_current_losses()
          message = ''
          if self.by_epoch:
              message += 'Epoch: %d/%d, iter: %d/%d ' % (
                  self.current_epoch, self.epochs, self.inner_iter,
                  self.iters_per_epoch)
          else:
              message += 'Iter: %d/%d ' % (self.current_iter, self.total_iters)
          message += f'lr: {self.current_learning_rate:.3e} '
          for k, v in losses.items():
              message += '%s: %.3f ' % (k, v)
              if self.enable_visualdl:
                  self.vdl_logger.add_scalar(k, v, step=self.global_steps)
          if hasattr(self, 'step_time'):
              message += 'batch_cost: %.5f sec ' % self.step_time
          if hasattr(self, 'data_time'):
              message += 'reader_cost: %.5f sec ' % self.data_time
          if hasattr(self, 'ips'):
              message += 'ips: %.5f images/s ' % self.ips
          if hasattr(self, 'step_time'):
              # remaining iterations x average batch time, clamped at 0
              eta = self.step_time * (self.total_iters - self.current_iter)
              eta = eta if eta > 0 else 0
              eta_str = str(datetime.timedelta(seconds=int(eta)))
              message += f'eta: {eta_str}'
          # print the message
          self.logger.info(message)

      @property
      def current_learning_rate(self):
          """Learning rate of the first optimizer in ``model.optimizers``.

          Returns None when the model has no optimizers.
          """
          for optimizer in self.model.optimizers.values():
              return optimizer.get_lr()

      def visual(self,
                 results_dir,
                 visual_results=None,
                 step=None,
                 is_save_image=False):
          """
          Visualise images via visualdl or by writing them to disk.

          Parameters:
              results_dir (str) -- directory name which contains saved images
              visual_results (dict) -- label -> image tensor; defaults to the
                  model's current visuals
              step (int) -- global step used by visualdl (falls back to
                  ``global_steps`` when falsy)
              is_save_image (bool) -- whether to write to the directory
                  instead of visualdl
          """
          self.model.compute_visuals()
          if visual_results is None:
              visual_results = self.model.get_current_visuals()
          min_max = self.cfg.get('min_max', None)
          if min_max is None:
              min_max = (-1., 1.)
          image_num = self.cfg.get('image_num', None)
          if (image_num is None) or (not self.enable_visualdl):
              image_num = 1
          for label, image in visual_results.items():
              image_numpy = tensor2img(image, min_max, image_num)
              if (not is_save_image) and self.enable_visualdl:
                  self.vdl_logger.add_image(
                      results_dir + '/' + label,
                      image_numpy,
                      step=step if step else self.global_steps,
                      dataformats="HWC" if image_num == 1 else "NCHW")
              else:
                  if self.cfg.is_train:
                      if self.by_epoch:
                          msg = 'epoch%.3d_' % self.current_epoch
                      else:
                          msg = 'iter%.3d_' % self.current_iter
                  else:
                      msg = ''
                  makedirs(os.path.join(self.output_dir, results_dir))
                  img_path = os.path.join(self.output_dir, results_dir,
                                          msg + '%s.png' % (label))
                  save_image(image_numpy, img_path)

      def save(self, epoch, name='checkpoint', keep=1):
          """Save net weights (plus optimizer state for checkpoints) on rank 0.

          Args:
              epoch (int): iteration counter (``train`` passes
                  ``current_iter``). It is stored under the ``'epoch'`` key and
                  divided by ``iters_per_epoch`` for the file name of
                  epoch-based runs.
              name (str): ``'checkpoint'`` (nets + optimizers + counter) or
                  ``'weight'`` (nets only).
              keep (int): when > 0, delete the checkpoint written
                  ``keep * weight_interval`` iterations earlier.
          """
          if self.local_rank != 0:
              return
          assert name in ['checkpoint', 'weight']
          state_dicts = {}
          if self.by_epoch:
              save_filename = 'epoch_%s_%s.pdparams' % (
                  epoch // self.iters_per_epoch, name)
          else:
              save_filename = 'iter_%s_%s.pdparams' % (epoch, name)
          os.makedirs(self.output_dir, exist_ok=True)
          save_path = os.path.join(self.output_dir, save_filename)
          for net_name, net in self.model.nets.items():
              state_dicts[net_name] = net.state_dict()
          if name == 'weight':
              save(state_dicts, save_path)
              return
          state_dicts['epoch'] = epoch
          for opt_name, opt in self.model.optimizers.items():
              state_dicts[opt_name] = opt.state_dict()
          save(state_dicts, save_path)
          if keep > 0:
              try:
                  if self.by_epoch:
                      checkpoint_name_to_be_removed = os.path.join(
                          self.output_dir, 'epoch_%s_%s.pdparams' % (
                              (epoch - keep * self.weight_interval) //
                              self.iters_per_epoch, name))
                  else:
                      checkpoint_name_to_be_removed = os.path.join(
                          self.output_dir, 'iter_%s_%s.pdparams' %
                          (epoch - keep * self.weight_interval, name))
                  if os.path.exists(checkpoint_name_to_be_removed):
                      os.remove(checkpoint_name_to_be_removed)
              except Exception as e:
                  # best-effort cleanup: never fail training over it
                  self.logger.info('remove old checkpoints error: {}'.format(e))

      def resume(self, checkpoint_path):
          """Restore nets, optimizers and counters from a checkpoint.

          ``save`` stores the iteration counter under ``'epoch'``, so training
          continues with the iteration after it.
          """
          state_dicts = load(checkpoint_path)
          if state_dicts.get('epoch', None) is not None:
              self.start_epoch = state_dicts['epoch'] + 1
              # The stored value is already an iteration count; scaling it by
              # iters_per_epoch produced a wildly wrong visualdl step.
              self.global_steps = state_dicts['epoch']
              self.current_iter = state_dicts['epoch'] + 1
          for net_name, net in self.model.nets.items():
              net.set_state_dict(state_dicts[net_name])
          for opt_name, opt in self.model.optimizers.items():
              opt.set_state_dict(state_dicts[opt_name])

      def load(self, weight_path):
          """Load net weights from ``weight_path``.

          Nets missing from the file are skipped with a warning.
          """
          # `load` below resolves to the module-level import, not this method.
          state_dicts = load(weight_path)
          for net_name, net in self.model.nets.items():
              if net_name in state_dicts:
                  net.set_state_dict(state_dicts[net_name])
                  self.logger.info('Loaded pretrained weight for net {}'.format(
                      net_name))
              else:
                  self.logger.warning(
                      'Can not find state dict of net {}. Skip load pretrained weight for net {}'
                      .format(net_name, net_name))

      def close(self):
          """
          Release resources once training is finished (closes the visualdl
          writer when enabled).
          """
          if self.enable_visualdl:
              self.vdl_logger.close()
  # Base trainer for super-resolution models.
  class BasicSRNet:
      """Base wrapper that assembles a training config and runs a Restorer.

      Subclasses fill in ``self.model``, ``self.optimizer``,
      ``self.lr_scheduler`` and ``self.min_max`` in ``__init__``. ``train``
      merges them with the run-specific arguments.
      """

      def __init__(self):
          self.model = {}
          self.optimizer = {}
          self.lr_scheduler = {}
          # String form of the (min, max) image range; create_attr_dict turns
          # it into a tuple via ast.literal_eval (e.g. '(0., 1.)').
          self.min_max = ''

      def train(
              self,
              total_iters,
              train_dataset,
              test_dataset,
              output_dir,
              validate,
              snapshot,
              log,
              lr_rate,
              evaluate_weights='',
              resume='',
              pretrain_weights='',
              periods=None,
              restart_weights=None, ):
          """Assemble the run configuration and train (or evaluate) the model.

          Args:
              total_iters (int): number of training iterations.
              train_dataset (dict): dataset config used for training.
              test_dataset (dict): dataset config used for validation/testing.
              output_dir (str): parent directory. A sub-directory named after
                  the model plus a timestamp is created inside it.
              validate (int): validation interval in iterations.
              snapshot (int): weight/checkpoint interval in iterations.
              log (int): logging interval in iterations.
              lr_rate (float): base learning rate.
              evaluate_weights (str): if set, load these weights, run one
                  evaluation and return without training.
              resume (str): checkpoint to resume nets, optimizers and the
                  iteration counter from.
              pretrain_weights (str): weights to initialise the nets from
                  (ignored when ``resume`` is given).
              periods (list | None): ``CosineAnnealingRestartLR`` periods;
                  ``None`` means ``[100000]``.
              restart_weights (list | None): ``CosineAnnealingRestartLR``
                  restart weights; ``None`` means ``[1]``.
          """
          # Build fresh lists per call instead of sharing mutable defaults.
          if periods is None:
              periods = [100000]
          if restart_weights is None:
              restart_weights = [1]
          self.lr_scheduler['learning_rate'] = lr_rate
          # .get(): the bare base class has no scheduler name configured.
          if self.lr_scheduler.get('name') == 'CosineAnnealingRestartLR':
              self.lr_scheduler['periods'] = periods
              self.lr_scheduler['restart_weights'] = restart_weights
          validate = {
              'interval': validate,
              'save_img': False,
              'metrics': {
                  'psnr': {
                      'name': 'PSNR',
                      'crop_border': 4,
                      'test_y_channel': True
                  },
                  'ssim': {
                      'name': 'SSIM',
                      'crop_border': 4,
                      'test_y_channel': True
                  }
              }
          }
          log_config = {'interval': log, 'visiual_interval': 500}
          snapshot_config = {'interval': snapshot}
          cfg = {
              'total_iters': total_iters,
              'output_dir': output_dir,
              'min_max': self.min_max,
              'model': self.model,
              'dataset': {
                  'train': train_dataset,
                  'test': test_dataset
              },
              'lr_scheduler': self.lr_scheduler,
              'optimizer': self.optimizer,
              'validate': validate,
              'log_config': log_config,
              'snapshot_config': snapshot_config
          }
          cfg = AttrDict(cfg)
          create_attr_dict(cfg)
          cfg.is_train = True
          cfg.profiler_options = None
          cfg.timestamp = time.strftime('-%Y-%m-%d-%H-%M', time.localtime())
          # Name the output folder after the generator for the generic
          # BaseSRModel wrapper, otherwise after the model itself.
          if cfg.model.name == 'BaseSRModel':
              folder_model_name = cfg.model.generator.name
          else:
              folder_model_name = cfg.model.name
          cfg.output_dir = os.path.join(cfg.output_dir,
                                        folder_model_name + cfg.timestamp)
          logger_cfg = setup_logger(cfg.output_dir)
          logger_cfg.info('Configs: {}'.format(cfg))
          if paddle.is_compiled_with_cuda():
              paddle.set_device('gpu')
          else:
              paddle.set_device('cpu')
          # build trainer
          trainer = Restorer(cfg, logger_cfg)
          try:
              # continue training or evaluate: the checkpoint must contain the
              # iteration counter and the optimizer states
              if resume:
                  trainer.resume(resume)
              # evaluate or fine-tune: only the net weights are loaded
              elif pretrain_weights:
                  trainer.load(pretrain_weights)
              if evaluate_weights:
                  trainer.load(evaluate_weights)
                  trainer.test()
                  return
              # training; on keyboard interrupt save a resumable checkpoint
              try:
                  trainer.train()
              except KeyboardInterrupt:
                  # save()/resume() count iterations, so store the current
                  # iteration rather than the epoch number.
                  trainer.save(trainer.current_iter)
          finally:
              # release the visualdl writer (if any) on every exit path
              trainer.close()
  # DRN model training.
  class DRNet(BasicSRNet):
      """Trainer preset for the DRN model (``DRNGenerator`` at scales 2 and 4).

      The generator is trained by ``optimG``. The two dual nets
      (``dual_model_0`` / ``dual_model_1``) are trained by ``optimD``.
      """

      def __init__(self,
                   n_blocks=30,
                   n_feats=16,
                   n_colors=3,
                   rgb_range=255,
                   negval=0.2):
          super().__init__()
          # Image range handed to the visualiser, parsed by literal_eval.
          self.min_max = '(0., 255.)'
          self.generator = dict(
              name='DRNGenerator',
              scale=(2, 4),
              n_blocks=n_blocks,
              n_feats=n_feats,
              n_colors=n_colors,
              rgb_range=rgb_range,
              negval=negval)
          self.pixel_criterion = dict(name='L1Loss')
          self.model = dict(
              name='DRN',
              generator=self.generator,
              pixel_criterion=self.pixel_criterion)

          def adam(net_names):
              # Both optimizers share the same Adam settings.
              return dict(name='Adam', net_names=net_names, weight_decay=0.0,
                          beta1=0.9, beta2=0.999)

          self.optimizer = dict(
              optimG=adam(['generator']),
              optimD=adam(['dual_model_0', 'dual_model_1']))
          self.lr_scheduler = dict(
              name='CosineAnnealingRestartLR', eta_min=1e-07)
  # Lightweight super-resolution model LESRCNN training.
  class LESRCNNet(BasicSRNet):
      """Trainer preset for ``LESRCNNGenerator`` wrapped in ``BaseSRModel``.

      Args:
          scale (int): upscaling factor passed to the generator.
          multi_scale (bool): passed to the generator's ``multi_scale``.
          group (int): passed to the generator's ``group``.
      """

      def __init__(self, scale=4, multi_scale=False, group=1):
          super(LESRCNNet, self).__init__()
          self.min_max = '(0., 1.)'
          self.generator = {
              'name': 'LESRCNNGenerator',
              'scale': scale,
              # honour the constructor arguments instead of hard-coded values
              'multi_scale': multi_scale,
              'group': group
          }
          self.pixel_criterion = {'name': 'L1Loss'}
          self.model = {
              'name': 'BaseSRModel',
              'generator': self.generator,
              'pixel_criterion': self.pixel_criterion
          }
          self.optimizer = {
              'name': 'Adam',
              'net_names': ['generator'],
              'beta1': 0.9,
              'beta2': 0.99
          }
          self.lr_scheduler = {
              'name': 'CosineAnnealingRestartLR',
              'eta_min': 1e-07
          }
  # ESRGAN model training.
  # loss_type == 'gan': perceptual, adversarial and pixel losses.
  # any other loss_type (e.g. 'pixel'): pixel loss only.
  class ESRGANet(BasicSRNet):
      """Trainer preset for the ``RRDBNet`` generator.

      Args:
          loss_type (str): ``'gan'`` builds the full ESRGAN model with a
              discriminator. Any other value trains the generator alone with
              an L1 pixel loss through ``BaseSRModel``.
          in_nc, out_nc, nf, nb: forwarded to the ``RRDBNet`` generator.
      """

      def __init__(self, loss_type='gan', in_nc=3, out_nc=3, nf=64, nb=23):
          super().__init__()
          self.min_max = '(0., 1.)'
          self.generator = dict(
              name='RRDBNet', in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb)

          if loss_type != 'gan':
              # Pixel-loss-only training through the generic BaseSRModel.
              self.pixel_criterion = dict(name='L1Loss')
              self.model = dict(
                  name='BaseSRModel',
                  generator=self.generator,
                  pixel_criterion=self.pixel_criterion)
              self.optimizer = dict(
                  name='Adam', net_names=['generator'], beta1=0.9, beta2=0.99)
              self.lr_scheduler = dict(
                  name='CosineAnnealingRestartLR', eta_min=1e-07)
              return

          # Loss functions for adversarial training.
          self.pixel_criterion = dict(name='L1Loss', loss_weight=0.01)
          self.discriminator = dict(
              name='VGGDiscriminator128', in_channels=3, num_feat=64)
          self.perceptual_criterion = dict(
              name='PerceptualLoss',
              layer_weights={'34': 1.0},
              perceptual_weight=1.0,
              style_weight=0.0,
              norm_img=False)
          self.gan_criterion = dict(
              name='GANLoss', gan_mode='vanilla', loss_weight=0.005)
          # Model definition.
          self.model = dict(
              name='ESRGAN',
              generator=self.generator,
              discriminator=self.discriminator,
              pixel_criterion=self.pixel_criterion,
              perceptual_criterion=self.perceptual_criterion,
              gan_criterion=self.gan_criterion)

          def adam(net_names):
              # Generator and discriminator share the same Adam settings.
              return dict(name='Adam', net_names=net_names, weight_decay=0.0,
                          beta1=0.9, beta2=0.99)

          self.optimizer = dict(
              optimG=adam(['generator']), optimD=adam(['discriminator']))
          self.lr_scheduler = dict(
              name='MultiStepDecay',
              milestones=[50000, 100000, 200000, 300000],
              gamma=0.5)
  # RCAN model training.
  class RCANet(BasicSRNet):
      """Trainer preset for the ``RCAN`` generator wrapped in ``RCANModel``.

      Args:
          scale (int): upscaling factor.
          n_resgroups (int): number of residual groups.
          n_resblocks (int): number of residual blocks per group.
      """

      def __init__(
              self,
              scale=2,
              n_resgroups=10,
              n_resblocks=20, ):
          super().__init__()
          # Image range handed to the visualiser, parsed by literal_eval.
          self.min_max = '(0., 255.)'
          self.generator = dict(
              name='RCAN',
              scale=scale,
              n_resgroups=n_resgroups,
              n_resblocks=n_resblocks)
          self.pixel_criterion = dict(name='L1Loss')
          self.model = dict(
              name='RCANModel',
              generator=self.generator,
              pixel_criterion=self.pixel_criterion)
          self.optimizer = dict(
              name='Adam', net_names=['generator'], beta1=0.9, beta2=0.99)
          milestones = [250000, 500000, 750000, 1000000]
          self.lr_scheduler = dict(
              name='MultiStepDecay', milestones=milestones, gamma=0.5)