|
@@ -32,8 +32,11 @@ class TestCDModel(TestModel):
|
|
|
self.assertIsInstance(output, list)
|
|
|
self.check_output_equal(len(output), len(target))
|
|
|
for o, t in zip(output, target):
|
|
|
- o = o.numpy()
|
|
|
- self.check_output_equal(o.shape, t.shape)
|
|
|
+ if isinstance(o, list):
|
|
|
+ self.check_output(o, t)
|
|
|
+ else:
|
|
|
+ o = o.numpy()
|
|
|
+ self.check_output_equal(o.shape, t.shape)
|
|
|
|
|
|
def set_inputs(self):
|
|
|
if self.EF_MODE == 'Concat':
|
|
@@ -234,5 +237,13 @@ class TestFCCDNModel(TestCDModel):
|
|
|
self.specs = [
|
|
|
dict(in_channels=3, num_classes=2),
|
|
|
dict(in_channels=8, num_classes=2),
|
|
|
- dict(in_channels=3, num_classes=8)
|
|
|
- ] # yapf: disable
|
|
|
+ dict(in_channels=3, num_classes=8),
|
|
|
+ dict(in_channels=3, num_classes=2, _phase='eval', _stop_grad=True)
|
|
|
+ ] # yapf: disable
|
|
|
+
|
|
|
+ def set_targets(self):
|
|
|
+ tar_c2 = [self.get_zeros_array(2), [self.get_zeros_array(1)] * 2]
|
|
|
+ self.targets = [
|
|
|
+ tar_c2, tar_c2, [self.get_zeros_array(8), tar_c2[1]],
|
|
|
+ [self.get_zeros_array(2)]
|
|
|
+ ]
|