|
@@ -52,16 +52,21 @@ class TestFarSegModel(TestSegModel):
|
|
|
MODEL_CLASS = paddlers.rs_models.seg.FarSeg
|
|
|
|
|
|
def set_specs(self):
|
|
|
+ base_spec = dict(in_channels=3, num_classes=2)
|
|
|
self.specs = [
|
|
|
- dict(), dict(
|
|
|
- in_channels=6, num_classes=10), dict(
|
|
|
- backbone='resnet18', backbone_pretrained=False), dict(
|
|
|
- fpn_out_channels=128,
|
|
|
- fsr_out_channels=64,
|
|
|
- decoder_out_channels=32), dict(scale_aware_proj=False)
|
|
|
- ]
|
|
|
+ base_spec,
|
|
|
+ dict(in_channels=6, num_classes=10),
|
|
|
+ dict(**base_spec,
|
|
|
+ backbone='resnet18',
|
|
|
+ backbone_pretrained=False),
|
|
|
+ dict(**base_spec,
|
|
|
+ fpn_out_channels=128,
|
|
|
+ fsr_out_channels=64,
|
|
|
+ decoder_out_channels=32),
|
|
|
+ dict(**base_spec, scale_aware_proj=False)
|
|
|
+ ] # yapf: disable
|
|
|
|
|
|
def set_targets(self):
|
|
|
- self.targets = [[self.get_zeros_array(16)], [self.get_zeros_array(10)],
|
|
|
- [self.get_zeros_array(16)], [self.get_zeros_array(16)],
|
|
|
- [self.get_zeros_array(16)]]
|
|
|
+ self.targets = [[self.get_zeros_array(2)], [self.get_zeros_array(10)],
|
|
|
+ [self.get_zeros_array(2)], [self.get_zeros_array(2)],
|
|
|
+ [self.get_zeros_array(2)]]
|