|
@@ -18,6 +18,7 @@ import paddle
|
|
|
import numpy as np
|
|
|
from paddle.static import InputSpec
|
|
|
|
|
|
+from paddlers.utils import logging
|
|
|
from testing_utils import CommonTest
|
|
|
|
|
|
|
|
@@ -37,20 +38,26 @@ class _TestModelNamespace:
|
|
|
for i, (
|
|
|
input, model, target
|
|
|
) in enumerate(zip(self.inputs, self.models, self.targets)):
|
|
|
- with self.subTest(i=i):
|
|
|
+ try:
|
|
|
if isinstance(input, list):
|
|
|
output = model(*input)
|
|
|
else:
|
|
|
output = model(input)
|
|
|
self.check_output(output, target)
|
|
|
+ except:
|
|
|
+ logging.warning(f"Model built with spec{i} failed!")
|
|
|
+ raise
|
|
|
|
|
|
def test_to_static(self):
|
|
|
for i, (
|
|
|
input, model, target
|
|
|
) in enumerate(zip(self.inputs, self.models, self.targets)):
|
|
|
- with self.subTest(i=i):
|
|
|
+ try:
|
|
|
static_model = paddle.jit.to_static(
|
|
|
model, input_spec=self.get_input_spec(model, input))
|
|
|
+ except:
|
|
|
+ logging.warning(f"Model built with spec{i} failed!")
|
|
|
+ raise
|
|
|
|
|
|
def check_output(self, output, target):
|
|
|
pass
|
|
@@ -117,4 +124,29 @@ class _TestModelNamespace:
|
|
|
return input_spec
|
|
|
|
|
|
|
|
|
+def allow_oom(cls):
|
|
|
+ def _deco(func):
|
|
|
+ def _wrapper(self, *args, **kwargs):
|
|
|
+ try:
|
|
|
+ func(self, *args, **kwargs)
|
|
|
+ except (SystemError, RuntimeError, OSError, MemoryError) as e:
|
|
|
+ # XXX: This may not cover all OOM cases.
|
|
|
+ msg = str(e)
|
|
|
+ if "Out of memory error" in msg \
|
|
|
+ or "(External) CUDNN error(4), CUDNN_STATUS_INTERNAL_ERROR." in msg \
|
|
|
+ or isinstance(e, MemoryError):
|
|
|
+ logging.warning("An OOM error has been ignored.")
|
|
|
+ else:
|
|
|
+ raise
|
|
|
+
|
|
|
+ return _wrapper
|
|
|
+
|
|
|
+ for key, value in inspect.getmembers(cls):
|
|
|
+ if key.startswith('test'):
|
|
|
+ value = _deco(value)
|
|
|
+ setattr(cls, key, value)
|
|
|
+
|
|
|
+ return cls
|
|
|
+
|
|
|
+
|
|
|
TestModel = _TestModelNamespace.TestModel
|