Prechádzať zdrojové kódy

[Fix] Remove extra line in changedetector.py (#26)

Lin Manhui pred 3 rokmi
rodič
commit
119b474a8a
1 zmenený súbor s 45 pridaniami a 64 odobraniami
  1. 45 64
      paddlers/tasks/changedetector.py

+ 45 - 64
paddlers/tasks/changedetector.py

@@ -34,19 +34,9 @@ from paddlers.utils.checkpoint import seg_pretrain_weights_dict
 from .base import BaseModel
 from .utils import seg_metrics as metrics
 
-
 __all__ = [
-    "CDNet", 
-    "UNetEarlyFusion", 
-    "UNetSiamConc", 
-    "UNetSiamDiff", 
-    "STANet", 
-    "BIT", 
-    "SNUNet", 
-    "DSIFN", 
-    "DSAMNet", 
-    "ChangeStar"
-    "DSAMNet"
+    "CDNet", "UNetEarlyFusion", "UNetSiamConc", "UNetSiamDiff", "STANet", "BIT",
+    "SNUNet", "DSIFN", "DSAMNet", "ChangeStar"
 ]
 
 
@@ -75,8 +65,8 @@ class BaseChangeDetector(BaseModel):
 
     def build_net(self, **params):
         # TODO: add other model
-        net = cd.models.__dict__[self.model_name](
-            num_classes=self.num_classes, **params)
+        net = cd.models.__dict__[self.model_name](num_classes=self.num_classes,
+                                                  **params)
         return net
 
     def _fix_transforms_shape(self, image_shape):
@@ -136,8 +126,7 @@ class BaseChangeDetector(BaseModel):
                         .squeeze().numpy())
                     score_map_list.append(
                         F.softmax(
-                            logit, axis=-1).squeeze().numpy().astype(
-                                'float32'))
+                            logit, axis=-1).squeeze().numpy().astype('float32'))
             outputs['label_map'] = label_map_list
             outputs['score_map'] = score_map_list
 
@@ -145,8 +134,7 @@ class BaseChangeDetector(BaseModel):
             if self.status == 'Infer':
                 pred = paddle.unsqueeze(net_out[0], axis=1)  # NCHW
             else:
-                pred = paddle.argmax(
-                    logit, axis=1, keepdim=True, dtype='int32')
+                pred = paddle.argmax(logit, axis=1, keepdim=True, dtype='int32')
             label = inputs[2]
             origin_shape = [label.shape[-2:]]
             pred = self._postprocess(
@@ -159,13 +147,21 @@ class BaseChangeDetector(BaseModel):
             outputs['conf_mat'] = metrics.confusion_matrix(pred, label,
                                                            self.num_classes)
         if mode == 'train':
-            if hasattr(net, 'USE_MULTITASK_DECODER') and net.USE_MULTITASK_DECODER is True:
+            if hasattr(net, 'USE_MULTITASK_DECODER') and \
+                net.USE_MULTITASK_DECODER is True:
                 # CD+Seg
                 if len(inputs) != 5:
-                    raise ValueError("Cannot perform loss computation with {} inputs.".format(len(inputs)))
-                labels_list = [inputs[2+idx] for idx in map(attrgetter('value'), net.OUT_TYPES)]
+                    raise ValueError(
+                        "Cannot perform loss computation with {} inputs.".
+                        format(len(inputs)))
+                labels_list = [
+                    inputs[2 + idx]
+                    for idx in map(attrgetter('value'), net.OUT_TYPES)
+                ]
                 loss_list = metrics.multitask_loss_computation(
-                    logits_list=net_out, labels_list=labels_list, losses=self.losses)
+                    logits_list=net_out,
+                    labels_list=labels_list,
+                    losses=self.losses)
             else:
                 loss_list = metrics.loss_computation(
                     logits_list=net_out, labels=inputs[2], losses=self.losses)
@@ -291,8 +287,7 @@ class BaseChangeDetector(BaseModel):
                                 "set pretrain_weights to be None.".format(
                                     seg_pretrain_weights_dict[self.model_name][
                                         0]))
-                pretrain_weights = seg_pretrain_weights_dict[self.model_name][
-                    0]
+                pretrain_weights = seg_pretrain_weights_dict[self.model_name][0]
         elif pretrain_weights is not None and osp.exists(pretrain_weights):
             if osp.splitext(pretrain_weights)[-1] != '.pdparams':
                 logging.error(
@@ -404,8 +399,8 @@ class BaseChangeDetector(BaseModel):
         local_rank = paddle.distributed.get_rank()
         if nranks > 1:
             # Initialize parallel environment if not done.
-            if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
-            ):
+            if not (paddle.distributed.parallel.parallel_helper.
+                    _is_parallel_ctx_initialized()):
                 paddle.distributed.init_parallel_env()
 
         batch_size_each_card = get_single_card_bs(batch_size)
@@ -415,7 +410,8 @@ class BaseChangeDetector(BaseModel):
             logging.warning(
                 "Segmenter only supports batch_size=1 for each gpu/cpu card " \
                 "during evaluation, so batch_size " \
-                "is forcibly set to {}.".format(batch_size))
+                "is forcibly set to {}.".format(batch_size)
+            )
         self.eval_data_loader = self.build_data_loader(
             eval_dataset, batch_size=batch_size, mode='eval')
 
@@ -471,8 +467,8 @@ class BaseChangeDetector(BaseModel):
         # TODO 确认是按oacc还是macc
         class_acc, oacc = paddleseg.utils.metrics.accuracy(intersect_area_all,
                                                            pred_area_all)
-        kappa = paddleseg.utils.metrics.kappa(intersect_area_all,
-                                              pred_area_all, label_area_all)
+        kappa = paddleseg.utils.metrics.kappa(intersect_area_all, pred_area_all,
+                                              label_area_all)
         category_f1score = metrics.f1_score(intersect_area_all, pred_area_all,
                                             label_area_all)
         eval_metrics = OrderedDict(
@@ -696,10 +692,7 @@ class UNetEarlyFusion(BaseChangeDetector):
                  in_channels=6,
                  use_dropout=False,
                  **params):
-        params.update({
-            'in_channels': in_channels,
-            'use_dropout': use_dropout
-        })
+        params.update({'in_channels': in_channels, 'use_dropout': use_dropout})
         super(UNetEarlyFusion, self).__init__(
             model_name='UNetEarlyFusion',
             num_classes=num_classes,
@@ -714,10 +707,7 @@ class UNetSiamConc(BaseChangeDetector):
                  in_channels=3,
                  use_dropout=False,
                  **params):
-        params.update({
-            'in_channels': in_channels,
-            'use_dropout': use_dropout
-        })
+        params.update({'in_channels': in_channels, 'use_dropout': use_dropout})
         super(UNetSiamConc, self).__init__(
             model_name='UNetSiamConc',
             num_classes=num_classes,
@@ -732,10 +722,7 @@ class UNetSiamDiff(BaseChangeDetector):
                  in_channels=3,
                  use_dropout=False,
                  **params):
-        params.update({
-            'in_channels': in_channels,
-            'use_dropout': use_dropout
-        })
+        params.update({'in_channels': in_channels, 'use_dropout': use_dropout})
         super(UNetSiamDiff, self).__init__(
             model_name='UNetSiamDiff',
             num_classes=num_classes,
@@ -768,16 +755,16 @@ class BIT(BaseChangeDetector):
                  num_classes=2,
                  use_mixed_loss=False,
                  in_channels=3,
-                 backbone='resnet18', 
-                 n_stages=4, 
-                 use_tokenizer=True, 
-                 token_len=4, 
-                 pool_mode='max', 
+                 backbone='resnet18',
+                 n_stages=4,
+                 use_tokenizer=True,
+                 token_len=4,
+                 pool_mode='max',
                  pool_size=2,
-                 enc_with_pos=True, 
-                 enc_depth=1, 
-                 enc_head_dim=64, 
-                 dec_depth=8, 
+                 enc_with_pos=True,
+                 enc_depth=1,
+                 enc_head_dim=64,
+                 dec_depth=8,
                  dec_head_dim=8,
                  **params):
         params.update({
@@ -808,10 +795,7 @@ class SNUNet(BaseChangeDetector):
                  in_channels=3,
                  width=32,
                  **params):
-        params.update({
-            'in_channels': in_channels,
-            'width': width
-        })
+        params.update({'in_channels': in_channels, 'width': width})
         super(SNUNet, self).__init__(
             model_name='SNUNet',
             num_classes=num_classes,
@@ -825,9 +809,7 @@ class DSIFN(BaseChangeDetector):
                  use_mixed_loss=False,
                  use_dropout=False,
                  **params):
-        params.update({
-            'use_dropout': use_dropout
-        })
+        params.update({'use_dropout': use_dropout})
         super(DSIFN, self).__init__(
             model_name='DSIFN',
             num_classes=num_classes,
@@ -838,8 +820,8 @@ class DSIFN(BaseChangeDetector):
         if self.use_mixed_loss is False:
             return {
                 # XXX: make sure the shallow copy works correctly here.
-                'types': [paddleseg.models.CrossEntropyLoss()]*5,
-                'coef': [1.0]*5
+                'types': [paddleseg.models.CrossEntropyLoss()] * 5,
+                'coef': [1.0] * 5
             }
         else:
             return super().default_loss()
@@ -868,9 +850,8 @@ class DSAMNet(BaseChangeDetector):
         if self.use_mixed_loss is False:
             return {
                 'types': [
-                    paddleseg.models.CrossEntropyLoss(), 
-                    paddleseg.models.DiceLoss(), 
-                    paddleseg.models.DiceLoss()
+                    paddleseg.models.CrossEntropyLoss(),
+                    paddleseg.models.DiceLoss(), paddleseg.models.DiceLoss()
                 ],
                 'coef': [1.0, 0.05, 0.05]
             }
@@ -903,8 +884,8 @@ class ChangeStar(BaseChangeDetector):
         if self.use_mixed_loss is False:
             return {
                 # XXX: make sure the shallow copy works correctly here.
-                'types': [paddleseg.models.CrossEntropyLoss()]*4,
-                'coef': [1.0]*4
+                'types': [paddleseg.models.CrossEntropyLoss()] * 4,
+                'coef': [1.0] * 4
             }
         else:
             return super().default_loss()