|
@@ -13,6 +13,7 @@
|
|
|
# limitations under the License.
|
|
|
|
|
|
import os.path as osp
|
|
|
+import sys
|
|
|
import tempfile
|
|
|
import unittest.mock as mock
|
|
|
|
|
@@ -54,7 +55,7 @@ class TestPredictor(CommonTest):
|
|
|
optimizer.state_dict.return_value = {'foo': 'bar'}
|
|
|
trainer.optimizer = optimizer
|
|
|
trainer.save_model(dynamic_model_dir)
|
|
|
- export_cmd = f"python export_model.py --model_dir {dynamic_model_dir} --save_dir {static_model_dir} "
|
|
|
+ export_cmd = f"{sys.executable} export_model.py --model_dir {dynamic_model_dir} --save_dir {static_model_dir} "
|
|
|
if trainer_name in self.TRAINER_NAME_TO_EXPORT_OPTS:
|
|
|
export_cmd += self.TRAINER_NAME_TO_EXPORT_OPTS[
|
|
|
trainer_name]
|
|
@@ -126,7 +127,7 @@ class TestCDPredictor(TestPredictor):
|
|
|
t2_path = "data/ssmt/optical_t2.bmp"
|
|
|
single_input = (t1_path, t2_path)
|
|
|
num_inputs = 2
|
|
|
- transforms = [pdrs.transforms.Normalize()]
|
|
|
+ transforms = pdrs.transforms.Compose([pdrs.transforms.Normalize()])
|
|
|
|
|
|
# Expected failure
|
|
|
with self.assertRaises(ValueError):
|
|
@@ -191,7 +192,7 @@ class TestClasPredictor(TestPredictor):
|
|
|
def check_predictor(self, predictor, trainer):
|
|
|
single_input = "data/ssst/optical.bmp"
|
|
|
num_inputs = 2
|
|
|
- transforms = [pdrs.transforms.Normalize()]
|
|
|
+ transforms = pdrs.transforms.Compose([pdrs.transforms.Normalize()])
|
|
|
labels = list(range(2))
|
|
|
trainer.labels = labels
|
|
|
predictor._model.labels = labels
|
|
@@ -257,7 +258,7 @@ class TestDetPredictor(TestPredictor):
|
|
|
# given that the network is (partially?) randomly initialized.
|
|
|
single_input = "data/ssst/optical.bmp"
|
|
|
num_inputs = 2
|
|
|
- transforms = [pdrs.transforms.Normalize()]
|
|
|
+ transforms = pdrs.transforms.Compose([pdrs.transforms.Normalize()])
|
|
|
labels = list(range(80))
|
|
|
trainer.labels = labels
|
|
|
predictor._model.labels = labels
|
|
@@ -319,7 +320,7 @@ class TestResPredictor(TestPredictor):
|
|
|
# because the output is of uint8 type.
|
|
|
single_input = "data/ssst/optical.bmp"
|
|
|
num_inputs = 2
|
|
|
- transforms = [pdrs.transforms.Normalize()]
|
|
|
+ transforms = pdrs.transforms.Compose([pdrs.transforms.Normalize()])
|
|
|
|
|
|
# Single input (file path)
|
|
|
input_ = single_input
|
|
@@ -371,7 +372,7 @@ class TestSegPredictor(TestPredictor):
|
|
|
def check_predictor(self, predictor, trainer):
|
|
|
single_input = "data/ssst/optical.bmp"
|
|
|
num_inputs = 2
|
|
|
- transforms = [pdrs.transforms.Normalize()]
|
|
|
+ transforms = pdrs.transforms.Compose([pdrs.transforms.Normalize()])
|
|
|
|
|
|
# Single input (file path)
|
|
|
input_ = single_input
|