Browse Source

Add installation and readme

juncaipeng 3 years ago
parent
commit
7d25f5855e
4 changed files with 133 additions and 0 deletions
  1. 1 0
      requirements.txt
  2. 43 0
      setup.py
  3. 64 0
      tutorials/train/detection/faster_rcnn_sar_ship.py
  4. 25 0
      tutorials/train/detection/readme.md

+ 1 - 0
requirements.txt

@@ -14,3 +14,4 @@ motmetrics
 matplotlib
 chardet
 openpyxl
+gdal

+ 43 - 0
setup.py

@@ -0,0 +1,43 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import setuptools
+
+long_description = "Awesome Remote Sensing Toolkit based on PaddlePaddle"
+
+setuptools.setup(
+    name="paddlers",
+    version='0.0.1',
+    author="paddlers",
+    author_email="paddlers@baidu.com",
+    description=long_description,
+    long_description=long_description,
+    long_description_content_type="text/plain",
+    url="https://github.com/PaddleCV-SIG/PaddleRS",
+    packages=setuptools.find_packages(),
+    setup_requires=['cython', 'numpy'],
+    install_requires=[
+        "pycocotools", 'pyyaml', 'colorama', 'tqdm', 'paddleslim==2.2.1',
+        'visualdl>=2.2.2', 'shapely>=1.7.0', 'opencv-python', 'scipy', 'lap',
+        'motmetrics', 'scikit-learn==0.23.2', 'chardet', 'flask_cors',
+        'openpyxl', 'gdal'
+    ],
+    classifiers=[
+        "Programming Language :: Python :: 3",
+        "License :: OSI Approved :: Apache Software License",
+        "Operating System :: OS Independent",
+    ],
+    license='Apache 2.0',
+    )
+

+ 64 - 0
tutorials/train/detection/faster_rcnn_sar_ship.py

@@ -0,0 +1,64 @@
+import os
+import paddlers as pdrs
+from paddlers import transforms as T
+
+# download dataset
+data_dir = 'sar_ship_1'
+if not os.path.exists(data_dir):
+    dataset_url = 'https://paddleseg.bj.bcebos.com/dataset/sar_ship_1.tar.gz'
+    pdrs.utils.download_and_decompress(dataset_url, path='./')
+
+# define transforms
+train_transforms = T.Compose([
+    T.RandomDistort(),
+    T.RandomExpand(im_padding_value=[123.675, 116.28, 103.53]),
+    T.RandomCrop(),
+    T.RandomHorizontalFlip(),
+    T.BatchRandomResize(
+        target_sizes=[320, 352, 384, 416, 448, 480, 512, 544, 576, 608],
+        interp='RANDOM'),
+    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+])
+
+eval_transforms = T.Compose([
+    T.Resize(
+        target_size=608, interp='CUBIC'), T.Normalize(
+            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+])
+
+# define dataset
+train_file_list = os.path.join(data_dir, 'train.txt')
+val_file_list = os.path.join(data_dir, 'valid.txt')
+label_file_list = os.path.join(data_dir, 'labels.txt')
+train_dataset = pdrs.datasets.VOCDetection(
+    data_dir=data_dir,
+    file_list=train_file_list,
+    label_list=label_file_list,
+    transforms=train_transforms,
+    shuffle=True)
+
+eval_dataset = pdrs.datasets.VOCDetection(
+    data_dir=data_dir,
+    file_list=train_file_list,
+    label_list=label_file_list,
+    transforms=eval_transforms,
+    shuffle=False)
+
+# define models
+num_classes = len(train_dataset.labels)
+model = pdrs.tasks.det.FasterRCNN(num_classes=num_classes)
+
+# train
+model.train(
+    num_epochs=60,
+    train_dataset=train_dataset,
+    train_batch_size=2,
+    eval_dataset=eval_dataset,
+    pretrain_weights='COCO',
+    learning_rate=0.005 / 12,
+    warmup_steps=10,
+    warmup_start_lr=0.0,
+    save_interval_epochs=5,
+    lr_decay_epochs=[20, 40],
+    save_dir='output/faster_rcnn_sar_ship',
+    use_vdl=True)

+ 25 - 0
tutorials/train/detection/readme.md

@@ -0,0 +1,25 @@
+Run the detection training demo:
+
+1, Install PaddleRS
+
+```
+git clone https://github.com/PaddleCV-SIG/PaddleRS.git
+cd PaddleRS
+pip install -r requirements.txt
+python setup.py install
+```
+
+
+2. Run the demo
+
+```
+cd tutorials/train/detection/
+
+# run training on single GPU
+export CUDA_VISIBLE_DEVICES=0
+python faster_rcnn_sar_ship.py
+
+# run traing on multi gpu
+export CUDA_VISIBLE_DEVICES=0,1
+python -m paddle.distributed.launch faster_rcnn_sar_ship.py
+```