Browse Source

Merge pull request #36 from Bobholamovic/add_cd_pretr

[Feat] Add Pretrained Models for CD Tasks
cc 3 years ago
parent
commit
61c74a24bf
2 changed files with 39 additions and 3 deletions
  1. 37 1
      paddlers/utils/checkpoint.py
  2. 2 2
      tutorials/train/change_detection/fccdn.py

+ 37 - 1
paddlers/utils/checkpoint.py

@@ -21,7 +21,18 @@ import paddle
 from . import logging
 from . import logging
 from .download import download_and_decompress
 from .download import download_and_decompress
 
 
-cd_pretrain_weights_dict = {}
+cd_pretrain_weights_dict = {
+    'BIT': ['LEVIRCD'],
+    'CDNet': ['LEVIRCD'],
+    'DSAMNet': ['LEVIRCD'],
+    'DSIFN': ['LEVIRCD'],
+    'FCEarlyFusion': ['LEVIRCD'],
+    'FCSiamConc': ['LEVIRCD'],
+    'FCSiamDiff': ['LEVIRCD'],
+    'FCCDN': ['LEVIRCD'],
+    'SNUNet': ['LEVIRCD'],
+    'STANet': ['LEVIRCD']
+}
 
 
 cls_pretrain_weights_dict = {
 cls_pretrain_weights_dict = {
     'ResNet50_vd': ['IMAGENET'],
     'ResNet50_vd': ['IMAGENET'],
@@ -404,6 +415,29 @@ coco_weights = {
     'https://paddledet.bj.bcebos.com/models/mask_rcnn_r101_vd_fpn_1x_coco.pdparams'
     'https://paddledet.bj.bcebos.com/models/mask_rcnn_r101_vd_fpn_1x_coco.pdparams'
 }
 }
 
 
+levircd_weights = {
+    'BIT_LEVIRCD':
+    'https://paddlers.bj.bcebos.com/pretrained/cd/levircd/weights/bit_levircd.pdparams',
+    'CDNet_LEVIRCD':
+    'https://paddlers.bj.bcebos.com/pretrained/cd/levircd/weights/cdnet_levircd.pdparams',
+    'DSAMNet_LEVIRCD':
+    'https://paddlers.bj.bcebos.com/pretrained/cd/levircd/weights/dsamnet_levircd.pdparams',
+    'DSIFN_LEVIRCD':
+    'https://paddlers.bj.bcebos.com/pretrained/cd/levircd/weights/dsifn_levircd.pdparams',
+    'FCEarlyFusion_LEVIRCD':
+    'https://paddlers.bj.bcebos.com/pretrained/cd/levircd/weights/fc_ef_levircd.pdparams',
+    'FCSiamConc_LEVIRCD':
+    'https://paddlers.bj.bcebos.com/pretrained/cd/levircd/weights/fc_siam_conc_levircd.pdparams',
+    'FCSiamDiff_LEVIRCD':
+    'https://paddlers.bj.bcebos.com/pretrained/cd/levircd/weights/fc_siam_diff_levircd.pdparams',
+    'FCCDN_LEVIRCD':
+    'https://paddlers.bj.bcebos.com/pretrained/cd/levircd/weights/fccdn_levircd.pdparams',
+    'SNUNet_LEVIRCD':
+    'https://paddlers.bj.bcebos.com/pretrained/cd/levircd/weights/snunet_levircd.pdparams',
+    'STANet_LEVIRCD':
+    'https://paddlers.bj.bcebos.com/pretrained/cd/levircd/weights/stanet_levircd.pdparams'
+}
+
 
 
 def get_pretrain_weights(flag, class_name, save_dir, backbone_name=None):
 def get_pretrain_weights(flag, class_name, save_dir, backbone_name=None):
     if flag is None:
     if flag is None:
@@ -427,6 +461,8 @@ def get_pretrain_weights(flag, class_name, save_dir, backbone_name=None):
         url = pascalvoc_weights[weights_key]
         url = pascalvoc_weights[weights_key]
     elif flag == 'COCO':
     elif flag == 'COCO':
         url = coco_weights[weights_key]
         url = coco_weights[weights_key]
+    elif flag == 'LEVIRCD':
+        url = levircd_weights[weights_key]
     else:
     else:
         raise ValueError('Given pretrained weights {} is undefined.'.format(
         raise ValueError('Given pretrained weights {} is undefined.'.format(
             flag))
             flag))

+ 2 - 2
tutorials/train/change_detection/fccdn.py

@@ -78,11 +78,11 @@ model = pdrs.tasks.cd.FCCDN()
 
 
 # 执行模型训练
 # 执行模型训练
 model.train(
 model.train(
-    num_epochs=5,
+    num_epochs=10,
     train_dataset=train_dataset,
     train_dataset=train_dataset,
     train_batch_size=4,
     train_batch_size=4,
     eval_dataset=eval_dataset,
     eval_dataset=eval_dataset,
-    save_interval_epochs=2,
+    save_interval_epochs=4,
     # 每多少次迭代记录一次日志
     # 每多少次迭代记录一次日志
     log_interval_steps=50,
     log_interval_steps=50,
     save_dir=EXP_DIR,
     save_dir=EXP_DIR,