Bobholamovic %!s(int64=2) %!d(string=hai) anos
pai
achega
5dcc6cd078

+ 1 - 2
paddlers/rs_models/res/generators/rcan.py

@@ -146,7 +146,6 @@ class RCAN(nn.Layer):
         n_feats = n_feats
         kernel_size = kernel_size
         reduction = reduction
-        scale = scale
         act = nn.ReLU()
 
         rgb_mean = (0.4488, 0.4371, 0.4040)
@@ -167,7 +166,7 @@ class RCAN(nn.Layer):
         # Define tail module
         modules_tail = [
             Upsampler(
-                conv, scale, n_feats, act=False),
+                conv, self.scale, n_feats, act=False),
             conv(n_feats, n_colors, kernel_size)
         ]
 

+ 1 - 1
paddlers/tasks/base.py

@@ -307,7 +307,7 @@ class BaseModel(metaclass=ModelMeta):
                    use_vdl=True):
         self._check_transforms(train_dataset.transforms, 'train')
 
-        if "RCNN" in self.__class__.__name__ and train_dataset.pos_num < len(
+        if self.model_type == 'detector' and 'RCNN' in self.__class__.__name__ and train_dataset.pos_num < len(
                 train_dataset.file_list):
             nranks = 1
         else:

+ 24 - 11
paddlers/tasks/restorer.py

@@ -54,7 +54,7 @@ class BaseRestorer(BaseModel):
 
     def build_net(self, **params):
         # Currently, only use models from cmres.
-        if not hasattr(cmres, model_name):
+        if not hasattr(cmres, self.model_name):
             raise ValueError("ERROR: There is no model named {}.".format(
                 model_name))
         net = dict(**cmres.__dict__)[self.model_name](**params)
@@ -618,7 +618,6 @@ class RCAN(BaseRestorer):
                  reduction=16,
                  **params):
         params.update({
-            'factor': sr_factor,
             'n_resgroups': n_resgroups,
             'n_resblocks': n_resblocks,
             'n_feats': n_feats,
@@ -661,8 +660,17 @@ class DRN(BaseRestorer):
 
 
 class LESRCNN(BaseRestorer):
-    def __init__(self, losses=None, sr_factor=4, multi_scale=False, group=1):
-        params.update({'scale': sr_factor, 'multi_scale': False, 'group': 1})
+    def __init__(self,
+                 losses=None,
+                 sr_factor=4,
+                 multi_scale=False,
+                 group=1,
+                 **params):
+        params.update({
+            'scale': sr_factor,
+            'multi_scale': multi_scale,
+            'group': group
+        })
         super(LESRCNN, self).__init__(
             model_name='LESRCNN', losses=losses, sr_factor=sr_factor, **params)
 
@@ -681,9 +689,11 @@ class ESRGAN(BaseRestorer):
                  in_channels=3,
                  out_channels=3,
                  nf=64,
-                 nb=23):
+                 nb=23,
+                 **params):
+        if sr_factor != 4:
+            raise ValueError("`sr_factor` must be 4.")
         params.update({
-            'scale': sr_factor,
             'in_nc': in_channels,
             'out_nc': out_channels,
             'nf': nf,
@@ -696,7 +706,7 @@ class ESRGAN(BaseRestorer):
     def build_net(self, **params):
         generator = ppgan.models.generators.RRDBNet(**params)
         if self.use_gan:
-            discriminator = ppgan.models.discriminators.VGGDiscrinimator128(
+            discriminator = ppgan.models.discriminators.VGGDiscriminator128(
                 in_channels=params['out_nc'], num_feat=64)
             net = GANAdapter(
                 generators=[generator], discriminators=[discriminator])
@@ -719,9 +729,9 @@ class ESRGAN(BaseRestorer):
     def default_optimizer(self, parameters, *args, **kwargs):
         if self.use_gan:
             optim_g = super(ESRGAN, self).default_optimizer(
-                parameters['optims_g'][0], *args, **kwargs)
+                parameters['params_g'][0], *args, **kwargs)
             optim_d = super(ESRGAN, self).default_optimizer(
-                parameters['optims_d'][0], *args, **kwargs)
+                parameters['params_d'][0], *args, **kwargs)
             return OptimizerAdapter(optim_g, optim_d)
         else:
             return super(ESRGAN, self).default_optimizer(params, *args,
@@ -777,14 +787,17 @@ class ESRGAN(BaseRestorer):
         if self.use_gan:
             optim_g, optim_d = self.optimizer
 
-            outputs = self.run_gan(net, data, gan_mode='forward_g')
+            outputs = self.run_gan(
+                net, data, mode='train', gan_mode='forward_g')
             optim_g.clear_grad()
             (outputs['loss_g_pps'] + outputs['loss_g_gan']).backward()
             optim_g.step()
 
             outputs.update(
                 self.run_gan(
-                    net, (outputs['g_pred'], data[1]), gan_mode='forward_d'))
+                    net, (outputs['g_pred'], data[1]),
+                    mode='train',
+                    gan_mode='forward_d'))
             optim_d.clear_grad()
             outputs['loss_d'].backward()
             optim_d.step()