Explorar o código

Update rs_models unittests

Bobholamovic %!s(int64=2) %!d(string=hai) anos
pai
achega
40a5edc039

+ 0 - 1
tests/check_coverage.sh

@@ -1,6 +1,5 @@
 #!/usr/bin/bash
 
-export PYTHONPATH=${PYTHONPATH}:'../tools'
 coverage run --source paddlers,$(ls -d ../tools/* | tr '\n' ',') --omit=../paddlers/models/* -m unittest discover -v
 coverage report
 coverage html -d coverage_html

+ 4 - 1
tests/rs_models/test_cd_models.py

@@ -69,6 +69,9 @@ class TestCDModel(TestModel):
         model = super().build_model(spec)
         return _CDModelAdapter(model)
 
+    def convert_to_static(self, model, input):
+        return super().convert_to_static(model.cd_model, input)
+
 
 class TestBITModel(TestCDModel):
     MODEL_CLASS = paddlers.custom_models.cd.BIT
@@ -127,7 +130,7 @@ class TestDSAMNetModel(TestCDModel):
             dict(in_channels=8, num_classes=2), 
             dict(in_channels=3, num_classes=8), 
             dict(**base_spec, ca_ratio=4, sa_kernel=5), 
-            dict(*base_spec, _phase='eval', _stop_grad=True)
+            dict(**base_spec, _phase='eval', _stop_grad=True)
         ]
 
     def set_targets(self):

+ 2 - 1
tests/rs_models/test_clas_models.py

@@ -32,4 +32,5 @@ class TestCDModel(TestModel):
         self.inputs = _gen_data(self.specs)
 
     def set_targets(self):
-        self.targets = [[2, spec.get('num_classes', 2)] for spec in self.specs]
+        self.targets = [[self.DEFAULT_BATCH_SIZE, spec.get('num_classes', 2)]
+                        for spec in self.specs]

+ 28 - 0
tests/rs_models/test_model.py

@@ -12,8 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import inspect
+
 import paddle
 import numpy as np
+from paddle.static import InputSpec
 
 from testing_utils import CommonTest
 
@@ -38,6 +41,15 @@ class _TestModelNamespace:
                     output = model(input)
                     self.check_output(output, target)
 
+        def test_to_static(self):
+            for i, (
+                    input, model, target
+            ) in enumerate(zip(self.inputs, self.models, self.targets)):
+                with self.subTest(i=i):
+                    static_model = self.convert_to_static(model, input)
+                    output = static_model(output)
+                    self.check_output(output, target)
+
         def check_output(self, output, target):
             pass
 
@@ -90,5 +102,21 @@ class _TestModelNamespace:
             shape = self.get_shape(c, b, h, w)
             return paddle.randn(shape)
 
+        def get_input_spec(self, model, input):
+            if not isinstance(input, list):
+                input = [input]
+            input_spec = []
+            for param_name, tensor in zip(
+                    inspect.signature(model.forward).parameters, input):
+                # XXX: Hard-code dtype
+                input_spec.append(
+                    InputSpec(
+                        shape=tensor.shape, name=param_name, dtype='float32'))
+            return input_spec
+
+        def convert_to_static(self, model, input):
+            return paddle.jit.to_static(
+                model, input_spec=self.get_input_spec(model, input))
+
 
 TestModel = _TestModelNamespace.TestModel

+ 0 - 1
tests/run_tests.sh

@@ -1,4 +1,3 @@
 #!/usr/bin/bash
 
-export PYTHONPATH=${PYTHONPATH}:'../tools'
 python -m unittest discover -v