瀏覽代碼

Add predictor unittests

Bobholamovic 2 年之前
父節點
當前提交
1a35b297af

+ 27 - 9
deploy/export/export_model.py

@@ -21,9 +21,23 @@ from paddlers.tasks import load_model
 
 def get_parser():
     parser = argparse.ArgumentParser()
-    parser.add_argument('--model_dir', '-m', type=str, default=None, help='model directory path')
-    parser.add_argument('--save_dir', '-s', type=str, default=None, help='path to save inference model')
-    parser.add_argument('--fixed_input_shape', '-fs', type=str, default=None,
+    parser.add_argument(
+        '--model_dir',
+        '-m',
+        type=str,
+        default=None,
+        help='model directory path')
+    parser.add_argument(
+        '--save_dir',
+        '-s',
+        type=str,
+        default=None,
+        help='path to save inference model')
+    parser.add_argument(
+        '--fixed_input_shape',
+        '-fs',
+        type=str,
+        default=None,
         help="export inference model with fixed input shape: [w,h] or [n,c,w,h]")
     return parser
 
@@ -39,13 +53,17 @@ if __name__ == '__main__':
         fixed_input_shape = literal_eval(args.fixed_input_shape)
         # Check validaty
         if not isinstance(fixed_input_shape, list):
-            raise ValueError("fixed_input_shape should be of None or list type.")
+            raise ValueError(
+                "fixed_input_shape should be of None or list type.")
         if len(fixed_input_shape) not in (2, 4):
-            raise ValueError("fixed_input_shape contains an incorrect number of elements.")
+            raise ValueError(
+                "fixed_input_shape contains an incorrect number of elements.")
         if fixed_input_shape[-1] <= 0 or fixed_input_shape[-2] <= 0:
-            raise ValueError("the input width and height must be positive integers.")
-        if len(fixed_input_shape)==4 and fixed_input_shape[1] <= 0:
-            raise ValueError("the number of input channels must be a positive integer.")
+            raise ValueError(
+                "Input width and height must be positive integers.")
+        if len(fixed_input_shape) == 4 and fixed_input_shape[1] <= 0:
+            raise ValueError(
+                "The number of input channels must be a positive integer.")
 
     # Set environment variables
     os.environ['PADDLEX_EXPORT_STAGE'] = 'True'
@@ -56,4 +74,4 @@ if __name__ == '__main__':
 
     # Do dynamic-to-static cast
     # XXX: Invoke a protected (single underscore) method outside of subclasses.
-    model._export_inference_model(args.save_dir, fixed_input_shape)
+    model._export_inference_model(args.save_dir, fixed_input_shape)

+ 3 - 3
paddlers/deploy/predictor.py

@@ -175,9 +175,9 @@ class Predictor(object):
             if self._model._postprocess is None:
                 self._model.build_postprocess_from_labels(topk)
             # XXX: Convert ndarray to tensor as self._model._postprocess requires
-            net_outputs = paddle.to_tensor(net_outputs)
-            assert net_outputs.shape[1] == 1
-            outputs = self._model._postprocess(net_outputs.squeeze(1))
+            assert len(net_outputs) == 1
+            net_outputs = paddle.to_tensor(net_outputs[0])
+            outputs = self._model._postprocess(net_outputs)
             class_ids = map(itemgetter('class_ids'), outputs)
             scores = map(itemgetter('scores'), outputs)
             label_names = map(itemgetter('label_names'), outputs)

+ 2 - 0
paddlers/tasks/change_detector.py

@@ -650,6 +650,8 @@ class BaseChangeDetector(BaseModel):
             if isinstance(sample['image_t1'], str) or \
                 isinstance(sample['image_t2'], str):
                 sample = DecodeImg(to_rgb=False)(sample)
+                sample['image'] = sample['image'].astype('float32')
+                sample['image2'] = sample['image2'].astype('float32')
                 ori_shape = sample['image'].shape[:2]
             else:
                 ori_shape = im1.shape[:2]

+ 1 - 0
paddlers/tasks/classifier.py

@@ -468,6 +468,7 @@ class BaseClassifier(BaseModel):
             sample = {'image': im}
             if isinstance(sample['image'], str):
                 sample = DecodeImg(to_rgb=False)(sample)
+                sample['image'] = sample['image'].astype('float32')
             ori_shape = sample['image'].shape[:2]
             im = transforms(sample)
             batch_im.append(im)

+ 6 - 2
paddlers/tasks/object_detector.py

@@ -27,7 +27,7 @@ import paddlers.models.ppdet as ppdet
 from paddlers.models.ppdet.modeling.proposal_generator.target_layer import BBoxAssigner, MaskAssigner
 import paddlers
 import paddlers.utils.logging as logging
-from paddlers.transforms.operators import _NormalizeBox, _PadBox, _BboxXYXY2XYWH, Resize, Pad
+from paddlers.transforms.operators import _NormalizeBox, _PadBox, _BboxXYXY2XYWH, Resize, Pad, DecodeImg
 from paddlers.transforms.batch_operators import BatchCompose, BatchRandomResize, BatchRandomResizeByShort, \
     _BatchPad, _Gt2YoloTarget
 from paddlers.transforms import arrange_transforms
@@ -550,7 +550,11 @@ class BaseDetector(BaseModel):
         batch_samples = list()
         for im in images:
             sample = {'image': im}
-            batch_samples.append(transforms(sample))
+            if isinstance(sample['image'], str):
+                sample = DecodeImg(to_rgb=False)(sample)
+                sample['image'] = sample['image'].astype('float32')
+            sample = transforms(sample)
+            batch_samples.append(sample)
         batch_transforms = self._compose_batch_transform(transforms, 'test')
         batch_samples = batch_transforms(batch_samples)
         if to_tensor:

+ 1 - 0
paddlers/tasks/segmenter.py

@@ -614,6 +614,7 @@ class BaseSegmenter(BaseModel):
             sample = {'image': im}
             if isinstance(sample['image'], str):
                 sample = DecodeImg(to_rgb=False)(sample)
+                sample['image'] = sample['image'].astype('float32')
             ori_shape = sample['image'].shape[:2]
             im = transforms(sample)[0]
             batch_im.append(im)

+ 0 - 13
tests/deploy/test_export.py

@@ -1,13 +0,0 @@
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.

+ 0 - 13
tests/deploy/test_predict.py

@@ -1,13 +0,0 @@
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.

+ 351 - 0
tests/deploy/test_predictor.py

@@ -0,0 +1,351 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tempfile
+import unittest.mock as mock
+
+import cv2
+import paddle
+
+import paddlers as pdrs
+from testing_utils import CommonTest, run_script
+
+
+class TestPredictor(CommonTest):
+    MODULE = pdrs.tasks
+    TRAINER_NAME_TO_EXPORT_OPTS = {}
+
+    @staticmethod
+    def add_tests(cls):
+        def _test_predictor(trainer_name):
+            def _test_predictor_impl(self):
+                trainer_class = getattr(self.MODULE, trainer_name)
+                # Construct trainer with default parameters
+                trainer = trainer_class()
+                with tempfile.TemporaryDirectory() as td:
+                    dynamic_model_dir = f"{td}/dynamic"
+                    static_model_dir = f"{td}/static"
+                    # HACK: BaseModel.save_model() requires BaseModel().optimizer to be set
+                    optimizer = mock.Mock()
+                    optimizer.state_dict.return_value = {'foo': 'bar'}
+                    trainer.optimizer = optimizer
+                    trainer.save_model(dynamic_model_dir)
+                    export_cmd = f"python export_model.py --model_dir {dynamic_model_dir} --save_dir {static_model_dir} "
+                    if trainer_name in self.TRAINER_NAME_TO_EXPORT_OPTS:
+                        export_cmd += self.TRAINER_NAME_TO_EXPORT_OPTS[
+                            trainer_name]
+                    elif '_default' in self.TRAINER_NAME_TO_EXPORT_OPTS:
+                        export_cmd += self.TRAINER_NAME_TO_EXPORT_OPTS[
+                            '_default']
+                    run_script(export_cmd, wd="../deploy/export")
+                    # Construct predictor
+                    # TODO: Test trt and mkl
+                    predictor = pdrs.deploy.Predictor(
+                        static_model_dir,
+                        use_gpu=paddle.device.get_device().startswith('gpu'))
+                    self.check_predictor(predictor, trainer)
+
+            return _test_predictor_impl
+
+        for trainer_name in cls.MODULE.__all__:
+            setattr(cls, 'test_' + trainer_name, _test_predictor(trainer_name))
+
+        return cls
+
+    def check_predictor(self, predictor, trainer):
+        raise NotImplementedError
+
+    def check_dict_equal(self, dict_, expected_dict):
+        if isinstance(dict_, list):
+            self.assertIsInstance(expected_dict, list)
+            self.assertEqual(len(dict_), len(expected_dict))
+            for d1, d2 in zip(dict_, expected_dict):
+                self.check_dict_equal(d1, d2)
+        else:
+            assert isinstance(dict_, dict)
+            assert isinstance(expected_dict, dict)
+            self.assertEqual(dict_.keys(), expected_dict.keys())
+            for key in dict_.keys():
+                self.check_output_equal(dict_[key], expected_dict[key])
+
+
+@TestPredictor.add_tests
+class TestCDPredictor(TestPredictor):
+    MODULE = pdrs.tasks.change_detector
+    TRAINER_NAME_TO_EXPORT_OPTS = {
+        'BIT': "--fixed_input_shape [1,3,256,256]",
+        '_default': "--fixed_input_shape [-1,3,256,256]"
+    }
+
+    def check_predictor(self, predictor, trainer):
+        t1_path = "data/ssmt/optical_t1.bmp"
+        t2_path = "data/ssmt/optical_t2.bmp"
+        single_input = (t1_path, t2_path)
+        num_inputs = 2
+        transforms = pdrs.transforms.Compose([pdrs.transforms.Normalize()])
+
+        # Expected failure
+        with self.assertRaises(ValueError):
+            predictor.predict(t1_path, transforms=transforms)
+
+        # Single input (file paths)
+        input_ = single_input
+        out_single_file_p = predictor.predict(input_, transforms=transforms)
+        out_single_file_t = trainer.predict(input_, transforms=transforms)
+        self.check_dict_equal(out_single_file_p, out_single_file_t)
+        out_single_file_list_p = predictor.predict(
+            [input_], transforms=transforms)
+        self.assertEqual(len(out_single_file_list_p), 1)
+        self.check_dict_equal(out_single_file_list_p[0], out_single_file_p)
+        out_single_file_list_t = trainer.predict(
+            [input_], transforms=transforms)
+        self.check_dict_equal(out_single_file_list_p[0],
+                              out_single_file_list_t[0])
+
+        # Single input (ndarrays)
+        input_ = (cv2.imread(t1_path).astype('float32'),
+                  cv2.imread(t2_path).astype('float32')
+                  )  # Reuse the name `input_`
+        out_single_array_p = predictor.predict(input_, transforms=transforms)
+        self.check_dict_equal(out_single_array_p, out_single_file_p)
+        out_single_array_t = trainer.predict(input_, transforms=transforms)
+        self.check_dict_equal(out_single_array_p, out_single_array_t)
+        out_single_array_list_p = predictor.predict(
+            [input_], transforms=transforms)
+        self.assertEqual(len(out_single_array_list_p), 1)
+        self.check_dict_equal(out_single_array_list_p[0], out_single_array_p)
+        out_single_array_list_t = trainer.predict(
+            [input_], transforms=transforms)
+        self.check_dict_equal(out_single_array_list_p[0],
+                              out_single_array_list_t[0])
+
+        if isinstance(trainer, pdrs.tasks.change_detector.BIT):
+            return
+
+        # Multiple inputs (file paths)
+        input_ = [single_input] * num_inputs  # Reuse the name `input_`
+        out_multi_file_p = predictor.predict(input_, transforms=transforms)
+        self.assertEqual(len(out_multi_file_p), num_inputs)
+        out_multi_file_t = trainer.predict(input_, transforms=transforms)
+        self.check_dict_equal(out_multi_file_p, out_multi_file_t)
+
+        # Multiple inputs (ndarrays)
+        input_ = [(cv2.imread(t1_path).astype('float32'), cv2.imread(t2_path)
+                   .astype('float32'))] * num_inputs  # Reuse the name `input_`
+        out_multi_array_p = predictor.predict(input_, transforms=transforms)
+        self.assertEqual(len(out_multi_array_p), num_inputs)
+        out_multi_array_t = trainer.predict(input_, transforms=transforms)
+        self.check_dict_equal(out_multi_array_p, out_multi_array_t)
+
+
+@TestPredictor.add_tests
+class TestClasPredictor(TestPredictor):
+    MODULE = pdrs.tasks.classifier
+    TRAINER_NAME_TO_EXPORT_OPTS = {
+        '_default': "--fixed_input_shape [-1,3,256,256]"
+    }
+
+    def check_predictor(self, predictor, trainer):
+        single_input = "data/ssmt/optical_t1.bmp"
+        num_inputs = 2
+        transforms = pdrs.transforms.Compose([pdrs.transforms.Normalize()])
+        labels = list(range(2))
+        trainer.labels = labels
+        predictor._model.labels = labels
+
+        # Single input (file paths)
+        input_ = single_input
+        out_single_file_p = predictor.predict(input_, transforms=transforms)
+        out_single_file_t = trainer.predict(input_, transforms=transforms)
+        self.check_dict_equal(out_single_file_p, out_single_file_t)
+        out_single_file_list_p = predictor.predict(
+            [input_], transforms=transforms)
+        self.assertEqual(len(out_single_file_list_p), 1)
+        self.check_dict_equal(out_single_file_list_p[0], out_single_file_p)
+        out_single_file_list_t = trainer.predict(
+            [input_], transforms=transforms)
+        self.check_dict_equal(out_single_file_list_p[0],
+                              out_single_file_list_t[0])
+
+        # Single input (ndarrays)
+        input_ = cv2.imread(single_input).astype(
+            'float32')  # Reuse the name `input_`
+        out_single_array_p = predictor.predict(input_, transforms=transforms)
+        self.check_dict_equal(out_single_array_p, out_single_file_p)
+        out_single_array_t = trainer.predict(input_, transforms=transforms)
+        self.check_dict_equal(out_single_array_p, out_single_array_t)
+        out_single_array_list_p = predictor.predict(
+            [input_], transforms=transforms)
+        self.assertEqual(len(out_single_array_list_p), 1)
+        self.check_dict_equal(out_single_array_list_p[0], out_single_array_p)
+        out_single_array_list_t = trainer.predict(
+            [input_], transforms=transforms)
+        self.check_dict_equal(out_single_array_list_p[0],
+                              out_single_array_list_t[0])
+
+        # Multiple inputs (file paths)
+        input_ = [single_input] * num_inputs  # Reuse the name `input_`
+        out_multi_file_p = predictor.predict(input_, transforms=transforms)
+        self.assertEqual(len(out_multi_file_p), num_inputs)
+        out_multi_file_t = trainer.predict(input_, transforms=transforms)
+        self.assertEqual(len(out_multi_file_p), len(out_multi_file_t))
+        self.check_dict_equal(out_multi_file_p, out_multi_file_t)
+
+        # Multiple inputs (ndarrays)
+        input_ = [cv2.imread(single_input).astype('float32')
+                  ] * num_inputs  # Reuse the name `input_`
+        out_multi_array_p = predictor.predict(input_, transforms=transforms)
+        self.assertEqual(len(out_multi_array_p), num_inputs)
+        out_multi_array_t = trainer.predict(input_, transforms=transforms)
+        self.assertEqual(len(out_multi_array_p), len(out_multi_array_t))
+        self.check_dict_equal(out_multi_array_p, out_multi_array_t)
+
+
+@TestPredictor.add_tests
+class TestDetPredictor(TestPredictor):
+    MODULE = pdrs.tasks.object_detector
+    TRAINER_NAME_TO_EXPORT_OPTS = {
+        '_default': "--fixed_input_shape [-1,3,256,256]"
+    }
+
+    def check_predictor(self, predictor, trainer):
+        single_input = "data/ssmt/optical_t1.bmp"
+        num_inputs = 2
+        transforms = pdrs.transforms.Compose([pdrs.transforms.Normalize()])
+        labels = list(range(80))
+        trainer.labels = labels
+        predictor._model.labels = labels
+
+        # Single input (file paths)
+        input_ = single_input
+        out_single_file_p = predictor.predict(input_, transforms=transforms)
+        out_single_file_t = trainer.predict(input_, transforms=transforms)
+        self.check_dict_equal(out_single_file_p, out_single_file_t)
+        out_single_file_list_p = predictor.predict(
+            [input_], transforms=transforms)
+        self.assertEqual(len(out_single_file_list_p), 1)
+        self.check_dict_equal(out_single_file_list_p[0], out_single_file_p)
+        out_single_file_list_t = trainer.predict(
+            [input_], transforms=transforms)
+        self.check_dict_equal(out_single_file_list_p[0],
+                              out_single_file_list_t[0])
+
+        # Single input (ndarrays)
+        input_ = cv2.imread(single_input).astype(
+            'float32')  # Reuse the name `input_`
+        out_single_array_p = predictor.predict(input_, transforms=transforms)
+        self.check_dict_equal(out_single_array_p, out_single_file_p)
+        out_single_array_t = trainer.predict(input_, transforms=transforms)
+        self.check_dict_equal(out_single_array_p, out_single_array_t)
+        out_single_array_list_p = predictor.predict(
+            [input_], transforms=transforms)
+        self.assertEqual(len(out_single_array_list_p), 1)
+        self.check_dict_equal(out_single_array_list_p[0], out_single_array_p)
+        out_single_array_list_t = trainer.predict(
+            [input_], transforms=transforms)
+        self.check_dict_equal(out_single_array_list_p[0],
+                              out_single_array_list_t[0])
+
+        # Single input (ndarrays)
+        input_ = cv2.imread(single_input).astype(
+            'float32')  # Reuse the name `input_`
+        out_single_array_p = predictor.predict(input_, transforms=transforms)
+        self.check_dict_equal(out_single_array_p, out_single_file_p)
+        out_single_array_t = trainer.predict(input_, transforms=transforms)
+        self.check_dict_equal(out_single_array_p, out_single_array_t)
+        out_single_array_list_p = predictor.predict(
+            [input_], transforms=transforms)
+        self.assertEqual(len(out_single_array_list_p), 1)
+        self.check_dict_equal(out_single_array_list_p[0], out_single_array_p)
+        out_single_array_list_t = trainer.predict(
+            [input_], transforms=transforms)
+        self.check_dict_equal(out_single_array_list_p[0],
+                              out_single_array_list_t[0])
+
+        # Multiple inputs (file paths)
+        input_ = [single_input] * num_inputs  # Reuse the name `input_`
+        out_multi_file_p = predictor.predict(input_, transforms=transforms)
+        self.assertEqual(len(out_multi_file_p), num_inputs)
+        out_multi_file_t = trainer.predict(input_, transforms=transforms)
+        self.assertEqual(len(out_multi_file_p), len(out_multi_file_t))
+        self.check_dict_equal(out_multi_file_p, out_multi_file_t)
+
+        # Multiple inputs (ndarrays)
+        input_ = [cv2.imread(single_input).astype('float32')
+                  ] * num_inputs  # Reuse the name `input_`
+        out_multi_array_p = predictor.predict(input_, transforms=transforms)
+        self.assertEqual(len(out_multi_array_p), num_inputs)
+        out_multi_array_t = trainer.predict(input_, transforms=transforms)
+        self.assertEqual(len(out_multi_array_p), len(out_multi_array_t))
+        self.check_dict_equal(out_multi_array_p, out_multi_array_t)
+
+
+@TestPredictor.add_tests
+class TestSegPredictor(TestPredictor):
+    MODULE = pdrs.tasks.segmenter
+    TRAINER_NAME_TO_EXPORT_OPTS = {
+        '_default': "--fixed_input_shape [-1,3,256,256]"
+    }
+
+    def check_predictor(self, predictor, trainer):
+        single_input = "data/ssmt/optical_t1.bmp"
+        num_inputs = 2
+        transforms = pdrs.transforms.Compose([pdrs.transforms.Normalize()])
+
+        # Single input (file paths)
+        input_ = single_input
+        out_single_file_p = predictor.predict(input_, transforms=transforms)
+        out_single_file_t = trainer.predict(input_, transforms=transforms)
+        self.check_dict_equal(out_single_file_p, out_single_file_t)
+        out_single_file_list_p = predictor.predict(
+            [input_], transforms=transforms)
+        self.assertEqual(len(out_single_file_list_p), 1)
+        self.check_dict_equal(out_single_file_list_p[0], out_single_file_p)
+        out_single_file_list_t = trainer.predict(
+            [input_], transforms=transforms)
+        self.check_dict_equal(out_single_file_list_p[0],
+                              out_single_file_list_t[0])
+
+        # Single input (ndarrays)
+        input_ = cv2.imread(single_input).astype(
+            'float32')  # Reuse the name `input_`
+        out_single_array_p = predictor.predict(input_, transforms=transforms)
+        self.check_dict_equal(out_single_array_p, out_single_file_p)
+        out_single_array_t = trainer.predict(input_, transforms=transforms)
+        self.check_dict_equal(out_single_array_p, out_single_array_t)
+        out_single_array_list_p = predictor.predict(
+            [input_], transforms=transforms)
+        self.assertEqual(len(out_single_array_list_p), 1)
+        self.check_dict_equal(out_single_array_list_p[0], out_single_array_p)
+        out_single_array_list_t = trainer.predict(
+            [input_], transforms=transforms)
+        self.check_dict_equal(out_single_array_list_p[0],
+                              out_single_array_list_t[0])
+
+        # Multiple inputs (file paths)
+        input_ = [single_input] * num_inputs  # Reuse the name `input_`
+        out_multi_file_p = predictor.predict(input_, transforms=transforms)
+        self.assertEqual(len(out_multi_file_p), num_inputs)
+        out_multi_file_t = trainer.predict(input_, transforms=transforms)
+        self.assertEqual(len(out_multi_file_p), len(out_multi_file_t))
+        self.check_dict_equal(out_multi_file_p, out_multi_file_t)
+
+        # Multiple inputs (ndarrays)
+        input_ = [cv2.imread(single_input).astype('float32')
+                  ] * num_inputs  # Reuse the name `input_`
+        out_multi_array_p = predictor.predict(input_, transforms=transforms)
+        self.assertEqual(len(out_multi_array_p), num_inputs)
+        out_multi_array_t = trainer.predict(input_, transforms=transforms)
+        self.assertEqual(len(out_multi_array_p), len(out_multi_array_t))
+        self.check_dict_equal(out_multi_array_p, out_multi_array_t)

+ 1 - 1
tests/tools/test_match.py

@@ -22,4 +22,4 @@ class TestMatch(CpuCommonTest):
         with tempfile.TemporaryDirectory() as td:
             run_script(
                 f"python match.py --im1_path ../tests/data/ssmt/multispectral_t1.tif --im2_path ../tests/data/ssmt/multispectral_t1.tif --save_path {td}/out.tiff",
-                wd='../tools')
+                wd="../tools")

+ 1 - 1
tests/tools/test_oif.py

@@ -21,4 +21,4 @@ class TestOIF(CpuCommonTest):
     def test_script(self):
         run_script(
             f"python oif.py --im_path ../tests/data/ssst/multispectral.tif",
-            wd='../tools')
+            wd="../tools")

+ 1 - 1
tests/tools/test_pca.py

@@ -22,4 +22,4 @@ class TestPCA(CpuCommonTest):
         with tempfile.TemporaryDirectory() as td:
             run_script(
                 f"python pca.py --im_path ../tests/data/ssst/multispectral.tif --save_dir {td} --dim 5",
-                wd='../tools')
+                wd="../tools")

+ 1 - 1
tests/tools/test_split.py

@@ -22,4 +22,4 @@ class TestSplit(CpuCommonTest):
         with tempfile.TemporaryDirectory() as td:
             run_script(
                 f"python split.py --image_path ../tests/data/ssst/multispectral.tif --mask_path ../tests/data/ssst/multiclass_gt2.png --block_size 128 --save_dir {td}",
-                wd='../tools')
+                wd="../tools")

+ 2 - 2
tests/transforms/test_functions.py

@@ -23,7 +23,7 @@ class TestMatchHistograms(CpuCommonTest):
     def setUp(self):
         self.inputs = [
             build_input_from_file(
-                'data/ssmt/test_mixed_binary.txt', prefix='./data/ssmt')
+                "data/ssmt/test_mixed_binary.txt", prefix="./data/ssmt")
         ]
 
     def test_output_shape(self):
@@ -43,7 +43,7 @@ class TestMatchByRegression(CpuCommonTest):
     def setUp(self):
         self.inputs = [
             build_input_from_file(
-                'data/ssmt/test_mixed_binary.txt', prefix='./data/ssmt')
+                "data/ssmt/test_mixed_binary.txt", prefix="./data/ssmt")
         ]
 
     def test_output_shape(self):

+ 29 - 29
tests/transforms/test_operators.py

@@ -136,47 +136,47 @@ class TestTransform(CpuCommonTest):
     def setUp(self):
         self.inputs = [
             build_input_from_file(
-                'data/ssst/test_optical_clas.txt',
-                prefix='./data/ssst'), 
+                "data/ssst/test_optical_clas.txt",
+                prefix="./data/ssst"), 
             build_input_from_file(
-                'data/ssst/test_sar_clas.txt',
-                prefix='./data/ssst'), 
+                "data/ssst/test_sar_clas.txt",
+                prefix="./data/ssst"), 
             build_input_from_file(
-                'data/ssst/test_multispectral_clas.txt',
-                prefix='./data/ssst'), 
+                "data/ssst/test_multispectral_clas.txt",
+                prefix="./data/ssst"), 
             build_input_from_file(
-                'data/ssst/test_optical_seg.txt',
-                prefix='./data/ssst'), 
+                "data/ssst/test_optical_seg.txt",
+                prefix="./data/ssst"), 
             build_input_from_file(
-                'data/ssst/test_sar_seg.txt',
-                prefix='./data/ssst'), 
+                "data/ssst/test_sar_seg.txt",
+                prefix="./data/ssst"), 
             build_input_from_file(
-                'data/ssst/test_multispectral_seg.txt',
-                prefix='./data/ssst'),
+                "data/ssst/test_multispectral_seg.txt",
+                prefix="./data/ssst"),
             build_input_from_file(
-                'data/ssst/test_optical_det.txt',
-                prefix='./data/ssst',
-                label_list='data/ssst/labels_det.txt'), 
+                "data/ssst/test_optical_det.txt",
+                prefix="./data/ssst",
+                label_list="data/ssst/labels_det.txt"), 
             build_input_from_file(
-                'data/ssst/test_sar_det.txt',
-                prefix='./data/ssst',
-                label_list='data/ssst/labels_det.txt'),
+                "data/ssst/test_sar_det.txt",
+                prefix="./data/ssst",
+                label_list="data/ssst/labels_det.txt"),
             build_input_from_file(
-                'data/ssst/test_multispectral_det.txt',
-                prefix='./data/ssst',
-                label_list='data/ssst/labels_det.txt'), 
+                "data/ssst/test_multispectral_det.txt",
+                prefix="./data/ssst",
+                label_list="data/ssst/labels_det.txt"), 
             build_input_from_file(
-                'data/ssst/test_det_coco.txt',
-                prefix='./data/ssst'), 
+                "data/ssst/test_det_coco.txt",
+                prefix="./data/ssst"), 
             build_input_from_file(
-                'data/ssmt/test_mixed_binary.txt',
-                prefix='./data/ssmt'), 
+                "data/ssmt/test_mixed_binary.txt",
+                prefix="./data/ssmt"), 
             build_input_from_file(
-                'data/ssmt/test_mixed_multiclass.txt',
-                prefix='./data/ssmt'), 
+                "data/ssmt/test_mixed_multiclass.txt",
+                prefix="./data/ssmt"), 
             build_input_from_file(
-                'data/ssmt/test_mixed_multitask.txt',
-                prefix='./data/ssmt')
+                "data/ssmt/test_mixed_multitask.txt",
+                prefix="./data/ssmt")
         ]
 
     def test_DecodeImg(self):