@@ -25,6 +25,7 @@ from paddle.static import InputSpec
 import paddlers
 import paddlers.models.ppgan as ppgan
 import paddlers.rs_models.res as cmres
+import paddlers.models.ppgan.metrics as metrics
 import paddlers.utils.logging as logging
 from paddlers.models import res_losses
 from paddlers.transforms import Resize, decode_image
@@ -32,12 +33,14 @@ from paddlers.transforms.functions import calc_hr_shape
 from paddlers.utils import get_single_card_bs
 from .base import BaseModel
 from .utils.res_adapters import GANAdapter, OptimizerAdapter
+from .utils.infer_nets import InferResNet

 __all__ = []


 class BaseRestorer(BaseModel):
-    MIN_MAX = (0., 255.)
+    MIN_MAX = (0., 1.)
+    TEST_OUT_KEY = None

     def __init__(self, model_name, losses=None, sr_factor=None, **params):
         self.init_params = locals()
@@ -63,9 +66,10 @@ class BaseRestorer(BaseModel):
     def _build_inference_net(self):
         # For GAN models, only the generator will be used for inference.
         if isinstance(self.net, GANAdapter):
-            infer_net = self.net.generator
+            infer_net = InferResNet(
+                self.net.generator, out_key=self.TEST_OUT_KEY)
         else:
-            infer_net = self.net
+            infer_net = InferResNet(self.net, out_key=self.TEST_OUT_KEY)
         infer_net.eval()
         return infer_net

@@ -108,15 +112,18 @@ class BaseRestorer(BaseModel):
         outputs = OrderedDict()

         if mode == 'test':
-            if isinstance(net, GANAdapter):
-                net_out = net.generator(inputs[0])
-            else:
-                net_out = net(inputs[0])
             tar_shape = inputs[1]
             if self.status == 'Infer':
+                net_out = net(inputs[0])
                 res_map_list = self._postprocess(
                     net_out, tar_shape, transforms=inputs[2])
             else:
+                if isinstance(net, GANAdapter):
+                    net_out = net.generator(inputs[0])
+                else:
+                    net_out = net(inputs[0])
+                if self.TEST_OUT_KEY is not None:
+                    net_out = net_out[self.TEST_OUT_KEY]
                 pred = self._postprocess(
                     net_out, tar_shape, transforms=inputs[2])
                 res_map_list = []
@@ -130,13 +137,15 @@ class BaseRestorer(BaseModel):
                 net_out = net.generator(inputs[0])
             else:
                 net_out = net(inputs[0])
+            if self.TEST_OUT_KEY is not None:
+                net_out = net_out[self.TEST_OUT_KEY]
             tar = inputs[1]
             tar_shape = [tar.shape[-2:]]
             pred = self._postprocess(
                 net_out, tar_shape, transforms=inputs[2])[0]  # NCHW
             pred = self._tensor_to_images(pred)
             outputs['pred'] = pred
-            tar = self.tensor_to_images(tar)
+            tar = self._tensor_to_images(tar)
             outputs['tar'] = tar

         if mode == 'train':
@@ -386,10 +395,11 @@ class BaseRestorer(BaseModel):
         self.eval_data_loader = self.build_data_loader(
             eval_dataset, batch_size=batch_size, mode='eval')
         # XXX: Hard-code crop_border and test_y_channel
-        psnr = ppgan.metrics.PSNR(crop_border=4, test_y_channel=True)
-        ssim = ppgan.metrics.SSIM(crop_border=4, test_y_channel=True)
+        psnr = metrics.PSNR(crop_border=4, test_y_channel=True)
+        ssim = metrics.SSIM(crop_border=4, test_y_channel=True)
         with paddle.no_grad():
             for step, data in enumerate(self.eval_data_loader):
+                data.append(eval_dataset.transforms.transforms)
                 outputs = self.run(self.net, data, 'eval')
                 psnr.update(outputs['pred'], outputs['tar'])
                 ssim.update(outputs['pred'], outputs['tar'])
@@ -520,10 +530,9 @@ class BaseRestorer(BaseModel):
     def _postprocess(self, batch_pred, batch_tar_shape, transforms):
         batch_restore_list = BaseRestorer.get_transforms_shape_info(
             batch_tar_shape, transforms)
-        if isinstance(batch_pred, (tuple, list)) and self.status == 'Infer':
+        if self.status == 'Infer':
             return self._infer_postprocess(
-                batch_res_map=batch_pred[0],
-                batch_restore_list=batch_restore_list)
+                batch_res_map=batch_pred, batch_restore_list=batch_restore_list)
         results = []
         if batch_pred.dtype == paddle.float32:
             mode = 'bilinear'
@@ -546,7 +555,7 @@ class BaseRestorer(BaseModel):

     def _infer_postprocess(self, batch_res_map, batch_restore_list):
         res_maps = []
-        for score_map, restore_list in zip(batch_res_map, batch_restore_list):
+        for res_map, restore_list in zip(batch_res_map, batch_restore_list):
             if not isinstance(res_map, np.ndarray):
                 res_map = paddle.unsqueeze(res_map, axis=0)
             for item in restore_list[::-1]:
@@ -557,15 +566,15 @@ class BaseRestorer(BaseModel):
                             res_map, (w, h), interpolation=cv2.INTER_LINEAR)
                     else:
                         res_map = F.interpolate(
-                            score_map, (h, w),
+                            res_map, (h, w),
                             mode='bilinear',
                             data_format='NHWC')
                 elif item[0] == 'padding':
                     x, y = item[2]
                     if isinstance(res_map, np.ndarray):
-                        res_map = res_map[..., y:y + h, x:x + w]
+                        res_map = res_map[y:y + h, x:x + w]
                     else:
-                        res_map = res_map[:, :, y:y + h, x:x + w]
+                        res_map = res_map[:, y:y + h, x:x + w, :]
                 else:
                     pass
             res_map = res_map.squeeze()
@@ -585,18 +594,25 @@ class BaseRestorer(BaseModel):
     def set_losses(self, losses):
         self.losses = losses

-    def _tensor_to_images(self, tensor, squeeze=True, quantize=True):
-        tensor = paddle.transpose(tensor, perm=[0, 2, 3, 1])  # NHWC
+    def _tensor_to_images(self,
+                          tensor,
+                          transpose=True,
+                          squeeze=True,
+                          quantize=True):
+        if transpose:
+            tensor = paddle.transpose(tensor, perm=[0, 2, 3, 1])  # NHWC
         if squeeze:
             tensor = tensor.squeeze()
         images = tensor.numpy().astype('float32')
-        images = np.clip(images, self.MIN_MAX[0], self.MIN_MAX[1])
-        images = self._normalize(images, copy=True, quantize=quantize)
+        images = self._normalize(
+            images, copy=True, clip=True, quantize=quantize)
         return images

-    def _normalize(self, im, copy=False, quantize=True):
+    def _normalize(self, im, copy=False, clip=True, quantize=True):
         if copy:
             im = im.copy()
+        if clip:
+            im = np.clip(im, self.MIN_MAX[0], self.MIN_MAX[1])
         im -= im.min()
         im /= im.max() + 1e-32
         if quantize:
@@ -605,32 +621,9 @@ class BaseRestorer(BaseModel):
         return im


-class RCAN(BaseRestorer):
-    def __init__(self,
-                 losses=None,
-                 sr_factor=4,
-                 n_resgroups=10,
-                 n_resblocks=20,
-                 n_feats=64,
-                 n_colors=3,
-                 rgb_range=255,
-                 kernel_size=3,
-                 reduction=16,
-                 **params):
-        params.update({
-            'n_resgroups': n_resgroups,
-            'n_resblocks': n_resblocks,
-            'n_feats': n_feats,
-            'n_colors': n_colors,
-            'rgb_range': rgb_range,
-            'kernel_size': kernel_size,
-            'reduction': reduction
-        })
-        super(RCAN, self).__init__(
-            model_name='RCAN', losses=losses, sr_factor=sr_factor, **params)
-
-
 class DRN(BaseRestorer):
+    TEST_OUT_KEY = -1
+
     def __init__(self,
                  losses=None,
                  sr_factor=4,
@@ -638,8 +631,10 @@ class DRN(BaseRestorer):
                  n_blocks=30,
                  n_feats=16,
                  n_colors=3,
-                 rgb_range=255,
+                 rgb_range=1.0,
                  negval=0.2,
+                 lq_loss_weight=0.1,
+                 dual_loss_weight=0.1,
                  **params):
         if sr_factor != max(scale):
             raise ValueError(f"`sr_factor` must be equal to `max(scale)`.")
@@ -651,12 +646,80 @@ class DRN(BaseRestorer):
             'rgb_range': rgb_range,
             'negval': negval
         })
+        self.lq_loss_weight = lq_loss_weight
+        self.dual_loss_weight = dual_loss_weight
         super(DRN, self).__init__(
             model_name='DRN', losses=losses, sr_factor=sr_factor, **params)

     def build_net(self, **params):
-        net = ppgan.models.generators.DRNGenerator(**params)
-        return net
+        from ppgan.modules.init import init_weights
+        generators = [ppgan.models.generators.DRNGenerator(**params)]
+        init_weights(generators[-1])
+        for scale in params['scale']:
+            dual_model = ppgan.models.generators.drn.DownBlock(
+                params['negval'], params['n_feats'], params['n_colors'], 2)
+            generators.append(dual_model)
+            init_weights(generators[-1])
+        return GANAdapter(generators, [])
+
+    def default_optimizer(self, parameters, *args, **kwargs):
+        optims_g = [
+            super(DRN, self).default_optimizer(params_g, *args, **kwargs)
+            for params_g in parameters['params_g']
+        ]
+        return OptimizerAdapter(*optims_g)
+
+    def run_gan(self, net, inputs, mode, gan_mode='forward_primary'):
+        if mode != 'train':
+            raise ValueError("`mode` is not 'train'.")
+        outputs = OrderedDict()
+        if gan_mode == 'forward_primary':
+            sr = net.generator(inputs[0])
+            lr = [inputs[0]]
+            lr.extend([
+                F.interpolate(
+                    inputs[0], scale_factor=s, mode='bicubic')
+                for s in net.generator.scale[:-1]
+            ])
+            loss = self.losses(sr[-1], inputs[1])
+            for i in range(1, len(sr)):
+                if self.lq_loss_weight > 0:
+                    loss += self.losses(sr[i - 1 - len(sr)],
+                                        lr[i - len(sr)]) * self.lq_loss_weight
+            outputs['loss_prim'] = loss
+            outputs['sr'] = sr
+            outputs['lr'] = lr
+        elif gan_mode == 'forward_dual':
+            sr, lr = inputs[0], inputs[1]
+            sr2lr = []
+            n_scales = len(net.generator.scale)
+            for i in range(n_scales):
+                sr2lr_i = net.generators[1 + i](sr[i - n_scales])
+                sr2lr.append(sr2lr_i)
+            loss = self.losses(sr2lr[0], lr[0])
+            for i in range(1, n_scales):
+                if self.dual_loss_weight > 0.0:
+                    loss += self.losses(sr2lr[i], lr[i]) * self.dual_loss_weight
+            outputs['loss_dual'] = loss
+        else:
+            raise ValueError("Invalid `gan_mode`!")
+        return outputs
+
+    def train_step(self, step, data, net):
+        outputs = self.run_gan(
+            net, data, mode='train', gan_mode='forward_primary')
+        outputs.update(
+            self.run_gan(
+                net, (outputs['sr'], outputs['lr']),
+                mode='train',
+                gan_mode='forward_dual'))
+        self.optimizer.clear_grad()
+        (outputs['loss_prim'] + outputs['loss_dual']).backward()
+        self.optimizer.step()
+        return {
+            'loss_prim': outputs['loss_prim'],
+            'loss_dual': outputs['loss_dual']
+        }


 class LESRCNN(BaseRestorer):
@@ -680,8 +743,6 @@ class LESRCNN(BaseRestorer):


 class ESRGAN(BaseRestorer):
-    MIN_MAX = (0., 1.)
-
     def __init__(self,
                  losses=None,
                  sr_factor=4,
@@ -704,7 +765,9 @@ class ESRGAN(BaseRestorer):
             model_name='ESRGAN', losses=losses, sr_factor=sr_factor, **params)

     def build_net(self, **params):
+        from ppgan.modules.init import init_weights
         generator = ppgan.models.generators.RRDBNet(**params)
+        init_weights(generator)
         if self.use_gan:
             discriminator = ppgan.models.discriminators.VGGDiscriminator128(
                 in_channels=params['out_nc'], num_feat=64)
@@ -716,10 +779,13 @@ class ESRGAN(BaseRestorer):

     def default_loss(self):
         if self.use_gan:
-            self.losses = {
+            return {
                 'pixel': res_losses.L1Loss(loss_weight=0.01),
-                'perceptual':
-                res_losses.PerceptualLoss(layer_weights={'34': 1.0}),
+                'perceptual': res_losses.PerceptualLoss(
+                    layer_weights={'34': 1.0},
+                    perceptual_weight=1.0,
+                    style_weight=0.0,
+                    norm_img=False),
                 'gan': res_losses.GANLoss(
                     gan_mode='vanilla', loss_weight=0.005)
             }
@@ -734,7 +800,7 @@ class ESRGAN(BaseRestorer):
                 parameters['params_d'][0], *args, **kwargs)
             return OptimizerAdapter(optim_g, optim_d)
         else:
-            return super(ESRGAN, self).default_optimizer(params, *args,
+            return super(ESRGAN, self).default_optimizer(parameters, *args,
                                                          **kwargs)

     def run_gan(self, net, inputs, mode, gan_mode='forward_g'):
@@ -744,8 +810,8 @@ class ESRGAN(BaseRestorer):
         if gan_mode == 'forward_g':
             loss_g = 0
             g_pred = net.generator(inputs[0])
-            loss_pix = self.losses['pixel'](g_pred, tar)
-            loss_perc, loss_sty = self.losses['perceptual'](g_pred, tar)
+            loss_pix = self.losses['pixel'](g_pred, inputs[1])
+            loss_perc, loss_sty = self.losses['perceptual'](g_pred, inputs[1])
             loss_g += loss_pix
             if loss_perc is not None:
                 loss_g += loss_perc
@@ -767,14 +833,14 @@ class ESRGAN(BaseRestorer):
         elif gan_mode == 'forward_d':
             self._set_requires_grad(net.discriminator, True)
             # Real
-            fake_d_pred = net.discriminator(data[0]).detach()
-            real_d_pred = net.discriminator(data[1])
+            fake_d_pred = net.discriminator(inputs[0]).detach()
+            real_d_pred = net.discriminator(inputs[1])
             loss_d_real = self.losses['gan'](
                 real_d_pred - paddle.mean(fake_d_pred), True,
                 is_disc=True) * 0.5
             # Fake
-            fake_d_pred = self.nets['discriminator'](self.output.detach())
-            loss_d_fake = self.gan_criterion(
+            fake_d_pred = net.discriminator(inputs[0].detach())
+            loss_d_fake = self.losses['gan'](
                 fake_d_pred - paddle.mean(real_d_pred.detach()),
                 False,
                 is_disc=True) * 0.5
@@ -802,30 +868,43 @@ class ESRGAN(BaseRestorer):
             outputs['loss_d'].backward()
             optim_d.step()

-            outputs['loss'] = outupts['loss_g_pps'] + outputs[
+            outputs['loss'] = outputs['loss_g_pps'] + outputs[
                 'loss_g_gan'] + outputs['loss_d']

-            if isinstance(optim_g._learning_rate,
-                          paddle.optimizer.lr.LRScheduler):
-                # If ReduceOnPlateau is used as the scheduler, use the loss value as the metric.
-                if isinstance(optim_g._learning_rate,
-                              paddle.optimizer.lr.ReduceOnPlateau):
-                    optim_g._learning_rate.step(loss.item())
-                else:
-                    optim_g._learning_rate.step()
-
-            if isinstance(optim_d._learning_rate,
-                          paddle.optimizer.lr.LRScheduler):
-                if isinstance(optim_d._learning_rate,
-                              paddle.optimizer.lr.ReduceOnPlateau):
-                    optim_d._learning_rate.step(loss.item())
-                else:
-                    optim_d._learning_rate.step()
-
-            return outputs
+            return {
+                'loss': outputs['loss'],
+                'loss_g_pps': outputs['loss_g_pps'],
+                'loss_g_gan': outputs['loss_g_gan'],
+                'loss_d': outputs['loss_d']
+            }
         else:
-            super(ESRGAN, self).train_step(step, data, net)
+            return super(ESRGAN, self).train_step(step, data, net)

     def _set_requires_grad(self, net, requires_grad):
         for p in net.parameters():
             p.trainable = requires_grad
+
+
+class RCAN(BaseRestorer):
+    def __init__(self,
+                 losses=None,
+                 sr_factor=4,
+                 n_resgroups=10,
+                 n_resblocks=20,
+                 n_feats=64,
+                 n_colors=3,
+                 rgb_range=1.0,
+                 kernel_size=3,
+                 reduction=16,
+                 **params):
+        params.update({
+            'n_resgroups': n_resgroups,
+            'n_resblocks': n_resblocks,
+            'n_feats': n_feats,
+            'n_colors': n_colors,
+            'rgb_range': rgb_range,
+            'kernel_size': kernel_size,
+            'reduction': reduction
+        })
+        super(RCAN, self).__init__(
+            model_name='RCAN', losses=losses, sr_factor=sr_factor, **params)