|
@@ -115,7 +115,7 @@ class BaseRestorer(BaseModel):
|
|
|
tar_shape = inputs[1]
|
|
|
if self.status == 'Infer':
|
|
|
net_out = net(inputs[0])
|
|
|
- res_map_list = self._postprocess(
|
|
|
+ res_map_list = self.postprocess(
|
|
|
net_out, tar_shape, transforms=inputs[2])
|
|
|
else:
|
|
|
if isinstance(net, GANAdapter):
|
|
@@ -124,7 +124,7 @@ class BaseRestorer(BaseModel):
|
|
|
net_out = net(inputs[0])
|
|
|
if self.TEST_OUT_KEY is not None:
|
|
|
net_out = net_out[self.TEST_OUT_KEY]
|
|
|
- pred = self._postprocess(
|
|
|
+ pred = self.postprocess(
|
|
|
net_out, tar_shape, transforms=inputs[2])
|
|
|
res_map_list = []
|
|
|
for res_map in pred:
|
|
@@ -141,7 +141,7 @@ class BaseRestorer(BaseModel):
|
|
|
net_out = net_out[self.TEST_OUT_KEY]
|
|
|
tar = inputs[1]
|
|
|
tar_shape = [tar.shape[-2:]]
|
|
|
- pred = self._postprocess(
|
|
|
+ pred = self.postprocess(
|
|
|
net_out, tar_shape, transforms=inputs[2])[0] # NCHW
|
|
|
pred = self._tensor_to_images(pred)
|
|
|
outputs['pred'] = pred
|
|
@@ -446,8 +446,8 @@ class BaseRestorer(BaseModel):
|
|
|
images = [img_file]
|
|
|
else:
|
|
|
images = img_file
|
|
|
- batch_im, batch_tar_shape = self._preprocess(images, transforms,
|
|
|
- self.model_type)
|
|
|
+ batch_im, batch_tar_shape = self.preprocess(images, transforms,
|
|
|
+ self.model_type)
|
|
|
self.net.eval()
|
|
|
data = (batch_im, batch_tar_shape, transforms.transforms)
|
|
|
outputs = self.run(self.net, data, 'test')
|
|
@@ -458,7 +458,7 @@ class BaseRestorer(BaseModel):
|
|
|
prediction = {'res_map': res_map_list[0]}
|
|
|
return prediction
|
|
|
|
|
|
- def _preprocess(self, images, transforms, to_tensor=True):
|
|
|
+ def preprocess(self, images, transforms, to_tensor=True):
|
|
|
self._check_transforms(transforms, 'test')
|
|
|
batch_im = list()
|
|
|
batch_tar_shape = list()
|
|
@@ -531,7 +531,7 @@ class BaseRestorer(BaseModel):
|
|
|
batch_restore_list.append(restore_list)
|
|
|
return batch_restore_list
|
|
|
|
|
|
- def _postprocess(self, batch_pred, batch_tar_shape, transforms):
|
|
|
+ def postprocess(self, batch_pred, batch_tar_shape, transforms):
|
|
|
batch_restore_list = BaseRestorer.get_transforms_shape_info(
|
|
|
batch_tar_shape, transforms)
|
|
|
if self.status == 'Infer':
|