|  | @@ -28,7 +28,7 @@ import paddlers.models.ppgan.metrics as metrics
 | 
											
												
													
														|  |  import paddlers.utils.logging as logging
 |  |  import paddlers.utils.logging as logging
 | 
											
												
													
														|  |  from paddlers.models import res_losses
 |  |  from paddlers.models import res_losses
 | 
											
												
													
														|  |  from paddlers.models.ppgan.modules.init import init_weights
 |  |  from paddlers.models.ppgan.modules.init import init_weights
 | 
											
												
													
														|  | -from paddlers.transforms import Resize, decode_image
 |  | 
 | 
											
												
													
														|  | 
 |  | +from paddlers.transforms import Resize, decode_image, construct_sample
 | 
											
												
													
														|  |  from paddlers.transforms.functions import calc_hr_shape
 |  |  from paddlers.transforms.functions import calc_hr_shape
 | 
											
												
													
														|  |  from paddlers.utils.checkpoint import res_pretrain_weights_dict
 |  |  from paddlers.utils.checkpoint import res_pretrain_weights_dict
 | 
											
												
													
														|  |  from .base import BaseModel
 |  |  from .base import BaseModel
 | 
											
										
											
												
													
														|  | @@ -58,7 +58,6 @@ class BaseRestorer(BaseModel):
 | 
											
												
													
														|  |          if params.get('with_net', True):
 |  |          if params.get('with_net', True):
 | 
											
												
													
														|  |              params.pop('with_net', None)
 |  |              params.pop('with_net', None)
 | 
											
												
													
														|  |              self.net = self.build_net(**params)
 |  |              self.net = self.build_net(**params)
 | 
											
												
													
														|  | -        self.find_unused_parameters = True
 |  | 
 | 
											
												
													
														|  |          if min_max is None:
 |  |          if min_max is None:
 | 
											
												
													
														|  |              self.min_max = self.MIN_MAX
 |  |              self.min_max = self.MIN_MAX
 | 
											
												
													
														|  |  
 |  |  
 | 
											
										
											
												
													
														|  | @@ -116,14 +115,13 @@ class BaseRestorer(BaseModel):
 | 
											
												
													
														|  |          return input_spec
 |  |          return input_spec
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |      def run(self, net, inputs, mode):
 |  |      def run(self, net, inputs, mode):
 | 
											
												
													
														|  | 
 |  | +        inputs, batch_restore_list = inputs
 | 
											
												
													
														|  |          outputs = OrderedDict()
 |  |          outputs = OrderedDict()
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |          if mode == 'test':
 |  |          if mode == 'test':
 | 
											
												
													
														|  | -            tar_shape = inputs[1]
 |  | 
 | 
											
												
													
														|  |              if self.status == 'Infer':
 |  |              if self.status == 'Infer':
 | 
											
												
													
														|  |                  net_out = net(inputs[0])
 |  |                  net_out = net(inputs[0])
 | 
											
												
													
														|  | -                res_map_list = self.postprocess(
 |  | 
 | 
											
												
													
														|  | -                    net_out, tar_shape, transforms=inputs[2])
 |  | 
 | 
											
												
													
														|  | 
 |  | +                res_map_list = self.postprocess(net_out, batch_restore_list)
 | 
											
												
													
														|  |              else:
 |  |              else:
 | 
											
												
													
														|  |                  if isinstance(net, GANAdapter):
 |  |                  if isinstance(net, GANAdapter):
 | 
											
												
													
														|  |                      net_out = net.generator(inputs[0])
 |  |                      net_out = net.generator(inputs[0])
 | 
											
										
											
												
													
														|  | @@ -131,8 +129,7 @@ class BaseRestorer(BaseModel):
 | 
											
												
													
														|  |                      net_out = net(inputs[0])
 |  |                      net_out = net(inputs[0])
 | 
											
												
													
														|  |                  if self.TEST_OUT_KEY is not None:
 |  |                  if self.TEST_OUT_KEY is not None:
 | 
											
												
													
														|  |                      net_out = net_out[self.TEST_OUT_KEY]
 |  |                      net_out = net_out[self.TEST_OUT_KEY]
 | 
											
												
													
														|  | -                pred = self.postprocess(
 |  | 
 | 
											
												
													
														|  | -                    net_out, tar_shape, transforms=inputs[2])
 |  | 
 | 
											
												
													
														|  | 
 |  | +                pred = self.postprocess(net_out, batch_restore_list)
 | 
											
												
													
														|  |                  res_map_list = []
 |  |                  res_map_list = []
 | 
											
												
													
														|  |                  for res_map in pred:
 |  |                  for res_map in pred:
 | 
											
												
													
														|  |                      res_map = self._tensor_to_images(res_map)
 |  |                      res_map = self._tensor_to_images(res_map)
 | 
											
										
											
												
													
														|  | @@ -147,9 +144,7 @@ class BaseRestorer(BaseModel):
 | 
											
												
													
														|  |              if self.TEST_OUT_KEY is not None:
 |  |              if self.TEST_OUT_KEY is not None:
 | 
											
												
													
														|  |                  net_out = net_out[self.TEST_OUT_KEY]
 |  |                  net_out = net_out[self.TEST_OUT_KEY]
 | 
											
												
													
														|  |              tar = inputs[1]
 |  |              tar = inputs[1]
 | 
											
												
													
														|  | -            tar_shape = [tar.shape[-2:]]
 |  | 
 | 
											
												
													
														|  | -            pred = self.postprocess(
 |  | 
 | 
											
												
													
														|  | -                net_out, tar_shape, transforms=inputs[2])[0]  # NCHW
 |  | 
 | 
											
												
													
														|  | 
 |  | +            pred = self.postprocess(net_out, batch_restore_list)[0]  # NCHW
 | 
											
												
													
														|  |              pred = self._tensor_to_images(pred)
 |  |              pred = self._tensor_to_images(pred)
 | 
											
												
													
														|  |              outputs['pred'] = pred
 |  |              outputs['pred'] = pred
 | 
											
												
													
														|  |              tar = self._tensor_to_images(tar)
 |  |              tar = self._tensor_to_images(tar)
 | 
											
										
											
												
													
														|  | @@ -424,7 +419,6 @@ class BaseRestorer(BaseModel):
 | 
											
												
													
														|  |                      eval_dataset.num_samples, eval_dataset.num_samples))
 |  |                      eval_dataset.num_samples, eval_dataset.num_samples))
 | 
											
												
													
														|  |              with paddle.no_grad():
 |  |              with paddle.no_grad():
 | 
											
												
													
														|  |                  for step, data in enumerate(self.eval_data_loader):
 |  |                  for step, data in enumerate(self.eval_data_loader):
 | 
											
												
													
														|  | -                    data.append(eval_dataset.transforms.transforms)
 |  | 
 | 
											
												
													
														|  |                      outputs = self.run(self.net, data, 'eval')
 |  |                      outputs = self.run(self.net, data, 'eval')
 | 
											
												
													
														|  |                      psnr.update(outputs['pred'], outputs['tar'])
 |  |                      psnr.update(outputs['pred'], outputs['tar'])
 | 
											
												
													
														|  |                      ssim.update(outputs['pred'], outputs['tar'])
 |  |                      ssim.update(outputs['pred'], outputs['tar'])
 | 
											
										
											
												
													
														|  | @@ -472,10 +466,8 @@ class BaseRestorer(BaseModel):
 | 
											
												
													
														|  |              images = [img_file]
 |  |              images = [img_file]
 | 
											
												
													
														|  |          else:
 |  |          else:
 | 
											
												
													
														|  |              images = img_file
 |  |              images = img_file
 | 
											
												
													
														|  | -        batch_im, batch_tar_shape = self.preprocess(images, transforms,
 |  | 
 | 
											
												
													
														|  | -                                                    self.model_type)
 |  | 
 | 
											
												
													
														|  | 
 |  | +        data = self.preprocess(images, transforms, self.model_type)
 | 
											
												
													
														|  |          self.net.eval()
 |  |          self.net.eval()
 | 
											
												
													
														|  | -        data = (batch_im, batch_tar_shape, transforms.transforms)
 |  | 
 | 
											
												
													
														|  |          outputs = self.run(self.net, data, 'test')
 |  |          outputs = self.run(self.net, data, 'test')
 | 
											
												
													
														|  |          res_map_list = outputs['res_map']
 |  |          res_map_list = outputs['res_map']
 | 
											
												
													
														|  |          if isinstance(img_file, list):
 |  |          if isinstance(img_file, list):
 | 
											
										
											
												
													
														|  | @@ -487,79 +479,24 @@ class BaseRestorer(BaseModel):
 | 
											
												
													
														|  |      def preprocess(self, images, transforms, to_tensor=True):
 |  |      def preprocess(self, images, transforms, to_tensor=True):
 | 
											
												
													
														|  |          self._check_transforms(transforms, 'test')
 |  |          self._check_transforms(transforms, 'test')
 | 
											
												
													
														|  |          batch_im = list()
 |  |          batch_im = list()
 | 
											
												
													
														|  | -        batch_tar_shape = list()
 |  | 
 | 
											
												
													
														|  | 
 |  | +        batch_trans_info = list()
 | 
											
												
													
														|  |          for im in images:
 |  |          for im in images:
 | 
											
												
													
														|  |              if isinstance(im, str):
 |  |              if isinstance(im, str):
 | 
											
												
													
														|  |                  im = decode_image(im, read_raw=True)
 |  |                  im = decode_image(im, read_raw=True)
 | 
											
												
													
														|  | -            ori_shape = im.shape[:2]
 |  | 
 | 
											
												
													
														|  | -            sample = {'image': im}
 |  | 
 | 
											
												
													
														|  | -            im = transforms(sample)[0]
 |  | 
 | 
											
												
													
														|  | 
 |  | +            sample = construct_sample(image=im)
 | 
											
												
													
														|  | 
 |  | +            data = transforms(sample)
 | 
											
												
													
														|  | 
 |  | +            im = data[0][0]
 | 
											
												
													
														|  | 
 |  | +            trans_info = data[1]
 | 
											
												
													
														|  |              batch_im.append(im)
 |  |              batch_im.append(im)
 | 
											
												
													
														|  | -            batch_tar_shape.append(self._get_target_shape(ori_shape))
 |  | 
 | 
											
												
													
														|  | 
 |  | +            batch_trans_info.append(trans_info)
 | 
											
												
													
														|  |          if to_tensor:
 |  |          if to_tensor:
 | 
											
												
													
														|  |              batch_im = paddle.to_tensor(batch_im)
 |  |              batch_im = paddle.to_tensor(batch_im)
 | 
											
												
													
														|  |          else:
 |  |          else:
 | 
											
												
													
														|  |              batch_im = np.asarray(batch_im)
 |  |              batch_im = np.asarray(batch_im)
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | -        return batch_im, batch_tar_shape
 |  | 
 | 
											
												
													
														|  | 
 |  | +        return (batch_im, ), batch_trans_info
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | -    def _get_target_shape(self, ori_shape):
 |  | 
 | 
											
												
													
														|  | -        if self.sr_factor is None:
 |  | 
 | 
											
												
													
														|  | -            return ori_shape
 |  | 
 | 
											
												
													
														|  | -        else:
 |  | 
 | 
											
												
													
														|  | -            return calc_hr_shape(ori_shape, self.sr_factor)
 |  | 
 | 
											
												
													
														|  | -
 |  | 
 | 
											
												
													
														|  | -    @staticmethod
 |  | 
 | 
											
												
													
														|  | -    def get_transforms_shape_info(batch_tar_shape, transforms):
 |  | 
 | 
											
												
													
														|  | -        batch_restore_list = list()
 |  | 
 | 
											
												
													
														|  | -        for tar_shape in batch_tar_shape:
 |  | 
 | 
											
												
													
														|  | -            restore_list = list()
 |  | 
 | 
											
												
													
														|  | -            h, w = tar_shape[0], tar_shape[1]
 |  | 
 | 
											
												
													
														|  | -            for op in transforms:
 |  | 
 | 
											
												
													
														|  | -                if op.__class__.__name__ == 'Resize':
 |  | 
 | 
											
												
													
														|  | -                    restore_list.append(('resize', (h, w)))
 |  | 
 | 
											
												
													
														|  | -                    h, w = op.target_size
 |  | 
 | 
											
												
													
														|  | -                elif op.__class__.__name__ == 'ResizeByShort':
 |  | 
 | 
											
												
													
														|  | -                    restore_list.append(('resize', (h, w)))
 |  | 
 | 
											
												
													
														|  | -                    im_short_size = min(h, w)
 |  | 
 | 
											
												
													
														|  | -                    im_long_size = max(h, w)
 |  | 
 | 
											
												
													
														|  | -                    scale = float(op.short_size) / float(im_short_size)
 |  | 
 | 
											
												
													
														|  | -                    if 0 < op.max_size < np.round(scale * im_long_size):
 |  | 
 | 
											
												
													
														|  | -                        scale = float(op.max_size) / float(im_long_size)
 |  | 
 | 
											
												
													
														|  | -                    h = int(round(h * scale))
 |  | 
 | 
											
												
													
														|  | -                    w = int(round(w * scale))
 |  | 
 | 
											
												
													
														|  | -                elif op.__class__.__name__ == 'ResizeByLong':
 |  | 
 | 
											
												
													
														|  | -                    restore_list.append(('resize', (h, w)))
 |  | 
 | 
											
												
													
														|  | -                    im_long_size = max(h, w)
 |  | 
 | 
											
												
													
														|  | -                    scale = float(op.long_size) / float(im_long_size)
 |  | 
 | 
											
												
													
														|  | -                    h = int(round(h * scale))
 |  | 
 | 
											
												
													
														|  | -                    w = int(round(w * scale))
 |  | 
 | 
											
												
													
														|  | -                elif op.__class__.__name__ == 'Pad':
 |  | 
 | 
											
												
													
														|  | -                    if op.target_size:
 |  | 
 | 
											
												
													
														|  | -                        target_h, target_w = op.target_size
 |  | 
 | 
											
												
													
														|  | -                    else:
 |  | 
 | 
											
												
													
														|  | -                        target_h = int(
 |  | 
 | 
											
												
													
														|  | -                            (np.ceil(h / op.size_divisor) * op.size_divisor))
 |  | 
 | 
											
												
													
														|  | -                        target_w = int(
 |  | 
 | 
											
												
													
														|  | -                            (np.ceil(w / op.size_divisor) * op.size_divisor))
 |  | 
 | 
											
												
													
														|  | -
 |  | 
 | 
											
												
													
														|  | -                    if op.pad_mode == -1:
 |  | 
 | 
											
												
													
														|  | -                        offsets = op.offsets
 |  | 
 | 
											
												
													
														|  | -                    elif op.pad_mode == 0:
 |  | 
 | 
											
												
													
														|  | -                        offsets = [0, 0]
 |  | 
 | 
											
												
													
														|  | -                    elif op.pad_mode == 1:
 |  | 
 | 
											
												
													
														|  | -                        offsets = [(target_h - h) // 2, (target_w - w) // 2]
 |  | 
 | 
											
												
													
														|  | -                    else:
 |  | 
 | 
											
												
													
														|  | -                        offsets = [target_h - h, target_w - w]
 |  | 
 | 
											
												
													
														|  | -                    restore_list.append(('padding', (h, w), offsets))
 |  | 
 | 
											
												
													
														|  | -                    h, w = target_h, target_w
 |  | 
 | 
											
												
													
														|  | -
 |  | 
 | 
											
												
													
														|  | -            batch_restore_list.append(restore_list)
 |  | 
 | 
											
												
													
														|  | -        return batch_restore_list
 |  | 
 | 
											
												
													
														|  | -
 |  | 
 | 
											
												
													
														|  | -    def postprocess(self, batch_pred, batch_tar_shape, transforms):
 |  | 
 | 
											
												
													
														|  | -        batch_restore_list = BaseRestorer.get_transforms_shape_info(
 |  | 
 | 
											
												
													
														|  | -            batch_tar_shape, transforms)
 |  | 
 | 
											
												
													
														|  | 
 |  | +    def postprocess(self, batch_pred, batch_restore_list):
 | 
											
												
													
														|  |          if self.status == 'Infer':
 |  |          if self.status == 'Infer':
 | 
											
												
													
														|  |              return self._infer_postprocess(
 |  |              return self._infer_postprocess(
 | 
											
												
													
														|  |                  batch_res_map=batch_pred, batch_restore_list=batch_restore_list)
 |  |                  batch_res_map=batch_pred, batch_restore_list=batch_restore_list)
 | 
											
										
											
												
													
														|  | @@ -572,11 +509,15 @@ class BaseRestorer(BaseModel):
 | 
											
												
													
														|  |              pred = paddle.unsqueeze(pred, axis=0)
 |  |              pred = paddle.unsqueeze(pred, axis=0)
 | 
											
												
													
														|  |              for item in restore_list[::-1]:
 |  |              for item in restore_list[::-1]:
 | 
											
												
													
														|  |                  h, w = item[1][0], item[1][1]
 |  |                  h, w = item[1][0], item[1][1]
 | 
											
												
													
														|  | 
 |  | +                if self.sr_factor:
 | 
											
												
													
														|  | 
 |  | +                    h, w = calc_hr_shape((h, w), self.sr_factor)
 | 
											
												
													
														|  |                  if item[0] == 'resize':
 |  |                  if item[0] == 'resize':
 | 
											
												
													
														|  |                      pred = F.interpolate(
 |  |                      pred = F.interpolate(
 | 
											
												
													
														|  |                          pred, (h, w), mode=mode, data_format='NCHW')
 |  |                          pred, (h, w), mode=mode, data_format='NCHW')
 | 
											
												
													
														|  |                  elif item[0] == 'padding':
 |  |                  elif item[0] == 'padding':
 | 
											
												
													
														|  |                      x, y = item[2]
 |  |                      x, y = item[2]
 | 
											
												
													
														|  | 
 |  | +                    if self.sr_factor:
 | 
											
												
													
														|  | 
 |  | +                        x, y = calc_hr_shape((x, y), self.sr_factor)
 | 
											
												
													
														|  |                      pred = pred[:, :, y:y + h, x:x + w]
 |  |                      pred = pred[:, :, y:y + h, x:x + w]
 | 
											
												
													
														|  |                  else:
 |  |                  else:
 | 
											
												
													
														|  |                      pass
 |  |                      pass
 | 
											
										
											
												
													
														|  | @@ -590,6 +531,8 @@ class BaseRestorer(BaseModel):
 | 
											
												
													
														|  |                  res_map = paddle.unsqueeze(res_map, axis=0)
 |  |                  res_map = paddle.unsqueeze(res_map, axis=0)
 | 
											
												
													
														|  |              for item in restore_list[::-1]:
 |  |              for item in restore_list[::-1]:
 | 
											
												
													
														|  |                  h, w = item[1][0], item[1][1]
 |  |                  h, w = item[1][0], item[1][1]
 | 
											
												
													
														|  | 
 |  | +                if self.sr_factor:
 | 
											
												
													
														|  | 
 |  | +                    h, w = calc_hr_shape((h, w), self.sr_factor)
 | 
											
												
													
														|  |                  if item[0] == 'resize':
 |  |                  if item[0] == 'resize':
 | 
											
												
													
														|  |                      if isinstance(res_map, np.ndarray):
 |  |                      if isinstance(res_map, np.ndarray):
 | 
											
												
													
														|  |                          res_map = cv2.resize(
 |  |                          res_map = cv2.resize(
 | 
											
										
											
												
													
														|  | @@ -601,6 +544,8 @@ class BaseRestorer(BaseModel):
 | 
											
												
													
														|  |                              data_format='NHWC')
 |  |                              data_format='NHWC')
 | 
											
												
													
														|  |                  elif item[0] == 'padding':
 |  |                  elif item[0] == 'padding':
 | 
											
												
													
														|  |                      x, y = item[2]
 |  |                      x, y = item[2]
 | 
											
												
													
														|  | 
 |  | +                    if self.sr_factor:
 | 
											
												
													
														|  | 
 |  | +                        x, y = calc_hr_shape((x, y), self.sr_factor)
 | 
											
												
													
														|  |                      if isinstance(res_map, np.ndarray):
 |  |                      if isinstance(res_map, np.ndarray):
 | 
											
												
													
														|  |                          res_map = res_map[y:y + h, x:x + w]
 |  |                          res_map = res_map[y:y + h, x:x + w]
 | 
											
												
													
														|  |                      else:
 |  |                      else:
 | 
											
										
											
												
													
														|  | @@ -621,7 +566,11 @@ class BaseRestorer(BaseModel):
 | 
											
												
													
														|  |              raise TypeError(
 |  |              raise TypeError(
 | 
											
												
													
														|  |                  "`transforms.arrange` must be an ArrangeRestorer object.")
 |  |                  "`transforms.arrange` must be an ArrangeRestorer object.")
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | -    def build_data_loader(self, dataset, batch_size, mode='train'):
 |  | 
 | 
											
												
													
														|  | 
 |  | +    def build_data_loader(self,
 | 
											
												
													
														|  | 
 |  | +                          dataset,
 | 
											
												
													
														|  | 
 |  | +                          batch_size,
 | 
											
												
													
														|  | 
 |  | +                          mode='train',
 | 
											
												
													
														|  | 
 |  | +                          collate_fn=None):
 | 
											
												
													
														|  |          if dataset.num_samples < batch_size:
 |  |          if dataset.num_samples < batch_size:
 | 
											
												
													
														|  |              raise ValueError(
 |  |              raise ValueError(
 | 
											
												
													
														|  |                  'The volume of dataset({}) must be larger than batch size({}).'
 |  |                  'The volume of dataset({}) must be larger than batch size({}).'
 | 
											
										
											
												
													
														|  | @@ -633,7 +582,8 @@ class BaseRestorer(BaseModel):
 | 
											
												
													
														|  |                  batch_size=batch_size,
 |  |                  batch_size=batch_size,
 | 
											
												
													
														|  |                  shuffle=dataset.shuffle,
 |  |                  shuffle=dataset.shuffle,
 | 
											
												
													
														|  |                  drop_last=False,
 |  |                  drop_last=False,
 | 
											
												
													
														|  | -                collate_fn=dataset.batch_transforms,
 |  | 
 | 
											
												
													
														|  | 
 |  | +                collate_fn=dataset.collate_fn
 | 
											
												
													
														|  | 
 |  | +                if collate_fn is None else collate_fn,
 | 
											
												
													
														|  |                  num_workers=dataset.num_workers,
 |  |                  num_workers=dataset.num_workers,
 | 
											
												
													
														|  |                  return_list=True,
 |  |                  return_list=True,
 | 
											
												
													
														|  |                  use_shared_memory=False)
 |  |                  use_shared_memory=False)
 | 
											
										
											
												
													
														|  | @@ -758,7 +708,7 @@ class DRN(BaseRestorer):
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |      def train_step(self, step, data, net):
 |  |      def train_step(self, step, data, net):
 | 
											
												
													
														|  |          outputs = self.run_gan(
 |  |          outputs = self.run_gan(
 | 
											
												
													
														|  | -            net, data, mode='train', gan_mode='forward_primary')
 |  | 
 | 
											
												
													
														|  | 
 |  | +            net, data[0], mode='train', gan_mode='forward_primary')
 | 
											
												
													
														|  |          outputs.update(
 |  |          outputs.update(
 | 
											
												
													
														|  |              self.run_gan(
 |  |              self.run_gan(
 | 
											
												
													
														|  |                  net, (outputs['sr'], outputs['lr']),
 |  |                  net, (outputs['sr'], outputs['lr']),
 | 
											
										
											
												
													
														|  | @@ -800,6 +750,9 @@ class LESRCNN(BaseRestorer):
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  class ESRGAN(BaseRestorer):
 |  |  class ESRGAN(BaseRestorer):
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    find_unused_parameters = True
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  |      def __init__(self,
 |  |      def __init__(self,
 | 
											
												
													
														|  |                   losses=None,
 |  |                   losses=None,
 | 
											
												
													
														|  |                   sr_factor=4,
 |  |                   sr_factor=4,
 | 
											
										
											
												
													
														|  | @@ -915,14 +868,14 @@ class ESRGAN(BaseRestorer):
 | 
											
												
													
														|  |              optim_g, optim_d = self.optimizer
 |  |              optim_g, optim_d = self.optimizer
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |              outputs = self.run_gan(
 |  |              outputs = self.run_gan(
 | 
											
												
													
														|  | -                net, data, mode='train', gan_mode='forward_g')
 |  | 
 | 
											
												
													
														|  | 
 |  | +                net, data[0], mode='train', gan_mode='forward_g')
 | 
											
												
													
														|  |              optim_g.clear_grad()
 |  |              optim_g.clear_grad()
 | 
											
												
													
														|  |              (outputs['loss_g_pps'] + outputs['loss_g_gan']).backward()
 |  |              (outputs['loss_g_pps'] + outputs['loss_g_gan']).backward()
 | 
											
												
													
														|  |              optim_g.step()
 |  |              optim_g.step()
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |              outputs.update(
 |  |              outputs.update(
 | 
											
												
													
														|  |                  self.run_gan(
 |  |                  self.run_gan(
 | 
											
												
													
														|  | -                    net, (outputs['g_pred'], data[1]),
 |  | 
 | 
											
												
													
														|  | 
 |  | +                    net, (outputs['g_pred'], data[0][1]),
 | 
											
												
													
														|  |                      mode='train',
 |  |                      mode='train',
 | 
											
												
													
														|  |                      gan_mode='forward_d'))
 |  |                      gan_mode='forward_d'))
 | 
											
												
													
														|  |              optim_d.clear_grad()
 |  |              optim_d.clear_grad()
 |