|
@@ -43,7 +43,12 @@ class BaseRestorer(BaseModel):
|
|
|
MIN_MAX = (0., 1.)
|
|
|
TEST_OUT_KEY = None
|
|
|
|
|
|
- def __init__(self, model_name, losses=None, sr_factor=None, **params):
|
|
|
+ def __init__(self,
|
|
|
+ model_name,
|
|
|
+ losses=None,
|
|
|
+ sr_factor=None,
|
|
|
+ min_max=None,
|
|
|
+ **params):
|
|
|
self.init_params = locals()
|
|
|
if 'with_net' in self.init_params:
|
|
|
del self.init_params['with_net']
|
|
@@ -55,6 +60,8 @@ class BaseRestorer(BaseModel):
|
|
|
params.pop('with_net', None)
|
|
|
self.net = self.build_net(**params)
|
|
|
self.find_unused_parameters = True
|
|
|
+ if min_max is None:
|
|
|
+ self.min_max = self.MIN_MAX
|
|
|
|
|
|
def build_net(self, **params):
|
|
|
# Currently, only use models from cmres.
|
|
@@ -283,11 +290,13 @@ class BaseRestorer(BaseModel):
|
|
|
exit=True)
|
|
|
pretrained_dir = osp.join(save_dir, 'pretrain')
|
|
|
is_backbone_weights = pretrain_weights == 'IMAGENET'
|
|
|
+ # XXX: Currently, do not load optimizer state dict.
|
|
|
self.initialize_net(
|
|
|
pretrain_weights=pretrain_weights,
|
|
|
save_dir=pretrained_dir,
|
|
|
resume_checkpoint=resume_checkpoint,
|
|
|
- is_backbone_weights=is_backbone_weights)
|
|
|
+ is_backbone_weights=is_backbone_weights,
|
|
|
+ load_optim_state=False)
|
|
|
|
|
|
self.train_loop(
|
|
|
num_epochs=num_epochs,
|
|
@@ -434,6 +443,7 @@ class BaseRestorer(BaseModel):
|
|
|
|
|
|
return eval_metrics
|
|
|
|
|
|
+ @paddle.no_grad()
|
|
|
def predict(self, img_file, transforms=None):
|
|
|
"""
|
|
|
Do inference.
|
|
@@ -653,9 +663,9 @@ class BaseRestorer(BaseModel):
|
|
|
if copy:
|
|
|
im = im.copy()
|
|
|
if clip:
|
|
|
- im = np.clip(im, self.MIN_MAX[0], self.MIN_MAX[1])
|
|
|
- im -= im.min()
|
|
|
- im /= im.max() + 1e-32
|
|
|
+ im = np.clip(im, self.min_max[0], self.min_max[1])
|
|
|
+ im -= self.min_max[0]
|
|
|
+ im /= self.min_max[1] - self.min_max[0]
|
|
|
if quantize:
|
|
|
im *= 255
|
|
|
im = im.astype('uint8')
|
|
@@ -668,6 +678,7 @@ class DRN(BaseRestorer):
|
|
|
def __init__(self,
|
|
|
losses=None,
|
|
|
sr_factor=4,
|
|
|
+ min_max=None,
|
|
|
scales=(2, 4),
|
|
|
n_blocks=30,
|
|
|
n_feats=16,
|
|
@@ -691,7 +702,11 @@ class DRN(BaseRestorer):
|
|
|
self.dual_loss_weight = dual_loss_weight
|
|
|
self.scales = scales
|
|
|
super(DRN, self).__init__(
|
|
|
- model_name='DRN', losses=losses, sr_factor=sr_factor, **params)
|
|
|
+ model_name='DRN',
|
|
|
+ losses=losses,
|
|
|
+ sr_factor=sr_factor,
|
|
|
+ min_max=min_max,
|
|
|
+ **params)
|
|
|
|
|
|
def build_net(self, **params):
|
|
|
from ppgan.modules.init import init_weights
|
|
@@ -769,6 +784,7 @@ class LESRCNN(BaseRestorer):
|
|
|
def __init__(self,
|
|
|
losses=None,
|
|
|
sr_factor=4,
|
|
|
+ min_max=None,
|
|
|
multi_scale=False,
|
|
|
group=1,
|
|
|
**params):
|
|
@@ -778,7 +794,11 @@ class LESRCNN(BaseRestorer):
|
|
|
'group': group
|
|
|
})
|
|
|
super(LESRCNN, self).__init__(
|
|
|
- model_name='LESRCNN', losses=losses, sr_factor=sr_factor, **params)
|
|
|
+ model_name='LESRCNN',
|
|
|
+ losses=losses,
|
|
|
+ sr_factor=sr_factor,
|
|
|
+ min_max=min_max,
|
|
|
+ **params)
|
|
|
|
|
|
def build_net(self, **params):
|
|
|
net = ppgan.models.generators.LESRCNNGenerator(**params)
|
|
@@ -789,6 +809,7 @@ class ESRGAN(BaseRestorer):
|
|
|
def __init__(self,
|
|
|
losses=None,
|
|
|
sr_factor=4,
|
|
|
+ min_max=None,
|
|
|
use_gan=True,
|
|
|
in_channels=3,
|
|
|
out_channels=3,
|
|
@@ -805,7 +826,11 @@ class ESRGAN(BaseRestorer):
|
|
|
})
|
|
|
self.use_gan = use_gan
|
|
|
super(ESRGAN, self).__init__(
|
|
|
- model_name='ESRGAN', losses=losses, sr_factor=sr_factor, **params)
|
|
|
+ model_name='ESRGAN',
|
|
|
+ losses=losses,
|
|
|
+ sr_factor=sr_factor,
|
|
|
+ min_max=min_max,
|
|
|
+ **params)
|
|
|
|
|
|
def build_net(self, **params):
|
|
|
from ppgan.modules.init import init_weights
|
|
@@ -932,6 +957,7 @@ class RCAN(BaseRestorer):
|
|
|
def __init__(self,
|
|
|
losses=None,
|
|
|
sr_factor=4,
|
|
|
+ min_max=None,
|
|
|
n_resgroups=10,
|
|
|
n_resblocks=20,
|
|
|
n_feats=64,
|
|
@@ -950,4 +976,8 @@ class RCAN(BaseRestorer):
|
|
|
'reduction': reduction
|
|
|
})
|
|
|
super(RCAN, self).__init__(
|
|
|
- model_name='RCAN', losses=losses, sr_factor=sr_factor, **params)
|
|
|
+ model_name='RCAN',
|
|
|
+ losses=losses,
|
|
|
+ sr_factor=sr_factor,
|
|
|
+ min_max=min_max,
|
|
|
+ **params)
|