Bladeren bron

Update style

Bobholamovic 2 jaren geleden
bovenliggende
commit
27bae3505c
1 gewijzigde bestanden met toevoegingen van 41 en 46 verwijderingen
  1. 41 46
      tests/rs_models/test_cd_models.py

+ 41 - 46
tests/rs_models/test_cd_models.py

@@ -76,12 +76,12 @@ class TestBITModel(TestCDModel):
     def set_specs(self):
         base_spec = dict(in_channels=3, num_classes=2)
         self.specs = [
-            base_spec, dict(
-                **base_spec, backbone='resnet34'), dict(
-                    **base_spec, n_stages=3), dict(
-                        **base_spec, enc_depth=4, dec_head_dim=16), dict(
-                            in_channels=4, num_classes=2), dict(
-                                in_channels=3, num_classes=8)
+            base_spec, 
+            dict(**base_spec, backbone='resnet34'), 
+            dict(**base_spec, n_stages=3), 
+            dict(**base_spec, enc_depth=4, dec_head_dim=16), 
+            dict(in_channels=4, num_classes=2), 
+            dict(in_channels=3, num_classes=8)
         ]
 
 
@@ -91,10 +91,9 @@ class TestCDNetModel(TestCDModel):
 
     def set_specs(self):
         self.specs = [
-            dict(
-                in_channels=6, num_classes=2), dict(
-                    in_channels=8, num_classes=2), dict(
-                        in_channels=6, num_classes=8)
+            dict(in_channels=6, num_classes=2), 
+            dict(in_channels=8, num_classes=2), 
+            dict(in_channels=6, num_classes=8)
         ]
 
 
@@ -103,9 +102,9 @@ class TestChangeStarModel(TestCDModel):
 
     def set_specs(self):
         self.specs = [
-            dict(num_classes=2), dict(num_classes=10), dict(
-                num_classes=2, mid_channels=128, num_convs=2), dict(
-                    num_classes=2, _phase='eval', _stop_grad=True)
+            dict(num_classes=2), dict(num_classes=10), 
+            dict(num_classes=2, mid_channels=128, num_convs=2), 
+            dict(num_classes=2, _phase='eval', _stop_grad=True)
         ]
 
     def set_targets(self):
@@ -124,11 +123,11 @@ class TestDSAMNetModel(TestCDModel):
     def set_specs(self):
         base_spec = dict(in_channels=3, num_classes=2)
         self.specs = [
-            base_spec, dict(
-                in_channels=8, num_classes=2), dict(
-                    in_channels=3, num_classes=8), dict(
-                        **base_spec, ca_ratio=4, sa_kernel=5), dict(
-                            **base_spec, _phase='eval', _stop_grad=True)
+            base_spec, 
+            dict(in_channels=8, num_classes=2), 
+            dict(in_channels=3, num_classes=8), 
+            dict(**base_spec, ca_ratio=4, sa_kernel=5), 
+            dict(*base_spec, _phase='eval', _stop_grad=True)
         ]
 
     def set_targets(self):
@@ -145,9 +144,9 @@ class TestDSIFNModel(TestCDModel):
 
     def set_specs(self):
         self.specs = [
-            dict(num_classes=2), dict(num_classes=10), dict(
-                num_classes=2, use_dropout=True), dict(
-                    num_classes=2, _phase='eval', _stop_grad=True)
+            dict(num_classes=2), dict(num_classes=10), 
+            dict(num_classes=2, use_dropout=True), 
+            dict(num_classes=2, _phase='eval', _stop_grad=True)
         ]
 
     def set_targets(self):
@@ -165,11 +164,10 @@ class TestFCEarlyFusionModel(TestCDModel):
 
     def set_specs(self):
         self.specs = [
-            dict(
-                in_channels=6, num_classes=2), dict(
-                    in_channels=8, num_classes=2), dict(
-                        in_channels=6, num_classes=8), dict(
-                            in_channels=6, num_classes=2, use_dropout=True)
+            dict(in_channels=6, num_classes=2), 
+            dict(in_channels=8, num_classes=2), 
+            dict(in_channels=6, num_classes=8), 
+            dict(in_channels=6, num_classes=2, use_dropout=True)
         ]
 
 
@@ -178,11 +176,10 @@ class TestFCSiamConcModel(TestCDModel):
 
     def set_specs(self):
         self.specs = [
-            dict(
-                in_channels=3, num_classes=2), dict(
-                    in_channels=8, num_classes=2), dict(
-                        in_channels=3, num_classes=8), dict(
-                            in_channels=3, num_classes=2, use_dropout=True)
+            dict(in_channels=3, num_classes=2), 
+            dict(in_channels=8, num_classes=2), 
+            dict(in_channels=3, num_classes=8), 
+            dict(in_channels=3, num_classes=2, use_dropout=True)
         ]
 
 
@@ -191,11 +188,10 @@ class TestFCSiamDiffModel(TestCDModel):
 
     def set_specs(self):
         self.specs = [
-            dict(
-                in_channels=3, num_classes=2), dict(
-                    in_channels=8, num_classes=2), dict(
-                        in_channels=3, num_classes=8), dict(
-                            in_channels=3, num_classes=2, use_dropout=True)
+            dict(in_channels=3, num_classes=2), 
+            dict(in_channels=8, num_classes=2), 
+            dict(in_channels=3, num_classes=8), 
+            dict(in_channels=3, num_classes=2, use_dropout=True)
         ]
 
 
@@ -204,11 +200,10 @@ class TestSNUNetModel(TestCDModel):
 
     def set_specs(self):
         self.specs = [
-            dict(
-                in_channels=3, num_classes=2), dict(
-                    in_channels=8, num_classes=2), dict(
-                        in_channels=3, num_classes=8), dict(
-                            in_channels=3, num_classes=2, width=64)
+            dict(in_channels=3, num_classes=2), 
+            dict(in_channels=8, num_classes=2), 
+            dict(in_channels=3, num_classes=8), 
+            dict(in_channels=3, num_classes=2, width=64)
         ]
 
 
@@ -218,9 +213,9 @@ class TestSTANetModel(TestCDModel):
     def set_specs(self):
         base_spec = dict(in_channels=3, num_classes=2)
         self.specs = [
-            base_spec, dict(
-                in_channels=8, num_classes=2), dict(
-                    in_channels=3, num_classes=8), dict(
-                        **base_spec, att_type='PAM'), dict(
-                            **base_spec, ds_factor=4)
+            base_spec, 
+            dict(in_channels=8, num_classes=2), 
+            dict(in_channels=3, num_classes=8), 
+            dict(**base_spec, att_type='PAM'), 
+            dict(**base_spec, ds_factor=4)
         ]