Bobholamovic 2 жил өмнө
parent
commit
c4743373e3

+ 15 - 4
tests/rs_models/test_cd_models.py

@@ -32,8 +32,11 @@ class TestCDModel(TestModel):
         self.assertIsInstance(output, list)
         self.check_output_equal(len(output), len(target))
         for o, t in zip(output, target):
-            o = o.numpy()
-            self.check_output_equal(o.shape, t.shape)
+            if isinstance(o, list):
+                self.check_output(o, t)
+            else:
+                o = o.numpy()
+                self.check_output_equal(o.shape, t.shape)
 
     def set_inputs(self):
         if self.EF_MODE == 'Concat':
@@ -234,5 +237,13 @@ class TestFCCDNModel(TestCDModel):
         self.specs = [
             dict(in_channels=3, num_classes=2),
             dict(in_channels=8, num_classes=2),
-            dict(in_channels=3, num_classes=8)
-        ]   # yapf: disable
+            dict(in_channels=3, num_classes=8),
+            dict(in_channels=3, num_classes=2, _phase='eval', _stop_grad=True)
+        ]   # yapf: disable
+
+    def set_targets(self):
+        tar_c2 = [self.get_zeros_array(2), [self.get_zeros_array(1)] * 2]
+        self.targets = [
+            tar_c2, tar_c2, [self.get_zeros_array(8), tar_c2[1]],
+            [self.get_zeros_array(2)]
+        ]

+ 5 - 2
tests/rs_models/test_seg_models.py

@@ -25,8 +25,11 @@ class TestSegModel(TestModel):
         self.assertIsInstance(output, list)
         self.check_output_equal(len(output), len(target))
         for o, t in zip(output, target):
-            o = o.numpy()
-            self.check_output_equal(o.shape, t.shape)
+            if isinstance(o, list):
+                self.check_output(o, t)
+            else:
+                o = o.numpy()
+                self.check_output_equal(o.shape, t.shape)
 
     def set_inputs(self):
         def _gen_data(specs):