Bobholamovic 2 år sedan
förälder
incheckning
107e0083bb
1 ändrade filer med 15 tillägg och 10 borttagningar
  1. 15 10
      tests/rs_models/test_seg_models.py

+ 15 - 10
tests/rs_models/test_seg_models.py

@@ -52,16 +52,21 @@ class TestFarSegModel(TestSegModel):
     MODEL_CLASS = paddlers.rs_models.seg.FarSeg
     MODEL_CLASS = paddlers.rs_models.seg.FarSeg
 
 
     def set_specs(self):
     def set_specs(self):
+        base_spec = dict(in_channels=3, num_classes=2)
         self.specs = [
         self.specs = [
-            dict(), dict(
+            base_spec,
-                in_channels=6, num_classes=10), dict(
+            dict(in_channels=6, num_classes=10),
-                    backbone='resnet18', backbone_pretrained=False), dict(
+            dict(**base_spec,
-                        fpn_out_channels=128,
+                backbone='resnet18',
-                        fsr_out_channels=64,
+                backbone_pretrained=False),
-                        decoder_out_channels=32), dict(scale_aware_proj=False)
+            dict(**base_spec,
-        ]
+                fpn_out_channels=128,
+                fsr_out_channels=64,
+                decoder_out_channels=32),
+            dict(**base_spec, scale_aware_proj=False)
+        ]   # yapf: disable
 
 
     def set_targets(self):
     def set_targets(self):
-        self.targets = [[self.get_zeros_array(16)], [self.get_zeros_array(10)],
+        self.targets = [[self.get_zeros_array(2)], [self.get_zeros_array(10)],
-                        [self.get_zeros_array(16)], [self.get_zeros_array(16)],
+                        [self.get_zeros_array(2)], [self.get_zeros_array(2)],
-                        [self.get_zeros_array(16)]]
+                        [self.get_zeros_array(2)]]