|
@@ -21,7 +21,18 @@ import paddle
|
|
|
from . import logging
|
|
|
from .download import download_and_decompress
|
|
|
|
|
|
-cd_pretrain_weights_dict = {}
|
|
|
+cd_pretrain_weights_dict = {
|
|
|
+ 'BIT': ['LEVIRCD'],
|
|
|
+ 'CDNet': ['LEVIRCD'],
|
|
|
+ 'DSAMNet': ['LEVIRCD'],
|
|
|
+ 'DSIFN': ['LEVIRCD'],
|
|
|
+ 'FCEarlyFusion': ['LEVIRCD'],
|
|
|
+ 'FCSiamConc': ['LEVIRCD'],
|
|
|
+ 'FCSiamDiff': ['LEVIRCD'],
|
|
|
+ 'FCCDN': ['LEVIRCD'],
|
|
|
+ 'SNUNet': ['LEVIRCD'],
|
|
|
+ 'STANet': ['LEVIRCD']
|
|
|
+}
|
|
|
|
|
|
cls_pretrain_weights_dict = {
|
|
|
'ResNet50_vd': ['IMAGENET'],
|
|
@@ -404,6 +415,29 @@ coco_weights = {
|
|
|
'https://paddledet.bj.bcebos.com/models/mask_rcnn_r101_vd_fpn_1x_coco.pdparams'
|
|
|
}
|
|
|
|
|
|
+levircd_weights = {
|
|
|
+ 'BIT_LEVIRCD':
|
|
|
+ 'https://paddlers.bj.bcebos.com/pretrained/cd/levircd/weights/bit_levircd.pdparams',
|
|
|
+ 'CDNet_LEVIRCD':
|
|
|
+ 'https://paddlers.bj.bcebos.com/pretrained/cd/levircd/weights/cdnet_levircd.pdparams',
|
|
|
+ 'DSAMNet_LEVIRCD':
|
|
|
+ 'https://paddlers.bj.bcebos.com/pretrained/cd/levircd/weights/dsamnet_levircd.pdparams',
|
|
|
+ 'DSIFN_LEVIRCD':
|
|
|
+ 'https://paddlers.bj.bcebos.com/pretrained/cd/levircd/weights/dsifn_levircd.pdparams',
|
|
|
+ 'FCEarlyFusion_LEVIRCD':
|
|
|
+ 'https://paddlers.bj.bcebos.com/pretrained/cd/levircd/weights/fc_ef_levircd.pdparams',
|
|
|
+ 'FCSiamConc_LEVIRCD':
|
|
|
+ 'https://paddlers.bj.bcebos.com/pretrained/cd/levircd/weights/fc_siam_conc_levircd.pdparams',
|
|
|
+ 'FCSiamDiff_LEVIRCD':
|
|
|
+ 'https://paddlers.bj.bcebos.com/pretrained/cd/levircd/weights/fc_siam_diff_levircd.pdparams',
|
|
|
+ 'FCCDN_LEVIRCD':
|
|
|
+ 'https://paddlers.bj.bcebos.com/pretrained/cd/levircd/weights/fccdn_levircd.pdparams',
|
|
|
+ 'SNUNet_LEVIRCD':
|
|
|
+ 'https://paddlers.bj.bcebos.com/pretrained/cd/levircd/weights/snunet_levircd.pdparams',
|
|
|
+ 'STANet_LEVIRCD':
|
|
|
+ 'https://paddlers.bj.bcebos.com/pretrained/cd/levircd/weights/stanet_levircd.pdparams'
|
|
|
+}
|
|
|
+
|
|
|
|
|
|
def get_pretrain_weights(flag, class_name, save_dir, backbone_name=None):
|
|
|
if flag is None:
|
|
@@ -427,6 +461,8 @@ def get_pretrain_weights(flag, class_name, save_dir, backbone_name=None):
|
|
|
url = pascalvoc_weights[weights_key]
|
|
|
elif flag == 'COCO':
|
|
|
url = coco_weights[weights_key]
|
|
|
+ elif flag == 'LEVIRCD':
|
|
|
+ url = levircd_weights[weights_key]
|
|
|
else:
|
|
|
raise ValueError('Given pretrained weights {} is undefined.'.format(
|
|
|
flag))
|