Update rcan_model.py

Alleviate gradient explosion, but convergence is still difficult
kongdebug 3 years ago
parent
commit 7e74e85ab7
1 file changed, 19 insertions and 6 deletions
  1. paddlers/custom_models/gan/rcan_model.py +19 −6

+19 −6 paddlers/custom_models/gan/rcan_model.py

@@ -27,7 +27,6 @@ from ...models.ppgan.modules.init import reset_parameters
 class RCANModel(BaseModel):
     """Base SR model for single image super-resolution.
     """
-
     def __init__(self, generator, pixel_criterion=None, use_init_weight=False):
         """
         Args:
@@ -37,7 +36,8 @@ class RCANModel(BaseModel):
         super(RCANModel, self).__init__()
 
         self.nets['generator'] = build_generator(generator)
-
+        self.error_last = 1e8
+        self.batch = 0
         if pixel_criterion:
             self.pixel_criterion = build_criterion(pixel_criterion)
         if use_init_weight:
@@ -63,8 +63,21 @@ class RCANModel(BaseModel):
         loss_pixel = self.pixel_criterion(self.output, self.gt)
         self.losses['loss_pixel'] = loss_pixel
 
-        loss_pixel.backward()
-        optims['optim'].step()
+        skip_threshold = 1e6
+
+        if loss_pixel.item() < skip_threshold * self.error_last:
+            loss_pixel.backward()
+            optims['optim'].step()
+        else:
+            print('Skip this batch {}! (Loss: {})'.format(
+                    self.batch + 1, loss_pixel.item()
+                ))
+        self.batch += 1
+
+        if self.batch % 1000 == 0:
+            self.error_last = loss_pixel.item()/1000
+            print("update error_last:{}".format(self.error_last))
+
 
     def test_iter(self, metrics=None):
         self.nets['generator'].eval()
@@ -86,8 +99,8 @@ class RCANModel(BaseModel):
 
 def init_sr_weight(net):
     def reset_func(m):
-        if hasattr(m, 'weight') and (
-                not isinstance(m, (nn.BatchNorm, nn.BatchNorm2D))):
+        if hasattr(m, 'weight') and (not isinstance(
+                m, (nn.BatchNorm, nn.BatchNorm2D))):
             reset_parameters(m)
 
     net.apply(reset_func)
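
The added logic is a loss-spike guard (similar in spirit to the batch-skipping check in the EDSR reference training code): a batch is backpropagated only when its pixel loss stays below `skip_threshold * self.error_last`, and the reference error is refreshed every 1000 batches. Below is a minimal standalone sketch of the same heuristic pulled out of `RCANModel.train_iter`; the `LossSpikeSkipper` name and its constructor arguments are illustrative, not part of the commit.

```python
class LossSpikeSkipper:
    """Skip optimizer steps whose loss spikes far above a running reference.

    Mirrors the heuristic added to RCANModel.train_iter in this commit;
    the class itself is a hypothetical standalone helper. `loss` is any
    scalar tensor with .item() and .backward(); `optimizer` any object
    with .step().
    """

    def __init__(self, skip_threshold=1e6, init_error=1e8, update_every=1000):
        self.skip_threshold = skip_threshold
        self.error_last = init_error   # matches self.error_last = 1e8
        self.update_every = update_every
        self.batch = 0                 # matches self.batch = 0

    def step(self, loss, optimizer):
        # Backpropagate and step only when the loss is under the gate.
        if loss.item() < self.skip_threshold * self.error_last:
            loss.backward()
            optimizer.step()
        else:
            print('Skip this batch {}! (Loss: {})'.format(
                self.batch + 1, loss.item()))
        self.batch += 1
        # Periodically refresh the reference error, as the commit does.
        if self.batch % self.update_every == 0:
            self.error_last = loss.item() / self.update_every
            print('update error_last: {}'.format(self.error_last))
```

Note the refresh rule: after the first update, `error_last` becomes the latest batch loss divided by 1000, so the gate `skip_threshold * error_last` works out to roughly 1000× the most recent loss (1e6 × loss / 1000). That keeps the guard permissive for ordinary fluctuations while still catching multi-order-of-magnitude spikes, which matches the commit message: the explosion is damped, but convergence remains hard.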