Browse Source

Optimize inheritance

Bobholamovic 2 years ago
parent
commit
de4ca3a3fd
1 changed files with 10 additions and 9 deletions
  1. 10 9
      paddlers/tasks/object_detector.py

+ 10 - 9
paddlers/tasks/object_detector.py

@@ -1017,7 +1017,7 @@ class PicoDet(BaseDetector):
                 dataset, batch_size, mode, collate_fn)
 
 
-class YOLOv3(BaseDetector):
+class _YOLOv3(BaseDetector):
     def __init__(self,
                  rotate=False,
                  num_classes=80,
@@ -1142,7 +1142,7 @@ class YOLOv3(BaseDetector):
                 'yolo_head': yolo_head,
                 'post_process': post_process
             })
-        super(YOLOv3, self).__init__(
+        super(_YOLOv3, self).__init__(
             model_name='YOLOv3', num_classes=num_classes, **params)
         self.anchors = anchors
         self.anchor_masks = anchor_masks
@@ -1455,7 +1455,7 @@ class FasterRCNN(BaseDetector):
         return self._define_input_spec(image_shape)
 
 
-class PPYOLO(YOLOv3):
+class PPYOLO(_YOLOv3):
     def __init__(self,
                  num_classes=80,
                  backbone='ResNet50_vd_dcn',
@@ -1614,7 +1614,7 @@ class PPYOLO(YOLOv3):
                 'post_process': post_process
             })
 
-        super(YOLOv3, self).__init__(
+        super(PPYOLO, self).__init__(
             model_name='YOLOv3', num_classes=num_classes, **params)
         self.anchors = anchors
         self.anchor_masks = anchor_masks
@@ -1643,7 +1643,7 @@ class PPYOLO(YOLOv3):
         return self._define_input_spec(image_shape)
 
 
-class PPYOLOTiny(YOLOv3):
+class PPYOLOTiny(_YOLOv3):
     def __init__(self,
                  num_classes=80,
                  backbone='MobileNetV3',
@@ -1741,7 +1741,7 @@ class PPYOLOTiny(YOLOv3):
                 'post_process': post_process
             })
 
-        super(YOLOv3, self).__init__(
+        super(PPYOLOTiny, self).__init__(
             model_name='YOLOv3', num_classes=num_classes, **params)
         self.anchors = anchors
         self.anchor_masks = anchor_masks
@@ -1771,7 +1771,7 @@ class PPYOLOTiny(YOLOv3):
         return self._define_input_spec(image_shape)
 
 
-class PPYOLOv2(YOLOv3):
+class PPYOLOv2(_YOLOv3):
     def __init__(self,
                  num_classes=80,
                  backbone='ResNet50_vd_dcn',
@@ -1888,7 +1888,7 @@ class PPYOLOv2(YOLOv3):
                 'post_process': post_process
             })
 
-        super(YOLOv3, self).__init__(
+        super(PPYOLOv2, self).__init__(
             model_name='YOLOv3', num_classes=num_classes, **params)
         self.anchors = anchors
         self.anchor_masks = anchor_masks
@@ -2187,4 +2187,5 @@ class MaskRCNN(BaseDetector):
         return self._define_input_spec(image_shape)
 
 
-FCOSR = functools.partial(YOLOv3, rotate=True)
+YOLOv3 = functools.partial(_YOLOv3, rotate=False)
+FCOSR = functools.partial(_YOLOv3, rotate=True)