浏览代码

Merge pull request #8 from LutaoChu/seg-pipeline

Support multi-channel transform and model training
Liu Yi 3 年之前
父节点
当前提交
c6d911e71d

+ 5 - 1
docs/README.md

@@ -1 +1,5 @@
-PaddleSeg commit fec42fd869b6f796c74cd510671595e3512bc8e9
+PaddleSeg commit fec42fd869b6f796c74cd510671595e3512bc8e9
+
+# 开发规范
+请注意,paddlers/models/ppxxx系列除了修改import路径和支持多通道模型外,不要增删改任何代码。
+新增的模型需放在paddlers/models/下的seg、det、cls、cd目录下。

+ 40 - 0
docs/datasets.md

@@ -0,0 +1,40 @@
+# 遥感数据集
+
+遥感影像的格式多种多样,不同传感器产生的数据格式也可能不同。PaddleRS至少兼容以下6种格式图片读取:
+
+- `tif`
+- `png`, `jpeg`, `bmp`
+- `img`
+- `npy`
+
+标注图要求必须为单通道的png格式图像,像素值即为对应的类别,像素标注类别需要从0开始递增。例如0,1,2,3表示有4种类别,255用于指定不参与训练和评估的像素,标注类别最多为256类。
+
+## L8 SPARCS数据集
+[L8 SPARCS公开数据集](https://www.usgs.gov/land-resources/nli/landsat/spatial-procedures-automated-removal-cloud-and-shadow-sparcs-validation)进行云雪分割,该数据集包含80张卫星影像,涵盖10个波段。原始标注图片包含7个类别,分别是`cloud`, `cloud shadow`, `shadow over water`, `snow/ice`, `water`, `land`和`flooded`。由于`flooded`和`shadow over water`2个类别占比仅为`1.8%`和`0.24%`,我们将其进行合并,`flooded`归为`land`,`shadow over water`归为`shadow`,合并后标注包含5个类别。
+
+数值、类别、颜色对应表:
+
+|Pixel value|Class|Color|
+|---|---|---|
+|0|cloud|white|
+|1|shadow|black|
+|2|snow/ice|cyan|
+|3|water|blue|
+|4|land|grey|
+
+<p align="center">
+ <img src="./images/dataset.png" align="middle"
+</p>
+
+<p align='center'>
+ L8 SPARCS数据集示例
+</p>
+
+执行以下命令下载并解压经过类别合并后的数据集:
+```shell script
+mkdir dataset && cd dataset
+wget https://paddleseg.bj.bcebos.com/dataset/remote_sensing_seg.zip
+unzip remote_sensing_seg.zip
+cd ..
+```
+其中`data`目录存放遥感影像,`data_vis`目录存放彩色合成预览图,`mask`目录存放标注图。

+ 2 - 2
paddlers/datasets/seg_dataset.py

@@ -64,10 +64,10 @@ class SegDataset(Dataset):
                         " file_list[{}] has a space in the image or label path.".format(line, file_list))
                 items[0] = path_normalization(items[0])
                 items[1] = path_normalization(items[1])
-                if not is_pic(items[0]) or not is_pic(items[1]):
-                    continue
                 full_path_im = osp.join(data_dir, items[0])
                 full_path_label = osp.join(data_dir, items[1])
+                if not is_pic(full_path_im) or not is_pic(full_path_label):
+                    continue
                 if not osp.exists(full_path_im):
                     raise IOError('Image file {} does not exist!'.format(
                         full_path_im))

+ 2 - 2
paddlers/datasets/voc.py

@@ -23,7 +23,7 @@ from collections import OrderedDict
 import xml.etree.ElementTree as ET
 from paddle.io import Dataset
 from paddlers.utils import logging, get_num_workers, get_encoding, path_normalization, is_pic
-from paddlers.transforms import Decode, MixupImage
+from paddlers.transforms import ImgDecoder, MixupImage
 from paddlers.tools import YOLOAnchorCluster
 
 
@@ -319,7 +319,7 @@ class VOCDetection(Dataset):
             if self.data_fields is not None:
                 sample_mix = {k: sample_mix[k] for k in self.data_fields}
             sample = self.mixup_op(sample=[
-                Decode(to_rgb=False)(sample), Decode(to_rgb=False)(sample_mix)
+                ImgDecoder(to_rgb=False)(sample), ImgDecoder(to_rgb=False)(sample_mix)
             ])
         sample = self.transforms(sample)
         return sample

+ 3 - 2
paddlers/models/ppseg/models/backbones/resnet_vd.py

@@ -211,13 +211,14 @@ class ResNet_vd(nn.Layer):
     """
 
     def __init__(self,
+                 input_channel=3,
                  layers=50,
                  output_stride=8,
                  multi_grid=(1, 1, 1),
                  pretrained=None,
                  data_format='NCHW'):
         super(ResNet_vd, self).__init__()
-
+        
         self.data_format = data_format
         self.conv1_logit = None  # for gscnn shape stream
         self.layers = layers
@@ -251,7 +252,7 @@ class ResNet_vd(nn.Layer):
             dilation_dict = {3: 2}
 
         self.conv1_1 = ConvBNLayer(
-            in_channels=3,
+            in_channels=input_channel,
             out_channels=32,
             kernel_size=3,
             stride=2,

+ 2 - 2
paddlers/tasks/changedetector.py

@@ -28,7 +28,7 @@ import paddlers.utils.logging as logging
 from .base import BaseModel
 from .utils import seg_metrics as metrics
 from paddlers.utils.checkpoint import seg_pretrain_weights_dict
-from paddlers.transforms import Decode, Resize
+from paddlers.transforms import ImgDecoder, Resize
 from paddlers.models.ppcd import CDNet as _CDNet
 
 __all__ = ["CDNet"]
@@ -516,7 +516,7 @@ class BaseChangeDetector(BaseModel):
         for im in images:
             sample = {'image': im}
             if isinstance(sample['image'], str):
-                sample = Decode(to_rgb=False)(sample)
+                sample = ImgDecoder(to_rgb=False)(sample)
             ori_shape = sample['image'].shape[:2]
             im = transforms(sample)[0]
             batch_im.append(im)

+ 2 - 2
paddlers/tasks/classifier.py

@@ -29,7 +29,7 @@ from paddlers.models.ppcls.metric import build_metrics
 from paddlers.models.ppcls.loss import build_loss
 from paddlers.models.ppcls.data.postprocess import build_postprocess
 from paddlers.utils.checkpoint import cls_pretrain_weights_dict
-from paddlers.transforms import Decode, Resize
+from paddlers.transforms import ImgDecoder, Resize
 
 __all__ = ["ResNet50_vd", "MobileNetV3_small_x1_0", "HRNet_W18_C"]
 
@@ -433,7 +433,7 @@ class BaseClassifier(BaseModel):
         for im in images:
             sample = {'image': im}
             if isinstance(sample['image'], str):
-                sample = Decode(to_rgb=False)(sample)
+                sample = ImgDecoder(to_rgb=False)(sample)
             ori_shape = sample['image'].shape[:2]
             im = transforms(sample)[0]
             batch_im.append(im)

+ 4 - 2
paddlers/tasks/segmenter.py

@@ -28,7 +28,7 @@ import paddlers.utils.logging as logging
 from .base import BaseModel
 from .utils import seg_metrics as metrics
 from paddlers.utils.checkpoint import seg_pretrain_weights_dict
-from paddlers.transforms import Decode, Resize
+from paddlers.transforms import ImgDecoder, Resize
 
 __all__ = ["UNet", "DeepLabV3P", "FastSCNN", "HRNet", "BiSeNetV2", "FarSeg"]
 
@@ -525,7 +525,7 @@ class BaseSegmenter(BaseModel):
         for im in images:
             sample = {'image': im}
             if isinstance(sample['image'], str):
-                sample = Decode(to_rgb=False)(sample)
+                sample = ImgDecode(to_rgb=False)(sample)
             ori_shape = sample['image'].shape[:2]
             im = transforms(sample)[0]
             batch_im.append(im)
@@ -679,6 +679,7 @@ class UNet(BaseSegmenter):
 
 class DeepLabV3P(BaseSegmenter):
     def __init__(self,
+                 input_channel=3,
                  num_classes=2,
                  backbone='ResNet50_vd',
                  use_mixed_loss=False,
@@ -696,6 +697,7 @@ class DeepLabV3P(BaseSegmenter):
         if params.get('with_net', True):
             with DisablePrint():
                 backbone = getattr(paddleseg.models, backbone)(
+                    input_channel=input_channel,
                     output_stride=output_stride)
         else:
             backbone = None

+ 0 - 157
paddlers/transforms/img_decoder.py

@@ -1,157 +0,0 @@
-# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import numpy as np
-import os.path as osp
-import cv2
-import copy
-import random
-import imghdr
-from PIL import Image
-
-try:
-    from collections.abc import Sequence
-except Exception:
-    from collections import Sequence
-
-# from paddlers.transforms.operators import Transform
-
-
-class Transform(object):
-    """
-    Parent class of all data augmentation operations
-    """
-
-    def __init__(self):
-        pass
-
-    def apply_im(self, image):
-        pass
-
-    def apply_mask(self, mask):
-        pass
-
-    def apply_bbox(self, bbox):
-        pass
-
-    def apply_segm(self, segms):
-        pass
-
-    def apply(self, sample):
-        sample['image'] = self.apply_im(sample['image'])
-        if 'mask' in sample:
-            sample['mask'] = self.apply_mask(sample['mask'])
-        if 'gt_bbox' in sample:
-            sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'])
-
-        return sample
-
-    def __call__(self, sample):
-        if isinstance(sample, Sequence):
-            sample = [self.apply(s) for s in sample]
-        else:
-            sample = self.apply(sample)
-
-        return sample
-
-
-class ImgDecode(Transform):
-    """
-    Decode image(s) in input.
-    Args:
-        to_rgb (bool, optional): If True, convert input images from BGR format to RGB format. Defaults to True.
-    """
-
-    def __init__(self, to_rgb=True):
-        super(ImgDecode, self).__init__()
-        self.to_rgb = to_rgb
-
-    def read_img(self, img_path, input_channel=3):
-        img_format = imghdr.what(img_path)
-        name, ext = osp.splitext(img_path)
-        if img_format == 'tiff' or ext == '.img':
-            try:
-                import gdal
-            except:
-                try:
-                    from osgeo import gdal
-                except:
-                    raise Exception(
-                        "Failed to import gdal! You can try use conda to install gdal"
-                    )
-                    six.reraise(*sys.exc_info())
-
-            dataset = gdal.Open(img_path)
-            if dataset == None:
-                raise Exception('Can not open', img_path)
-            im_data = dataset.ReadAsArray()
-            return im_data.transpose((1, 2, 0))
-        elif img_format in ['jpeg', 'bmp', 'png', 'jpg']:
-            if input_channel == 3:
-                return cv2.imread(img_path, cv2.IMREAD_ANYDEPTH |
-                                  cv2.IMREAD_ANYCOLOR | cv2.IMREAD_COLOR)
-            else:
-                return cv2.imread(img_path, cv2.IMREAD_ANYDEPTH |
-                                  cv2.IMREAD_ANYCOLOR)
-        elif ext == '.npy':
-            return np.load(img_path)
-        else:
-            raise Exception('Image format {} is not supported!'.format(ext))
-
-    def apply_im(self, im_path):
-        if isinstance(im_path, str):
-            try:
-                image = self.read_img(im_path)
-            except:
-                raise ValueError('Cannot read the image file {}!'.format(
-                    im_path))
-        else:
-            image = im_path
-
-        if self.to_rgb:
-            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-
-        return image
-
-    def apply_mask(self, mask):
-        try:
-            mask = np.asarray(Image.open(mask))
-        except:
-            raise ValueError("Cannot read the mask file {}!".format(mask))
-        if len(mask.shape) != 2:
-            raise Exception(
-                "Mask should be a 1-channel image, but recevied is a {}-channel image.".
-                format(mask.shape[2]))
-        return mask
-
-    def apply(self, sample):
-        """
-        Args:
-            sample (dict): Input sample, containing 'image' at least.
-        Returns:
-            dict: Decoded sample.
-        """
-        sample['image'] = self.apply_im(sample['image'])
-        if 'mask' in sample:
-            sample['mask'] = self.apply_mask(sample['mask'])
-            im_height, im_width, _ = sample['image'].shape
-            se_height, se_width = sample['mask'].shape
-            if im_height != se_height or im_width != se_width:
-                raise Exception(
-                    "The height or width of the im is not same as the mask")
-
-        sample['im_shape'] = np.array(
-            sample['image'].shape[:2], dtype=np.float32)
-        sample['scale_factor'] = np.array([1., 1.], dtype=np.float32)
-        return sample

+ 64 - 62
paddlers/transforms/operators.py

@@ -31,7 +31,7 @@ from .functions import normalize, horizontal_flip, permute, vertical_flip, cente
     crop_rle, expand_poly, expand_rle, resize_poly, resize_rle
 
 __all__ = [
-    "Compose", "Decode", "Resize", "RandomResize", "ResizeByShort",
+    "Compose", "ImgDecoder", "Resize", "RandomResize", "ResizeByShort",
     "RandomResizeByShort", "ResizeByLong", "RandomHorizontalFlip",
     "RandomVerticalFlip", "Normalize", "CenterCrop", "RandomCrop",
     "RandomScaleAspect", "RandomExpand", "Padding", "MixupImage",
@@ -90,66 +90,15 @@ class Transform(object):
         return sample
 
 
-class Compose(Transform):
-    """
-    Apply a series of data augmentation to the input.
-    All input images are in Height-Width-Channel ([H, W, C]) format.
-
-    Args:
-        transforms (List[paddlers.transforms.Transform]): List of data preprocess or augmentations.
-    Raises:
-        TypeError: Invalid type of transforms.
-        ValueError: Invalid length of transforms.
-    """
-
-    def __init__(self, transforms):
-        super(Compose, self).__init__()
-        if not isinstance(transforms, list):
-            raise TypeError(
-                'Type of transforms is invalid. Must be List, but received is {}'
-                .format(type(transforms)))
-        if len(transforms) < 1:
-            raise ValueError(
-                'Length of transforms must not be less than 1, but received is {}'
-                .format(len(transforms)))
-        self.transforms = transforms
-        self.decode_image = Decode()
-        self.arrange_outputs = None
-        self.apply_im_only = False
-
-    def __call__(self, sample):
-        if self.apply_im_only and 'mask' in sample:
-            mask_backup = copy.deepcopy(sample['mask'])
-            del sample['mask']
-
-        sample = self.decode_image(sample)
-
-        for op in self.transforms:
-            # skip batch transforms amd mixup
-            if isinstance(op, (paddlers.transforms.BatchRandomResize,
-                               paddlers.transforms.BatchRandomResizeByShort,
-                               MixupImage)):
-                continue
-            sample = op(sample)
-
-        if self.arrange_outputs is not None:
-            if self.apply_im_only:
-                sample['mask'] = mask_backup
-            sample = self.arrange_outputs(sample)
-
-        return sample
-
-
-class Decode(Transform):
+class ImgDecoder(Transform):
     """
     Decode image(s) in input.
-
     Args:
         to_rgb (bool, optional): If True, convert input images from BGR format to RGB format. Defaults to True.
     """
 
     def __init__(self, to_rgb=True):
-        super(Decode, self).__init__()
+        super(ImgDecoder, self).__init__()
         self.to_rgb = to_rgb
 
     def read_img(self, img_path, input_channel=3):
@@ -172,7 +121,7 @@ class Decode(Transform):
                 raise Exception('Can not open', img_path)
             im_data = dataset.ReadAsArray()
             if im_data.ndim == 3:
-                im_data.transpose((1, 2, 0))
+                im_data = im_data.transpose((1, 2, 0))
             return im_data
         elif img_format in ['jpeg', 'bmp', 'png', 'jpg']:
             if input_channel == 3:
@@ -196,7 +145,7 @@ class Decode(Transform):
         else:
             image = im_path
 
-        if self.to_rgb:
+        if self.to_rgb and image.shape[-1] == 3:
             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
 
         return image
@@ -214,13 +163,10 @@ class Decode(Transform):
 
     def apply(self, sample):
         """
-
         Args:
             sample (dict): Input sample, containing 'image' at least.
-
         Returns:
             dict: Decoded sample.
-
         """
         if 'image' in sample:
             sample['image'] = self.apply_im(sample['image'])
@@ -234,12 +180,63 @@ class Decode(Transform):
             if im_height != se_height or im_width != se_width:
                 raise Exception(
                     "The height or width of the im is not same as the mask")
+
         sample['im_shape'] = np.array(
             sample['image'].shape[:2], dtype=np.float32)
         sample['scale_factor'] = np.array([1., 1.], dtype=np.float32)
         return sample
 
 
+class Compose(Transform):
+    """
+    Apply a series of data augmentation to the input.
+    All input images are in Height-Width-Channel ([H, W, C]) format.
+
+    Args:
+        transforms (List[paddlers.transforms.Transform]): List of data preprocess or augmentations.
+    Raises:
+        TypeError: Invalid type of transforms.
+        ValueError: Invalid length of transforms.
+    """
+
+    def __init__(self, transforms):
+        super(Compose, self).__init__()
+        if not isinstance(transforms, list):
+            raise TypeError(
+                'Type of transforms is invalid. Must be List, but received is {}'
+                .format(type(transforms)))
+        if len(transforms) < 1:
+            raise ValueError(
+                'Length of transforms must not be less than 1, but received is {}'
+                .format(len(transforms)))
+        self.transforms = transforms
+        self.decode_image = ImgDecoder()
+        self.arrange_outputs = None
+        self.apply_im_only = False
+
+    def __call__(self, sample):
+        if self.apply_im_only and 'mask' in sample:
+            mask_backup = copy.deepcopy(sample['mask'])
+            del sample['mask']
+
+        sample = self.decode_image(sample)
+
+        for op in self.transforms:
+            # skip batch transforms amd mixup
+            if isinstance(op, (paddlers.transforms.BatchRandomResize,
+                               paddlers.transforms.BatchRandomResizeByShort,
+                               MixupImage)):
+                continue
+            sample = op(sample)
+
+        if self.arrange_outputs is not None:
+            if self.apply_im_only:
+                sample['mask'] = mask_backup
+            sample = self.arrange_outputs(sample)
+
+        return sample
+
+
 class Resize(Transform):
     """
     Resize input.
@@ -618,10 +615,16 @@ class Normalize(Transform):
     def __init__(self,
                  mean=[0.485, 0.456, 0.406],
                  std=[0.229, 0.224, 0.225],
-                 min_val=[0, 0, 0],
-                 max_val=[255., 255., 255.],
+                 min_val=None,
+                 max_val=None,
                  is_scale=True):
         super(Normalize, self).__init__()
+        channel = len(mean)
+        if min_val is None:
+            min_val = [0] * channel
+        if max_val is None:
+            max_val = [255.] * channel
+
         from functools import reduce
         if reduce(lambda x, y: x * y, std) == 0:
             raise ValueError(
@@ -633,7 +636,6 @@ class Normalize(Transform):
                     '(max_val - min_val) should not have 0, but received is {}'.
                     format((np.asarray(max_val) - np.asarray(min_val)).tolist(
                     )))
-
         self.mean = mean
         self.std = std
         self.min_val = min_val

+ 12 - 6
paddlers/utils/utils.py

@@ -14,8 +14,10 @@
 
 import sys
 import os
+import os.path as osp
 import time
 import math
+import imghdr
 import chardet
 import json
 import numpy as np
@@ -73,12 +75,16 @@ def path_normalization(path):
     return path
 
 
-def is_pic(img_name):
-    valid_suffix = ['JPEG', 'jpeg', 'JPG', 'jpg', 'BMP', 'bmp', 'PNG', 'png', 'tiff']
-    suffix = img_name.split('.')[-1]
-    if suffix not in valid_suffix:
-        return False
-    return True
+def is_pic(img_path):
+    valid_suffix = ['JPEG', 'jpeg', 'JPG', 'jpg', 'BMP', 'bmp', 'PNG', 'png', '.npy']
+    suffix = img_path.split('.')[-1]
+    if suffix in valid_suffix:
+        return True
+    img_format = imghdr.what(img_path)
+    _, ext = osp.splitext(img_path)
+    if img_format == 'tiff' or ext == '.img':
+        return True
+    return False
 
 
 class MyEncoder(json.JSONEncoder):

+ 5 - 5
tutorials/train/README.md

@@ -5,7 +5,7 @@
 |代码 | 模型任务 | 数据 |
 |------|--------|---------|
 |object_detection/ppyolo.py | 目标检测PPYOLO | 昆虫检测 |
-|semantic_segmentation/deeplabv3p_resnet50_vd.py | 语义分割DeepLabV3 | 视盘分割 |
+|semantic_segmentation/deeplabv3p_resnet50_multi_channel.py | 语义分割DeepLabV3 | 视盘分割 |
 |semantic_segmentation/farseg_test.py | 语义分割FarSeg | 遥感建筑分割 |
 |change_detection/cdnet_build.py | 变化检测CDNet | 遥感变化检测 |
 |classification/resnet50_vd_rs.py | 图像分类ResNet50_vd | 遥感场景分类 |
@@ -25,7 +25,7 @@
 <!-- - [PaddleRS安装](../../docs/install.md) -->
 
 ## 开始训练
-* 修改tutorials/train/semantic_segmentation/deeplabv3p_resnet50_vd.py中sys.path路径
+* 修改tutorials/train/semantic_segmentation/deeplabv3p_resnet50_multi_channel.py中sys.path路径
 ```
 sys.path.append("your/PaddleRS/path")
 ```
@@ -34,13 +34,13 @@ sys.path.append("your/PaddleRS/path")
 
 ```commandline
 export CUDA_VISIBLE_DEVICES=0
-python tutorials/train/semantic_segmentation/deeplabv3p_resnet50_vd.py
+python tutorials/train/semantic_segmentation/deeplabv3p_resnet50_multi_channel.py
 ```
 
 * 若需使用多张GPU卡进行训练,例如使用2张卡时执行:
 
 ```commandline
-python -m paddle.distributed.launch --gpus 0,1 tutorials/train/semantic_segmentation/deeplabv3p_resnet50_vd.py
+python -m paddle.distributed.launch --gpus 0,1 tutorials/train/semantic_segmentation/deeplabv3p_resnet50_multi_channel.py
 ```
 使用多卡时,参考[训练参数调整](../../docs/parameters.md)调整学习率和批量大小。
 
@@ -48,7 +48,7 @@ python -m paddle.distributed.launch --gpus 0,1 tutorials/train/semantic_segmenta
 ## VisualDL可视化训练指标
 在模型训练过程,在`train`函数中,将`use_vdl`设为True,则训练过程会自动将训练日志以VisualDL的格式打点在`save_dir`(用户自己指定的路径)下的`vdl_log`目录,用户可以使用如下命令启动VisualDL服务,查看可视化指标
 ```commandline
-visualdl --logdir output/deeplabv3p_resnet50_vd/vdl_log --port 8001
+visualdl --logdir output/deeplabv3p_resnet50_multi_channel/vdl_log --port 8001
 ```
 
 服务启动后,使用浏览器打开 https://0.0.0.0:8001 或 https://localhost:8001

+ 13 - 12
tutorials/train/semantic_segmentation/deeplabv3p_resnet50_vd.py → tutorials/train/semantic_segmentation/deeplabv3p_resnet50_multi_channel.py

@@ -5,39 +5,40 @@ sys.path.append("/mnt/chulutao/PaddleRS")
 import paddlers as pdrs
 from paddlers import transforms as T
 
-# 下载和解压视盘分割数据集
-optic_dataset = 'https://bj.bcebos.com/paddlex/datasets/optic_disc_seg.tar.gz'
-pdrs.utils.download_and_decompress(optic_dataset, path='./')
+# 下载和解压多光谱地块分类数据集
+dataset = 'https://paddleseg.bj.bcebos.com/dataset/remote_sensing_seg.zip'
+pdrs.utils.download_and_decompress(dataset, path='./data')
 
 # 定义训练和验证时的transforms
 # API说明:https://github.com/PaddlePaddle/paddlers/blob/develop/docs/apis/transforms/transforms.md
+channel = 10
 train_transforms = T.Compose([
     T.Resize(target_size=512),
     T.RandomHorizontalFlip(),
     T.Normalize(
-        mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+        mean=[0.5] * 10, std=[0.5] * 10),
 ])
 
 eval_transforms = T.Compose([
     T.Resize(target_size=512),
     T.Normalize(
-        mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+        mean=[0.5] * 10, std=[0.5] * 10),
 ])
 
 # 定义训练和验证所用的数据集
 # API说明:https://github.com/PaddlePaddle/paddlers/blob/develop/docs/apis/datasets.md
 train_dataset = pdrs.datasets.SegDataset(
-    data_dir='optic_disc_seg',
-    file_list='optic_disc_seg/train_list.txt',
-    label_list='optic_disc_seg/labels.txt',
+    data_dir='./data/remote_sensing_seg',
+    file_list='./data/remote_sensing_seg/train.txt',
+    label_list='./data/remote_sensing_seg/labels.txt',
     transforms=train_transforms,
     num_workers=0,
     shuffle=True)
 
 eval_dataset = pdrs.datasets.SegDataset(
-    data_dir='optic_disc_seg',
-    file_list='optic_disc_seg/val_list.txt',
-    label_list='optic_disc_seg/labels.txt',
+    data_dir='./data/remote_sensing_seg',
+    file_list='./data/remote_sensing_seg/val.txt',
+    label_list='./data/remote_sensing_seg/labels.txt',
     transforms=eval_transforms,
     num_workers=0,
     shuffle=False)
@@ -45,7 +46,7 @@ eval_dataset = pdrs.datasets.SegDataset(
 # 初始化模型,并进行训练
 # 可使用VisualDL查看训练指标,参考https://github.com/PaddlePaddle/paddlers/blob/develop/docs/visualdl.md
 num_classes = len(train_dataset.labels)
-model = pdrs.tasks.DeepLabV3P(num_classes=num_classes, backbone='ResNet50_vd')
+model = pdrs.tasks.DeepLabV3P(input_channel=channel, num_classes=num_classes, backbone='ResNet50_vd')
 
 # API说明:https://github.com/PaddlePaddle/paddlers/blob/develop/docs/apis/models/semantic_segmentation.md
 # 各参数介绍与调整说明:https://github.com/PaddlePaddle/paddlers/blob/develop/docs/parameters.md