Browse Source

Add iSAID pretrained weight of FactSeg (#57)

Lin Manhui 2 years ago
parent
commit
b75e0d19d5
2 changed files with 11 additions and 2 deletions
  1. 9 1
      paddlers/utils/checkpoint.py
  2. 2 1
      tutorials/train/semantic_segmentation/factseg.py

+ 9 - 1
paddlers/utils/checkpoint.py

@@ -86,7 +86,8 @@ seg_pretrain_weights_dict = {
     'DeepLabV3P': ['CITYSCAPES', 'PascalVOC', 'IMAGENET'],
     'DeepLabV3P': ['CITYSCAPES', 'PascalVOC', 'IMAGENET'],
     'FastSCNN': ['CITYSCAPES'],
     'FastSCNN': ['CITYSCAPES'],
     'HRNet': ['CITYSCAPES', 'PascalVOC'],
     'HRNet': ['CITYSCAPES', 'PascalVOC'],
-    'BiSeNetV2': ['CITYSCAPES']
+    'BiSeNetV2': ['CITYSCAPES'],
+    'FactSeg': ['iSAID']
 }
 }
 
 
 cityscapes_weights = {
 cityscapes_weights = {
@@ -438,6 +439,11 @@ levircd_weights = {
     'https://paddlers.bj.bcebos.com/pretrained/cd/levircd/weights/stanet_levircd.pdparams'
     'https://paddlers.bj.bcebos.com/pretrained/cd/levircd/weights/stanet_levircd.pdparams'
 }
 }
 
 
+isaid_weights = {
+    'FactSeg_iSAID':
+    'https://paddlers.bj.bcebos.com/pretrained/seg/isaid/weights/factseg_isaid.pdparams'
+}
+
 
 
 def get_pretrain_weights(flag, class_name, save_dir, backbone_name=None):
 def get_pretrain_weights(flag, class_name, save_dir, backbone_name=None):
     if flag is None:
     if flag is None:
@@ -463,6 +469,8 @@ def get_pretrain_weights(flag, class_name, save_dir, backbone_name=None):
         url = coco_weights[weights_key]
         url = coco_weights[weights_key]
     elif flag == 'LEVIRCD':
     elif flag == 'LEVIRCD':
         url = levircd_weights[weights_key]
         url = levircd_weights[weights_key]
+    elif flag == 'iSAID':
+        url = isaid_weights[weights_key]
     else:
     else:
         raise ValueError('Given pretrained weights {} is undefined.'.format(
         raise ValueError('Given pretrained weights {} is undefined.'.format(
             flag))
             flag))

+ 2 - 1
tutorials/train/semantic_segmentation/factseg.py

@@ -83,7 +83,8 @@ model.train(
     # 每多少次迭代记录一次日志
     # 每多少次迭代记录一次日志
     log_interval_steps=4,
     log_interval_steps=4,
     save_dir=EXP_DIR,
     save_dir=EXP_DIR,
-    pretrain_weights=None,
+    # 使用iSAID数据集上的预训练权重
+    pretrain_weights='iSAID',
     # 初始学习率大小
     # 初始学习率大小
     learning_rate=0.001,
     learning_rate=0.001,
     # 是否使用early stopping策略,当精度不再改善时提前终止训练
     # 是否使用early stopping策略,当精度不再改善时提前终止训练