|
@@ -31,7 +31,7 @@ import paddlers.utils.logging as logging
|
|
|
from paddlers.models import seg_losses
|
|
|
from paddlers.transforms import Resize, decode_image
|
|
|
from paddlers.utils import get_single_card_bs
|
|
|
-from paddlers.utils.checkpoint import seg_pretrain_weights_dict
|
|
|
+from paddlers.utils.checkpoint import cd_pretrain_weights_dict
|
|
|
from .base import BaseModel
|
|
|
from .utils import seg_metrics as metrics
|
|
|
from .utils.infer_nets import InferCDNet
|
|
@@ -275,7 +275,7 @@ class BaseChangeDetector(BaseModel):
|
|
|
exit=True)
|
|
|
if pretrain_weights is not None and resume_checkpoint is not None:
|
|
|
logging.error(
|
|
|
- "pretrain_weights and resume_checkpoint cannot be set simultaneously.",
|
|
|
+ "`pretrain_weights` and `resume_checkpoint` cannot be set simultaneously.",
|
|
|
exit=True)
|
|
|
self.labels = train_dataset.labels
|
|
|
if self.losses is None:
|
|
@@ -289,23 +289,30 @@ class BaseChangeDetector(BaseModel):
|
|
|
else:
|
|
|
self.optimizer = optimizer
|
|
|
|
|
|
- if pretrain_weights is not None and not osp.exists(pretrain_weights):
|
|
|
- if pretrain_weights not in seg_pretrain_weights_dict[
|
|
|
- self.model_name]:
|
|
|
- logging.warning(
|
|
|
- "Path of pretrain_weights('{}') does not exist!".format(
|
|
|
- pretrain_weights))
|
|
|
- logging.warning("Pretrain_weights is forcibly set to '{}'. "
|
|
|
- "If don't want to use pretrain weights, "
|
|
|
- "set pretrain_weights to be None.".format(
|
|
|
- seg_pretrain_weights_dict[self.model_name][
|
|
|
- 0]))
|
|
|
- pretrain_weights = seg_pretrain_weights_dict[self.model_name][0]
|
|
|
- elif pretrain_weights is not None and osp.exists(pretrain_weights):
|
|
|
- if osp.splitext(pretrain_weights)[-1] != '.pdparams':
|
|
|
- logging.error(
|
|
|
- "Invalid pretrain weights. Please specify a '.pdparams' file.",
|
|
|
- exit=True)
|
|
|
+ if pretrain_weights is not None:
|
|
|
+ if not osp.exists(pretrain_weights):
|
|
|
+ if self.model_name not in cd_pretrain_weights_dict:
|
|
|
+ logging.warning(
|
|
|
+ "Path of pretrained weights ('{}') does not exist!".
|
|
|
+ format(pretrain_weights))
|
|
|
+ pretrain_weights = None
|
|
|
+ elif pretrain_weights not in cd_pretrain_weights_dict[
|
|
|
+ self.model_name]:
|
|
|
+ logging.warning(
|
|
|
+ "Path of pretrained weights ('{}') does not exist!".
|
|
|
+ format(pretrain_weights))
|
|
|
+ pretrain_weights = cd_pretrain_weights_dict[
|
|
|
+ self.model_name][0]
|
|
|
+ logging.warning(
|
|
|
+ "`pretrain_weights` is forcibly set to '{}'. "
|
|
|
+ "If you don't want to use pretrained weights, "
|
|
|
+ "please set `pretrain_weights` to None.".format(
|
|
|
+ pretrain_weights))
|
|
|
+ else:
|
|
|
+ if osp.splitext(pretrain_weights)[-1] != '.pdparams':
|
|
|
+ logging.error(
|
|
|
+ "Invalid pretrained weights. Please specify a .pdparams file.",
|
|
|
+ exit=True)
|
|
|
pretrained_dir = osp.join(save_dir, 'pretrain')
|
|
|
is_backbone_weights = pretrain_weights == 'IMAGENET'
|
|
|
self.net_initialize(
|
|
@@ -409,18 +416,18 @@ class BaseChangeDetector(BaseModel):
|
|
|
key-value pairs:
|
|
|
For binary change detection (number of classes == 2), the key-value
|
|
|
pairs are like:
|
|
|
- {"iou": `intersection over union for the change class`,
|
|
|
- "f1": `F1 score for the change class`,
|
|
|
- "oacc": `overall accuracy`,
|
|
|
- "kappa": ` kappa coefficient`}.
|
|
|
+ {"iou": intersection over union for the change class,
|
|
|
+ "f1": F1 score for the change class,
|
|
|
+ "oacc": overall accuracy,
|
|
|
+ "kappa": kappa coefficient}.
|
|
|
For multi-class change detection (number of classes > 2), the key-value
|
|
|
pairs are like:
|
|
|
- {"miou": `mean intersection over union`,
|
|
|
- "category_iou": `category-wise mean intersection over union`,
|
|
|
- "oacc": `overall accuracy`,
|
|
|
- "category_acc": `category-wise accuracy`,
|
|
|
- "kappa": ` kappa coefficient`,
|
|
|
- "category_F1-score": `F1 score`}.
|
|
|
+ {"miou": mean intersection over union,
|
|
|
+ "category_iou": category-wise mean intersection over union,
|
|
|
+ "oacc": overall accuracy,
|
|
|
+ "category_acc": category-wise accuracy,
|
|
|
+ "kappa": kappa coefficient,
|
|
|
+ "category_F1-score": F1 score}.
|
|
|
"""
|
|
|
|
|
|
self._check_transforms(eval_dataset.transforms, 'eval')
|