Browse Source

[Fix] Fix error when clas predict (#75)

Yizhou Chen 2 years ago
parent
commit
b583337932
1 changed files with 8 additions and 1 deletions
  1. 8 1
      paddlers/tasks/classifier.py

+ 8 - 1
paddlers/tasks/classifier.py

@@ -423,6 +423,13 @@ class BaseClassifier(BaseModel):
                                                         self.model_type)
         self.net.eval()
         data = (batch_im, batch_origin_shape, transforms.transforms)
+        # add class_id_map from model.yml
+        if self._postprocess is None:
+            label_dict = dict()
+            for i, label in enumerate(self.labels):
+                label_dict[i] = label
+            self._postprocess = self.default_postprocess(None)
+            self._postprocess.class_id_map = label_dict
         outputs = self.run(self.net, data, 'test')
         label_list = outputs['class_ids']
         score_list = outputs['scores']
@@ -451,7 +458,7 @@ class BaseClassifier(BaseModel):
             if isinstance(sample['image'], str):
                 sample = ImgDecoder(to_rgb=False)(sample)
             ori_shape = sample['image'].shape[:2]
-            im = transforms(sample)[0]
+            im = transforms(sample)
             batch_im.append(im)
             batch_ori_shape.append(ori_shape)
         if to_tensor: