|
@@ -28,7 +28,7 @@ import paddlers.models.ppgan.metrics as metrics
|
|
|
import paddlers.utils.logging as logging
|
|
|
from paddlers.models import res_losses
|
|
|
from paddlers.models.ppgan.modules.init import init_weights
|
|
|
-from paddlers.transforms import Resize, decode_image
|
|
|
+from paddlers.transforms import Resize, decode_image, construct_sample
|
|
|
from paddlers.transforms.functions import calc_hr_shape
|
|
|
from paddlers.utils.checkpoint import res_pretrain_weights_dict
|
|
|
from .base import BaseModel
|
|
@@ -58,7 +58,6 @@ class BaseRestorer(BaseModel):
|
|
|
if params.get('with_net', True):
|
|
|
params.pop('with_net', None)
|
|
|
self.net = self.build_net(**params)
|
|
|
- self.find_unused_parameters = True
|
|
|
if min_max is None:
|
|
|
self.min_max = self.MIN_MAX
|
|
|
|
|
@@ -116,14 +115,13 @@ class BaseRestorer(BaseModel):
|
|
|
return input_spec
|
|
|
|
|
|
def run(self, net, inputs, mode):
|
|
|
+ inputs, batch_restore_list = inputs
|
|
|
outputs = OrderedDict()
|
|
|
|
|
|
if mode == 'test':
|
|
|
- tar_shape = inputs[1]
|
|
|
if self.status == 'Infer':
|
|
|
net_out = net(inputs[0])
|
|
|
- res_map_list = self.postprocess(
|
|
|
- net_out, tar_shape, transforms=inputs[2])
|
|
|
+ res_map_list = self.postprocess(net_out, batch_restore_list)
|
|
|
else:
|
|
|
if isinstance(net, GANAdapter):
|
|
|
net_out = net.generator(inputs[0])
|
|
@@ -131,8 +129,7 @@ class BaseRestorer(BaseModel):
|
|
|
net_out = net(inputs[0])
|
|
|
if self.TEST_OUT_KEY is not None:
|
|
|
net_out = net_out[self.TEST_OUT_KEY]
|
|
|
- pred = self.postprocess(
|
|
|
- net_out, tar_shape, transforms=inputs[2])
|
|
|
+ pred = self.postprocess(net_out, batch_restore_list)
|
|
|
res_map_list = []
|
|
|
for res_map in pred:
|
|
|
res_map = self._tensor_to_images(res_map)
|
|
@@ -147,9 +144,7 @@ class BaseRestorer(BaseModel):
|
|
|
if self.TEST_OUT_KEY is not None:
|
|
|
net_out = net_out[self.TEST_OUT_KEY]
|
|
|
tar = inputs[1]
|
|
|
- tar_shape = [tar.shape[-2:]]
|
|
|
- pred = self.postprocess(
|
|
|
- net_out, tar_shape, transforms=inputs[2])[0] # NCHW
|
|
|
+ pred = self.postprocess(net_out, batch_restore_list)[0] # NCHW
|
|
|
pred = self._tensor_to_images(pred)
|
|
|
outputs['pred'] = pred
|
|
|
tar = self._tensor_to_images(tar)
|
|
@@ -424,7 +419,6 @@ class BaseRestorer(BaseModel):
|
|
|
eval_dataset.num_samples, eval_dataset.num_samples))
|
|
|
with paddle.no_grad():
|
|
|
for step, data in enumerate(self.eval_data_loader):
|
|
|
- data.append(eval_dataset.transforms.transforms)
|
|
|
outputs = self.run(self.net, data, 'eval')
|
|
|
psnr.update(outputs['pred'], outputs['tar'])
|
|
|
ssim.update(outputs['pred'], outputs['tar'])
|
|
@@ -472,10 +466,8 @@ class BaseRestorer(BaseModel):
|
|
|
images = [img_file]
|
|
|
else:
|
|
|
images = img_file
|
|
|
- batch_im, batch_tar_shape = self.preprocess(images, transforms,
|
|
|
- self.model_type)
|
|
|
+ data = self.preprocess(images, transforms, self.model_type)
|
|
|
self.net.eval()
|
|
|
- data = (batch_im, batch_tar_shape, transforms.transforms)
|
|
|
outputs = self.run(self.net, data, 'test')
|
|
|
res_map_list = outputs['res_map']
|
|
|
if isinstance(img_file, list):
|
|
@@ -487,79 +479,24 @@ class BaseRestorer(BaseModel):
|
|
|
def preprocess(self, images, transforms, to_tensor=True):
|
|
|
self._check_transforms(transforms, 'test')
|
|
|
batch_im = list()
|
|
|
- batch_tar_shape = list()
|
|
|
+ batch_trans_info = list()
|
|
|
for im in images:
|
|
|
if isinstance(im, str):
|
|
|
im = decode_image(im, read_raw=True)
|
|
|
- ori_shape = im.shape[:2]
|
|
|
- sample = {'image': im}
|
|
|
- im = transforms(sample)[0]
|
|
|
+ sample = construct_sample(image=im)
|
|
|
+ data = transforms(sample)
|
|
|
+ im = data[0][0]
|
|
|
+ trans_info = data[1]
|
|
|
batch_im.append(im)
|
|
|
- batch_tar_shape.append(self._get_target_shape(ori_shape))
|
|
|
+ batch_trans_info.append(trans_info)
|
|
|
if to_tensor:
|
|
|
batch_im = paddle.to_tensor(batch_im)
|
|
|
else:
|
|
|
batch_im = np.asarray(batch_im)
|
|
|
|
|
|
- return batch_im, batch_tar_shape
|
|
|
+ return (batch_im, ), batch_trans_info
|
|
|
|
|
|
- def _get_target_shape(self, ori_shape):
|
|
|
- if self.sr_factor is None:
|
|
|
- return ori_shape
|
|
|
- else:
|
|
|
- return calc_hr_shape(ori_shape, self.sr_factor)
|
|
|
-
|
|
|
- @staticmethod
|
|
|
- def get_transforms_shape_info(batch_tar_shape, transforms):
|
|
|
- batch_restore_list = list()
|
|
|
- for tar_shape in batch_tar_shape:
|
|
|
- restore_list = list()
|
|
|
- h, w = tar_shape[0], tar_shape[1]
|
|
|
- for op in transforms:
|
|
|
- if op.__class__.__name__ == 'Resize':
|
|
|
- restore_list.append(('resize', (h, w)))
|
|
|
- h, w = op.target_size
|
|
|
- elif op.__class__.__name__ == 'ResizeByShort':
|
|
|
- restore_list.append(('resize', (h, w)))
|
|
|
- im_short_size = min(h, w)
|
|
|
- im_long_size = max(h, w)
|
|
|
- scale = float(op.short_size) / float(im_short_size)
|
|
|
- if 0 < op.max_size < np.round(scale * im_long_size):
|
|
|
- scale = float(op.max_size) / float(im_long_size)
|
|
|
- h = int(round(h * scale))
|
|
|
- w = int(round(w * scale))
|
|
|
- elif op.__class__.__name__ == 'ResizeByLong':
|
|
|
- restore_list.append(('resize', (h, w)))
|
|
|
- im_long_size = max(h, w)
|
|
|
- scale = float(op.long_size) / float(im_long_size)
|
|
|
- h = int(round(h * scale))
|
|
|
- w = int(round(w * scale))
|
|
|
- elif op.__class__.__name__ == 'Pad':
|
|
|
- if op.target_size:
|
|
|
- target_h, target_w = op.target_size
|
|
|
- else:
|
|
|
- target_h = int(
|
|
|
- (np.ceil(h / op.size_divisor) * op.size_divisor))
|
|
|
- target_w = int(
|
|
|
- (np.ceil(w / op.size_divisor) * op.size_divisor))
|
|
|
-
|
|
|
- if op.pad_mode == -1:
|
|
|
- offsets = op.offsets
|
|
|
- elif op.pad_mode == 0:
|
|
|
- offsets = [0, 0]
|
|
|
- elif op.pad_mode == 1:
|
|
|
- offsets = [(target_h - h) // 2, (target_w - w) // 2]
|
|
|
- else:
|
|
|
- offsets = [target_h - h, target_w - w]
|
|
|
- restore_list.append(('padding', (h, w), offsets))
|
|
|
- h, w = target_h, target_w
|
|
|
-
|
|
|
- batch_restore_list.append(restore_list)
|
|
|
- return batch_restore_list
|
|
|
-
|
|
|
- def postprocess(self, batch_pred, batch_tar_shape, transforms):
|
|
|
- batch_restore_list = BaseRestorer.get_transforms_shape_info(
|
|
|
- batch_tar_shape, transforms)
|
|
|
+ def postprocess(self, batch_pred, batch_restore_list):
|
|
|
if self.status == 'Infer':
|
|
|
return self._infer_postprocess(
|
|
|
batch_res_map=batch_pred, batch_restore_list=batch_restore_list)
|
|
@@ -572,11 +509,15 @@ class BaseRestorer(BaseModel):
|
|
|
pred = paddle.unsqueeze(pred, axis=0)
|
|
|
for item in restore_list[::-1]:
|
|
|
h, w = item[1][0], item[1][1]
|
|
|
+ if self.sr_factor:
|
|
|
+ h, w = calc_hr_shape((h, w), self.sr_factor)
|
|
|
if item[0] == 'resize':
|
|
|
pred = F.interpolate(
|
|
|
pred, (h, w), mode=mode, data_format='NCHW')
|
|
|
elif item[0] == 'padding':
|
|
|
x, y = item[2]
|
|
|
+ if self.sr_factor:
|
|
|
+ x, y = calc_hr_shape((x, y), self.sr_factor)
|
|
|
pred = pred[:, :, y:y + h, x:x + w]
|
|
|
else:
|
|
|
pass
|
|
@@ -590,6 +531,8 @@ class BaseRestorer(BaseModel):
|
|
|
res_map = paddle.unsqueeze(res_map, axis=0)
|
|
|
for item in restore_list[::-1]:
|
|
|
h, w = item[1][0], item[1][1]
|
|
|
+ if self.sr_factor:
|
|
|
+ h, w = calc_hr_shape((h, w), self.sr_factor)
|
|
|
if item[0] == 'resize':
|
|
|
if isinstance(res_map, np.ndarray):
|
|
|
res_map = cv2.resize(
|
|
@@ -601,6 +544,8 @@ class BaseRestorer(BaseModel):
|
|
|
data_format='NHWC')
|
|
|
elif item[0] == 'padding':
|
|
|
x, y = item[2]
|
|
|
+ if self.sr_factor:
|
|
|
+ x, y = calc_hr_shape((x, y), self.sr_factor)
|
|
|
if isinstance(res_map, np.ndarray):
|
|
|
res_map = res_map[y:y + h, x:x + w]
|
|
|
else:
|
|
@@ -621,7 +566,11 @@ class BaseRestorer(BaseModel):
|
|
|
raise TypeError(
|
|
|
"`transforms.arrange` must be an ArrangeRestorer object.")
|
|
|
|
|
|
- def build_data_loader(self, dataset, batch_size, mode='train'):
|
|
|
+ def build_data_loader(self,
|
|
|
+ dataset,
|
|
|
+ batch_size,
|
|
|
+ mode='train',
|
|
|
+ collate_fn=None):
|
|
|
if dataset.num_samples < batch_size:
|
|
|
raise ValueError(
|
|
|
'The volume of dataset({}) must be larger than batch size({}).'
|
|
@@ -633,7 +582,8 @@ class BaseRestorer(BaseModel):
|
|
|
batch_size=batch_size,
|
|
|
shuffle=dataset.shuffle,
|
|
|
drop_last=False,
|
|
|
- collate_fn=dataset.batch_transforms,
|
|
|
+ collate_fn=dataset.collate_fn
|
|
|
+ if collate_fn is None else collate_fn,
|
|
|
num_workers=dataset.num_workers,
|
|
|
return_list=True,
|
|
|
use_shared_memory=False)
|
|
@@ -758,7 +708,7 @@ class DRN(BaseRestorer):
|
|
|
|
|
|
def train_step(self, step, data, net):
|
|
|
outputs = self.run_gan(
|
|
|
- net, data, mode='train', gan_mode='forward_primary')
|
|
|
+ net, data[0], mode='train', gan_mode='forward_primary')
|
|
|
outputs.update(
|
|
|
self.run_gan(
|
|
|
net, (outputs['sr'], outputs['lr']),
|
|
@@ -800,6 +750,9 @@ class LESRCNN(BaseRestorer):
|
|
|
|
|
|
|
|
|
class ESRGAN(BaseRestorer):
|
|
|
+
|
|
|
+ find_unused_parameters = True
|
|
|
+
|
|
|
def __init__(self,
|
|
|
losses=None,
|
|
|
sr_factor=4,
|
|
@@ -915,14 +868,14 @@ class ESRGAN(BaseRestorer):
|
|
|
optim_g, optim_d = self.optimizer
|
|
|
|
|
|
outputs = self.run_gan(
|
|
|
- net, data, mode='train', gan_mode='forward_g')
|
|
|
+ net, data[0], mode='train', gan_mode='forward_g')
|
|
|
optim_g.clear_grad()
|
|
|
(outputs['loss_g_pps'] + outputs['loss_g_gan']).backward()
|
|
|
optim_g.step()
|
|
|
|
|
|
outputs.update(
|
|
|
self.run_gan(
|
|
|
- net, (outputs['g_pred'], data[1]),
|
|
|
+ net, (outputs['g_pred'], data[0][1]),
|
|
|
mode='train',
|
|
|
gan_mode='forward_d'))
|
|
|
optim_d.clear_grad()
|