|
@@ -28,7 +28,7 @@ import paddlers.models.ppdet as ppdet
|
|
|
from paddlers.models.ppdet.modeling.proposal_generator.target_layer import BBoxAssigner, MaskAssigner
|
|
|
from paddlers.transforms import decode_image, construct_sample
|
|
|
from paddlers.transforms.operators import _NormalizeBox, _PadBox, _BboxXYXY2XYWH, Resize, Pad
|
|
|
-from paddlers.transforms.batch_operators import BatchCompose, _BatchPad, _Gt2YoloTarget
|
|
|
+from paddlers.transforms.batch_operators import BatchCompose, _BatchPad, _Gt2YoloTarget, BatchPadRGT, BatchNormalizeImage
|
|
|
from paddlers.models.ppdet.optimizer import ModelEMA
|
|
|
import paddlers.utils.logging as logging
|
|
|
from paddlers.utils.checkpoint import det_pretrain_weights_dict
|
|
@@ -43,6 +43,7 @@ __all__ = [
|
|
|
"PPYOLOv2",
|
|
|
"MaskRCNN",
|
|
|
"FCOSR",
|
|
|
+ "PPYOLOE_R",
|
|
|
]
|
|
|
|
|
|
# TODO: Prune and decoupling
|
|
@@ -57,6 +58,7 @@ class BaseDetector(BaseModel):
|
|
|
'rbox':
|
|
|
{'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class', 'gt_poly'},
|
|
|
}
|
|
|
+ supported_backbones = None
|
|
|
|
|
|
def __init__(self, model_name, num_classes=80, **params):
|
|
|
self.init_params.update(locals())
|
|
@@ -83,6 +85,13 @@ class BaseDetector(BaseModel):
|
|
|
def set_data_fields(cls, data_name, data_fields):
|
|
|
cls.data_fields[data_name] = data_fields
|
|
|
|
|
|
+ def _is_backbone_weight(self):
|
|
|
+ target_backbone = ['ESNET_', 'CSPResNet_']
|
|
|
+ for b in target_backbone:
|
|
|
+ if b in self.backbone_name:
|
|
|
+ return True
|
|
|
+ return False
|
|
|
+
|
|
|
def _build_inference_net(self):
|
|
|
infer_net = self.net
|
|
|
infer_net.eval()
|
|
@@ -102,6 +111,12 @@ class BaseDetector(BaseModel):
|
|
|
}]
|
|
|
return input_spec
|
|
|
|
|
|
+ def _check_backbone(self, backbone):
|
|
|
+ if backbone not in self.supported_backbones:
|
|
|
+ raise ValueError(
|
|
|
+ "backbone: {} is not supported. Please choose one of "
|
|
|
+ "{}.".format(backbone, self.supported_backbones))
|
|
|
+
|
|
|
def _check_image_shape(self, image_shape):
|
|
|
if len(image_shape) == 2:
|
|
|
image_shape = [1, 3] + image_shape
|
|
@@ -129,7 +144,7 @@ class BaseDetector(BaseModel):
|
|
|
depth = name[0]
|
|
|
fixed_kwargs['depth'] = int(depth[6:])
|
|
|
if len(name) > 1:
|
|
|
- fixed_kwargs['variant'] = name[1]
|
|
|
+ fixed_kwargs['variant'] = name[1][1]
|
|
|
backbone = getattr(ppdet.modeling, 'ResNet')
|
|
|
backbone = functools.partial(backbone, **fixed_kwargs)
|
|
|
else:
|
|
@@ -254,6 +269,7 @@ class BaseDetector(BaseModel):
|
|
|
early_stop_patience=5,
|
|
|
use_vdl=True,
|
|
|
clip_grad_by_norm=None,
|
|
|
+ reg_coeff=1e-4,
|
|
|
resume_checkpoint=None,
|
|
|
precision='fp32',
|
|
|
amp_level='O1',
|
|
@@ -286,6 +302,8 @@ class BaseDetector(BaseModel):
|
|
|
Defaults to 0.
|
|
|
warmup_start_lr (float, optional): Start learning rate of warm-up training.
|
|
|
Defaults to 0..
|
|
|
+ scheduler (str, optional): Learning rate scheduler used for training. If None,
|
|
|
+ a default scheduler will be used. Defaults to None.
|
|
|
lr_decay_epochs (list|tuple, optional): Epoch milestones for learning
|
|
|
rate decay. Defaults to (216, 243).
|
|
|
lr_decay_gamma (float, optional): Gamma coefficient of learning rate decay.
|
|
@@ -303,6 +321,10 @@ class BaseDetector(BaseModel):
|
|
|
early_stop_patience (int, optional): Early stop patience. Defaults to 5.
|
|
|
use_vdl(bool, optional): Whether to use VisualDL to monitor the training
|
|
|
process. Defaults to True.
|
|
|
+ clip_grad_by_norm (float, optional): Maximum global norm for gradient clipping.
|
|
|
+ Defaults to None.
|
|
|
+ reg_coeff (float, optional): Coefficient for L2 weight decay regularization.
|
|
|
+ Defaults to 1e-4.
|
|
|
resume_checkpoint (str|None, optional): Path of the checkpoint to resume
|
|
|
training from. If None, no training checkpoint will be resumed. At most
|
|
|
one of `resume_checkpoint` and `pretrain_weights` can be set simultaneously.
|
|
@@ -330,14 +352,14 @@ class BaseDetector(BaseModel):
|
|
|
def _pre_train(self, in_args):
|
|
|
return in_args
|
|
|
|
|
|
- def _real_train(self, num_epochs, train_dataset, train_batch_size,
|
|
|
- eval_dataset, optimizer, save_interval_epochs,
|
|
|
- log_interval_steps, save_dir, pretrain_weights,
|
|
|
- learning_rate, warmup_steps, warmup_start_lr,
|
|
|
- lr_decay_epochs, lr_decay_gamma, metric, use_ema,
|
|
|
- early_stop, early_stop_patience, use_vdl, resume_checkpoint,
|
|
|
- scheduler, cosine_decay_num_epochs, clip_grad_by_norm,
|
|
|
- precision, amp_level, custom_white_list, custom_black_list):
|
|
|
+ def _real_train(
|
|
|
+ self, num_epochs, train_dataset, train_batch_size, eval_dataset,
|
|
|
+ optimizer, save_interval_epochs, log_interval_steps, save_dir,
|
|
|
+ pretrain_weights, learning_rate, warmup_steps, warmup_start_lr,
|
|
|
+ lr_decay_epochs, lr_decay_gamma, metric, use_ema, early_stop,
|
|
|
+ early_stop_patience, use_vdl, resume_checkpoint, scheduler,
|
|
|
+ cosine_decay_num_epochs, clip_grad_by_norm, reg_coeff, precision,
|
|
|
+ amp_level, custom_white_list, custom_black_list):
|
|
|
self.precision = precision
|
|
|
self.amp_level = amp_level
|
|
|
self.custom_white_list = custom_white_list
|
|
@@ -366,14 +388,12 @@ class BaseDetector(BaseModel):
|
|
|
|
|
|
self.labels = train_dataset.labels
|
|
|
self.num_max_boxes = train_dataset.num_max_boxes
|
|
|
- train_batch_transforms = self._default_batch_transforms(
|
|
|
- 'train') if train_dataset.batch_transforms is None else None
|
|
|
- eval_batch_transforms = self._default_batch_transforms(
|
|
|
- 'eval') if eval_dataset.batch_transforms is None else None
|
|
|
+
|
|
|
+ train_batch_transforms = self._compose_batch_transforms(
|
|
|
+ 'train', train_dataset.batch_transforms)
|
|
|
+
|
|
|
train_dataset.build_collate_fn(train_batch_transforms,
|
|
|
self._default_collate_fn)
|
|
|
- eval_dataset.build_collate_fn(eval_batch_transforms,
|
|
|
- self._default_collate_fn)
|
|
|
|
|
|
# Build optimizer if not defined
|
|
|
if optimizer is None:
|
|
@@ -389,7 +409,8 @@ class BaseDetector(BaseModel):
|
|
|
num_steps_each_epoch=num_steps_each_epoch,
|
|
|
num_epochs=num_epochs,
|
|
|
clip_grad_by_norm=clip_grad_by_norm,
|
|
|
- cosine_decay_num_epochs=cosine_decay_num_epochs)
|
|
|
+ cosine_decay_num_epochs=cosine_decay_num_epochs,
|
|
|
+ reg_coeff=reg_coeff, )
|
|
|
else:
|
|
|
self.optimizer = optimizer
|
|
|
|
|
@@ -422,8 +443,8 @@ class BaseDetector(BaseModel):
|
|
|
pretrain_weights=pretrain_weights,
|
|
|
save_dir=pretrained_dir,
|
|
|
resume_checkpoint=resume_checkpoint,
|
|
|
- is_backbone_weights=(pretrain_weights == 'IMAGENET' and
|
|
|
- 'ESNet_' in self.backbone_name))
|
|
|
+ is_backbone_weights=pretrain_weights == 'IMAGENET' and
|
|
|
+ self._is_backbone_weight())
|
|
|
|
|
|
if use_ema:
|
|
|
ema = ModelEMA(model=self.net, decay=.9998, use_thres_step=True)
|
|
@@ -454,6 +475,27 @@ class BaseDetector(BaseModel):
|
|
|
def _default_batch_transforms(self, mode):
|
|
|
raise NotImplementedError
|
|
|
|
|
|
+ def _filter_batch_transforms(self, defaults, targets):
|
|
|
+ # TODO: Warning message
|
|
|
+ if targets is None:
|
|
|
+ return defaults
|
|
|
+ target_types = [type(i) for i in targets]
|
|
|
+ filtered = [i for i in defaults if type(i) not in target_types]
|
|
|
+ return filtered
|
|
|
+
|
|
|
+ def _compose_batch_transforms(self, mode, batch_transforms):
|
|
|
+ defaults = self._default_batch_transforms(mode)
|
|
|
+ out = []
|
|
|
+ if isinstance(batch_transforms, BatchCompose):
|
|
|
+ batch_transforms = batch_transforms.batch_transforms
|
|
|
+ if batch_transforms is not None:
|
|
|
+ out.extend(batch_transforms)
|
|
|
+ filtered = self._filter_batch_transforms(defaults.batch_transforms,
|
|
|
+ batch_transforms)
|
|
|
+ out.extend(filtered)
|
|
|
+
|
|
|
+ return BatchCompose(out, collate_batch=defaults.collate_batch)
|
|
|
+
|
|
|
def quant_aware_train(self,
|
|
|
num_epochs,
|
|
|
train_dataset,
|
|
@@ -580,9 +622,12 @@ class BaseDetector(BaseModel):
|
|
|
"Evaluation metric {} is not supported. Please choose from 'COCO' and 'VOC'."
|
|
|
|
|
|
eval_dataset.data_fields = self.data_fields[self.metric]
|
|
|
- eval_batch_transforms = self._default_batch_transforms(
|
|
|
- 'eval') if eval_dataset.batch_transforms is None else None
|
|
|
- eval_dataset._build_collate_fn(eval_batch_transforms)
|
|
|
+
|
|
|
+ eval_batch_transforms = self._compose_batch_transforms(
|
|
|
+ 'eval', eval_dataset.batch_transforms)
|
|
|
+ eval_dataset.build_collate_fn(eval_batch_transforms,
|
|
|
+ self._default_collate_fn)
|
|
|
+
|
|
|
self._check_transforms(eval_dataset.transforms)
|
|
|
|
|
|
self.net.eval()
|
|
@@ -791,6 +836,9 @@ class BaseDetector(BaseModel):
|
|
|
|
|
|
|
|
|
class PicoDet(BaseDetector):
|
|
|
+ supported_backbones = ('ESNet_s', 'ESNet_m', 'ESNet_l', 'LCNet',
|
|
|
+ 'MobileNetV3', 'ResNet18_vd')
|
|
|
+
|
|
|
def __init__(self,
|
|
|
num_classes=80,
|
|
|
backbone='ESNet_m',
|
|
@@ -800,14 +848,8 @@ class PicoDet(BaseDetector):
|
|
|
nms_iou_threshold=.6,
|
|
|
**params):
|
|
|
self.init_params = locals()
|
|
|
- if backbone not in {
|
|
|
- 'ESNet_s', 'ESNet_m', 'ESNet_l', 'LCNet', 'MobileNetV3',
|
|
|
- 'ResNet18_vd'
|
|
|
- }:
|
|
|
- raise ValueError(
|
|
|
- "backbone: {} is not supported. Please choose one of "
|
|
|
- "{'ESNet_s', 'ESNet_m', 'ESNet_l', 'LCNet', 'MobileNetV3', 'ResNet18_vd'}.".
|
|
|
- format(backbone))
|
|
|
+ self._check_backbone(backbone)
|
|
|
+
|
|
|
self.backbone_name = backbone
|
|
|
if params.get('with_net', True):
|
|
|
kwargs = {}
|
|
@@ -1017,9 +1059,12 @@ class PicoDet(BaseDetector):
|
|
|
dataset, batch_size, mode, collate_fn)
|
|
|
|
|
|
|
|
|
-class _YOLOv3(BaseDetector):
|
|
|
+class YOLOv3(BaseDetector):
|
|
|
+ supported_backbones = ('MobileNetV1', 'MobileNetV1_ssld', 'MobileNetV3',
|
|
|
+ 'MobileNetV3_ssld', 'DarkNet53', 'ResNet50_vd_dcn',
|
|
|
+ 'ResNet34')
|
|
|
+
|
|
|
def __init__(self,
|
|
|
- rotate=False,
|
|
|
num_classes=80,
|
|
|
backbone='MobileNetV1',
|
|
|
post_process=None,
|
|
@@ -1035,16 +1080,7 @@ class _YOLOv3(BaseDetector):
|
|
|
label_smooth=False,
|
|
|
**params):
|
|
|
self.init_params = locals()
|
|
|
- if backbone not in {
|
|
|
- 'MobileNetV1', 'MobileNetV1_ssld', 'MobileNetV3',
|
|
|
- 'MobileNetV3_ssld', 'DarkNet53', 'ResNet50_vd_dcn', 'ResNet34',
|
|
|
- 'ResNeXt50_32x4d'
|
|
|
- }:
|
|
|
- raise ValueError(
|
|
|
- "backbone: {} is not supported. Please choose one of "
|
|
|
- "{'MobileNetV1', 'MobileNetV1_ssld', 'MobileNetV3', 'MobileNetV3_ssld', 'DarkNet53', "
|
|
|
- "'ResNet50_vd_dcn', 'ResNet34', 'ResNeXt50_32x4d'}.".format(
|
|
|
- backbone))
|
|
|
+ self._check_backbone(backbone)
|
|
|
|
|
|
self.backbone_name = backbone
|
|
|
if params.get('with_net', True):
|
|
@@ -1058,6 +1094,7 @@ class _YOLOv3(BaseDetector):
|
|
|
kwargs['norm_type'] = norm_type
|
|
|
|
|
|
if 'MobileNetV3' in backbone:
|
|
|
+ backbone = 'MobileNetV3'
|
|
|
kwargs['feature_maps'] = [7, 13, 16]
|
|
|
elif backbone == 'ResNet50_vd_dcn':
|
|
|
kwargs.update(
|
|
@@ -1071,14 +1108,10 @@ class _YOLOv3(BaseDetector):
|
|
|
kwargs.update(
|
|
|
dict(
|
|
|
return_idx=[1, 2, 3], freeze_at=-1, freeze_norm=False))
|
|
|
- elif backbone == 'ResNeXt50_32x4d':
|
|
|
- backbone = 'ResNet50'
|
|
|
- kwargs.update(
|
|
|
- dict(
|
|
|
- return_idx=[1, 2, 3],
|
|
|
- base_width=4,
|
|
|
- groups=32,
|
|
|
- freeze_norm=False))
|
|
|
+ elif backbone == 'DarkNet53':
|
|
|
+ backbone = 'DarkNet'
|
|
|
+ elif 'MobileNet' in backbone:
|
|
|
+ backbone = 'MobileNet'
|
|
|
|
|
|
backbone = self._get_backbone(backbone, **kwargs)
|
|
|
nms = ppdet.modeling.MultiClassNMS(
|
|
@@ -1087,62 +1120,36 @@ class _YOLOv3(BaseDetector):
|
|
|
keep_top_k=nms_keep_topk,
|
|
|
nms_threshold=nms_iou_threshold,
|
|
|
normalized=nms_normalized)
|
|
|
- if rotate:
|
|
|
- neck = ppdet.modeling.FPN(
|
|
|
- in_channels=[i.channels for i in backbone.out_shape],
|
|
|
- out_channel=256,
|
|
|
- has_extra_convs=True,
|
|
|
- use_c5=False,
|
|
|
- relu_before_extra_convs=True)
|
|
|
- assigner = ppdet.modeling.FCOSRAssigner(
|
|
|
- num_classes=num_classes,
|
|
|
- factor=12,
|
|
|
- threshold=0.23,
|
|
|
- boundary=[[-1, 64], [64, 128], [128, 256], [256, 512],
|
|
|
- [512, 100000000.0]])
|
|
|
- yolo_head = ppdet.modeling.FCOSRHead(
|
|
|
- num_classes=num_classes,
|
|
|
- in_channels=[i.channels for i in neck.out_shape],
|
|
|
- feat_channels=256,
|
|
|
- fpn_strides=[8, 16, 32, 64, 128],
|
|
|
- stacked_convs=4,
|
|
|
- loss_weight={'class': 1.,
|
|
|
- 'probiou': 1.},
|
|
|
- assigner=assigner,
|
|
|
- nms=nms)
|
|
|
- post_process = None
|
|
|
- else:
|
|
|
- neck = ppdet.modeling.YOLOv3FPN(
|
|
|
- norm_type=norm_type,
|
|
|
- in_channels=[i.channels for i in backbone.out_shape])
|
|
|
- loss = ppdet.modeling.YOLOv3Loss(
|
|
|
- num_classes=num_classes,
|
|
|
- ignore_thresh=ignore_threshold,
|
|
|
- label_smooth=label_smooth)
|
|
|
- yolo_head = ppdet.modeling.YOLOv3Head(
|
|
|
- in_channels=[i.channels for i in neck.out_shape],
|
|
|
- anchors=anchors,
|
|
|
- anchor_masks=anchor_masks,
|
|
|
- num_classes=num_classes,
|
|
|
- loss=loss)
|
|
|
- post_process = ppdet.modeling.BBoxPostProcess(
|
|
|
- decode=ppdet.modeling.YOLOBox(num_classes=num_classes),
|
|
|
- nms=nms)
|
|
|
- post_process = ppdet.modeling.BBoxPostProcess(
|
|
|
- decode=ppdet.modeling.YOLOBox(num_classes=num_classes),
|
|
|
- nms=ppdet.modeling.MultiClassNMS(
|
|
|
- score_threshold=nms_score_threshold,
|
|
|
- nms_top_k=nms_topk,
|
|
|
- keep_top_k=nms_keep_topk,
|
|
|
- nms_threshold=nms_iou_threshold,
|
|
|
- normalized=nms_normalized))
|
|
|
+ neck = ppdet.modeling.YOLOv3FPN(
|
|
|
+ norm_type=norm_type,
|
|
|
+ in_channels=[i.channels for i in backbone.out_shape])
|
|
|
+ loss = ppdet.modeling.YOLOv3Loss(
|
|
|
+ num_classes=num_classes,
|
|
|
+ ignore_thresh=ignore_threshold,
|
|
|
+ label_smooth=label_smooth)
|
|
|
+ yolo_head = ppdet.modeling.YOLOv3Head(
|
|
|
+ in_channels=[i.channels for i in neck.out_shape],
|
|
|
+ anchors=anchors,
|
|
|
+ anchor_masks=anchor_masks,
|
|
|
+ num_classes=num_classes,
|
|
|
+ loss=loss)
|
|
|
+ post_process = ppdet.modeling.BBoxPostProcess(
|
|
|
+ decode=ppdet.modeling.YOLOBox(num_classes=num_classes), nms=nms)
|
|
|
+ post_process = ppdet.modeling.BBoxPostProcess(
|
|
|
+ decode=ppdet.modeling.YOLOBox(num_classes=num_classes),
|
|
|
+ nms=ppdet.modeling.MultiClassNMS(
|
|
|
+ score_threshold=nms_score_threshold,
|
|
|
+ nms_top_k=nms_topk,
|
|
|
+ keep_top_k=nms_keep_topk,
|
|
|
+ nms_threshold=nms_iou_threshold,
|
|
|
+ normalized=nms_normalized))
|
|
|
params.update({
|
|
|
'backbone': backbone,
|
|
|
'neck': neck,
|
|
|
'yolo_head': yolo_head,
|
|
|
'post_process': post_process
|
|
|
})
|
|
|
- super(_YOLOv3, self).__init__(
|
|
|
+ super(YOLOv3, self).__init__(
|
|
|
model_name='YOLOv3', num_classes=num_classes, **params)
|
|
|
self.anchors = anchors
|
|
|
self.anchor_masks = anchor_masks
|
|
@@ -1196,6 +1203,10 @@ class _YOLOv3(BaseDetector):
|
|
|
|
|
|
|
|
|
class FasterRCNN(BaseDetector):
|
|
|
+ supported_backbones = ('ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld',
|
|
|
+ 'ResNet34', 'ResNet34_vd', 'ResNet101',
|
|
|
+ 'ResNet101_vd', 'HRNet_W18')
|
|
|
+
|
|
|
def __init__(self,
|
|
|
num_classes=80,
|
|
|
backbone='ResNet50',
|
|
@@ -1213,26 +1224,23 @@ class FasterRCNN(BaseDetector):
|
|
|
test_post_nms_top_n=1000,
|
|
|
**params):
|
|
|
self.init_params = locals()
|
|
|
- if backbone not in {
|
|
|
- 'ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet34',
|
|
|
- 'ResNet34_vd', 'ResNet101', 'ResNet101_vd', 'HRNet_W18'
|
|
|
- }:
|
|
|
- raise ValueError(
|
|
|
- "backbone: {} is not supported. Please choose one of "
|
|
|
- "{'ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet34', 'ResNet34_vd', "
|
|
|
- "'ResNet101', 'ResNet101_vd', 'HRNet_W18'}.".format(backbone))
|
|
|
+ self._check_backbone(backbone)
|
|
|
+
|
|
|
self.backbone_name = backbone
|
|
|
|
|
|
if params.get('with_net', True):
|
|
|
dcn_v2_stages = [1, 2, 3] if with_dcn else [-1]
|
|
|
kwargs = {}
|
|
|
- kwargs['dcn_v2_stages'] = dcn_v2_stages
|
|
|
if backbone == 'HRNet_W18':
|
|
|
if not with_fpn:
|
|
|
logging.warning(
|
|
|
"Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
|
|
|
format(backbone))
|
|
|
with_fpn = True
|
|
|
+ kwargs.update(
|
|
|
+ dict(
|
|
|
+ width=18, freeze_at=0, return_idx=[0, 1, 2, 3]))
|
|
|
+ backbone = 'HRNet'
|
|
|
if with_dcn:
|
|
|
logging.warning(
|
|
|
"Backbone {} should be used along with dcn disabled, 'with_dcn' is forcibly set to False".
|
|
@@ -1244,13 +1252,13 @@ class FasterRCNN(BaseDetector):
|
|
|
format(backbone))
|
|
|
with_fpn = True
|
|
|
kwargs['lr_mult_list'] = [0.05, 0.05, 0.1, 0.15]
|
|
|
+ kwargs['dcn_v2_stages'] = dcn_v2_stages
|
|
|
elif 'ResNet50' in backbone:
|
|
|
if not with_fpn and with_dcn:
|
|
|
logging.warning(
|
|
|
"Backbone {} without fpn should be used along with dcn disabled, 'with_dcn' is forcibly set to False".
|
|
|
format(backbone))
|
|
|
- kwargs.update(dict(return_idx=[2], num_stages=3))
|
|
|
- kwargs.pop('dcn_v2_stages')
|
|
|
+ kwargs.update(dict(return_idx=[2], num_stages=3))
|
|
|
elif 'ResNet34' in backbone:
|
|
|
if not with_fpn:
|
|
|
logging.warning(
|
|
@@ -1455,7 +1463,10 @@ class FasterRCNN(BaseDetector):
|
|
|
return self._define_input_spec(image_shape)
|
|
|
|
|
|
|
|
|
-class PPYOLO(_YOLOv3):
|
|
|
+class PPYOLO(YOLOv3):
|
|
|
+ supported_backbones = ('ResNet50_vd_dcn', 'ResNet18_vd',
|
|
|
+ 'MobileNetV3_large', 'MobileNetV3_small')
|
|
|
+
|
|
|
def __init__(self,
|
|
|
num_classes=80,
|
|
|
backbone='ResNet50_vd_dcn',
|
|
@@ -1476,14 +1487,8 @@ class PPYOLO(_YOLOv3):
|
|
|
nms_iou_threshold=0.45,
|
|
|
**params):
|
|
|
self.init_params = locals()
|
|
|
- if backbone not in {
|
|
|
- 'ResNet50_vd_dcn', 'ResNet18_vd', 'MobileNetV3_large',
|
|
|
- 'MobileNetV3_small'
|
|
|
- }:
|
|
|
- raise ValueError(
|
|
|
- "backbone: {} is not supported. Please choose one of "
|
|
|
- "{'ResNet50_vd_dcn', 'ResNet18_vd', 'MobileNetV3_large', 'MobileNetV3_small'}.".
|
|
|
- format(backbone))
|
|
|
+ self._check_backbone(backbone)
|
|
|
+
|
|
|
self.backbone_name = backbone
|
|
|
self.downsample_ratios = [
|
|
|
32, 16, 8
|
|
@@ -1514,7 +1519,7 @@ class PPYOLO(_YOLOv3):
|
|
|
|
|
|
if backbone == 'ResNet50_vd_dcn':
|
|
|
backbone = self._get_backbone(
|
|
|
- 'ResNet',
|
|
|
+ backbone,
|
|
|
variant='d',
|
|
|
norm_type=norm_type,
|
|
|
return_idx=[1, 2, 3],
|
|
@@ -1525,7 +1530,7 @@ class PPYOLO(_YOLOv3):
|
|
|
|
|
|
elif backbone == 'ResNet18_vd':
|
|
|
backbone = self._get_backbone(
|
|
|
- 'ResNet',
|
|
|
+ backbone,
|
|
|
depth=18,
|
|
|
variant='d',
|
|
|
norm_type=norm_type,
|
|
@@ -1614,7 +1619,8 @@ class PPYOLO(_YOLOv3):
|
|
|
'post_process': post_process
|
|
|
})
|
|
|
|
|
|
- super(PPYOLO, self).__init__(
|
|
|
+ # NOTE: call BaseDetector.__init__ instead of YOLOv3.__init__
|
|
|
+ super(YOLOv3, self).__init__(
|
|
|
model_name='YOLOv3', num_classes=num_classes, **params)
|
|
|
self.anchors = anchors
|
|
|
self.anchor_masks = anchor_masks
|
|
@@ -1643,7 +1649,9 @@ class PPYOLO(_YOLOv3):
|
|
|
return self._define_input_spec(image_shape)
|
|
|
|
|
|
|
|
|
-class PPYOLOTiny(_YOLOv3):
|
|
|
+class PPYOLOTiny(YOLOv3):
|
|
|
+ supported_backbones = ('MobileNetV3', )
|
|
|
+
|
|
|
def __init__(self,
|
|
|
num_classes=80,
|
|
|
backbone='MobileNetV3',
|
|
@@ -1668,6 +1676,7 @@ class PPYOLOTiny(_YOLOv3):
|
|
|
logging.warning("PPYOLOTiny only supports MobileNetV3 as backbone. "
|
|
|
"Backbone is forcibly set to MobileNetV3.")
|
|
|
self.backbone_name = 'MobileNetV3'
|
|
|
+
|
|
|
self.downsample_ratios = [32, 16, 8]
|
|
|
if params.get('with_net', True):
|
|
|
if paddlers.env_info['place'] == 'gpu' and paddlers.env_info[
|
|
@@ -1741,7 +1750,8 @@ class PPYOLOTiny(_YOLOv3):
|
|
|
'post_process': post_process
|
|
|
})
|
|
|
|
|
|
- super(PPYOLOTiny, self).__init__(
|
|
|
+ # NOTE: call BaseDetector.__init__ instead of YOLOv3.__init__
|
|
|
+ super(YOLOv3, self).__init__(
|
|
|
model_name='YOLOv3', num_classes=num_classes, **params)
|
|
|
self.anchors = anchors
|
|
|
self.anchor_masks = anchor_masks
|
|
@@ -1771,7 +1781,9 @@ class PPYOLOTiny(_YOLOv3):
|
|
|
return self._define_input_spec(image_shape)
|
|
|
|
|
|
|
|
|
-class PPYOLOv2(_YOLOv3):
|
|
|
+class PPYOLOv2(YOLOv3):
|
|
|
+ supported_backbones = ('ResNet50_vd_dcn', 'ResNet101_vd_dcn')
|
|
|
+
|
|
|
def __init__(self,
|
|
|
num_classes=80,
|
|
|
backbone='ResNet50_vd_dcn',
|
|
@@ -1792,10 +1804,7 @@ class PPYOLOv2(_YOLOv3):
|
|
|
nms_iou_threshold=0.45,
|
|
|
**params):
|
|
|
self.init_params = locals()
|
|
|
- if backbone not in {'ResNet50_vd_dcn', 'ResNet101_vd_dcn'}:
|
|
|
- raise ValueError(
|
|
|
- "backbone: {} is not supported. Please choose one of "
|
|
|
- "{'ResNet50_vd_dcn', 'ResNet101_vd_dcn'}.".format(backbone))
|
|
|
+ self._check_backbone(backbone)
|
|
|
self.backbone_name = backbone
|
|
|
self.downsample_ratios = [32, 16, 8]
|
|
|
|
|
@@ -1808,8 +1817,7 @@ class PPYOLOv2(_YOLOv3):
|
|
|
|
|
|
if backbone == 'ResNet50_vd_dcn':
|
|
|
backbone = self._get_backbone(
|
|
|
- 'ResNet',
|
|
|
- variant='d',
|
|
|
+ backbone,
|
|
|
norm_type=norm_type,
|
|
|
return_idx=[1, 2, 3],
|
|
|
dcn_v2_stages=[3],
|
|
@@ -1819,9 +1827,7 @@ class PPYOLOv2(_YOLOv3):
|
|
|
|
|
|
elif backbone == 'ResNet101_vd_dcn':
|
|
|
backbone = self._get_backbone(
|
|
|
- 'ResNet',
|
|
|
- depth=101,
|
|
|
- variant='d',
|
|
|
+ backbone,
|
|
|
norm_type=norm_type,
|
|
|
return_idx=[1, 2, 3],
|
|
|
dcn_v2_stages=[3],
|
|
@@ -1888,7 +1894,8 @@ class PPYOLOv2(_YOLOv3):
|
|
|
'post_process': post_process
|
|
|
})
|
|
|
|
|
|
- super(PPYOLOv2, self).__init__(
|
|
|
+ # NOTE: call BaseDetector.__init__ instead of YOLOv3.__init__
|
|
|
+ super(YOLOv3, self).__init__(
|
|
|
model_name='YOLOv3', num_classes=num_classes, **params)
|
|
|
self.anchors = anchors
|
|
|
self.anchor_masks = anchor_masks
|
|
@@ -1919,6 +1926,9 @@ class PPYOLOv2(_YOLOv3):
|
|
|
|
|
|
|
|
|
class MaskRCNN(BaseDetector):
|
|
|
+ supported_backbones = ('ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld',
|
|
|
+ 'ResNet101', 'ResNet101_vd')
|
|
|
+
|
|
|
def __init__(self,
|
|
|
num_classes=80,
|
|
|
backbone='ResNet50_vd',
|
|
@@ -1936,14 +1946,7 @@ class MaskRCNN(BaseDetector):
|
|
|
test_post_nms_top_n=1000,
|
|
|
**params):
|
|
|
self.init_params = locals()
|
|
|
- if backbone not in {
|
|
|
- 'ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet101',
|
|
|
- 'ResNet101_vd'
|
|
|
- }:
|
|
|
- raise ValueError(
|
|
|
- "backbone: {} is not supported. Please choose one of "
|
|
|
- "{'ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet101', 'ResNet101_vd'}.".
|
|
|
- format(backbone))
|
|
|
+ self._check_backbone(backbone)
|
|
|
|
|
|
self.backbone_name = backbone + '_fpn' if with_fpn else backbone
|
|
|
dcn_v2_stages = [1, 2, 3] if with_dcn else [-1]
|
|
@@ -2187,5 +2190,212 @@ class MaskRCNN(BaseDetector):
|
|
|
return self._define_input_spec(image_shape)
|
|
|
|
|
|
|
|
|
-YOLOv3 = functools.partial(_YOLOv3, rotate=False)
|
|
|
-FCOSR = functools.partial(_YOLOv3, rotate=True)
|
|
|
+class FCOSR(YOLOv3):
|
|
|
+ supported_backbones = {'ResNeXt50_32x4d'}
|
|
|
+
|
|
|
+ def __init__(self,
|
|
|
+ num_classes=80,
|
|
|
+ backbone='ResNeXt50_32x4d',
|
|
|
+ post_process=None,
|
|
|
+ anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
|
|
|
+ [59, 119], [116, 90], [156, 198], [373, 326]],
|
|
|
+ anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
|
|
|
+ nms_score_threshold=0.01,
|
|
|
+ nms_topk=1000,
|
|
|
+ nms_keep_topk=100,
|
|
|
+ nms_iou_threshold=0.45,
|
|
|
+ nms_normalized=True,
|
|
|
+ **params):
|
|
|
+ self.init_params = locals()
|
|
|
+ self._check_backbone(backbone)
|
|
|
+
|
|
|
+ self.backbone_name = backbone
|
|
|
+ if params.get('with_net', True):
|
|
|
+ if paddlers.env_info['place'] == 'gpu' and paddlers.env_info[
|
|
|
+ 'num'] > 1 and not os.environ.get('PADDLERS_EXPORT_STAGE'):
|
|
|
+ norm_type = 'sync_bn'
|
|
|
+ else:
|
|
|
+ norm_type = 'bn'
|
|
|
+
|
|
|
+ kwargs = {}
|
|
|
+ kwargs['norm_type'] = norm_type
|
|
|
+
|
|
|
+ backbone = 'ResNet50'
|
|
|
+ kwargs.update(
|
|
|
+ dict(
|
|
|
+ return_idx=[1, 2, 3],
|
|
|
+ base_width=4,
|
|
|
+ groups=32,
|
|
|
+ freeze_norm=False))
|
|
|
+
|
|
|
+ backbone = self._get_backbone(backbone, **kwargs)
|
|
|
+ nms = ppdet.modeling.MultiClassNMS(
|
|
|
+ score_threshold=nms_score_threshold,
|
|
|
+ nms_top_k=nms_topk,
|
|
|
+ keep_top_k=nms_keep_topk,
|
|
|
+ nms_threshold=nms_iou_threshold,
|
|
|
+ normalized=nms_normalized)
|
|
|
+ neck = ppdet.modeling.FPN(
|
|
|
+ in_channels=[i.channels for i in backbone.out_shape],
|
|
|
+ out_channel=256,
|
|
|
+ has_extra_convs=True,
|
|
|
+ use_c5=False,
|
|
|
+ relu_before_extra_convs=True)
|
|
|
+ assigner = ppdet.modeling.FCOSRAssigner(
|
|
|
+ num_classes=num_classes,
|
|
|
+ factor=12,
|
|
|
+ threshold=0.23,
|
|
|
+ boundary=[[-1, 64], [64, 128], [128, 256], [256, 512],
|
|
|
+ [512, 100000000.0]])
|
|
|
+ yolo_head = ppdet.modeling.FCOSRHead(
|
|
|
+ num_classes=num_classes,
|
|
|
+ in_channels=[i.channels for i in neck.out_shape],
|
|
|
+ feat_channels=256,
|
|
|
+ fpn_strides=[8, 16, 32, 64, 128],
|
|
|
+ stacked_convs=4,
|
|
|
+ loss_weight={'class': 1.,
|
|
|
+ 'probiou': 1.},
|
|
|
+ assigner=assigner,
|
|
|
+ nms=nms)
|
|
|
+ post_process = None
|
|
|
+ params.update({
|
|
|
+ 'backbone': backbone,
|
|
|
+ 'neck': neck,
|
|
|
+ 'yolo_head': yolo_head,
|
|
|
+ 'post_process': post_process
|
|
|
+ })
|
|
|
+ # NOTE: call BaseDetector.__init__ instead of YOLOv3.__init__
|
|
|
+ super(YOLOv3, self).__init__(
|
|
|
+ model_name='YOLOv3', num_classes=num_classes, **params)
|
|
|
+ self.model_name = 'FCOSR'
|
|
|
+ self.anchors = anchors
|
|
|
+ self.anchor_masks = anchor_masks
|
|
|
+
|
|
|
+ def _default_batch_transforms(self, mode='train'):
|
|
|
+ if mode == 'train':
|
|
|
+ batch_transforms = [BatchPadRGT(), _BatchPad(pad_to_stride=32)]
|
|
|
+ else:
|
|
|
+ batch_transforms = [_BatchPad(pad_to_stride=32)]
|
|
|
+
|
|
|
+ if mode == 'eval' and self.metric == 'voc':
|
|
|
+ collate_batch = False
|
|
|
+ else:
|
|
|
+ collate_batch = True
|
|
|
+
|
|
|
+ batch_transforms = BatchCompose(
|
|
|
+ batch_transforms, collate_batch=collate_batch)
|
|
|
+
|
|
|
+ return batch_transforms
|
|
|
+
|
|
|
+
|
|
|
+class PPYOLOE_R(YOLOv3):
|
|
|
+ supported_backbones = ('CSPResNet_m', 'CSPResNet_l', 'CSPResNet_s',
|
|
|
+ 'CSPResNet_x')
|
|
|
+
|
|
|
+ def __init__(self,
|
|
|
+ num_classes=80,
|
|
|
+ backbone='CSPResNet_l',
|
|
|
+ post_process=None,
|
|
|
+ anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
|
|
|
+ [59, 119], [116, 90], [156, 198], [373, 326]],
|
|
|
+ anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
|
|
|
+ nms_score_threshold=0.01,
|
|
|
+ nms_topk=1000,
|
|
|
+ nms_keep_topk=100,
|
|
|
+ nms_iou_threshold=0.45,
|
|
|
+ nms_normalized=True,
|
|
|
+ **params):
|
|
|
+ self.init_params = locals()
|
|
|
+ self._check_backbone(backbone)
|
|
|
+
|
|
|
+ self.backbone_name = backbone
|
|
|
+ if params.get('with_net', True):
|
|
|
+ if paddlers.env_info['place'] == 'gpu' and paddlers.env_info[
|
|
|
+ 'num'] > 1 and not os.environ.get('PADDLERS_EXPORT_STAGE'):
|
|
|
+ norm_type = 'sync_bn'
|
|
|
+ else:
|
|
|
+ norm_type = 'bn'
|
|
|
+
|
|
|
+ kwargs = {}
|
|
|
+ kwargs['norm_type'] = norm_type
|
|
|
+ kwargs.update(
|
|
|
+ dict(
|
|
|
+ layers=[3, 6, 6, 3],
|
|
|
+ channels=[64, 128, 256, 512, 1024],
|
|
|
+ return_idx=[1, 2, 3],
|
|
|
+ use_large_stem=True,
|
|
|
+ use_alpha=True))
|
|
|
+ if backbone == 'CSPResNet_l':
|
|
|
+ kwargs.update(dict(depth_mult=1.0, width_mult=1.0))
|
|
|
+ elif backbone == 'CSPResNet_m':
|
|
|
+ kwargs.update(dict(depth_mult=0.67, width_mult=0.75))
|
|
|
+ elif backbone == 'CSPResNet_s':
|
|
|
+ kwargs.update(dict(depth_mult=0.33, width_mult=0.5))
|
|
|
+ elif backbone == 'CSPResNet_x':
|
|
|
+ kwargs.update(dict(depth_mult=1.33, width_mult=1.25))
|
|
|
+ backbone = 'CSPResNet'
|
|
|
+
|
|
|
+ backbone = self._get_backbone(backbone, **kwargs)
|
|
|
+ nms = ppdet.modeling.MultiClassNMS(
|
|
|
+ score_threshold=nms_score_threshold,
|
|
|
+ nms_top_k=nms_topk,
|
|
|
+ keep_top_k=nms_keep_topk,
|
|
|
+ nms_threshold=nms_iou_threshold,
|
|
|
+ normalized=nms_normalized)
|
|
|
+ neck = ppdet.modeling.CustomCSPPAN(
|
|
|
+ in_channels=[i.channels for i in backbone.out_shape],
|
|
|
+ out_channels=[768, 384, 192],
|
|
|
+ stage_num=1,
|
|
|
+ block_num=3,
|
|
|
+ act='swish',
|
|
|
+ spp=True,
|
|
|
+ use_alpha=True)
|
|
|
+ static_assigner = ppdet.modeling.FCOSRAssigner(
|
|
|
+ num_classes=num_classes,
|
|
|
+ factor=12,
|
|
|
+ threshold=0.23,
|
|
|
+ boundary=[[512, 10000], [256, 512], [-1, 256]])
|
|
|
+ assigner = ppdet.modeling.RotatedTaskAlignedAssigner(
|
|
|
+ topk=13,
|
|
|
+ alpha=1.0,
|
|
|
+ beta=6.0, )
|
|
|
+ yolo_head = ppdet.modeling.PPYOLOERHead(
|
|
|
+ num_classes=num_classes,
|
|
|
+ in_channels=[i.channels for i in neck.out_shape],
|
|
|
+ fpn_strides=[32, 16, 8],
|
|
|
+ grid_cell_offset=0.5,
|
|
|
+ use_varifocal_loss=True,
|
|
|
+ loss_weight={'class': 1.,
|
|
|
+ 'iou': 2.5,
|
|
|
+ 'dfl': 0.05},
|
|
|
+ static_assigner=static_assigner,
|
|
|
+ assigner=assigner,
|
|
|
+ nms=nms)
|
|
|
+ params.update({
|
|
|
+ 'backbone': backbone,
|
|
|
+ 'neck': neck,
|
|
|
+ 'yolo_head': yolo_head,
|
|
|
+ 'post_process': post_process
|
|
|
+ })
|
|
|
+ # NOTE: call BaseDetector.__init__ instead of YOLOv3.__init__
|
|
|
+ super(YOLOv3, self).__init__(
|
|
|
+ model_name='YOLOv3', num_classes=num_classes, **params)
|
|
|
+ self.model_name = "PPYOLOE_R"
|
|
|
+ self.anchors = anchors
|
|
|
+ self.anchor_masks = anchor_masks
|
|
|
+
|
|
|
+ def _default_batch_transforms(self, mode='train'):
|
|
|
+ if mode == 'train':
|
|
|
+ batch_transforms = [BatchPadRGT(), _BatchPad(pad_to_stride=32)]
|
|
|
+ else:
|
|
|
+ batch_transforms = [_BatchPad(pad_to_stride=32)]
|
|
|
+
|
|
|
+ if mode == 'eval' and self.metric == 'voc':
|
|
|
+ collate_batch = False
|
|
|
+ else:
|
|
|
+ collate_batch = True
|
|
|
+
|
|
|
+ batch_transforms = BatchCompose(
|
|
|
+ batch_transforms, collate_batch=collate_batch)
|
|
|
+
|
|
|
+ return batch_transforms
|