|
@@ -423,6 +423,13 @@ class BaseClassifier(BaseModel):
|
|
|
self.model_type)
|
|
|
self.net.eval()
|
|
|
data = (batch_im, batch_origin_shape, transforms.transforms)
|
|
|
+ # add class_id_map from model.yml
|
|
|
+ if self._postprocess is None:
|
|
|
+ label_dict = dict()
|
|
|
+ for i, label in enumerate(self.labels):
|
|
|
+ label_dict[i] = label
|
|
|
+ self._postprocess = self.default_postprocess(None)
|
|
|
+ self._postprocess.class_id_map = label_dict
|
|
|
outputs = self.run(self.net, data, 'test')
|
|
|
label_list = outputs['class_ids']
|
|
|
score_list = outputs['scores']
|
|
@@ -451,7 +458,7 @@ class BaseClassifier(BaseModel):
|
|
|
if isinstance(sample['image'], str):
|
|
|
sample = ImgDecoder(to_rgb=False)(sample)
|
|
|
ori_shape = sample['image'].shape[:2]
|
|
|
- im = transforms(sample)[0]
|
|
|
+ im = transforms(sample)
|
|
|
batch_im.append(im)
|
|
|
batch_ori_shape.append(ori_shape)
|
|
|
if to_tensor:
|