Browse Source

Fix seg test

Bobholamovic 2 years ago
parent
commit
107e0083bb
1 changed files with 15 additions and 10 deletions
  1. 15 10
      tests/rs_models/test_seg_models.py

+ 15 - 10
tests/rs_models/test_seg_models.py

@@ -52,16 +52,21 @@ class TestFarSegModel(TestSegModel):
     MODEL_CLASS = paddlers.rs_models.seg.FarSeg
 
     def set_specs(self):
+        base_spec = dict(in_channels=3, num_classes=2)
         self.specs = [
-            dict(), dict(
-                in_channels=6, num_classes=10), dict(
-                    backbone='resnet18', backbone_pretrained=False), dict(
-                        fpn_out_channels=128,
-                        fsr_out_channels=64,
-                        decoder_out_channels=32), dict(scale_aware_proj=False)
-        ]
+            base_spec,
+            dict(in_channels=6, num_classes=10),
+            dict(**base_spec,
+                backbone='resnet18',
+                backbone_pretrained=False),
+            dict(**base_spec,
+                fpn_out_channels=128,
+                fsr_out_channels=64,
+                decoder_out_channels=32),
+            dict(**base_spec, scale_aware_proj=False)
+        ]   # yapf: disable
 
     def set_targets(self):
-        self.targets = [[self.get_zeros_array(16)], [self.get_zeros_array(10)],
-                        [self.get_zeros_array(16)], [self.get_zeros_array(16)],
-                        [self.get_zeros_array(16)]]
+        self.targets = [[self.get_zeros_array(2)], [self.get_zeros_array(10)],
+                        [self.get_zeros_array(2)], [self.get_zeros_array(2)],
+                        [self.get_zeros_array(2)]]