|
@@ -12,8 +12,11 @@
|
|
|
# See the License for the specific language governing permissions and
|
|
|
# limitations under the License.
|
|
|
|
|
|
+import inspect
|
|
|
+
|
|
|
import paddle
|
|
|
import numpy as np
|
|
|
+from paddle.static import InputSpec
|
|
|
|
|
|
from testing_utils import CommonTest
|
|
|
|
|
@@ -38,6 +41,15 @@ class _TestModelNamespace:
|
|
|
output = model(input)
|
|
|
self.check_output(output, target)
|
|
|
|
|
|
+ def test_to_static(self):
|
|
|
+ for i, (
|
|
|
+ input, model, target
|
|
|
+ ) in enumerate(zip(self.inputs, self.models, self.targets)):
|
|
|
+ with self.subTest(i=i):
|
|
|
+ static_model = self.convert_to_static(model, input)
|
|
|
+ output = static_model(output)
|
|
|
+ self.check_output(output, target)
|
|
|
+
|
|
|
def check_output(self, output, target):
|
|
|
pass
|
|
|
|
|
@@ -90,5 +102,21 @@ class _TestModelNamespace:
|
|
|
shape = self.get_shape(c, b, h, w)
|
|
|
return paddle.randn(shape)
|
|
|
|
|
|
+ def get_input_spec(self, model, input):
|
|
|
+ if not isinstance(input, list):
|
|
|
+ input = [input]
|
|
|
+ input_spec = []
|
|
|
+ for param_name, tensor in zip(
|
|
|
+ inspect.signature(model.forward).parameters, input):
|
|
|
+ # XXX: Hard-code dtype
|
|
|
+ input_spec.append(
|
|
|
+ InputSpec(
|
|
|
+ shape=tensor.shape, name=param_name, dtype='float32'))
|
|
|
+ return input_spec
|
|
|
+
|
|
|
+ def convert_to_static(self, model, input):
|
|
|
+ return paddle.jit.to_static(
|
|
|
+ model, input_spec=self.get_input_spec(model, input))
|
|
|
+
|
|
|
|
|
|
TestModel = _TestModelNamespace.TestModel
|