浏览代码

Fix rotated object detection predict bug

Bobholamovic 2 年之前
父节点
当前提交
719ca2d321
共有 3 个文件被更改,包括 23 次插入6 次删除
  1. 1 1
      docs/apis/infer_cn.md
  2. 1 1
      docs/apis/infer_en.md
  3. 21 4
      paddlers/tasks/object_detector.py

+ 1 - 1
docs/apis/infer_cn.md

@@ -82,7 +82,7 @@ def predict(self, img_file, transforms=None):
 ```
 {"category_id": 类别ID,
  "category": 类别名称,
- "bbox": 目标框位置信息,依次包含目标框左上角的横、纵坐标以及目标框的宽度和长度,  
+ "bbox": 目标框位置信息,对于水平目标框依次包含目标框左上角的横、纵坐标以及目标框的宽度和高度,对于旋转框依次包含目标框的四个角点的横、纵坐标,  
  "score": 类别置信度,
  "mask": [RLE格式](https://baike.baidu.com/item/rle/366352)的掩模图(mask),仅实例分割模型预测结果包含此键值对}
 ```

+ 1 - 1
docs/apis/infer_en.md

@@ -82,7 +82,7 @@ If `img_file` is a string or NumPy array, returns a list with a predicted target
 ```
 {"category_id": Category ID,
  "category": Category name,
- "bbox": Bounding box position information, including the horizontal and vertical coordinates of the upper left corner of the box and the width and length of the box,  
+ "bbox": Bounding box position information, including the horizontal and vertical coordinates of the upper left corner of the box and the width and height of the box (for horizontal bounding boxes), or the horizontal and vertical coordinates of the four corners of the box (for rotated bounding boxes),  
  "score": Category confidence score,
  "mask": [RLE Format](https://baike.baidu.com/item/rle/366352) mask, only instance segmentation model prediction results contain this key-value pair}
 ```

+ 21 - 4
paddlers/tasks/object_detector.py

@@ -770,13 +770,22 @@ class BaseDetector(BaseModel):
                 for j in range(det_nums):
                     dt = bboxes[k]
                     k = k + 1
-                    num_id, score, xmin, ymin, xmax, ymax = dt.tolist()
+                    dt = dt.tolist()
+                    if len(dt) == 8:
+                        # Generic object detection
+                        num_id, score, xmin, ymin, xmax, ymax = dt
+                        w = xmax - xmin
+                        h = ymax - ymin
+                        bbox = [xmin, ymin, w, h]
+                    elif len(dt) == 10:
+                        # Rotated object detection
+                        num_id, score, *pts = dt
+                        bbox = list(pts)
+                    else:
+                        raise AssertionError
                     if int(num_id) < 0:
                         continue
                     category = self.labels[int(num_id)]
-                    w = xmax - xmin
-                    h = ymax - ymin
-                    bbox = [xmin, ymin, w, h]
                     dt_res = {
                         'category_id': int(num_id),
                         'category': category,
@@ -2287,6 +2296,10 @@ class FCOSR(YOLOv3):
 
         return batch_transforms
 
+    def export_inference_model(self, save_dir, image_shape=None):
+        raise RuntimeError("Currently, {} model cannot be exported.".format(
+            self.__class__.__name__))
+
 
 class PPYOLOE_R(YOLOv3):
     supported_backbones = ('CSPResNet_m', 'CSPResNet_l', 'CSPResNet_s',
@@ -2399,3 +2412,7 @@ class PPYOLOE_R(YOLOv3):
             batch_transforms, collate_batch=collate_batch)
 
         return batch_transforms
+
+    def export_inference_model(self, save_dir, image_shape=None):
+        raise RuntimeError("Currently, {} model cannot be exported.".format(
+            self.__class__.__name__))