@@ -54,7 +54,7 @@ class BaseRestorer(BaseModel):
 
     def build_net(self, **params):
         # Currently, only use models from cmres.
-        if not hasattr(cmres, model_name):
+        if not hasattr(cmres, self.model_name):
             raise ValueError("ERROR: There is no model named {}.".format(
-                model_name))
+                self.model_name))
         net = dict(**cmres.__dict__)[self.model_name](**params)
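For context, the lookup pattern this hunk repairs can be reproduced in isolation. The sketch below uses a hypothetical stand-in namespace rather than the real `cmres` module, so the model names are illustrative only:

```python
import types

# Hypothetical stand-in for the `cmres` module namespace.
cmres = types.SimpleNamespace(RCAN=object, LESRCNN=object)

def lookup_model(model_name):
    # Mirror of the fixed check: validate the name before the dict-style
    # lookup so a bad name raises a readable ValueError, not a KeyError.
    if not hasattr(cmres, model_name):
        raise ValueError(
            "ERROR: There is no model named {}.".format(model_name))
    return vars(cmres)[model_name]

print(lookup_model('RCAN'))   # <class 'object'>
# lookup_model('SRCNN')       # ValueError: ERROR: There is no model named SRCNN.
```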
@@ -618,7 +618,6 @@ class RCAN(BaseRestorer):
                  reduction=16,
                  **params):
         params.update({
-            'factor': sr_factor,
             'n_resgroups': n_resgroups,
             'n_resblocks': n_resblocks,
             'n_feats': n_feats,
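Worth noting why dropping `'factor'` matters: if the RCAN network class does not accept that keyword, forwarding it through `**params` fails at construction time. A hypothetical stand-in shows the failure mode (this is not the real `cmres.RCAN` signature):

```python
class FakeRCAN:
    """Hypothetical stand-in; not the real cmres.RCAN signature."""

    def __init__(self, n_resgroups=10, n_resblocks=20, n_feats=64):
        self.shape = (n_resgroups, n_resblocks, n_feats)

FakeRCAN(n_resgroups=10, n_resblocks=20, n_feats=64)  # fine
# FakeRCAN(factor=4)  # TypeError: unexpected keyword argument 'factor'
```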
@@ -661,8 +660,17 @@ class DRN(BaseRestorer):
 
 
 class LESRCNN(BaseRestorer):
-    def __init__(self, losses=None, sr_factor=4, multi_scale=False, group=1):
-        params.update({'scale': sr_factor, 'multi_scale': False, 'group': 1})
+    def __init__(self,
+                 losses=None,
+                 sr_factor=4,
+                 multi_scale=False,
+                 group=1,
+                 **params):
+        params.update({
+            'scale': sr_factor,
+            'multi_scale': multi_scale,
+            'group': group
+        })
         super(LESRCNN, self).__init__(
             model_name='LESRCNN', losses=losses, sr_factor=sr_factor, **params)
 
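After this rewrite, `multi_scale` and `group` are actually forwarded to the network instead of being overwritten with hard-coded `False`/`1` (and the old body referenced `params` without accepting it). Assuming the class otherwise matches the hunk, usage would look like:

```python
# Before the fix these two arguments were silently ignored;
# now they reach the underlying network via `params`.
model = LESRCNN(sr_factor=4, multi_scale=True, group=2)
```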
@@ -681,9 +689,11 @@ class ESRGAN(BaseRestorer):
                  in_channels=3,
                  out_channels=3,
                  nf=64,
-                 nb=23):
+                 nb=23,
+                 **params):
+        if sr_factor != 4:
+            raise ValueError("`sr_factor` must be 4.")
         params.update({
-            'scale': sr_factor,
             'in_nc': in_channels,
             'out_nc': out_channels,
             'nf': nf,
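The new guard fails fast instead of letting an unsupported scale propagate into the network (note `'scale'` is also dropped from `params`, consistent with the model being fixed at 4x here). Expected behavior, assuming the constructor otherwise mirrors this hunk:

```python
model = ESRGAN(sr_factor=4)   # OK
model = ESRGAN(sr_factor=2)   # ValueError: `sr_factor` must be 4.
```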
@@ -696,7 +706,7 @@ class ESRGAN(BaseRestorer):
     def build_net(self, **params):
         generator = ppgan.models.generators.RRDBNet(**params)
         if self.use_gan:
-            discriminator = ppgan.models.discriminators.VGGDiscrinimator128(
+            discriminator = ppgan.models.discriminators.VGGDiscriminator128(
                 in_channels=params['out_nc'], num_feat=64)
             net = GANAdapter(
                 generators=[generator], discriminators=[discriminator])
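The old spelling would only surface as an `AttributeError` once a GAN-enabled model was actually built. A quick sanity check along these lines (assuming `ppgan` is importable and exposes this module path) confirms the corrected name resolves:

```python
from ppgan.models import discriminators

# The corrected name should resolve; the misspelled one should not.
assert hasattr(discriminators, 'VGGDiscriminator128')
assert not hasattr(discriminators, 'VGGDiscrinimator128')
```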
@@ -719,9 +729,9 @@ class ESRGAN(BaseRestorer):
     def default_optimizer(self, parameters, *args, **kwargs):
         if self.use_gan:
             optim_g = super(ESRGAN, self).default_optimizer(
-                parameters['optims_g'][0], *args, **kwargs)
+                parameters['params_g'][0], *args, **kwargs)
             optim_d = super(ESRGAN, self).default_optimizer(
-                parameters['optims_d'][0], *args, **kwargs)
+                parameters['params_d'][0], *args, **kwargs)
             return OptimizerAdapter(optim_g, optim_d)
         else:
-            return super(ESRGAN, self).default_optimizer(params, *args,
+            return super(ESRGAN, self).default_optimizer(parameters, *args,
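The key rename implies `parameters` arrives keyed by `'params_g'`/`'params_d'`, each a list holding one parameter collection per generator/discriminator. A minimal sketch of the paired-optimizer construction under that assumption, in plain Paddle without the PaddleRS adapters:

```python
import paddle

generator = paddle.nn.Linear(8, 8)       # stand-in generator
discriminator = paddle.nn.Linear(8, 1)   # stand-in discriminator

# Assumed layout of `parameters`: one entry per network on each side.
parameters = {
    'params_g': [generator.parameters()],
    'params_d': [discriminator.parameters()],
}

optim_g = paddle.optimizer.Adam(
    learning_rate=1e-4, parameters=parameters['params_g'][0])
optim_d = paddle.optimizer.Adam(
    learning_rate=1e-4, parameters=parameters['params_d'][0])
```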
@@ -777,14 +787,17 @@ class ESRGAN(BaseRestorer):
         if self.use_gan:
             optim_g, optim_d = self.optimizer
 
-            outputs = self.run_gan(net, data, gan_mode='forward_g')
+            outputs = self.run_gan(
+                net, data, mode='train', gan_mode='forward_g')
             optim_g.clear_grad()
             (outputs['loss_g_pps'] + outputs['loss_g_gan']).backward()
             optim_g.step()
 
             outputs.update(
                 self.run_gan(
-                    net, (outputs['g_pred'], data[1]), gan_mode='forward_d'))
+                    net, (outputs['g_pred'], data[1]),
+                    mode='train',
+                    gan_mode='forward_d'))
             optim_d.clear_grad()
             outputs['loss_d'].backward()
             optim_d.step()
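Stripped of the PaddleRS plumbing, this hunk follows the standard two-phase GAN update: a generator step on the combined pixel and adversarial loss, then a discriminator step on the generator's prediction. A generic sketch of that ordering, with the loss computations abstracted into callables (these names are illustrative, not the `run_gan` API):

```python
def gan_train_step(g_loss_fn, d_loss_fn, optim_g, optim_d):
    """Two-phase GAN update mirroring the structure above.

    g_loss_fn/d_loss_fn are assumed to return scalar Paddle tensors;
    they stand in for the losses run_gan computes in its 'forward_g'
    and 'forward_d' modes.
    """
    # Phase 1: generator update on its combined loss.
    optim_g.clear_grad()
    g_loss_fn().backward()
    optim_g.step()

    # Phase 2: discriminator update on the generator's output.
    optim_d.clear_grad()
    d_loss_fn().backward()
    optim_d.step()
```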