
Add a SAR ship detection training demo using Faster R-CNN

juncaipeng committed 3 years ago
commit 5f9bc1d54f
4 changed files with 59 additions and 62 deletions
  1. paddlers/__init__.py (+23, -4)
  2. paddlers/transforms/operators.py (+35, -3)
  3. paddlers/utils/utils.py (+1, -1)
  4. tutorials/train/ppyolo.py (+0, -54)

+ 23 - 4
paddlers/__init__.py

@@ -1,5 +1,24 @@
-from . import tasks, datasets, transforms, utils, tools, models
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 
-# TODO, add these info in installation
-env_info = {'place': 'gpu', 'num': 1}
-__version__ = 0.1
+__version__ = '0.0.1'
+
+from paddlers.utils.env import get_environ_info, init_parallel_env
+init_parallel_env()
+
+env_info = get_environ_info()
+
+log_level = 2
+
+from . import tasks, datasets, transforms, utils, tools, models
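
With this change, importing the package initializes the parallel environment and queries device information at import time instead of hard-coding it. A minimal usage sketch (the exact env_info contents depend on the machine):

    import paddlers as pdrs

    print(pdrs.__version__)  # '0.0.1'
    print(pdrs.env_info)     # e.g. {'place': 'gpu', 'num': 1} on a single-GPU machine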

+ 35 - 3
paddlers/transforms/operators.py

@@ -16,6 +16,8 @@ import numpy as np
 import cv2
 import copy
 import random
+import imghdr
+import os
 from PIL import Image
 import paddlers
 
@@ -146,9 +148,39 @@ class Decode(Transform):
         super(Decode, self).__init__()
         self.to_rgb = to_rgb
 
-    def read_img(self, img_path):
-        return cv2.imread(img_path, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_ANYCOLOR |
-                          cv2.IMREAD_COLOR)
+    def read_img(self, img_path, input_channel=3):
+        img_format = imghdr.what(img_path)
+        name, ext = os.path.splitext(img_path)
+        if img_format == 'tiff' or ext == '.img':
+            # GDAL is required to decode GeoTIFF and .img rasters.
+            try:
+                import gdal
+            except ImportError:
+                try:
+                    from osgeo import gdal
+                except ImportError:
+                    raise Exception(
+                        "Failed to import gdal! You can try using conda to install gdal."
+                    )
+
+            dataset = gdal.Open(img_path)
+            if dataset is None:
+                raise Exception('Cannot open {}'.format(img_path))
+            im_data = dataset.ReadAsArray()
+            if im_data.ndim == 3:
+                im_data = im_data.transpose((1, 2, 0))  # CHW -> HWC
+            return im_data
+        elif img_format in ['jpeg', 'bmp', 'png', 'jpg']:
+            if input_channel == 3:
+                return cv2.imread(img_path, cv2.IMREAD_ANYDEPTH |
+                                  cv2.IMREAD_ANYCOLOR | cv2.IMREAD_COLOR)
+            else:
+                return cv2.imread(img_path, cv2.IMREAD_ANYDEPTH |
+                                  cv2.IMREAD_ANYCOLOR)
+        elif ext == '.npy':
+            return np.load(img_path)
+        else:
+            raise Exception('Image format {} is not supported!'.format(ext))
 
     def apply_im(self, im_path):
         if isinstance(im_path, str):
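
For reference, the GDAL branch added above follows the standard pattern for reading multi-band rasters: open the dataset, read all bands into a (bands, height, width) array, and move channels last so the result matches what cv2.imread returns. A standalone sketch of the same idea (the helper name and file path are illustrative, not part of the library):

    from osgeo import gdal

    def read_raster(path):
        dataset = gdal.Open(path)           # returns None on failure instead of raising
        if dataset is None:
            raise IOError('Cannot open {}'.format(path))
        arr = dataset.ReadAsArray()         # (bands, H, W) for multi-band rasters, (H, W) otherwise
        if arr.ndim == 3:
            arr = arr.transpose((1, 2, 0))  # -> (H, W, bands)
        return arr

    # image = read_raster('sar_scene.tif')  # placeholder path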

+ 1 - 1
paddlers/utils/utils.py

@@ -74,7 +74,7 @@ def path_normalization(path):
 
 
 def is_pic(img_name):
-    valid_suffix = ['JPEG', 'jpeg', 'JPG', 'jpg', 'BMP', 'bmp', 'PNG', 'png']
+    valid_suffix = ['JPEG', 'jpeg', 'JPG', 'jpg', 'BMP', 'bmp', 'PNG', 'png', 'tiff']
     suffix = img_name.split('.')[-1]
     if suffix not in valid_suffix:
         return False
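
Note that is_pic compares the raw filename suffix against this list, so only the exact lowercase 'tiff' spelling added here passes; 'tif' and 'TIFF' are still rejected. A small illustration of that behaviour (reimplemented here for clarity, not the library function itself):

    valid_suffix = ['JPEG', 'jpeg', 'JPG', 'jpg', 'BMP', 'bmp', 'PNG', 'png', 'tiff']

    def is_pic(img_name):
        return img_name.split('.')[-1] in valid_suffix

    print(is_pic('scene.tiff'))  # True
    print(is_pic('scene.tif'))   # False: 'tif' is not in the list
    print(is_pic('scene.TIFF'))  # False: the comparison is case sensitive for 'tiff'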

+ 0 - 54
tutorials/train/ppyolo.py

@@ -1,54 +0,0 @@
-import sys
-
-sys.path.append("/ssd2/pengjuncai/PaddleRS")
-
-import paddlers as pdrs
-from paddlers import transforms as T
-
-train_transforms = T.Compose([
-    T.MixupImage(mixup_epoch=-1), T.RandomDistort(),
-    T.RandomExpand(im_padding_value=[123.675, 116.28, 103.53]), T.RandomCrop(),
-    T.RandomHorizontalFlip(), T.BatchRandomResize(
-        target_sizes=[320, 352, 384, 416, 448, 480, 512, 544, 576, 608],
-        interp='RANDOM'), T.Normalize(
-            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-])
-
-eval_transforms = T.Compose([
-    T.Resize(
-        target_size=608, interp='CUBIC'), T.Normalize(
-            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-])
-
-
-train_dataset = pdrs.datasets.VOCDetection(
-    data_dir='insect_det',
-    file_list='insect_det/train_list.txt',
-    label_list='insect_det/labels.txt',
-    transforms=train_transforms,
-    shuffle=True)
-
-eval_dataset = pdrs.datasets.VOCDetection(
-    data_dir='insect_det',
-    file_list='insect_det/val_list.txt',
-    label_list='insect_det/labels.txt',
-    transforms=eval_transforms,
-    shuffle=False)
-
-
-num_classes = len(train_dataset.labels)
-model = pdrs.tasks.det.PPYOLO(num_classes=num_classes, backbone='ResNet50_vd_dcn')
-
-model.train(
-    num_epochs=200,
-    train_dataset=train_dataset,
-    train_batch_size=8,
-    eval_dataset=eval_dataset,
-    pretrain_weights='COCO',
-    learning_rate=0.005 / 12,
-    warmup_steps=500,
-    warmup_start_lr=0.0,
-    save_interval_epochs=5,
-    lr_decay_epochs=[85, 135],
-    save_dir='output/ppyolo_r50vd_dcn',
-    use_vdl=True)
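
The Faster R-CNN SAR ship demo named in the commit message is not shown in this diff. As a rough sketch only, assuming it follows the same structure as the removed ppyolo.py above and that a FasterRCNN task class exists alongside PPYOLO (dataset paths, file names, and hyperparameters below are placeholders, not the actual demo):

    import paddlers as pdrs
    from paddlers import transforms as T

    # Transforms reuse operators that appear in the removed script.
    train_transforms = T.Compose([
        T.RandomHorizontalFlip(),
        T.Resize(target_size=608, interp='CUBIC'),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Placeholder dataset layout; a real demo would point at the SAR ship data.
    train_dataset = pdrs.datasets.VOCDetection(
        data_dir='sar_ship_det',
        file_list='sar_ship_det/train_list.txt',
        label_list='sar_ship_det/labels.txt',
        transforms=train_transforms,
        shuffle=True)

    # Assumes pdrs.tasks.det.FasterRCNN mirrors the PPYOLO task interface used above.
    model = pdrs.tasks.det.FasterRCNN(num_classes=len(train_dataset.labels))

    model.train(
        num_epochs=60,
        train_dataset=train_dataset,
        train_batch_size=8,
        learning_rate=0.0025,
        save_interval_epochs=5,
        save_dir='output/faster_rcnn_sar_ship',
        use_vdl=True)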