|
@@ -28,6 +28,7 @@ import paddlers
|
|
|
import paddlers.models.ppseg as ppseg
|
|
|
import paddlers.rs_models.cd as cmcd
|
|
|
import paddlers.utils.logging as logging
|
|
|
+from paddlers.models import seg_losses
|
|
|
from paddlers.transforms import Resize, decode_image
|
|
|
from paddlers.utils import get_single_card_bs, DisablePrint
|
|
|
from paddlers.utils.checkpoint import seg_pretrain_weights_dict
|
|
@@ -45,6 +46,7 @@ class BaseChangeDetector(BaseModel):
|
|
|
model_name,
|
|
|
num_classes=2,
|
|
|
use_mixed_loss=False,
|
|
|
+ losses=None,
|
|
|
**params):
|
|
|
self.init_params = locals()
|
|
|
if 'with_net' in self.init_params:
|
|
@@ -56,7 +58,7 @@ class BaseChangeDetector(BaseModel):
|
|
|
self.model_name = model_name
|
|
|
self.num_classes = num_classes
|
|
|
self.use_mixed_loss = use_mixed_loss
|
|
|
- self.losses = None
|
|
|
+ self.losses = losses
|
|
|
self.labels = None
|
|
|
if params.get('with_net', True):
|
|
|
params.pop('with_net', None)
|
|
@@ -178,13 +180,13 @@ class BaseChangeDetector(BaseModel):
|
|
|
if isinstance(self.use_mixed_loss, bool):
|
|
|
if self.use_mixed_loss:
|
|
|
losses = [
|
|
|
- ppseg.models.CrossEntropyLoss(),
|
|
|
- ppseg.models.LovaszSoftmaxLoss()
|
|
|
+ seg_losses.CrossEntropyLoss(),
|
|
|
+ seg_losses.LovaszSoftmaxLoss()
|
|
|
]
|
|
|
coef = [.8, .2]
|
|
|
- loss_type = [ppseg.models.MixedLoss(losses=losses, coef=coef), ]
|
|
|
+ loss_type = [seg_losses.MixedLoss(losses=losses, coef=coef), ]
|
|
|
else:
|
|
|
- loss_type = [ppseg.models.CrossEntropyLoss()]
|
|
|
+ loss_type = [seg_losses.CrossEntropyLoss()]
|
|
|
else:
|
|
|
losses, coef = list(zip(*self.use_mixed_loss))
|
|
|
if not set(losses).issubset(
|
|
@@ -192,8 +194,8 @@ class BaseChangeDetector(BaseModel):
|
|
|
raise ValueError(
|
|
|
"Only 'CrossEntropyLoss', 'DiceLoss', 'LovaszSoftmaxLoss' are supported."
|
|
|
)
|
|
|
- losses = [getattr(ppseg.models, loss)() for loss in losses]
|
|
|
- loss_type = [ppseg.models.MixedLoss(losses=losses, coef=list(coef))]
|
|
|
+ losses = [getattr(seg_losses, loss)() for loss in losses]
|
|
|
+ loss_type = [seg_losses.MixedLoss(losses=losses, coef=list(coef))]
|
|
|
loss_coef = [1.0]
|
|
|
losses = {'types': loss_type, 'coef': loss_coef}
|
|
|
return losses
|
|
@@ -810,11 +812,17 @@ class BaseChangeDetector(BaseModel):
|
|
|
raise TypeError(
|
|
|
"`transforms.arrange` must be an ArrangeChangeDetector object.")
|
|
|
|
|
|
+ def set_losses(self, losses, weights=None):
|
|
|
+ if weights is None:
|
|
|
+ weights = [1. for _ in range(len(losses))]
|
|
|
+ self.losses = {'types': losses, 'coef': weights}
|
|
|
+
|
|
|
|
|
|
class CDNet(BaseChangeDetector):
|
|
|
def __init__(self,
|
|
|
num_classes=2,
|
|
|
use_mixed_loss=False,
|
|
|
+ losses=None,
|
|
|
in_channels=6,
|
|
|
**params):
|
|
|
params.update({'in_channels': in_channels})
|
|
@@ -822,6 +830,7 @@ class CDNet(BaseChangeDetector):
|
|
|
model_name='CDNet',
|
|
|
num_classes=num_classes,
|
|
|
use_mixed_loss=use_mixed_loss,
|
|
|
+ losses=losses,
|
|
|
**params)
|
|
|
|
|
|
|
|
@@ -829,6 +838,7 @@ class FCEarlyFusion(BaseChangeDetector):
|
|
|
def __init__(self,
|
|
|
num_classes=2,
|
|
|
use_mixed_loss=False,
|
|
|
+ losses=None,
|
|
|
in_channels=6,
|
|
|
use_dropout=False,
|
|
|
**params):
|
|
@@ -837,6 +847,7 @@ class FCEarlyFusion(BaseChangeDetector):
|
|
|
model_name='FCEarlyFusion',
|
|
|
num_classes=num_classes,
|
|
|
use_mixed_loss=use_mixed_loss,
|
|
|
+ losses=losses,
|
|
|
**params)
|
|
|
|
|
|
|
|
@@ -844,6 +855,7 @@ class FCSiamConc(BaseChangeDetector):
|
|
|
def __init__(self,
|
|
|
num_classes=2,
|
|
|
use_mixed_loss=False,
|
|
|
+ losses=None,
|
|
|
in_channels=3,
|
|
|
use_dropout=False,
|
|
|
**params):
|
|
@@ -852,6 +864,7 @@ class FCSiamConc(BaseChangeDetector):
|
|
|
model_name='FCSiamConc',
|
|
|
num_classes=num_classes,
|
|
|
use_mixed_loss=use_mixed_loss,
|
|
|
+ losses=losses,
|
|
|
**params)
|
|
|
|
|
|
|
|
@@ -859,6 +872,7 @@ class FCSiamDiff(BaseChangeDetector):
|
|
|
def __init__(self,
|
|
|
num_classes=2,
|
|
|
use_mixed_loss=False,
|
|
|
+ losses=None,
|
|
|
in_channels=3,
|
|
|
use_dropout=False,
|
|
|
**params):
|
|
@@ -867,6 +881,7 @@ class FCSiamDiff(BaseChangeDetector):
|
|
|
model_name='FCSiamDiff',
|
|
|
num_classes=num_classes,
|
|
|
use_mixed_loss=use_mixed_loss,
|
|
|
+ losses=losses,
|
|
|
**params)
|
|
|
|
|
|
|
|
@@ -874,6 +889,7 @@ class STANet(BaseChangeDetector):
|
|
|
def __init__(self,
|
|
|
num_classes=2,
|
|
|
use_mixed_loss=False,
|
|
|
+ losses=None,
|
|
|
in_channels=3,
|
|
|
att_type='BAM',
|
|
|
ds_factor=1,
|
|
@@ -887,6 +903,7 @@ class STANet(BaseChangeDetector):
|
|
|
model_name='STANet',
|
|
|
num_classes=num_classes,
|
|
|
use_mixed_loss=use_mixed_loss,
|
|
|
+ losses=losses,
|
|
|
**params)
|
|
|
|
|
|
|
|
@@ -894,6 +911,7 @@ class BIT(BaseChangeDetector):
|
|
|
def __init__(self,
|
|
|
num_classes=2,
|
|
|
use_mixed_loss=False,
|
|
|
+ losses=None,
|
|
|
in_channels=3,
|
|
|
backbone='resnet18',
|
|
|
n_stages=4,
|
|
@@ -925,6 +943,7 @@ class BIT(BaseChangeDetector):
|
|
|
model_name='BIT',
|
|
|
num_classes=num_classes,
|
|
|
use_mixed_loss=use_mixed_loss,
|
|
|
+ losses=losses,
|
|
|
**params)
|
|
|
|
|
|
|
|
@@ -932,6 +951,7 @@ class SNUNet(BaseChangeDetector):
|
|
|
def __init__(self,
|
|
|
num_classes=2,
|
|
|
use_mixed_loss=False,
|
|
|
+ losses=None,
|
|
|
in_channels=3,
|
|
|
width=32,
|
|
|
**params):
|
|
@@ -940,6 +960,7 @@ class SNUNet(BaseChangeDetector):
|
|
|
model_name='SNUNet',
|
|
|
num_classes=num_classes,
|
|
|
use_mixed_loss=use_mixed_loss,
|
|
|
+ losses=losses,
|
|
|
**params)
|
|
|
|
|
|
|
|
@@ -947,6 +968,7 @@ class DSIFN(BaseChangeDetector):
|
|
|
def __init__(self,
|
|
|
num_classes=2,
|
|
|
use_mixed_loss=False,
|
|
|
+ losses=None,
|
|
|
use_dropout=False,
|
|
|
**params):
|
|
|
params.update({'use_dropout': use_dropout})
|
|
@@ -954,13 +976,14 @@ class DSIFN(BaseChangeDetector):
|
|
|
model_name='DSIFN',
|
|
|
num_classes=num_classes,
|
|
|
use_mixed_loss=use_mixed_loss,
|
|
|
+ losses=losses,
|
|
|
**params)
|
|
|
|
|
|
def default_loss(self):
|
|
|
if self.use_mixed_loss is False:
|
|
|
return {
|
|
|
# XXX: make sure the shallow copy works correctly here.
|
|
|
- 'types': [ppseg.models.CrossEntropyLoss()] * 5,
|
|
|
+ 'types': [seg_losses.CrossEntropyLoss()] * 5,
|
|
|
'coef': [1.0] * 5
|
|
|
}
|
|
|
else:
|
|
@@ -973,6 +996,7 @@ class DSAMNet(BaseChangeDetector):
|
|
|
def __init__(self,
|
|
|
num_classes=2,
|
|
|
use_mixed_loss=False,
|
|
|
+ losses=None,
|
|
|
in_channels=3,
|
|
|
ca_ratio=8,
|
|
|
sa_kernel=7,
|
|
@@ -986,14 +1010,15 @@ class DSAMNet(BaseChangeDetector):
|
|
|
model_name='DSAMNet',
|
|
|
num_classes=num_classes,
|
|
|
use_mixed_loss=use_mixed_loss,
|
|
|
+ losses=losses,
|
|
|
**params)
|
|
|
|
|
|
def default_loss(self):
|
|
|
if self.use_mixed_loss is False:
|
|
|
return {
|
|
|
'types': [
|
|
|
- ppseg.models.CrossEntropyLoss(), ppseg.models.DiceLoss(),
|
|
|
- ppseg.models.DiceLoss()
|
|
|
+ seg_losses.CrossEntropyLoss(), seg_losses.DiceLoss(),
|
|
|
+ seg_losses.DiceLoss()
|
|
|
],
|
|
|
'coef': [1.0, 0.05, 0.05]
|
|
|
}
|
|
@@ -1007,6 +1032,7 @@ class ChangeStar(BaseChangeDetector):
|
|
|
def __init__(self,
|
|
|
num_classes=2,
|
|
|
use_mixed_loss=False,
|
|
|
+ losses=None,
|
|
|
mid_channels=256,
|
|
|
inner_channels=16,
|
|
|
num_convs=4,
|
|
@@ -1022,13 +1048,14 @@ class ChangeStar(BaseChangeDetector):
|
|
|
model_name='ChangeStar',
|
|
|
num_classes=num_classes,
|
|
|
use_mixed_loss=use_mixed_loss,
|
|
|
+ losses=losses,
|
|
|
**params)
|
|
|
|
|
|
def default_loss(self):
|
|
|
if self.use_mixed_loss is False:
|
|
|
return {
|
|
|
# XXX: make sure the shallow copy works correctly here.
|
|
|
- 'types': [ppseg.models.CrossEntropyLoss()] * 4,
|
|
|
+ 'types': [seglosses.CrossEntropyLoss()] * 4,
|
|
|
'coef': [1.0] * 4
|
|
|
}
|
|
|
else:
|