Browse Source

Bypass OOM error for some models (#27)

Lin Manhui 2 years ago
parent
commit
ee05f40d72
2 changed files with 35 additions and 10 deletions
  1. 3 8
      tests/rs_models/test_cd_models.py
  2. 32 2
      tests/rs_models/test_model.py

+ 3 - 8
tests/rs_models/test_cd_models.py

@@ -12,11 +12,10 @@
 # See the License for the specific language governing permissions and
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # limitations under the License.
 
 
-import platform
 from itertools import cycle
 from itertools import cycle
 
 
 import paddlers
 import paddlers
-from rs_models.test_model import TestModel
+from rs_models.test_model import TestModel, allow_oom
 
 
 __all__ = [
 __all__ = [
     'TestBITModel', 'TestCDNetModel', 'TestChangeStarModel', 'TestDSAMNetModel',
     'TestBITModel', 'TestCDNetModel', 'TestChangeStarModel', 'TestDSAMNetModel',
@@ -202,6 +201,7 @@ class TestSNUNetModel(TestCDModel):
         ]   # yapf: disable
         ]   # yapf: disable
 
 
 
 
+@allow_oom
 class TestSTANetModel(TestCDModel):
 class TestSTANetModel(TestCDModel):
     MODEL_CLASS = paddlers.rs_models.cd.STANet
     MODEL_CLASS = paddlers.rs_models.cd.STANet
 
 
@@ -216,6 +216,7 @@ class TestSTANetModel(TestCDModel):
         ]   # yapf: disable
         ]   # yapf: disable
 
 
 
 
+@allow_oom
 class TestChangeFormerModel(TestCDModel):
 class TestChangeFormerModel(TestCDModel):
     MODEL_CLASS = paddlers.rs_models.cd.ChangeFormer
     MODEL_CLASS = paddlers.rs_models.cd.ChangeFormer
 
 
@@ -226,9 +227,3 @@ class TestChangeFormerModel(TestCDModel):
             dict(**base_spec, decoder_softmax=True),
             dict(**base_spec, decoder_softmax=True),
             dict(**base_spec, embed_dim=56)
             dict(**base_spec, embed_dim=56)
         ]   # yapf: disable
         ]   # yapf: disable
-
-
-# HACK:FIXME: We observe an OOM error when running TestSTANetModel.test_forward() on a Windows machine.
-# Currently, we do not perform this test.
-if platform.system() == 'Windows':
-    TestSTANetModel.test_forward = lambda self: None

+ 32 - 2
tests/rs_models/test_model.py

@@ -18,6 +18,7 @@ import paddle
 import numpy as np
 import numpy as np
 from paddle.static import InputSpec
 from paddle.static import InputSpec
 
 
+from paddlers.utils import logging
 from testing_utils import CommonTest
 from testing_utils import CommonTest
 
 
 
 
@@ -37,20 +38,26 @@ class _TestModelNamespace:
             for i, (
             for i, (
                     input, model, target
                     input, model, target
             ) in enumerate(zip(self.inputs, self.models, self.targets)):
             ) in enumerate(zip(self.inputs, self.models, self.targets)):
-                with self.subTest(i=i):
+                try:
                     if isinstance(input, list):
                     if isinstance(input, list):
                         output = model(*input)
                         output = model(*input)
                     else:
                     else:
                         output = model(input)
                         output = model(input)
                     self.check_output(output, target)
                     self.check_output(output, target)
+                except:
+                    logging.warning(f"Model built with spec{i} failed!")
+                    raise
 
 
         def test_to_static(self):
         def test_to_static(self):
             for i, (
             for i, (
                     input, model, target
                     input, model, target
             ) in enumerate(zip(self.inputs, self.models, self.targets)):
             ) in enumerate(zip(self.inputs, self.models, self.targets)):
-                with self.subTest(i=i):
+                try:
                     static_model = paddle.jit.to_static(
                     static_model = paddle.jit.to_static(
                         model, input_spec=self.get_input_spec(model, input))
                         model, input_spec=self.get_input_spec(model, input))
+                except:
+                    logging.warning(f"Model built with spec{i} failed!")
+                    raise
 
 
         def check_output(self, output, target):
         def check_output(self, output, target):
             pass
             pass
@@ -117,4 +124,27 @@ class _TestModelNamespace:
             return input_spec
             return input_spec
 
 
 
 
+def allow_oom(cls):
+    def _deco(func):
+        def _wrapper(self, *args, **kwargs):
+            try:
+                func(self, *args, **kwargs)
+            except (SystemError, RuntimeError, OSError) as e:
+                msg = str(e)
+                if "Out of memory error" in msg \
+                    or "(External) CUDNN error(4), CUDNN_STATUS_INTERNAL_ERROR." in msg:
+                    logging.warning("An OOM error has been ignored.")
+                else:
+                    raise
+
+        return _wrapper
+
+    for key, value in inspect.getmembers(cls):
+        if key.startswith('test'):
+            value = _deco(value)
+            setattr(cls, key, value)
+
+    return cls
+
+
 TestModel = _TestModelNamespace.TestModel
 TestModel = _TestModelNamespace.TestModel