瀏覽代碼

Merge branch 'develop' into det1

cc 3 年之前
父節點
當前提交
cb370b169b

+ 1 - 0
paddlers/datasets/__init__.py

@@ -1,2 +1,3 @@
 from .voc import VOCDetection
 from .seg_dataset import SegDataset
+from .raster import Raster

+ 139 - 0
paddlers/datasets/raster.py

@@ -0,0 +1,139 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os.path as osp
+import numpy as np
+from typing import List, Tuple, Union
+from paddlers.utils import raster2uint8
+
+try:
+    from osgeo import gdal
+except:
+    import gdal
+
+
+class Raster:
+    def __init__(self, 
+                 path: str,
+                 band_list: Union[List[int], Tuple[int], None]=None, 
+                 to_uint8: bool=False) -> None:
+        """ Class of read raster.
+
+        Args:
+            path (str): The path of raster.
+            band_list (Union[List[int], Tuple[int], None], optional): 
+                band list (start with 1) or None (all of bands). Defaults to None.
+            to_uint8 (bool, optional): 
+                Convert uint8 or return raw data. Defaults to False.
+        """
+        super(Raster, self).__init__()
+        if osp.exists(path):
+            self.path = path
+            self.__src_data = np.load(path) if path.split(".")[-1] == "npy" \
+                                            else gdal.Open(path)
+            self.__getInfo()
+            self.to_uint8 = to_uint8
+            self.setBands(band_list)
+        else:
+            raise ValueError("The path {0} not exists.".format(path))
+
+    def setBands(self,
+                 band_list: Union[List[int], Tuple[int], None]) -> None:
+        """ Set band of data.
+
+        Args:
+            band_list (Union[List[int], Tuple[int], None]): 
+                band list (start with 1) or None (all of bands).
+        """
+        if band_list is not None:
+            if len(band_list) > self.bands:
+                raise ValueError("The lenght of band_list must be less than {0}.".format(str(self.bands)))
+            if max(band_list) > self.bands or min(band_list) < 1:
+                raise ValueError("The range of band_list must within [1, {0}].".format(str(self.bands)))
+        self.band_list = band_list
+
+    def getArray(self,
+                 start_loc: Union[List[int], Tuple[int], None]=None, 
+                 block_size: Union[List[int], Tuple[int]]=[512, 512]) -> np.ndarray:
+        """ Get ndarray data 
+
+        Args:
+            start_loc (Union[List[int], Tuple[int], None], optional): 
+                Coordinates of the upper left corner of the block, if None means return full image.
+            block_size (Union[List[int], Tuple[int]], optional): 
+                Block size. Defaults to [512, 512].
+
+        Returns:
+            np.ndarray: data's ndarray.
+        """
+        if start_loc is None:
+            return self.__getAarray()
+        else:
+            return self.__getBlock(start_loc, block_size)
+
+    def __getInfo(self) -> None:
+        self.bands = self.__src_data.RasterCount
+        self.width = self.__src_data.RasterXSize
+        self.height = self.__src_data.RasterYSize
+
+    def __getAarray(self, window: Union[None, List[int], Tuple[int]]=None) -> np.ndarray:
+        if window is not None:
+            xoff, yoff, xsize, ysize = window
+        if self.band_list is None:
+            if window is None:
+                ima = self.__src_data.ReadAsArray()
+            else:
+                ima = self.__src_data.ReadAsArray(xoff, yoff, xsize, ysize)
+        else:
+            band_array = []
+            for b in self.band_list:
+                if window is None:
+                    band_i = self.__src_data.GetRasterBand(b).ReadAsArray()
+                else:
+                    band_i = self.__src_data.GetRasterBand(b).ReadAsArray(xoff, yoff, xsize, ysize)
+                band_array.append(band_i)
+            ima = np.stack(band_array, axis=0)
+        if self.bands == 1:
+            # the type is complex means this is a SAR data
+            if isinstance(type(ima[0, 0]), complex):
+                ima = abs(ima)
+        else:
+            ima = ima.transpose((1, 2, 0))
+        if self.to_uint8 is True:
+            ima = raster2uint8(ima)
+        return ima
+
+    def __getBlock(self,
+                   start_loc: Union[List[int], Tuple[int]], 
+                   block_size: Union[List[int], Tuple[int]]=[512, 512]) -> np.ndarray:
+        if len(start_loc) != 2 or len(block_size) != 2:
+            raise ValueError("The length start_loc/block_size must be 2.")
+        xoff, yoff = start_loc
+        xsize, ysize = block_size
+        if (xoff < 0 or xoff > self.width) or (yoff < 0 or yoff > self.height):
+            raise ValueError(
+                "start_loc must be within [0-{0}, 0-{1}].".format(str(self.width), str(self.height)))
+        if xoff + xsize > self.width:
+            xsize = self.width - xoff
+        if yoff + ysize > self.height:
+            ysize = self.height - yoff
+        ima = self.__getAarray([int(xoff), int(yoff), int(xsize), int(ysize)])
+        h, w = ima.shape[:2] if len(ima.shape) == 3 else ima.shape
+        if self.bands != 1:
+            tmp = np.zeros((block_size[0], block_size[1], self.bands), dtype=ima.dtype)
+            tmp[:h, :w, :] = ima
+        else:
+            tmp = np.zeros((block_size[0], block_size[1]), dtype=ima.dtype)
+            tmp[:h, :w] = ima
+        return tmp

+ 1 - 1
paddlers/tools/yolo_cluster.py

@@ -99,7 +99,7 @@ class YOLOAnchorCluster(BaseAnchorCluster):
             num_anchors (int): number of clusters
             dataset (DataSet): DataSet instance, VOC or COCO
             image_size (list or int): [h, w], being an int means image height and image width are the same.
-            cache (bool): whether using cache Defaults to True.
+            cache (bool): whether using cache. Defaults to True.
             cache_path (str or None, optional): cache directory path. If None, use `data_dir` of dataset. Defaults to None.
             iters (int, optional): iters of kmeans algorithm. Defaults to 300.
             gen_iters (int, optional): iters of genetic algorithm. Defaults to 1000.

+ 2 - 2
paddlers/transforms/batch_operators.py

@@ -69,7 +69,7 @@ class BatchRandomResize(Transform):
     """
     Resize a batch of input to random sizes.
 
-    AttentionIf interp is 'RANDOM', the interpolation method will be chose randomly.
+    Attention: If interp is 'RANDOM', the interpolation method will be chose randomly.
 
     Args:
         target_sizes (List[int], List[list or tuple] or Tuple[list or tuple]):
@@ -108,7 +108,7 @@ class BatchRandomResize(Transform):
 class BatchRandomResizeByShort(Transform):
     """Resize a batch of input to random sizes with keeping the aspect ratio.
 
-    AttentionIf interp is 'RANDOM', the interpolation method will be chose randomly.
+    Attention: If interp is 'RANDOM', the interpolation method will be chose randomly.
 
     Args:
         short_sizes (List[int], Tuple[int]): Target sizes of the shorter side of the image(s).

+ 2 - 3
paddlers/transforms/img_decoder.py

@@ -1,5 +1,3 @@
-
-   
 # copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -21,6 +19,7 @@ import copy
 import random
 import imghdr
 from PIL import Image
+
 try:
     from collections.abc import Sequence
 except Exception:
@@ -103,7 +102,7 @@ class ImgDecode(Transform):
                 return cv2.imread(img_path, cv2.IMREAD_ANYDEPTH |
                                   cv2.IMREAD_ANYCOLOR | cv2.IMREAD_COLOR)
             else:
-                return cv2.imread(im_file, cv2.IMREAD_ANYDEPTH |
+                return cv2.imread(img_path, cv2.IMREAD_ANYDEPTH |
                                   cv2.IMREAD_ANYCOLOR)
         elif ext == '.npy':
             return np.load(img_path)

+ 8 - 8
paddlers/transforms/operators.py

@@ -236,9 +236,9 @@ class Resize(Transform):
     """
     Resize input.
 
-    - If target_size is an intresize the image(s) to (target_size, target_size).
+    - If target_size is an int, resize the image(s) to (target_size, target_size).
     - If target_size is a list or tuple, resize the image(s) to target_size.
-    AttentionIf interp is 'RANDOM', the interpolation method will be chose randomly.
+    Attention: If interp is 'RANDOM', the interpolation method will be chose randomly.
 
     Args:
         target_size (int, List[int] or Tuple[int]): Target size. If int, the height and width share the same target_size.
@@ -347,7 +347,7 @@ class RandomResize(Transform):
     """
     Resize input to random sizes.
 
-    AttentionIf interp is 'RANDOM', the interpolation method will be chose randomly.
+    Attention: If interp is 'RANDOM', the interpolation method will be chose randomly.
 
     Args:
         target_sizes (List[int], List[list or tuple] or Tuple[list or tuple]):
@@ -388,7 +388,7 @@ class ResizeByShort(Transform):
     """
     Resize input with keeping the aspect ratio.
 
-    AttentionIf interp is 'RANDOM', the interpolation method will be chose randomly.
+    Attention: If interp is 'RANDOM', the interpolation method will be chose randomly.
 
     Args:
         short_size (int): Target size of the shorter side of the image(s).
@@ -427,7 +427,7 @@ class RandomResizeByShort(Transform):
     """
     Resize input to random sizes with keeping the aspect ratio.
 
-    AttentionIf interp is 'RANDOM', the interpolation method will be chose randomly.
+    Attention: If interp is 'RANDOM', the interpolation method will be chose randomly.
 
     Args:
         short_sizes (List[int]): Target size of the shorter side of the image(s).
@@ -865,8 +865,8 @@ class RandomCrop(Transform):
 class RandomScaleAspect(Transform):
     """
     Crop input image(s) and resize back to original sizes.
-    Args
-        min_scale (float)Minimum ratio between the cropped region and the original image.
+    Args: 
+        min_scale (float): Minimum ratio between the cropped region and the original image.
             If 0, image(s) will not be cropped. Defaults to .5.
         aspect_ratio (float): Aspect ratio of cropped region. Defaults to .33.
     """
@@ -1262,7 +1262,7 @@ class RandomBlur(Transform):
     """
     Randomly blur input image(s).
 
-    Args
+    Args: 
         prob (float): Probability of blurring.
     """
 

+ 1 - 0
paddlers/utils/__init__.py

@@ -22,3 +22,4 @@ from .env import get_environ_info, get_num_workers, init_parallel_env
 from .download import download_and_decompress, decompress
 from .stats import SmoothedValue, TrainingStats
 from .shm import _get_shared_memory_size_in_M
+from .convert import raster2uint8

+ 95 - 0
paddlers/utils/convert.py

@@ -0,0 +1,95 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import operator
+from functools import reduce
+
+
+def raster2uint8(image: np.ndarray) -> np.ndarray:
+    """ Convert raster to uint8.
+    Args:
+        image (np.ndarray): image.
+    Returns:
+        np.ndarray: image on uint8.
+    """
+    dtype = image.dtype.name
+    dtypes = ["uint8", "uint16", "float32"]
+    if dtype not in dtypes:
+        raise ValueError(f"'dtype' must be uint8/uint16/float32, not {dtype}.")
+    if dtype == "uint8":
+        return image
+    else:
+        if dtype == "float32":
+            image = _sample_norm(image)
+        return _two_percentLinear(image)
+
+
+# 2% linear stretch
+def _two_percentLinear(image: np.ndarray, max_out: int=255, min_out: int=0) -> np.ndarray:
+    def _gray_process(gray, maxout=max_out, minout=min_out):
+        # get the corresponding gray level at 98% histogram
+        high_value = np.percentile(gray, 98)
+        low_value = np.percentile(gray, 2)
+        truncated_gray = np.clip(gray, a_min=low_value, a_max=high_value)
+        processed_gray = ((truncated_gray - low_value) / (high_value - low_value)) * (maxout - minout)
+        return processed_gray
+    if len(image.shape) == 3:
+        processes = []
+        for b in range(image.shape[-1]):
+            processes.append(_gray_process(image[:, :, b]))
+        result = np.stack(processes, axis=2)
+    else:  # if len(image.shape) == 2
+        result = _gray_process(image)
+    return np.uint8(result)
+
+
+# simple image standardization
+def _sample_norm(image: np.ndarray, NUMS: int=65536) -> np.ndarray:
+    stretches = []
+    if len(image.shape) == 3:
+        for b in range(image.shape[-1]):
+            stretched = _stretch(image[:, :, b], NUMS)
+            stretched /= float(NUMS)
+            stretches.append(stretched)
+        stretched_img = np.stack(stretches, axis=2)
+    else:  # if len(image.shape) == 2
+        stretched_img = _stretch(image, NUMS)
+    return np.uint8(stretched_img * 255)
+
+
+# histogram equalization
+def _stretch(ima: np.ndarray, NUMS: int) -> np.ndarray:
+    hist = _histogram(ima, NUMS)
+    lut = []
+    for bt in range(0, len(hist), NUMS):
+        # step size
+        step = reduce(operator.add, hist[bt : bt + NUMS]) / (NUMS - 1)
+        # create balanced lookup table
+        n = 0
+        for i in range(NUMS):
+            lut.append(n / step)
+            n += hist[i + bt]
+        np.take(lut, ima, out=ima)
+        return ima
+
+
+# calculate histogram
+def _histogram(ima: np.ndarray, NUMS: int) -> np.ndarray:
+    bins = list(range(0, NUMS))
+    flat = ima.flat
+    n = np.searchsorted(np.sort(flat), bins)
+    n = np.concatenate([n, [len(flat)]])
+    hist = n[1:] - n[:-1]
+    return hist

+ 2 - 2
requirements.txt

@@ -8,10 +8,10 @@ paddleslim == 2.2.1
 shapely
 paddlepaddle-gpu >= 2.2.0
 opencv-python
-scikit-learn==0.20.3
+scikit-learn == 0.20.3
 lap
 motmetrics
 matplotlib
 chardet
 openpyxl
-gdal
+GDAL >= 3.2.2

+ 53 - 0
tutorials/train/README.md

@@ -0,0 +1,53 @@
+# 使用教程——训练模型
+
+本目录下整理了使用PaddleRS训练模型的示例代码,代码中均提供了示例数据的自动下载,并均使用单张GPU卡进行训练。
+
+|代码 | 模型任务 | 数据 |
+|------|--------|---------|
+|object_detection/ppyolo.py | 目标检测PPYOLO | 昆虫检测 |
+|semantic_segmentation/deeplabv3p_resnet50_vd.py | 语义分割DeepLabV3 | 视盘分割 |
+
+<!-- 可参考API接口说明了解示例代码中的API:
+* [数据集读取API](../../docs/apis/datasets.md)
+* [数据预处理和数据增强API](../../docs/apis/transforms/transforms.md)
+* [模型API/模型加载API](../../docs/apis/models/README.md)
+* [预测结果可视化API](../../docs/apis/visualize.md) -->
+
+
+# 环境准备
+
+- [PaddlePaddle安装](https://www.paddlepaddle.org.cn/install/quick)
+* 版本要求:PaddlePaddle>=2.1.0
+
+<!-- - [PaddleRS安装](../../docs/install.md) -->
+
+## 开始训练
+* 修改tutorials/train/semantic_segmentation/deeplabv3p_resnet50_vd.py中sys.path路径
+```
+sys.path.append("your/PaddleRS/path")
+```
+
+* 在安装PaddleRS后,使用如下命令开始训练,代码会自动下载训练数据, 并均使用单张GPU卡进行训练。
+
+```commandline
+export CUDA_VISIBLE_DEVICES=0
+python tutorials/train/semantic_segmentation/deeplabv3p_resnet50_vd.py
+```
+
+* 若需使用多张GPU卡进行训练,例如使用2张卡时执行:
+
+```commandline
+python -m paddle.distributed.launch --gpus 0,1 tutorials/train/semantic_segmentation/deeplabv3p_resnet50_vd.py
+```
+使用多卡时,参考[训练参数调整](../../docs/parameters.md)调整学习率和批量大小。
+
+
+## VisualDL可视化训练指标
+在模型训练过程,在`train`函数中,将`use_vdl`设为True,则训练过程会自动将训练日志以VisualDL的格式打点在`save_dir`(用户自己指定的路径)下的`vdl_log`目录,用户可以使用如下命令启动VisualDL服务,查看可视化指标
+```commandline
+visualdl --logdir output/deeplabv3p_resnet50_vd/vdl_log --port 8001
+```
+
+服务启动后,使用浏览器打开 https://0.0.0.0:8001 或 https://localhost:8001
+
+

+ 54 - 0
tutorials/train/object_detection/ppyolo.py

@@ -0,0 +1,54 @@
+import sys
+
+sys.path.append("/ssd2/pengjuncai/PaddleRS")
+
+import paddlers as pdrs
+from paddlers import transforms as T
+
+train_transforms = T.Compose([
+    T.MixupImage(mixup_epoch=-1), T.RandomDistort(),
+    T.RandomExpand(im_padding_value=[123.675, 116.28, 103.53]), T.RandomCrop(),
+    T.RandomHorizontalFlip(), T.BatchRandomResize(
+        target_sizes=[320, 352, 384, 416, 448, 480, 512, 544, 576, 608],
+        interp='RANDOM'), T.Normalize(
+            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+])
+
+eval_transforms = T.Compose([
+    T.Resize(
+        target_size=608, interp='CUBIC'), T.Normalize(
+            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+])
+
+
+train_dataset = pdrs.datasets.VOCDetection(
+    data_dir='insect_det',
+    file_list='insect_det/train_list.txt',
+    label_list='insect_det/labels.txt',
+    transforms=train_transforms,
+    shuffle=True)
+
+eval_dataset = pdrs.datasets.VOCDetection(
+    data_dir='insect_det',
+    file_list='insect_det/val_list.txt',
+    label_list='insect_det/labels.txt',
+    transforms=eval_transforms,
+    shuffle=False)
+
+
+num_classes = len(train_dataset.labels)
+model = pdrs.tasks.det.PPYOLO(num_classes=num_classes, backbone='ResNet50_vd_dcn')
+
+model.train(
+    num_epochs=200,
+    train_dataset=train_dataset,
+    train_batch_size=8,
+    eval_dataset=eval_dataset,
+    pretrain_weights='COCO',
+    learning_rate=0.005 / 12,
+    warmup_steps=500,
+    warmup_start_lr=0.0,
+    save_interval_epochs=5,
+    lr_decay_epochs=[85, 135],
+    save_dir='output/ppyolo_r50vd_dcn',
+    use_vdl=True)