Browse Source

[feature] Add NAFNet and SwinIR for cloud removal (#135)

kongdebug 2 years ago
parent
commit
9148c821a1

+ 2 - 0
README_CN.md

@@ -103,6 +103,8 @@ PaddleRS具有以下五大特色:
           <li><a href="./tutorials/train/image_restoration/drn.py">DRN</a></li>
           <li><a href="./tutorials/train/image_restoration/esrgan.py">ESRGAN</a></li>
           <li><a href="./tutorials/train/image_restoration/lesrcnn.py">LESRCNN</a></li>
+          <li><a href="./tutorials/train/image_restoration/nafnet.py">NAFNet</a></li>
+          <li><a href="./tutorials/train/image_restoration/swinir.py">SwinIR</a></li>
         </ul>
         </details>
         <details><summary><b>目标检测</b></summary>

+ 2 - 0
README_EN.md

@@ -101,6 +101,8 @@ PaddleRS is an end-to-end high-efficent development toolkit for remote sensing a
           <li><a href="./tutorials/train/image_restoration/drn.py">DRN</a></li>
           <li><a href="./tutorials/train/image_restoration/esrgan.py">ESRGAN</a></li>
           <li><a href="./tutorials/train/image_restoration/lesrcnn.py">LESRCNN</a></li>
+           <li><a href="./tutorials/train/image_restoration/nafnet.py">NAFNet</a></li>
+          <li><a href="./tutorials/train/image_restoration/swinir.py">SwinIR</a></li>
         </ul>
         </details>
         <details><summary><b>Object Detection</b></summary>

+ 36 - 0
docs/intro/model_cons_params_cn.md

@@ -311,6 +311,42 @@
 | `group (int)`        | 卷积操作的分组数量                                        | `1` |
 
 
+## `NAFNet`
+
+基于PaddlePaddle的NAFNet实现。
+
+| 参数名                  | 描述                                                                                      | 默认值 |
+|----------------------|-----------------------------------------------------------------------------------------| --- |
+| `losses (list)`      | 损失函数列表                                                                                  | `None` |
+| `sr_factor (int)`    | 图像复原的缩放因子。NAFNet不适用于图像超分辨率重建任务,不改变图像的大小,请设置`sr_factor`为`None` | `None` |
+| `min_max (tuple)`    | 输入图像的像素值的最小值和最大值。如果未指定,则使用数据类型的默认最小值和最大值                                               | `None` |
+| `use_tlsc (bool)` | 是否在推理时使用tlsc技术                                                | `False` |
+| `in_channels (int)`  | 输入图像的通道数                                                | `3` |
+| `width (int)`        | NAFBlock的通道数                                        | `32` |
+| `middle_blk_num (int)`        | 过渡模块中NAFBlock的数量                                        | `1` |
+| `enc_blk_nums (list[int])`         | 不同层编码器中NAFBlock的数量                                        | `None` |
+| `dec_blk_nums (list[int])`         | 不同层解码器中NAFBlock的数量                                        | `None` |
+
+
+## `SwinIR`
+
+基于PaddlePaddle的SwinIR实现。
+
+| 参数名                  | 描述                                                                                      | 默认值 |
+|----------------------|-----------------------------------------------------------------------------------------| --- |
+| `losses (list)`      | 损失函数列表                                                                                  | `None` |
+| `sr_factor (int)`    | 图像复原的缩放因子。如果原始图像大小为 `H` x `W`,则输出图像大小将为 `sr_factor * H` x `sr_factor * W`  | `1` |
+| `min_max (tuple)`    | 输入图像的像素值的最小值和最大值。如果未指定,则使用数据类型的默认最小值和最大值                                               | `None` |
+| `in_channels (int)`  | 输入图像的通道数                                                | `3` |
+| `img_size (int)`        | 输入图像块的大小                                       | `128` |
+| `window_size (int)`        | 窗口大小                                        | `8` |
+| `depths (list[int])`         | 每个Swin Transformer 层的深度                                     | `[6, 6, 6, 6, 6, 6]` |
+| `num_heads (list[int])`         | 不同层中注意力头的数量                                       | `[6, 6, 6, 6]` |
+| `embed_dim (int)`        | Patch embedding 的维度                                       | `96` |
+| `window_size (int)`        | MLP中隐藏维度与编码维度的比率                                        | `4` |
+
+
+
 ## `FasterRCNN`
 
 基于PaddlePaddle的Faster R-CNN实现。

+ 36 - 0
docs/intro/model_cons_params_en.md

@@ -306,6 +306,42 @@ The LESRCNN implementation based on PaddlePaddle.
 | `group (int)`        | Number of groups used in convolution operations.                                                                    | `1` |
 
 
+## `NAFNet`
+
+The NAFNet implementation based on PaddlePaddle.
+
+| Parameter Name       | Description                                                                                                                                                                                                        | Default Value |
+|----------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| --- |
+| `losses (list)`      | List of loss functions                                                                                                                                                                                             | `None` |
+| `sr_factor (int)`    | Scaling factor for image restoration. NAFNet is not suitable for image super-resolution tasks and does not change the size of the image. Please set the `sr factor` to `None` | `None` |
+| `min_max (tuple)`    | Minimum and maximum pixel values of the input image. If not specified, the data type's default minimum and maximum values are used                                                                                 | `None` |
+| `use_tlsc (bool)`     | Whether to use tlsc (test-time local statistics converter) during testing. If yes, tlsc will be used                                                                                                   | `False` |
+| `in_channels (int)`  | Number of channels of the input image                                         | `3` |
+| `width (int)`        | Number of channels of NAFBlock                                      | `32` |
+| `middle_blk_num (int)`        | Number of NAFBlocks in middle block                                        | `1` |
+| `enc_blk_nums (list[int])`         | Number of NAFBlocks in different layers of the encoder                                   | `None` |
+| `dec_blk_nums (list[int])`         | Number of NAFBlocks in different layers of the decoder                                   | `None` |
+
+
+## `SwinIR`
+
+The SwinIR implementation based on PaddlePaddle.
+
+| 参数名                  | 描述                                                                                      | 默认值 |
+|----------------------|-----------------------------------------------------------------------------------------| --- |
+| `losses (list)`      | List of loss functions                                                                                  | `None` |
+| `sr_factor (int)`    | Scaling factor for image restoration. The output image size will be the original image size multiplied by this factor. For example, if the original image is `H` x `W`, the output image will be `sr_factor * H` x `sr_factor * W` | `1` |
+| `min_max (tuple)`    | Minimum and maximum pixel values of the input image. If not specified, the data type's default minimum and maximum values are used                                                                                 | `None` |
+| `in_channels (int)`  | Number of channels of the input image                                                 | `3` |
+| `img_size (int)`        |  Input image size                                       | `128` |
+| `window_size (int)`        | Window size                                        | `8` |
+| `depths (list[int])`         | Depth of each Swin Transformer layer                                    | `[6, 6, 6, 6, 6, 6]` |
+| `num_heads (list[int])`         | Number of attention heads in different layers  | `[6, 6, 6, 6]` |
+| `embed_dim (int)`        | Patch embedding dimension    | `96` |
+| `window_size (int)`        | Ratio of MLP hidden dim to embedding dim                                   | `4` |
+
+
+
 ##  `FasterRCNN`
 
 The Faster R-CNN implementation based on PaddlePaddle.

+ 2 - 0
docs/intro/model_zoo_cn.md

@@ -30,6 +30,8 @@ PaddleRS目前已支持的全部模型如下(标注\*的为遥感专用模型
 | 图像复原 | DRN | 否 |
 | 图像复原 | ESRGAN | 是 |
 | 图像复原 | LESRCNN | 否 |
+| 图像复原 | NAFNet | 是 |
+| 图像复原 | SwinIR | 是 |
 | 目标检测 | Faster R-CNN | 否 |
 | 目标检测 | PP-YOLO | 否 |
 | 目标检测 | PP-YOLO Tiny | 否 |

+ 2 - 0
docs/intro/model_zoo_en.md

@@ -30,6 +30,8 @@ All models currently supported by PaddleRS are listed below (those marked \* are
 | Image Restoration | DRN | No |
 | Image Restoration | ESRGAN | Yes |
 | Image Restoration | LESRCNN | No |
+| Image Restoration | SwinIR | Yes |
+| Image Restoration | NAFNet | Yes |
 | Object Detection | Faster R-CNN | No |
 | Object Detection | PP-YOLO | No |
 | Object Detection | PP-YOLO Tiny | No |

+ 3 - 2
examples/README.md

@@ -42,10 +42,10 @@ PaddleRS提供从科学研究到产业应用的丰富示例,希望帮助遥感
 |[手把手教你PaddleRS实现变化检测](https://aistudio.baidu.com/aistudio/projectdetail/3737991)|奔向未来的样子|入门教程|变化检测|
 |[【PPSIG】PaddleRS变化检测模型部署:以BIT为例](https://aistudio.baidu.com/aistudio/projectdetail/4184759)|古代飞|入门教程|变化检测,模型部署|
 |[【PPSIG】PaddleRS实现遥感影像场景分类](https://aistudio.baidu.com/aistudio/projectdetail/4198965)|古代飞|入门教程|场景分类|
-|[PaddleRS:使用超分模型提高真实的低分辨率无人机影像的分割精度](https://aistudio.baidu.com/aistudio/projectdetail/3696814)|KeyK-小胡之父|应用案例|超分辨率重建,无人机影像|
+|[PaddleRS:使用超分模型提高真实的低分辨率无人机影像的分割精度](https://aistudio.baidu.com/aistudio/projectdetail/3696814)|不爱做科研的KeyK|应用案例|超分辨率重建,无人机影像|
 |[PaddleRS:无人机汽车识别](https://aistudio.baidu.com/aistudio/projectdetail/3713122)|geoyee|应用案例|目标检测,无人机影像|
 |[PaddleRS:高光谱卫星影像场景分类](https://aistudio.baidu.com/aistudio/projectdetail/3711240)|geoyee|应用案例|场景分类,高光谱影像|
-|[PaddleRS:利用卫星影像与数字高程模型进行滑坡识别](https://aistudio.baidu.com/aistudio/projectdetail/4066570)|KeyK-小胡之父|应用案例|图像分割,DEM|
+|[PaddleRS:利用卫星影像与数字高程模型进行滑坡识别](https://aistudio.baidu.com/aistudio/projectdetail/4066570)|不爱做科研的KeyK|应用案例|图像分割,DEM|
 |[为PaddleRS添加一个袖珍配置系统](https://aistudio.baidu.com/aistudio/projectdetail/4203534)|古代飞|创意开发||
 |[万丈高楼平地起 基于PaddleGAN与PaddleRS的建筑物生成](https://aistudio.baidu.com/aistudio/projectdetail/3716885)|奔向未来的样子|创意开发|超分辨率重建|
 |[【官方】第十一届 “中国软件杯”百度遥感赛项:变化检测功能](https://aistudio.baidu.com/aistudio/projectdetail/3684588)|古代飞|竞赛打榜|变化检测,比赛基线|
@@ -55,3 +55,4 @@ PaddleRS提供从科学研究到产业应用的丰富示例,希望帮助遥感
 |[【十一届软件杯】遥感解译赛道:变化检测任务——预赛第四名方案分享](https://aistudio.baidu.com/aistudio/projectdetail/4116895)|lzzzzzm|竞赛打榜|变化检测,高分方案|
 |[【方案分享】第十一届 “中国软件杯”大学生软件设计大赛遥感解译赛道 比赛方案分享](https://aistudio.baidu.com/aistudio/projectdetail/4146154)|trainer|竞赛打榜|变化检测,高分方案|
 |[遥感变化检测助力信贷场景下工程进度管控](https://aistudio.baidu.com/aistudio/projectdetail/4543160)|古代飞|产业范例|变化检测,金融风控|
+|[使用PaddleRS进行单幅遥感影像的薄云去除](https://aistudio.baidu.com/aistudio/projectdetail/5955630)|不爱做科研的KeyK|应用案例|云雾去除|

+ 82 - 1
paddlers/tasks/restorer.py

@@ -35,7 +35,7 @@ from .base import BaseModel
 from .utils.res_adapters import GANAdapter, OptimizerAdapter
 from .utils.infer_nets import InferResNet
 
-__all__ = ["DRN", "LESRCNN", "ESRGAN"]
+__all__ = ["DRN", "LESRCNN", "ESRGAN", "NAFNet", "SwinIR"]
 
 
 class BaseRestorer(BaseModel):
@@ -924,3 +924,84 @@ class RCAN(BaseRestorer):
             sr_factor=sr_factor,
             min_max=min_max,
             **params)
+
+
+class NAFNet(BaseRestorer):
+    def __init__(self,
+                 losses=None,
+                 sr_factor=None,
+                 min_max=None,
+                 use_tlsc=False,
+                 in_channels=3,
+                 width=32,
+                 middle_blk_num=1,
+                 enc_blk_nums=None,
+                 dec_blk_nums=None,
+                 **params):
+        if sr_factor is not None:
+            raise ValueError(f"`sr_factor` must be set to None.")
+
+        params.update({
+            'img_channel': in_channels,
+            'width': width,
+            'middle_blk_num': middle_blk_num,
+            'enc_blk_nums': enc_blk_nums,
+            'dec_blk_nums': dec_blk_nums
+        })
+        self.use_tlsc = use_tlsc
+
+        super(NAFNet, self).__init__(
+            model_name='NAFNet',
+            losses=losses,
+            sr_factor=sr_factor,
+            min_max=min_max,
+            **params)
+
+    def build_net(self, **params):
+        if not self.use_tlsc:
+            net = ppgan.models.generators.NAFNet(**params)
+        else:
+            net = ppgan.models.generators.NAFNetLocal(**params)
+        return net
+
+    def default_loss(self):
+        return res_losses.PSNRLoss()
+
+
+class SwinIR(BaseRestorer):
+    def __init__(self,
+                 losses=None,
+                 sr_factor=1,
+                 min_max=None,
+                 in_channels=3,
+                 img_size=128,
+                 window_size=8,
+                 depths=[6, 6, 6, 6, 6, 6],
+                 embed_dim=180,
+                 num_heads=[6, 6, 6, 6, 6, 6],
+                 mlp_ratio=2,
+                 **params):
+
+        params.update({
+            'in_chans': in_channels,
+            'upscale': sr_factor,
+            'img_size': img_size,
+            'window_size': window_size,
+            'depths': depths,
+            'embed_dim': embed_dim,
+            'num_heads': num_heads,
+            'mlp_ratio': mlp_ratio
+        })
+        super(SwinIR, self).__init__(
+            model_name='SwinIR',
+            losses=losses,
+            sr_factor=sr_factor,
+            min_max=min_max,
+            **params)
+
+    def build_net(self, **params):
+        net = ppgan.models.generators.SwinIR(**params)
+        return net
+
+    def default_loss(self):
+        return res_losses.CharbonnierLoss(eps=0.000000001, reduction='mean')

+ 47 - 0
tools/prepare_dataset/prepare_rice.py

@@ -0,0 +1,47 @@
+#!/usr/bin/env python
+
+import random
+import os.path as osp
+from glob import iglob
+from functools import reduce, partial
+
+from common import (get_default_parser, create_file_list, link_dataset,
+                    random_split, get_path_tuples)
+
+SUBSETS = ('train', 'val')
+SUBDIRS = ('cloud', 'label')
+FILE_LIST_PATTERN = "{subset}.txt"
+
+if __name__ == '__main__':
+    parser = get_default_parser()
+    parser.add_argument('--seed', type=int, default=None, help="Random seed.")
+    parser.add_argument(
+        '--ratios',
+        type=float,
+        nargs='+',
+        default=(0.8, 0.2),
+        help="Ratios of each subset (train/val or train/val/test).")
+    args = parser.parse_args()
+
+    if args.seed is not None:
+        random.seed(args.seed)
+
+    if len(args.ratios) not in (2, 3):
+        raise ValueError("Wrong number of ratios!")
+
+    out_dir = osp.join(args.out_dataset_dir,
+                       osp.basename(osp.normpath(args.in_dataset_dir)))
+
+    link_dataset(args.in_dataset_dir, args.out_dataset_dir)
+
+    path_tuples = get_path_tuples(
+        *(osp.join(out_dir, subdir) for subdir in SUBDIRS),
+        glob_pattern='**/*.png',
+        data_dir=args.out_dataset_dir)
+    splits = random_split(path_tuples, ratios=args.ratios)
+
+    for subset, split in zip(SUBSETS, splits):
+        file_list = osp.join(
+            args.out_dataset_dir, FILE_LIST_PATTERN.format(subset=subset))
+        create_file_list(file_list, split)
+        print(f"Write file list to {file_list}.")

+ 1 - 1
tools/utils/raster.py

@@ -87,7 +87,7 @@ class Raster:
                 self._src_data = gdal_obj
             else:
                 raise ValueError(
-                    "At least one of `path` and `gdal_obj` is not None.")
+                    "At least one of `path` and `gdal_obj` should not be None.")
         self.to_uint8 = to_uint8
         self._getInfo()
         self.setBands(band_list)

+ 109 - 0
tutorials/train/image_restoration/nafnet.py

@@ -0,0 +1,109 @@
+#!/usr/bin/env python
+
+# 图像复原模型NAFNet训练示例脚本
+# 执行此脚本前,请确认已正确安装PaddleRS库
+
+import paddle
+
+import paddlers as pdrs
+from paddlers import transforms as T
+
+# 数据集存放目录
+DATA_DIR = './data/RICE1'
+# 训练集`file_list`文件路径
+TRAIN_FILE_LIST_PATH = './data/RICE1/train.txt'
+# 验证集`file_list`文件路径
+EVAL_FILE_LIST_PATH = './data/RICE1/val.txt'
+# 实验目录,保存输出的模型权重和结果
+EXP_DIR = './output/nafnet/'
+
+# 下载和解压遥感影像去云数据集
+pdrs.utils.download_and_decompress(
+    'https://paddlers.bj.bcebos.com/datasets/RICE1.zip', path='./data/')
+
+# 定义训练和验证时使用的数据变换(数据增强、预处理等)
+# 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行
+# API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/data.md
+train_transforms = [
+    # 从输入影像中裁剪256×256大小的影像块
+    T.RandomCrop(crop_size=256),
+    # 以50%的概率实施随机水平翻转
+    T.RandomHorizontalFlip(prob=0.5),
+    # 以50%的概率实施随机垂直翻转
+    T.RandomVerticalFlip(prob=0.5),
+    # 以默认设置实施随机的翻转或旋转
+    T.RandomFlipOrRotate(),
+    # 将数据归一化到[0,1]
+    T.Normalize(
+        mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0])
+]
+
+eval_transforms = [
+    # 验证阶段与训练阶段的数据归一化方式必须相同
+    T.Normalize(
+        mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0])
+]
+
+# 分别构建训练和验证所用的数据集
+train_dataset = pdrs.datasets.ResDataset(
+    data_dir=DATA_DIR,
+    file_list=TRAIN_FILE_LIST_PATH,
+    transforms=train_transforms,
+    num_workers=0,
+    shuffle=True,
+    sr_factor=None)
+
+eval_dataset = pdrs.datasets.ResDataset(
+    data_dir=DATA_DIR,
+    file_list=EVAL_FILE_LIST_PATH,
+    transforms=eval_transforms,
+    num_workers=0,
+    shuffle=False,
+    sr_factor=None)
+
+# 使用以下参数构建NAFNet模型
+in_channels = 3
+width = 32
+middle_blk_num = 12
+enc_blk_nums = [2, 2, 4, 8]
+dec_blk_nums = [2, 2, 2, 2]
+# 目前已支持的模型请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md
+# 模型输入参数请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/restorer.py
+
+model = pdrs.tasks.res.NAFNet(
+    in_channels=in_channels,
+    width=width,
+    middle_blk_num=middle_blk_num,
+    enc_blk_nums=enc_blk_nums,
+    dec_blk_nums=dec_blk_nums)
+
+# 制定余弦学习率衰减策略
+lr_scheduler = paddle.optimizer.lr.CosineAnnealingDecay(
+    learning_rate=0.0006, T_max=4000, eta_min=8e-7)
+
+# 构造AdamW优化器
+optimizer = paddle.optimizer.AdamW(
+    learning_rate=lr_scheduler,
+    parameters=model.net.parameters(),
+    weight_decay=0.0,
+    beta1=0.9,
+    beta2=0.9,
+    epsilon=1e-8)
+
+# 执行模型训练
+model.train(
+    num_epochs=200,
+    train_dataset=train_dataset,
+    train_batch_size=20,
+    eval_dataset=eval_dataset,
+    optimizer=optimizer,
+    save_interval_epochs=10,
+    # 每多少次迭代记录一次日志
+    log_interval_steps=10,
+    save_dir=EXP_DIR,
+    # 是否使用early stopping策略,当精度不再改善时提前终止训练
+    early_stop=False,
+    # 是否启用VisualDL日志功能
+    use_vdl=True,
+    # 指定从某个检查点继续训练
+    resume_checkpoint=None)

+ 100 - 0
tutorials/train/image_restoration/swinir.py

@@ -0,0 +1,100 @@
+#!/usr/bin/env python
+
+# 图像复原模型SwinIR训练示例脚本
+# 执行此脚本前,请确认已正确安装PaddleRS库
+
+import paddle
+
+import paddlers as pdrs
+from paddlers import transforms as T
+
+# 数据集存放目录
+DATA_DIR = './data/RICE1'
+# 训练集`file_list`文件路径
+TRAIN_FILE_LIST_PATH = './data/RICE1/train.txt'
+# 验证集`file_list`文件路径
+EVAL_FILE_LIST_PATH = './data/RICE1/val.txt'
+# 实验目录,保存输出的模型权重和结果
+EXP_DIR = './output/swinir/'
+
+# 下载和解压遥感影像超分辨率数据集
+pdrs.utils.download_and_decompress(
+    'https://paddlers.bj.bcebos.com/datasets/RICE1.zip', path='./data/')
+
+# 定义训练和验证时使用的数据变换(数据增强、预处理等)
+# 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行
+# API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/data.md
+train_transforms = [
+    # 从输入影像中裁剪256×256大小的影像块
+    T.RandomCrop(crop_size=128),
+    # 以50%的概率实施随机水平翻转
+    T.RandomHorizontalFlip(prob=0.5),
+    # 以50%的概率实施随机垂直翻转
+    T.RandomVerticalFlip(prob=0.5),
+    # 以默认设置实施随机的翻转或旋转
+    T.RandomFlipOrRotate(),
+    # 将数据归一化到[0,1]
+    T.Normalize(
+        mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0])
+]
+
+eval_transforms = [
+    # 验证阶段与训练阶段的数据归一化方式必须相同
+    T.Normalize(
+        mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0])
+]
+
+# 分别构建训练和验证所用的数据集
+train_dataset = pdrs.datasets.ResDataset(
+    data_dir=DATA_DIR,
+    file_list=TRAIN_FILE_LIST_PATH,
+    transforms=train_transforms,
+    num_workers=0,
+    shuffle=True,
+    sr_factor=1)
+
+eval_dataset = pdrs.datasets.ResDataset(
+    data_dir=DATA_DIR,
+    file_list=EVAL_FILE_LIST_PATH,
+    transforms=eval_transforms,
+    num_workers=0,
+    shuffle=False,
+    sr_factor=1)
+
+# 使用默认参数构建SwinIR模型
+# 目前已支持的模型请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md
+# 模型输入参数请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/restorer.py
+
+model = pdrs.tasks.res.SwinIR()
+
+# 制定定步长学习率衰减策略
+lr_scheduler = paddle.optimizer.lr.MultiStepDecay(
+    learning_rate=0.00005,
+    milestones=[20000, 30000, 35000, 38000, 40000],
+    gamma=0.5)
+
+# 构造Adam优化器
+optimizer = paddle.optimizer.Adam(
+    learning_rate=lr_scheduler,
+    parameters=model.net.parameters(),
+    beta1=0.9,
+    beta2=0.999,
+    epsilon=1e-8)
+
+# 执行模型训练
+model.train(
+    num_epochs=200,
+    train_dataset=train_dataset,
+    train_batch_size=2,
+    eval_dataset=eval_dataset,
+    optimizer=optimizer,
+    save_interval_epochs=10,
+    # 每多少次迭代记录一次日志
+    log_interval_steps=10,
+    save_dir=EXP_DIR,
+    # 是否使用early stopping策略,当精度不再改善时提前终止训练
+    early_stop=False,
+    # 是否启用VisualDL日志功能
+    use_vdl=True,
+    # 指定从某个检查点继续训练
+    resume_checkpoint=None)