Bläddra i källkod

Fix rotated object detection predict bug

Bobholamovic 2 år sedan
förälder
incheckning
719ca2d321
3 ändrade filer med 23 tillägg och 6 borttagningar
  1. 1 1
      docs/apis/infer_cn.md
  2. 1 1
      docs/apis/infer_en.md
  3. 21 4
      paddlers/tasks/object_detector.py

+ 1 - 1
docs/apis/infer_cn.md

@@ -82,7 +82,7 @@ def predict(self, img_file, transforms=None):
 ```
 {"category_id": 类别ID,
  "category": 类别名称,
- "bbox": 目标框位置信息,依次包含目标框左上角的横、纵坐标以及目标框的宽度和长度,  
+ "bbox": 目标框位置信息,对于水平目标框依次包含目标框左上角的横、纵坐标以及目标框的宽度和高度,对于旋转框依次包含目标框的四个角点的横、纵坐标,  
  "score": 类别置信度,
  "mask": [RLE格式](https://baike.baidu.com/item/rle/366352)的掩模图(mask),仅实例分割模型预测结果包含此键值对}
 ```

+ 1 - 1
docs/apis/infer_en.md

@@ -82,7 +82,7 @@ If `img_file` is a string or NumPy array, returns a list with a predicted target
 ```
 {"category_id": Category ID,
  "category": Category name,
- "bbox": Bounding box position information, including the horizontal and vertical coordinates of the upper left corner of the box and the width and length of the box,  
+ "bbox": Bounding box position information, including the horizontal and vertical coordinates of the upper left corner of the box and the width and height of the box (for horizontal bounding boxes), or the horizontal and vertical coordinates of the four corners of the box (for rotated bounding boxes),  
  "score": Category confidence score,
  "mask": [RLE Format](https://baike.baidu.com/item/rle/366352) mask, only instance segmentation model prediction results contain this key-value pair}
 ```

+ 21 - 4
paddlers/tasks/object_detector.py

@@ -770,13 +770,22 @@ class BaseDetector(BaseModel):
                 for j in range(det_nums):
                     dt = bboxes[k]
                     k = k + 1
-                    num_id, score, xmin, ymin, xmax, ymax = dt.tolist()
+                    dt = dt.tolist()
+                    if len(dt) == 6:
+                        # Generic object detection: [num_id, score, xmin, ymin, xmax, ymax]
+                        num_id, score, xmin, ymin, xmax, ymax = dt
+                        w = xmax - xmin
+                        h = ymax - ymin
+                        bbox = [xmin, ymin, w, h]
+                    elif len(dt) == 10:
+                        # Rotated object detection: [num_id, score, eight corner coordinates]
+                        num_id, score, *pts = dt
+                        bbox = list(pts)
+                    else:
+                        raise AssertionError
                     if int(num_id) < 0:
                         continue
                     category = self.labels[int(num_id)]
-                    w = xmax - xmin
-                    h = ymax - ymin
-                    bbox = [xmin, ymin, w, h]
                     dt_res = {
                         'category_id': int(num_id),
                         'category': category,
@@ -2287,6 +2296,10 @@ class FCOSR(YOLOv3):
 
         return batch_transforms
 
+    def export_inference_model(self, save_dir, image_shape=None):
+        raise RuntimeError("Currently, {} model cannot be exported.".format(
+            self.__class__.__name__))
+
 
 class PPYOLOE_R(YOLOv3):
     supported_backbones = ('CSPResNet_m', 'CSPResNet_l', 'CSPResNet_s',
@@ -2399,3 +2412,7 @@ class PPYOLOE_R(YOLOv3):
             batch_transforms, collate_batch=collate_batch)
 
         return batch_transforms
+
+    def export_inference_model(self, save_dir, image_shape=None):
+        raise RuntimeError("Currently, {} model cannot be exported.".format(
+            self.__class__.__name__))