test_tutorials.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os.path as osp
  15. import tempfile
  16. import shutil
  17. from glob import iglob
  18. from testing_utils import run_script, CpuCommonTest
  19. class TestTutorial(CpuCommonTest):
  20. SUBDIR = "./"
  21. TIMEOUT = 300
  22. PATTERN = "*.py"
  23. @classmethod
  24. def setUpClass(cls):
  25. cls._td = tempfile.TemporaryDirectory(dir='./')
  26. # Recursively copy the content of `cls.SUBDIR` to td.
  27. # This is necessary for running scripts in td.
  28. cls._TSUBDIR = osp.join(cls._td.name, osp.basename(cls.SUBDIR))
  29. shutil.copytree(cls.SUBDIR, cls._TSUBDIR)
  30. return super().setUpClass()
  31. @classmethod
  32. def tearDownClass(cls):
  33. cls._td.cleanup()
  34. @staticmethod
  35. def add_tests(cls):
  36. """
  37. Automatically patch testing functions to cls.
  38. """
  39. def _test_tutorial(script_name):
  40. def _test_tutorial_impl(self):
  41. # Set working directory to `cls._TSUBDIR` such that the
  42. # files generated by the script will be automatically cleaned.
  43. run_script(f"python {script_name}", wd=cls._TSUBDIR)
  44. return _test_tutorial_impl
  45. for script_path in iglob(osp.join(cls.SUBDIR, cls.PATTERN)):
  46. script_name = osp.basename(script_path)
  47. if osp.normpath(osp.join(cls.SUBDIR, script_name)) != osp.normpath(
  48. script_path):
  49. raise ValueError(
  50. f"{script_name} should be directly contained in {cls.SUBDIR}"
  51. )
  52. setattr(cls, 'test_' + script_name, _test_tutorial(script_name))
  53. return cls
  54. @TestTutorial.add_tests
  55. class TestCDTutorial(TestTutorial):
  56. SUBDIR = "../tutorials/train/change_detection"
  57. @TestTutorial.add_tests
  58. class TestClasTutorial(TestTutorial):
  59. SUBDIR = "../tutorials/train/classification"
  60. @TestTutorial.add_tests
  61. class TestDetTutorial(TestTutorial):
  62. SUBDIR = "../tutorials/train/object_detection"
  63. @TestTutorial.add_tests
  64. class TestSegTutorial(TestTutorial):
  65. SUBDIR = "../tutorials/train/semantic_segmentation"