|
@@ -27,7 +27,6 @@ from ...models.ppgan.modules.init import reset_parameters
|
|
|
class RCANModel(BaseModel):
|
|
|
"""Base SR model for single image super-resolution.
|
|
|
"""
|
|
|
-
|
|
|
def __init__(self, generator, pixel_criterion=None, use_init_weight=False):
|
|
|
"""
|
|
|
Args:
|
|
@@ -37,7 +36,8 @@ class RCANModel(BaseModel):
|
|
|
super(RCANModel, self).__init__()
|
|
|
|
|
|
self.nets['generator'] = build_generator(generator)
|
|
|
-
|
|
|
+ self.error_last = 1e8
|
|
|
+ self.batch = 0
|
|
|
if pixel_criterion:
|
|
|
self.pixel_criterion = build_criterion(pixel_criterion)
|
|
|
if use_init_weight:
|
|
@@ -63,8 +63,21 @@ class RCANModel(BaseModel):
|
|
|
loss_pixel = self.pixel_criterion(self.output, self.gt)
|
|
|
self.losses['loss_pixel'] = loss_pixel
|
|
|
|
|
|
- loss_pixel.backward()
|
|
|
- optims['optim'].step()
|
|
|
+ skip_threshold = 1e6
|
|
|
+
|
|
|
+ if loss_pixel.item() < skip_threshold * self.error_last:
|
|
|
+ loss_pixel.backward()
|
|
|
+ optims['optim'].step()
|
|
|
+ else:
|
|
|
+ print('Skip this batch {}! (Loss: {})'.format(
|
|
|
+ self.batch + 1, loss_pixel.item()
|
|
|
+ ))
|
|
|
+ self.batch += 1
|
|
|
+
|
|
|
+ if self.batch % 1000 == 0:
|
|
|
+ self.error_last = loss_pixel.item()/1000
|
|
|
+ print("update error_last:{}".format(self.error_last))
|
|
|
+
|
|
|
|
|
|
def test_iter(self, metrics=None):
|
|
|
self.nets['generator'].eval()
|
|
@@ -86,8 +99,8 @@ class RCANModel(BaseModel):
|
|
|
|
|
|
def init_sr_weight(net):
|
|
|
def reset_func(m):
|
|
|
- if hasattr(m, 'weight') and (
|
|
|
- not isinstance(m, (nn.BatchNorm, nn.BatchNorm2D))):
|
|
|
+ if hasattr(m, 'weight') and (not isinstance(
|
|
|
+ m, (nn.BatchNorm, nn.BatchNorm2D))):
|
|
|
reset_parameters(m)
|
|
|
|
|
|
net.apply(reset_func)
|