kongdebug 3 жил өмнө
parent
commit
33b5f0fd3e

+ 5 - 0
paddlers/custom_models/gan/__init__.py

@@ -11,3 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+<<<<<<< HEAD
+
+from .rcan_model import RCANModel
+=======
+>>>>>>> 343f646f7dabf2ff08d80fab4ac5a37511260bd2

+ 15 - 0
paddlers/custom_models/gan/generators/__init__.py

@@ -0,0 +1,15 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .rcan import RCAN

+ 25 - 0
paddlers/custom_models/gan/generators/builder.py

@@ -0,0 +1,25 @@
+#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+from ....models.ppgan.utils.registry import Registry
+
+GENERATORS = Registry("GENERATOR")
+
+
+def build_generator(cfg):
+    cfg_copy = copy.deepcopy(cfg)
+    name = cfg_copy.pop('name')
+    generator = GENERATORS.get(name)(**cfg_copy)
+    return generator

+ 190 - 0
paddlers/custom_models/gan/generators/rcan.py

@@ -0,0 +1,190 @@
+# base on https://github.com/kongdebug/RCAN-Paddle
+import math
+import paddle
+import paddle.nn as nn
+
+from .builder import GENERATORS
+
+
+def default_conv(in_channels, out_channels, kernel_size, bias=True):
+    return nn.Conv2D(
+        in_channels,
+        out_channels,
+        kernel_size,
+        padding=(kernel_size // 2),
+        bias_attr=bias)
+
+
+class MeanShift(nn.Conv2D):
+    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
+        super(MeanShift, self).__init__(3, 3, kernel_size=1)
+        std = paddle.to_tensor(rgb_std)
+        self.weight.set_value(paddle.eye(3).reshape([3, 3, 1, 1]))
+        self.weight.set_value(self.weight / (std.reshape([3, 1, 1, 1])))
+
+        mean = paddle.to_tensor(rgb_mean)
+        self.bias.set_value(sign * rgb_range * mean / std)
+
+        self.weight.trainable = False
+        self.bias.trainable = False
+
+
+## Channel Attention (CA) Layer
+class CALayer(nn.Layer):
+    def __init__(self, channel, reduction=16):
+        super(CALayer, self).__init__()
+        # global average pooling: feature --> point
+        self.avg_pool = nn.AdaptiveAvgPool2D(1)
+        # feature channel downscale and upscale --> channel weight
+        self.conv_du = nn.Sequential(
+            nn.Conv2D(
+                channel, channel // reduction, 1, padding=0, bias_attr=True),
+            nn.ReLU(),
+            nn.Conv2D(
+                channel // reduction, channel, 1, padding=0, bias_attr=True),
+            nn.Sigmoid())
+
+    def forward(self, x):
+        y = self.avg_pool(x)
+        y = self.conv_du(y)
+        return x * y
+
+
+class RCAB(nn.Layer):
+    def __init__(self,
+                 conv,
+                 n_feat,
+                 kernel_size,
+                 reduction=16,
+                 bias=True,
+                 bn=False,
+                 act=nn.ReLU(),
+                 res_scale=1):
+        super(RCAB, self).__init__()
+        modules_body = []
+        for i in range(2):
+            modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))
+            if bn: modules_body.append(nn.BatchNorm2D(n_feat))
+            if i == 0: modules_body.append(act)
+        modules_body.append(CALayer(n_feat, reduction))
+        self.body = nn.Sequential(*modules_body)
+        self.res_scale = res_scale
+
+    def forward(self, x):
+        res = self.body(x)
+        res += x
+        return res
+
+
+## Residual Group (RG)
+class ResidualGroup(nn.Layer):
+    def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale,
+                 n_resblocks):
+        super(ResidualGroup, self).__init__()
+        modules_body = []
+        modules_body = [
+            RCAB(
+                conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(), res_scale=1) \
+            for _ in range(n_resblocks)]
+        modules_body.append(conv(n_feat, n_feat, kernel_size))
+        self.body = nn.Sequential(*modules_body)
+
+    def forward(self, x):
+        res = self.body(x)
+        res += x
+        return res
+
+
+class Upsampler(nn.Sequential):
+    def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):
+        m = []
+        if (scale & (scale - 1)) == 0:  # Is scale = 2^n?
+            for _ in range(int(math.log(scale, 2))):
+                m.append(conv(n_feats, 4 * n_feats, 3, bias))
+                m.append(nn.PixelShuffle(2))
+                if bn: m.append(nn.BatchNorm2D(n_feats))
+
+                if act == 'relu':
+                    m.append(nn.ReLU())
+                elif act == 'prelu':
+                    m.append(nn.PReLU(n_feats))
+
+        elif scale == 3:
+            m.append(conv(n_feats, 9 * n_feats, 3, bias))
+            m.append(nn.PixelShuffle(3))
+            if bn: m.append(nn.BatchNorm2D(n_feats))
+
+            if act == 'relu':
+                m.append(nn.ReLU())
+            elif act == 'prelu':
+                m.append(nn.PReLU(n_feats))
+        else:
+            raise NotImplementedError
+
+        super(Upsampler, self).__init__(*m)
+
+
+@GENERATORS.register()
+class RCAN(nn.Layer):
+    def __init__(
+            self,
+            scale,
+            n_resgroups,
+            n_resblocks,
+            n_feats=64,
+            n_colors=3,
+            rgb_range=255,
+            kernel_size=3,
+            reduction=16,
+            conv=default_conv, ):
+        super(RCAN, self).__init__()
+        self.scale = scale
+        act = nn.ReLU()
+
+        n_resgroups = n_resgroups
+        n_resblocks = n_resblocks
+        n_feats = n_feats
+        kernel_size = kernel_size
+        reduction = reduction
+        scale = scale
+        act = nn.ReLU()
+
+        rgb_mean = (0.4488, 0.4371, 0.4040)
+        rgb_std = (1.0, 1.0, 1.0)
+        self.sub_mean = MeanShift(rgb_range, rgb_mean, rgb_std)
+
+        # define head module
+        modules_head = [conv(n_colors, n_feats, kernel_size)]
+
+        # define body module
+        modules_body = [
+            ResidualGroup(
+                conv, n_feats, kernel_size, reduction, act=act, res_scale= 1, n_resblocks=n_resblocks) \
+            for _ in range(n_resgroups)]
+
+        modules_body.append(conv(n_feats, n_feats, kernel_size))
+
+        # define tail module
+        modules_tail = [
+            Upsampler(
+                conv, scale, n_feats, act=False),
+            conv(n_feats, n_colors, kernel_size)
+        ]
+
+        self.head = nn.Sequential(*modules_head)
+        self.body = nn.Sequential(*modules_body)
+        self.tail = nn.Sequential(*modules_tail)
+
+        self.add_mean = MeanShift(rgb_range, rgb_mean, rgb_std, 1)
+
+    def forward(self, x):
+        x = self.sub_mean(x)
+        x = self.head(x)
+
+        res = self.body(x)
+        res += x
+
+        x = self.tail(res)
+        x = self.add_mean(x)
+
+        return x

+ 93 - 0
paddlers/custom_models/gan/rcan_model.py

@@ -0,0 +1,93 @@
+#   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.nn as nn
+
+from .generators.builder import build_generator
+from ...models.ppgan.models.criterions.builder import build_criterion
+from ...models.ppgan.models.base_model import BaseModel
+from ...models.ppgan.models.builder import MODELS
+from ...models.ppgan.utils.visual import tensor2img
+from ...models.ppgan.modules.init import reset_parameters
+
+
+@MODELS.register()
+class RCANModel(BaseModel):
+    """Base SR model for single image super-resolution.
+    """
+
+    def __init__(self, generator, pixel_criterion=None, use_init_weight=False):
+        """
+        Args:
+            generator (dict): config of generator.
+            pixel_criterion (dict): config of pixel criterion.
+        """
+        super(RCANModel, self).__init__()
+
+        self.nets['generator'] = build_generator(generator)
+
+        if pixel_criterion:
+            self.pixel_criterion = build_criterion(pixel_criterion)
+        if use_init_weight:
+            init_sr_weight(self.nets['generator'])
+
+    def setup_input(self, input):
+        self.lq = paddle.to_tensor(input['lq'])
+        self.visual_items['lq'] = self.lq
+        if 'gt' in input:
+            self.gt = paddle.to_tensor(input['gt'])
+            self.visual_items['gt'] = self.gt
+        self.image_paths = input['lq_path']
+
+    def forward(self):
+        pass
+
+    def train_iter(self, optims=None):
+        optims['optim'].clear_grad()
+
+        self.output = self.nets['generator'](self.lq)
+        self.visual_items['output'] = self.output
+        # pixel loss
+        loss_pixel = self.pixel_criterion(self.output, self.gt)
+        self.losses['loss_pixel'] = loss_pixel
+
+        loss_pixel.backward()
+        optims['optim'].step()
+
+    def test_iter(self, metrics=None):
+        self.nets['generator'].eval()
+        with paddle.no_grad():
+            self.output = self.nets['generator'](self.lq)
+            self.visual_items['output'] = self.output
+        self.nets['generator'].train()
+
+        out_img = []
+        gt_img = []
+        for out_tensor, gt_tensor in zip(self.output, self.gt):
+            out_img.append(tensor2img(out_tensor, (0., 255.)))
+            gt_img.append(tensor2img(gt_tensor, (0., 255.)))
+
+        if metrics is not None:
+            for metric in metrics.values():
+                metric.update(out_img, gt_img)
+
+
+def init_sr_weight(net):
+    def reset_func(m):
+        if hasattr(m, 'weight') and (
+                not isinstance(m, (nn.BatchNorm, nn.BatchNorm2D))):
+            reset_parameters(m)
+
+    net.apply(reset_func)

+ 34 - 0
paddlers/tasks/imagerestorer.py

@@ -751,3 +751,37 @@ class ESRGANet(BasicSRNet):
                 'name': 'CosineAnnealingRestartLR',
                 'eta_min': 1e-07
             }
+
+
+# RCAN模型训练
+class RCANet(BasicSRNet):
+    def __init__(
+            self,
+            scale=2,
+            n_resgroups=10,
+            n_resblocks=20, ):
+        super(RCANet, self).__init__()
+        self.min_max = '(0., 255.)'
+        self.generator = {
+            'name': 'RCAN',
+            'scale': scale,
+            'n_resgroups': n_resgroups,
+            'n_resblocks': n_resblocks
+        }
+        self.pixel_criterion = {'name': 'L1Loss'}
+        self.model = {
+            'name': 'RCANModel',
+            'generator': self.generator,
+            'pixel_criterion': self.pixel_criterion
+        }
+        self.optimizer = {
+            'name': 'Adam',
+            'net_names': ['generator'],
+            'beta1': 0.9,
+            'beta2': 0.99
+        }
+        self.lr_scheduler = {
+            'name': 'MultiStepDecay',
+            'milestones': [250000, 500000, 750000, 1000000],
+            'gamma': 0.5
+        }