object_detector.py
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import copy
import os
import os.path as osp

import numpy as np
import paddle
from paddle.static import InputSpec

import paddlers
import paddlers.models.ppdet as ppdet
from paddlers.models.ppdet.modeling.proposal_generator.target_layer import BBoxAssigner, MaskAssigner
from paddlers.transforms import decode_image
from paddlers.transforms.operators import _NormalizeBox, _PadBox, _BboxXYXY2XYWH, Resize, Pad
from paddlers.transforms.batch_operators import BatchCompose, BatchRandomResize, BatchRandomResizeByShort, \
    _BatchPad, _Gt2YoloTarget
from paddlers.models.ppdet.optimizer import ModelEMA
import paddlers.utils.logging as logging
from paddlers.utils.checkpoint import det_pretrain_weights_dict
from .base import BaseModel
from .utils.det_metrics import VOCMetric, COCOMetric

__all__ = [
    "YOLOv3", "FasterRCNN", "PPYOLO", "PPYOLOTiny", "PPYOLOv2", "MaskRCNN"
]

class BaseDetector(BaseModel):
    def __init__(self, model_name, num_classes=80, **params):
        self.init_params.update(locals())
        if 'with_net' in self.init_params:
            del self.init_params['with_net']
        super(BaseDetector, self).__init__('detector')
        if not hasattr(ppdet.modeling, model_name):
            raise ValueError("ERROR: There is no model named {}.".format(
                model_name))
        self.model_name = model_name
        self.num_classes = num_classes
        self.labels = None
        if params.get('with_net', True):
            params.pop('with_net', None)
            self.net = self.build_net(**params)

    def build_net(self, **params):
        with paddle.utils.unique_name.guard():
            net = ppdet.modeling.__dict__[self.model_name](**params)
        return net

    def _fix_transforms_shape(self, image_shape):
        raise NotImplementedError("_fix_transforms_shape: not implemented!")

    def _define_input_spec(self, image_shape):
        input_spec = [{
            "image": InputSpec(
                shape=image_shape, name='image', dtype='float32'),
            "im_shape": InputSpec(
                shape=[image_shape[0], 2], name='im_shape', dtype='float32'),
            "scale_factor": InputSpec(
                shape=[image_shape[0], 2], name='scale_factor', dtype='float32')
        }]
        return input_spec

    def _check_image_shape(self, image_shape):
        if len(image_shape) == 2:
            image_shape = [1, 3] + image_shape
        if image_shape[-2] % 32 > 0 or image_shape[-1] % 32 > 0:
            raise ValueError(
                "Height and width in fixed_input_shape must be a multiple of 32, but received {}.".
                format(image_shape[-2:]))
        return image_shape
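    # Shape-check illustration (hypothetical values, not part of the original
    # file): a 2-element [H, W] shape is first expanded to NCHW, then H and W
    # must both be multiples of 32.
    #
    #   _check_image_shape([608, 608])  ->  [1, 3, 608, 608]
    #   _check_image_shape([600, 600])  ->  raises ValueError (600 % 32 != 0)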
    def _get_test_inputs(self, image_shape):
        if image_shape is not None:
            image_shape = self._check_image_shape(image_shape)
            self._fix_transforms_shape(image_shape[-2:])
        else:
            image_shape = [None, 3, -1, -1]
        self.fixed_input_shape = image_shape
        return self._define_input_spec(image_shape)

    def _get_backbone(self, backbone_name, **params):
        backbone = getattr(ppdet.modeling, backbone_name)(**params)
        return backbone

    def run(self, net, inputs, mode):
        net_out = net(inputs)
        if mode in ['train', 'eval']:
            outputs = net_out
        else:
            outputs = dict()
            for key in net_out:
                outputs[key] = net_out[key].numpy()
        return outputs
    def default_optimizer(self,
                          parameters,
                          learning_rate,
                          warmup_steps,
                          warmup_start_lr,
                          lr_decay_epochs,
                          lr_decay_gamma,
                          num_steps_each_epoch,
                          reg_coeff=1e-04,
                          scheduler='Piecewise',
                          num_epochs=None):
        if scheduler.lower() == 'piecewise':
            if warmup_steps > 0 and warmup_steps > lr_decay_epochs[
                    0] * num_steps_each_epoch:
                logging.error(
                    "In function train(), parameters must satisfy: "
                    "warmup_steps <= lr_decay_epochs[0] * num_samples_in_train_dataset. "
                    "See this doc for more information: "
                    "https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/parameters.md",
                    exit=False)
                logging.error(
                    "Either `warmup_steps` be less than {} or lr_decay_epochs[0] be greater than {} "
                    "must be satisfied, please modify 'warmup_steps' or 'lr_decay_epochs' in train function".
                    format(lr_decay_epochs[0] * num_steps_each_epoch,
                           warmup_steps // num_steps_each_epoch),
                    exit=True)
            boundaries = [b * num_steps_each_epoch for b in lr_decay_epochs]
            values = [(lr_decay_gamma**i) * learning_rate
                      for i in range(len(lr_decay_epochs) + 1)]
            scheduler = paddle.optimizer.lr.PiecewiseDecay(boundaries, values)
        elif scheduler.lower() == 'cosine':
            if num_epochs is None:
                logging.error(
                    "`num_epochs` must be set while using cosine annealing decay scheduler, but received {}".
                    format(num_epochs),
                    exit=False)
            if warmup_steps > 0 and warmup_steps > num_epochs * num_steps_each_epoch:
                logging.error(
                    "In function train(), parameters must satisfy: "
                    "warmup_steps <= num_epochs * num_samples_in_train_dataset. "
                    "See this doc for more information: "
                    "https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/parameters.md",
                    exit=False)
                logging.error(
                    "`warmup_steps` must be less than the total number of steps({}), "
                    "please modify 'num_epochs' or 'warmup_steps' in train function".
                    format(num_epochs * num_steps_each_epoch),
                    exit=True)
            T_max = num_epochs * num_steps_each_epoch - warmup_steps
            scheduler = paddle.optimizer.lr.CosineAnnealingDecay(
                learning_rate=learning_rate,
                T_max=T_max,
                eta_min=0.0,
                last_epoch=-1)
        else:
            logging.error(
                "Invalid learning rate scheduler: {}!".format(scheduler),
                exit=True)
        if warmup_steps > 0:
            scheduler = paddle.optimizer.lr.LinearWarmup(
                learning_rate=scheduler,
                warmup_steps=warmup_steps,
                start_lr=warmup_start_lr,
                end_lr=learning_rate)
        optimizer = paddle.optimizer.Momentum(
            scheduler,
            momentum=.9,
            weight_decay=paddle.regularizer.L2Decay(coeff=reg_coeff),
            parameters=parameters)
        return optimizer
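    # Worked example of the piecewise schedule above (hypothetical numbers,
    # not from the original file): with learning_rate=0.001,
    # lr_decay_epochs=(216, 243), lr_decay_gamma=0.1, and
    # num_steps_each_epoch=100, the decay boundaries fall at steps
    # [21600, 24300] and the stage-wise learning rates are
    # [0.001, 0.0001, 0.00001].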
    def train(self,
              num_epochs,
              train_dataset,
              train_batch_size=64,
              eval_dataset=None,
              optimizer=None,
              save_interval_epochs=1,
              log_interval_steps=10,
              save_dir='output',
              pretrain_weights='IMAGENET',
              learning_rate=.001,
              warmup_steps=0,
              warmup_start_lr=0.0,
              lr_decay_epochs=(216, 243),
              lr_decay_gamma=0.1,
              metric=None,
              use_ema=False,
              early_stop=False,
              early_stop_patience=5,
              use_vdl=True,
              resume_checkpoint=None):
        """
        Train the model.

        Args:
            num_epochs (int): Number of epochs.
            train_dataset (paddlers.datasets.COCODetDataset|paddlers.datasets.VOCDetDataset):
                Training dataset.
            train_batch_size (int, optional): Total batch size among all cards used in
                training. Defaults to 64.
            eval_dataset (paddlers.datasets.COCODetDataset|paddlers.datasets.VOCDetDataset|None, optional):
                Evaluation dataset. If None, the model will not be evaluated during the
                training process. Defaults to None.
            optimizer (paddle.optimizer.Optimizer|None, optional): Optimizer used for
                training. If None, a default optimizer will be used. Defaults to None.
            save_interval_epochs (int, optional): Epoch interval for saving the model.
                Defaults to 1.
            log_interval_steps (int, optional): Step interval for printing training
                information. Defaults to 10.
            save_dir (str, optional): Directory to save the model. Defaults to 'output'.
            pretrain_weights (str|None, optional): None or name/path of pretrained
                weights. If None, no pretrained weights will be loaded.
                Defaults to 'IMAGENET'.
            learning_rate (float, optional): Learning rate for training. Defaults to .001.
            warmup_steps (int, optional): Number of steps of warm-up training.
                Defaults to 0.
            warmup_start_lr (float, optional): Start learning rate of warm-up training.
                Defaults to 0.0.
            lr_decay_epochs (list|tuple, optional): Epoch milestones for learning
                rate decay. Defaults to (216, 243).
            lr_decay_gamma (float, optional): Gamma coefficient of learning rate decay.
                Defaults to .1.
            metric (str|None, optional): Evaluation metric. Choices are {'VOC', 'COCO', None}.
                If None, the metric is determined according to the dataset format.
                Defaults to None.
            use_ema (bool, optional): Whether to use the exponential moving average
                strategy. Defaults to False.
            early_stop (bool, optional): Whether to adopt the early stopping strategy.
                Defaults to False.
            early_stop_patience (int, optional): Early stopping patience. Defaults to 5.
            use_vdl (bool, optional): Whether to use VisualDL to monitor the training
                process. Defaults to True.
            resume_checkpoint (str|None, optional): Path of the checkpoint to resume
                training from. If None, no training checkpoint will be resumed. At most
                one of `resume_checkpoint` and `pretrain_weights` can be set.
                Defaults to None.
        """
        args = self._pre_train(locals())
        args.pop('self')  # `locals()` captures `self`; drop it before forwarding.
        return self._real_train(**args)
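    # Minimal usage sketch for train() (hypothetical dataset objects; the
    # VOCDetDataset constructor arguments are elided, not part of this file):
    #
    #   train_ds = paddlers.datasets.VOCDetDataset(...)  # training set
    #   eval_ds = paddlers.datasets.VOCDetDataset(...)   # validation set
    #   model = YOLOv3(num_classes=len(train_ds.labels))
    #   model.train(num_epochs=270,
    #               train_dataset=train_ds,
    #               train_batch_size=8,
    #               eval_dataset=eval_ds,
    #               save_dir='output/yolov3')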
    def _pre_train(self, in_args):
        return in_args

    def _real_train(self,
                    num_epochs,
                    train_dataset,
                    train_batch_size=64,
                    eval_dataset=None,
                    optimizer=None,
                    save_interval_epochs=1,
                    log_interval_steps=10,
                    save_dir='output',
                    pretrain_weights='IMAGENET',
                    learning_rate=.001,
                    warmup_steps=0,
                    warmup_start_lr=0.0,
                    lr_decay_epochs=(216, 243),
                    lr_decay_gamma=0.1,
                    metric=None,
                    use_ema=False,
                    early_stop=False,
                    early_stop_patience=5,
                    use_vdl=True,
                    resume_checkpoint=None):
        if self.status == 'Infer':
            logging.error(
                "Exported inference model does not support training.",
                exit=True)
        if pretrain_weights is not None and resume_checkpoint is not None:
            logging.error(
                "pretrain_weights and resume_checkpoint cannot be set simultaneously.",
                exit=True)
        if train_dataset.__class__.__name__ == 'VOCDetDataset':
            train_dataset.data_fields = {
                'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
                'difficult'
            }
        elif train_dataset.__class__.__name__ == 'COCODetDataset':
            if self.__class__.__name__ == 'MaskRCNN':
                train_dataset.data_fields = {
                    'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
                    'gt_poly', 'is_crowd'
                }
            else:
                train_dataset.data_fields = {
                    'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
                    'is_crowd'
                }
        if metric is None:
            if eval_dataset.__class__.__name__ == 'VOCDetDataset':
                self.metric = 'voc'
            elif eval_dataset.__class__.__name__ == 'COCODetDataset':
                self.metric = 'coco'
        else:
            assert metric.lower() in ['coco', 'voc'], \
                "Evaluation metric {} is not supported. Please choose from 'COCO' and 'VOC'.".format(metric)
            self.metric = metric.lower()
        self.labels = train_dataset.labels
        self.num_max_boxes = train_dataset.num_max_boxes
        train_dataset.batch_transforms = self._compose_batch_transform(
            train_dataset.transforms, mode='train')

        # Build the optimizer if none is given.
        if optimizer is None:
            num_steps_each_epoch = len(train_dataset) // train_batch_size
            self.optimizer = self.default_optimizer(
                parameters=self.net.parameters(),
                learning_rate=learning_rate,
                warmup_steps=warmup_steps,
                warmup_start_lr=warmup_start_lr,
                lr_decay_epochs=lr_decay_epochs,
                lr_decay_gamma=lr_decay_gamma,
                num_steps_each_epoch=num_steps_each_epoch)
        else:
            self.optimizer = optimizer

        # Initialize weights.
        if pretrain_weights is not None and not osp.exists(pretrain_weights):
            if pretrain_weights not in det_pretrain_weights_dict['_'.join(
                    [self.model_name, self.backbone_name])]:
                logging.warning(
                    "Path of pretrain_weights('{}') does not exist!".format(
                        pretrain_weights))
                pretrain_weights = det_pretrain_weights_dict['_'.join(
                    [self.model_name, self.backbone_name])][0]
                logging.warning("pretrain_weights is forcibly set to '{}'. "
                                "If you don't want to use pretrained weights, "
                                "set pretrain_weights to None.".format(
                                    pretrain_weights))
        elif pretrain_weights is not None and osp.exists(pretrain_weights):
            if osp.splitext(pretrain_weights)[-1] != '.pdparams':
                logging.error(
                    "Invalid pretrain weights. Please specify a '.pdparams' file.",
                    exit=True)
        pretrained_dir = osp.join(save_dir, 'pretrain')
        self.net_initialize(
            pretrain_weights=pretrain_weights,
            save_dir=pretrained_dir,
            resume_checkpoint=resume_checkpoint,
            is_backbone_weights=(pretrain_weights == 'IMAGENET' and
                                 'ESNet_' in self.backbone_name))

        if use_ema:
            ema = ModelEMA(model=self.net, decay=.9998, use_thres_step=True)
        else:
            ema = None

        # Start the training loop.
        self.train_loop(
            num_epochs=num_epochs,
            train_dataset=train_dataset,
            train_batch_size=train_batch_size,
            eval_dataset=eval_dataset,
            save_interval_epochs=save_interval_epochs,
            log_interval_steps=log_interval_steps,
            save_dir=save_dir,
            ema=ema,
            early_stop=early_stop,
            early_stop_patience=early_stop_patience,
            use_vdl=use_vdl)
    def quant_aware_train(self,
                          num_epochs,
                          train_dataset,
                          train_batch_size=64,
                          eval_dataset=None,
                          optimizer=None,
                          save_interval_epochs=1,
                          log_interval_steps=10,
                          save_dir='output',
                          learning_rate=.00001,
                          warmup_steps=0,
                          warmup_start_lr=0.0,
                          lr_decay_epochs=(216, 243),
                          lr_decay_gamma=0.1,
                          metric=None,
                          use_ema=False,
                          early_stop=False,
                          early_stop_patience=5,
                          use_vdl=True,
                          resume_checkpoint=None,
                          quant_config=None):
        """
        Quantization-aware training.

        Args:
            num_epochs (int): Number of epochs.
            train_dataset (paddlers.datasets.COCODetDataset|paddlers.datasets.VOCDetDataset):
                Training dataset.
            train_batch_size (int, optional): Total batch size among all cards used in
                training. Defaults to 64.
            eval_dataset (paddlers.datasets.COCODetDataset|paddlers.datasets.VOCDetDataset|None, optional):
                Evaluation dataset. If None, the model will not be evaluated during the
                training process. Defaults to None.
            optimizer (paddle.optimizer.Optimizer|None, optional): Optimizer used for
                training. If None, a default optimizer will be used. Defaults to None.
            save_interval_epochs (int, optional): Epoch interval for saving the model.
                Defaults to 1.
            log_interval_steps (int, optional): Step interval for printing training
                information. Defaults to 10.
            save_dir (str, optional): Directory to save the model. Defaults to 'output'.
            learning_rate (float, optional): Learning rate for training.
                Defaults to .00001.
            warmup_steps (int, optional): Number of steps of warm-up training.
                Defaults to 0.
            warmup_start_lr (float, optional): Start learning rate of warm-up training.
                Defaults to 0.0.
            lr_decay_epochs (list|tuple, optional): Epoch milestones for learning rate
                decay. Defaults to (216, 243).
            lr_decay_gamma (float, optional): Gamma coefficient of learning rate decay.
                Defaults to .1.
            metric (str|None, optional): Evaluation metric. Choices are {'VOC', 'COCO', None}.
                If None, the metric is determined according to the dataset format.
                Defaults to None.
            use_ema (bool, optional): Whether to use the exponential moving average
                strategy. Defaults to False.
            early_stop (bool, optional): Whether to adopt the early stopping strategy.
                Defaults to False.
            early_stop_patience (int, optional): Early stopping patience. Defaults to 5.
            use_vdl (bool, optional): Whether to use VisualDL to monitor the training
                process. Defaults to True.
            quant_config (dict|None, optional): Quantization configuration. If None,
                a default rule-of-thumb configuration will be used. Defaults to None.
            resume_checkpoint (str|None, optional): Path of the checkpoint to resume
                quantization-aware training from. If None, no training checkpoint will
                be resumed. Defaults to None.
        """
        self._prepare_qat(quant_config)
        self.train(
            num_epochs=num_epochs,
            train_dataset=train_dataset,
            train_batch_size=train_batch_size,
            eval_dataset=eval_dataset,
            optimizer=optimizer,
            save_interval_epochs=save_interval_epochs,
            log_interval_steps=log_interval_steps,
            save_dir=save_dir,
            pretrain_weights=None,
            learning_rate=learning_rate,
            warmup_steps=warmup_steps,
            warmup_start_lr=warmup_start_lr,
            lr_decay_epochs=lr_decay_epochs,
            lr_decay_gamma=lr_decay_gamma,
            metric=metric,
            use_ema=use_ema,
            early_stop=early_stop,
            early_stop_patience=early_stop_patience,
            use_vdl=use_vdl,
            resume_checkpoint=resume_checkpoint)
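    # Usage sketch for quant_aware_train() (hypothetical datasets; typically
    # applied to an already-trained model, since pretrain_weights is set to
    # None internally):
    #
    #   model.quant_aware_train(num_epochs=30,
    #                           train_dataset=train_ds,
    #                           train_batch_size=8,
    #                           eval_dataset=eval_ds,
    #                           learning_rate=1e-5,
    #                           save_dir='output/yolov3_quant',
    #                           quant_config=None)  # None -> default config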
    def evaluate(self,
                 eval_dataset,
                 batch_size=1,
                 metric=None,
                 return_details=False):
        """
        Evaluate the model.

        Args:
            eval_dataset (paddlers.datasets.COCODetDataset|paddlers.datasets.VOCDetDataset):
                Evaluation dataset.
            batch_size (int, optional): Total batch size among all cards used for
                evaluation. Defaults to 1.
            metric (str|None, optional): Evaluation metric. Choices are {'VOC', 'COCO', None}.
                If None, the metric is determined according to the dataset format.
                Defaults to None.
            return_details (bool, optional): Whether to return evaluation details.
                Defaults to False.

        Returns:
            collections.OrderedDict with key-value pairs:
                {"bbox_mmap": `mean average precision (0.50, 11point)`}.
        """
        if metric is None:
            if not hasattr(self, 'metric'):
                if eval_dataset.__class__.__name__ == 'VOCDetDataset':
                    self.metric = 'voc'
                elif eval_dataset.__class__.__name__ == 'COCODetDataset':
                    self.metric = 'coco'
        else:
            assert metric.lower() in ['coco', 'voc'], \
                "Evaluation metric {} is not supported. Please choose from 'COCO' and 'VOC'.".format(metric)
            self.metric = metric.lower()
        if self.metric == 'voc':
            eval_dataset.data_fields = {
                'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
                'difficult'
            }
        elif self.metric == 'coco':
            if self.__class__.__name__ == 'MaskRCNN':
                eval_dataset.data_fields = {
                    'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
                    'gt_poly', 'is_crowd'
                }
            else:
                eval_dataset.data_fields = {
                    'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
                    'is_crowd'
                }
        eval_dataset.batch_transforms = self._compose_batch_transform(
            eval_dataset.transforms, mode='eval')
        self._check_transforms(eval_dataset.transforms, 'eval')
        self.net.eval()
        nranks = paddle.distributed.get_world_size()
        local_rank = paddle.distributed.get_rank()
        if nranks > 1:
            # Initialize the parallel environment if it has not been set up yet.
            if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
            ):
                paddle.distributed.init_parallel_env()
        if batch_size > 1:
            logging.warning(
                "Detector only supports single-card evaluation with batch_size=1, "
                "so batch_size is forcibly set to 1.")
            batch_size = 1
        if nranks < 2 or local_rank == 0:
            self.eval_data_loader = self.build_data_loader(
                eval_dataset, batch_size=batch_size, mode='eval')
            is_bbox_normalized = False
            if eval_dataset.batch_transforms is not None:
                is_bbox_normalized = any(
                    isinstance(t, _NormalizeBox)
                    for t in eval_dataset.batch_transforms.batch_transforms)
            if self.metric == 'voc':
                eval_metric = VOCMetric(
                    labels=eval_dataset.labels,
                    coco_gt=copy.deepcopy(eval_dataset.coco_gt),
                    is_bbox_normalized=is_bbox_normalized,
                    classwise=False)
            else:
                eval_metric = COCOMetric(
                    coco_gt=copy.deepcopy(eval_dataset.coco_gt),
                    classwise=False)
            scores = collections.OrderedDict()
            logging.info(
                "Start to evaluate(total_samples={}, total_steps={})...".format(
                    eval_dataset.num_samples, eval_dataset.num_samples))
            with paddle.no_grad():
                for step, data in enumerate(self.eval_data_loader):
                    outputs = self.run(self.net, data, 'eval')
                    eval_metric.update(data, outputs)
                eval_metric.accumulate()
                self.eval_details = eval_metric.details
                scores.update(eval_metric.get())
                eval_metric.reset()
            if return_details:
                return scores, self.eval_details
            return scores
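    # Usage sketch for evaluate() (hypothetical dataset object; the reported
    # keys come from VOCMetric/COCOMetric):
    #
    #   scores = model.evaluate(eval_ds, batch_size=1)
    #   print(scores['bbox_mmap'])  # mAP of the detection results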
    def predict(self, img_file, transforms=None):
        """
        Do inference.

        Args:
            img_file (list[np.ndarray|str] | str | np.ndarray): Image path or decoded
                image data, or a list of them, in which case all images are predicted
                as a mini-batch.
            transforms (paddlers.transforms.Compose|None, optional): Transforms for
                inputs. If None, the transforms for the evaluation process will be
                used. Defaults to None.

        Returns:
            If `img_file` is a string or np.ndarray, the result is a list of dicts
            with the following key-value pairs:
                {"category_id": `category_id`,
                 "category": `category`,
                 "bbox": `[x, y, w, h]`,
                 "score": `score`,
                 "mask": `mask`}.
            If `img_file` is a list, the result is a list of such lists of dicts
            with the corresponding fields:
                category_id (int): Predicted category ID. 0 represents the first
                    category in the dataset, and so on.
                category (str): Category name.
                bbox (list): Bounding box in [x, y, w, h] format.
                score (float): Confidence.
                mask (dict): Mask of the object in RLE format. Only for instance
                    segmentation tasks.
        """
        if transforms is None and not hasattr(self, 'test_transforms'):
            raise ValueError("transforms need to be defined, but received None.")
        if transforms is None:
            transforms = self.test_transforms
        if isinstance(img_file, (str, np.ndarray)):
            images = [img_file]
        else:
            images = img_file
        batch_samples = self.preprocess(images, transforms)
        self.net.eval()
        outputs = self.run(self.net, batch_samples, 'test')
        prediction = self.postprocess(outputs)
        if isinstance(img_file, (str, np.ndarray)):
            prediction = prediction[0]
        return prediction
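    # Usage sketch for predict() (hypothetical image path; test_transforms is
    # assumed to have been set, e.g. after loading a trained model):
    #
    #   pred = model.predict('demo.jpg')
    #   for det in pred:
    #       print(det['category'], det['score'], det['bbox'])  # bbox: [x, y, w, h]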
    def preprocess(self, images, transforms, to_tensor=True):
        self._check_transforms(transforms, 'test')
        batch_samples = list()
        for im in images:
            if isinstance(im, str):
                im = decode_image(im, to_rgb=False)
            sample = {'image': im}
            sample = transforms(sample)
            batch_samples.append(sample)
        batch_transforms = self._compose_batch_transform(transforms, 'test')
        batch_samples = batch_transforms(batch_samples)
        if to_tensor:
            for k in batch_samples:
                batch_samples[k] = paddle.to_tensor(batch_samples[k])
        return batch_samples
    def postprocess(self, batch_pred):
        infer_result = {}
        if 'bbox' in batch_pred:
            bboxes = batch_pred['bbox']
            bbox_nums = batch_pred['bbox_num']
            det_res = []
            k = 0
            for i in range(len(bbox_nums)):
                det_nums = bbox_nums[i]
                for j in range(det_nums):
                    dt = bboxes[k]
                    k = k + 1
                    num_id, score, xmin, ymin, xmax, ymax = dt.tolist()
                    if int(num_id) < 0:
                        continue
                    category = self.labels[int(num_id)]
                    w = xmax - xmin
                    h = ymax - ymin
                    bbox = [xmin, ymin, w, h]
                    dt_res = {
                        'category_id': int(num_id),
                        'category': category,
                        'bbox': bbox,
                        'score': score
                    }
                    det_res.append(dt_res)
            infer_result['bbox'] = det_res
        if 'mask' in batch_pred:
            masks = batch_pred['mask']
            bboxes = batch_pred['bbox']
            mask_nums = batch_pred['bbox_num']
            seg_res = []
            k = 0
            for i in range(len(mask_nums)):
                det_nums = mask_nums[i]
                for j in range(det_nums):
                    mask = masks[k].astype(np.uint8)
                    score = float(bboxes[k][1])
                    label = int(bboxes[k][0])
                    k = k + 1
                    if label == -1:
                        continue
                    category = self.labels[int(label)]
                    sg_res = {
                        'category_id': int(label),
                        'category': category,
                        'mask': mask.astype('uint8'),
                        'score': score
                    }
                    seg_res.append(sg_res)
            infer_result['mask'] = seg_res
        bbox_num = batch_pred['bbox_num']
        results = []
        start = 0
        for num in bbox_num:
            end = start + num
            curr_res = infer_result['bbox'][start:end]
            if 'mask' in infer_result:
                mask_res = infer_result['mask'][start:end]
                for box, mask in zip(curr_res, mask_res):
                    box.update(mask)
            results.append(curr_res)
            start = end
        return results
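    # Box-format note for postprocess() above (hypothetical numbers): raw
    # network output rows are [class_id, score, xmin, ymin, xmax, ymax] and
    # are converted to COCO-style [x, y, w, h]. E.g. corners (10, 20) and
    # (50, 80) give w = 50 - 10 = 40 and h = 80 - 20 = 60, so
    # bbox = [10, 20, 40, 60].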
    def _check_transforms(self, transforms, mode):
        super()._check_transforms(transforms, mode)
        if not isinstance(transforms.arrange,
                          paddlers.transforms.ArrangeDetector):
            raise TypeError(
                "`transforms.arrange` must be an ArrangeDetector object.")

class PicoDet(BaseDetector):
    def __init__(self,
                 num_classes=80,
                 backbone='ESNet_m',
                 nms_score_threshold=.025,
                 nms_topk=1000,
                 nms_keep_topk=100,
                 nms_iou_threshold=.6,
                 **params):
        self.init_params = locals()
        if backbone not in {
                'ESNet_s', 'ESNet_m', 'ESNet_l', 'LCNet', 'MobileNetV3',
                'ResNet18_vd'
        }:
            raise ValueError(
                "backbone: {} is not supported. Please choose one of "
                "{'ESNet_s', 'ESNet_m', 'ESNet_l', 'LCNet', 'MobileNetV3', 'ResNet18_vd'}.".
                format(backbone))
        self.backbone_name = backbone
        if params.get('with_net', True):
            if backbone == 'ESNet_s':
                backbone = self._get_backbone(
                    'ESNet',
                    scale=.75,
                    feature_maps=[4, 11, 14],
                    act="hard_swish",
                    channel_ratio=[
                        0.875, 0.5, 0.5, 0.5, 0.625, 0.5, 0.625, 0.5, 0.5, 0.5,
                        0.5, 0.5, 0.5
                    ])
                neck_out_channels = 96
                head_num_convs = 2
            elif backbone == 'ESNet_m':
                backbone = self._get_backbone(
                    'ESNet',
                    scale=1.0,
                    feature_maps=[4, 11, 14],
                    act="hard_swish",
                    channel_ratio=[
                        0.875, 0.5, 1.0, 0.625, 0.5, 0.75, 0.625, 0.625, 0.5,
                        0.625, 1.0, 0.625, 0.75
                    ])
                neck_out_channels = 128
                head_num_convs = 4
            elif backbone == 'ESNet_l':
                backbone = self._get_backbone(
                    'ESNet',
                    scale=1.25,
                    feature_maps=[4, 11, 14],
                    act="hard_swish",
                    channel_ratio=[
                        0.875, 0.5, 1.0, 0.625, 0.5, 0.75, 0.625, 0.625, 0.5,
                        0.625, 1.0, 0.625, 0.75
                    ])
                neck_out_channels = 160
                head_num_convs = 4
            elif backbone == 'LCNet':
                backbone = self._get_backbone(
                    'LCNet', scale=1.5, feature_maps=[3, 4, 5])
                neck_out_channels = 128
                head_num_convs = 4
            elif backbone == 'MobileNetV3':
                backbone = self._get_backbone(
                    'MobileNetV3',
                    scale=1.0,
                    with_extra_blocks=False,
                    extra_block_filters=[],
                    feature_maps=[7, 13, 16])
                neck_out_channels = 128
                head_num_convs = 4
            else:
                backbone = self._get_backbone(
                    'ResNet',
                    depth=18,
                    variant='d',
                    return_idx=[1, 2, 3],
                    freeze_at=-1,
                    freeze_norm=False,
                    norm_decay=0.)
                neck_out_channels = 128
                head_num_convs = 4
            neck = ppdet.modeling.CSPPAN(
                in_channels=[i.channels for i in backbone.out_shape],
                out_channels=neck_out_channels,
                num_features=4,
                num_csp_blocks=1,
                use_depthwise=True)
            head_conv_feat = ppdet.modeling.PicoFeat(
                feat_in=neck_out_channels,
                feat_out=neck_out_channels,
                num_fpn_stride=4,
                num_convs=head_num_convs,
                norm_type='bn',
                share_cls_reg=True)
            loss_class = ppdet.modeling.VarifocalLoss(
                use_sigmoid=True, iou_weighted=True, loss_weight=1.0)
            loss_dfl = ppdet.modeling.DistributionFocalLoss(loss_weight=.25)
            loss_bbox = ppdet.modeling.GIoULoss(loss_weight=2.0)
            assigner = ppdet.modeling.SimOTAAssigner(
                candidate_topk=10, iou_weight=6, num_classes=num_classes)
            nms = ppdet.modeling.MultiClassNMS(
                nms_top_k=nms_topk,
                keep_top_k=nms_keep_topk,
                score_threshold=nms_score_threshold,
                nms_threshold=nms_iou_threshold)
            head = ppdet.modeling.PicoHead(
                conv_feat=head_conv_feat,
                num_classes=num_classes,
                fpn_stride=[8, 16, 32, 64],
                prior_prob=0.01,
                reg_max=7,
                cell_offset=.5,
                loss_class=loss_class,
                loss_dfl=loss_dfl,
                loss_bbox=loss_bbox,
                assigner=assigner,
                feat_in_chan=neck_out_channels,
                nms=nms)
            params.update({
                'backbone': backbone,
                'neck': neck,
                'head': head,
            })
        super(PicoDet, self).__init__(
            model_name='PicoDet', num_classes=num_classes, **params)
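    # Construction sketch (hypothetical class count; any extra keyword
    # arguments in **params are forwarded to ppdet's PicoDet model):
    #
    #   model = PicoDet(num_classes=10, backbone='ESNet_m')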
    def _compose_batch_transform(self, transforms, mode='train'):
        default_batch_transforms = [_BatchPad(pad_to_stride=32)]
        if mode == 'eval':
            collate_batch = True
        else:
            collate_batch = False
        custom_batch_transforms = []
        for i, op in enumerate(transforms.transforms):
            if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
                if mode != 'train':
                    raise ValueError(
                        "{} cannot be present in the {} transforms.".format(
                            op.__class__.__name__, mode) +
                        "Please check the {} transforms.".format(mode))
                custom_batch_transforms.insert(0, copy.deepcopy(op))
        batch_transforms = BatchCompose(
            custom_batch_transforms + default_batch_transforms,
            collate_batch=collate_batch)
        return batch_transforms
    def _fix_transforms_shape(self, image_shape):
        if getattr(self, 'test_transforms', None):
            has_resize_op = False
            resize_op_idx = -1
            normalize_op_idx = len(self.test_transforms.transforms)
            for idx, op in enumerate(self.test_transforms.transforms):
                name = op.__class__.__name__
                if name == 'Resize':
                    has_resize_op = True
                    resize_op_idx = idx
                if name == 'Normalize':
                    normalize_op_idx = idx
            if not has_resize_op:
                self.test_transforms.transforms.insert(
                    normalize_op_idx,
                    Resize(
                        target_size=image_shape, interp='CUBIC'))
            else:
                self.test_transforms.transforms[
                    resize_op_idx].target_size = image_shape
    def _get_test_inputs(self, image_shape):
        if image_shape is not None:
            image_shape = self._check_image_shape(image_shape)
            self._fix_transforms_shape(image_shape[-2:])
        else:
            image_shape = [None, 3, 320, 320]
            if getattr(self, 'test_transforms', None):
                for idx, op in enumerate(self.test_transforms.transforms):
                    name = op.__class__.__name__
                    if name == 'Resize':
                        image_shape = [None, 3] + list(
                            self.test_transforms.transforms[idx].target_size)
            logging.warning(
                '[Important!!!] When exporting inference model for {}, '
                'if fixed_input_shape is not set, it will be forcibly set to {}. '
                'Please ensure image shape after transforms is {}, if not, '
                'fixed_input_shape should be specified manually.'
                .format(self.__class__.__name__, image_shape, image_shape[1:]))
        self.fixed_input_shape = image_shape
        return self._define_input_spec(image_shape)
    def _pre_train(self, in_args):
        optimizer = in_args['optimizer']
        if optimizer is None:
            num_steps_each_epoch = len(in_args['train_dataset']) // in_args[
                'train_batch_size']
            optimizer = self.default_optimizer(
                parameters=self.net.parameters(),
                learning_rate=in_args['learning_rate'],
                warmup_steps=in_args['warmup_steps'],
                warmup_start_lr=in_args['warmup_start_lr'],
                lr_decay_epochs=in_args['lr_decay_epochs'],
                lr_decay_gamma=in_args['lr_decay_gamma'],
                # Use the locally computed value; `in_args` carries no
                # 'num_steps_each_epoch' key.
                num_steps_each_epoch=num_steps_each_epoch,
                reg_coeff=4e-05,
                scheduler='Cosine',
                num_epochs=in_args['num_epochs'])
            in_args['optimizer'] = optimizer
        return in_args

class YOLOv3(BaseDetector):
    def __init__(self,
                 num_classes=80,
                 backbone='MobileNetV1',
                 anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
                          [59, 119], [116, 90], [156, 198], [373, 326]],
                 anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
                 ignore_threshold=0.7,
                 nms_score_threshold=0.01,
                 nms_topk=1000,
                 nms_keep_topk=100,
                 nms_iou_threshold=0.45,
                 label_smooth=False,
                 **params):
        self.init_params = locals()
        if backbone not in {
                'MobileNetV1', 'MobileNetV1_ssld', 'MobileNetV3',
                'MobileNetV3_ssld', 'DarkNet53', 'ResNet50_vd_dcn', 'ResNet34'
        }:
            raise ValueError(
                "backbone: {} is not supported. Please choose one of "
                "{'MobileNetV1', 'MobileNetV1_ssld', 'MobileNetV3', 'MobileNetV3_ssld', 'DarkNet53', "
                "'ResNet50_vd_dcn', 'ResNet34'}.".format(backbone))
        self.backbone_name = backbone
        if params.get('with_net', True):
            if paddlers.env_info['place'] == 'gpu' and paddlers.env_info[
                    'num'] > 1 and not os.environ.get('PADDLERS_EXPORT_STAGE'):
                norm_type = 'sync_bn'
            else:
                norm_type = 'bn'
            if 'MobileNetV1' in backbone:
                norm_type = 'bn'
                backbone = self._get_backbone('MobileNet', norm_type=norm_type)
            elif 'MobileNetV3' in backbone:
                backbone = self._get_backbone(
                    'MobileNetV3',
                    norm_type=norm_type,
                    feature_maps=[7, 13, 16])
            elif backbone == 'ResNet50_vd_dcn':
                backbone = self._get_backbone(
                    'ResNet',
                    norm_type=norm_type,
                    variant='d',
                    return_idx=[1, 2, 3],
                    dcn_v2_stages=[3],
                    freeze_at=-1,
                    freeze_norm=False)
            elif backbone == 'ResNet34':
                backbone = self._get_backbone(
                    'ResNet',
                    depth=34,
                    norm_type=norm_type,
                    return_idx=[1, 2, 3],
                    freeze_at=-1,
                    freeze_norm=False,
                    norm_decay=0.)
            else:
                backbone = self._get_backbone('DarkNet', norm_type=norm_type)
            neck = ppdet.modeling.YOLOv3FPN(
                norm_type=norm_type,
                in_channels=[i.channels for i in backbone.out_shape])
            loss = ppdet.modeling.YOLOv3Loss(
                num_classes=num_classes,
                ignore_thresh=ignore_threshold,
                label_smooth=label_smooth)
            yolo_head = ppdet.modeling.YOLOv3Head(
                in_channels=[i.channels for i in neck.out_shape],
                anchors=anchors,
                anchor_masks=anchor_masks,
                num_classes=num_classes,
                loss=loss)
            post_process = ppdet.modeling.BBoxPostProcess(
                decode=ppdet.modeling.YOLOBox(num_classes=num_classes),
                nms=ppdet.modeling.MultiClassNMS(
                    score_threshold=nms_score_threshold,
                    nms_top_k=nms_topk,
                    keep_top_k=nms_keep_topk,
                    nms_threshold=nms_iou_threshold))
            params.update({
                'backbone': backbone,
                'neck': neck,
                'yolo_head': yolo_head,
                'post_process': post_process
            })
        super(YOLOv3, self).__init__(
            model_name='YOLOv3', num_classes=num_classes, **params)
        self.anchors = anchors
        self.anchor_masks = anchor_masks
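    # Construction sketch (hypothetical class count; the default anchors and
    # anchor_masks above follow the standard YOLOv3 configuration):
    #
    #   model = YOLOv3(num_classes=20, backbone='DarkNet53')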
    def _compose_batch_transform(self, transforms, mode='train'):
        if mode == 'train':
            default_batch_transforms = [
                _BatchPad(pad_to_stride=-1), _NormalizeBox(),
                _PadBox(getattr(self, 'num_max_boxes', 50)), _BboxXYXY2XYWH(),
                _Gt2YoloTarget(
                    anchor_masks=self.anchor_masks,
                    anchors=self.anchors,
                    downsample_ratios=getattr(self, 'downsample_ratios',
                                              [32, 16, 8]),
                    num_classes=self.num_classes)
            ]
        else:
            default_batch_transforms = [_BatchPad(pad_to_stride=-1)]
        if mode == 'eval' and self.metric == 'voc':
            collate_batch = False
        else:
            collate_batch = True
        custom_batch_transforms = []
        for i, op in enumerate(transforms.transforms):
            if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
                if mode != 'train':
                    raise ValueError(
                        "{} cannot be present in the {} transforms. ".format(
                            op.__class__.__name__, mode) +
                        "Please check the {} transforms.".format(mode))
                custom_batch_transforms.insert(0, copy.deepcopy(op))
        batch_transforms = BatchCompose(
            custom_batch_transforms + default_batch_transforms,
            collate_batch=collate_batch)
        return batch_transforms
    def _fix_transforms_shape(self, image_shape):
        if getattr(self, 'test_transforms', None):
            has_resize_op = False
            resize_op_idx = -1
            normalize_op_idx = len(self.test_transforms.transforms)
            for idx, op in enumerate(self.test_transforms.transforms):
                name = op.__class__.__name__
                if name == 'Resize':
                    has_resize_op = True
                    resize_op_idx = idx
                if name == 'Normalize':
                    normalize_op_idx = idx
            if not has_resize_op:
                self.test_transforms.transforms.insert(
                    normalize_op_idx,
                    Resize(
                        target_size=image_shape, interp='CUBIC'))
            else:
                self.test_transforms.transforms[
                    resize_op_idx].target_size = image_shape

class FasterRCNN(BaseDetector):
    def __init__(self,
                 num_classes=80,
                 backbone='ResNet50',
                 with_fpn=True,
                 with_dcn=False,
                 aspect_ratios=[0.5, 1.0, 2.0],
                 anchor_sizes=[[32], [64], [128], [256], [512]],
                 keep_top_k=100,
                 nms_threshold=0.5,
                 score_threshold=0.05,
                 fpn_num_channels=256,
                 rpn_batch_size_per_im=256,
                 rpn_fg_fraction=0.5,
                 test_pre_nms_top_n=None,
                 test_post_nms_top_n=1000,
                 **params):
        self.init_params = locals()
        if backbone not in {
                'ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet34',
                'ResNet34_vd', 'ResNet101', 'ResNet101_vd', 'HRNet_W18'
        }:
            raise ValueError(
                "backbone: {} is not supported. Please choose one of "
                "{'ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet34', 'ResNet34_vd', "
                "'ResNet101', 'ResNet101_vd', 'HRNet_W18'}.".format(backbone))
        self.backbone_name = backbone
        if params.get('with_net', True):
            dcn_v2_stages = [1, 2, 3] if with_dcn else [-1]
            if backbone == 'HRNet_W18':
                if not with_fpn:
                    logging.warning(
                        "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
                        format(backbone))
                    with_fpn = True
                if with_dcn:
                    logging.warning(
                        "Backbone {} should be used along with dcn disabled, 'with_dcn' is forcibly set to False".
                        format(backbone))
                backbone = self._get_backbone(
                    'HRNet', width=18, freeze_at=0, return_idx=[0, 1, 2, 3])
            elif backbone == 'ResNet50_vd_ssld':
                if not with_fpn:
                    logging.warning(
                        "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
                        format(backbone))
                    with_fpn = True
                backbone = self._get_backbone(
                    'ResNet',
                    variant='d',
                    norm_type='bn',
                    freeze_at=0,
                    return_idx=[0, 1, 2, 3],
                    num_stages=4,
                    lr_mult_list=[0.05, 0.05, 0.1, 0.15],
                    dcn_v2_stages=dcn_v2_stages)
            elif 'ResNet50' in backbone:
                if with_fpn:
                    backbone = self._get_backbone(
                        'ResNet',
                        variant='d' if '_vd' in backbone else 'b',
                        norm_type='bn',
                        freeze_at=0,
                        return_idx=[0, 1, 2, 3],
                        num_stages=4,
                        dcn_v2_stages=dcn_v2_stages)
                else:
                    if with_dcn:
                        logging.warning(
                            "Backbone {} without fpn should be used along with dcn disabled, 'with_dcn' is forcibly set to False".
                            format(backbone))
                    backbone = self._get_backbone(
                        'ResNet',
                        variant='d' if '_vd' in backbone else 'b',
                        norm_type='bn',
                        freeze_at=0,
                        return_idx=[2],
                        num_stages=3)
            elif 'ResNet34' in backbone:
                if not with_fpn:
                    logging.warning(
                        "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
                        format(backbone))
                    with_fpn = True
                backbone = self._get_backbone(
                    'ResNet',
                    depth=34,
                    variant='d' if 'vd' in backbone else 'b',
                    norm_type='bn',
                    freeze_at=0,
                    return_idx=[0, 1, 2, 3],
                    num_stages=4,
                    dcn_v2_stages=dcn_v2_stages)
            else:
                if not with_fpn:
                    logging.warning(
                        "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
                        format(backbone))
                    with_fpn = True
                backbone = self._get_backbone(
                    'ResNet',
                    depth=101,
                    variant='d' if 'vd' in backbone else 'b',
                    norm_type='bn',
                    freeze_at=0,
                    return_idx=[0, 1, 2, 3],
                    num_stages=4,
                    dcn_v2_stages=dcn_v2_stages)
            rpn_in_channel = backbone.out_shape[0].channels
            if with_fpn:
                self.backbone_name = self.backbone_name + '_fpn'
                if 'HRNet' in self.backbone_name:
                    neck = ppdet.modeling.HRFPN(
                        in_channels=[i.channels for i in backbone.out_shape],
                        out_channel=fpn_num_channels,
                        spatial_scales=[
                            1.0 / i.stride for i in backbone.out_shape
                        ],
                        share_conv=False)
                else:
                    neck = ppdet.modeling.FPN(
                        in_channels=[i.channels for i in backbone.out_shape],
                        out_channel=fpn_num_channels,
                        spatial_scales=[
                            1.0 / i.stride for i in backbone.out_shape
                        ])
                rpn_in_channel = neck.out_shape[0].channels
                anchor_generator_cfg = {
                    'aspect_ratios': aspect_ratios,
                    'anchor_sizes': anchor_sizes,
                    'strides': [4, 8, 16, 32, 64]
                }
                train_proposal_cfg = {
                    'min_size': 0.0,
                    'nms_thresh': .7,
                    'pre_nms_top_n': 2000,
                    'post_nms_top_n': 1000,
                    'topk_after_collect': True
                }
                test_proposal_cfg = {
                    'min_size': 0.0,
                    'nms_thresh': .7,
                    'pre_nms_top_n': 1000
                    if test_pre_nms_top_n is None else test_pre_nms_top_n,
                    'post_nms_top_n': test_post_nms_top_n
                }
                head = ppdet.modeling.TwoFCHead(
                    in_channel=neck.out_shape[0].channels, out_channel=1024)
                roi_extractor_cfg = {
                    'resolution': 7,
                    'spatial_scale': [1. / i.stride for i in neck.out_shape],
                    'sampling_ratio': 0,
                    'aligned': True
                }
                with_pool = False
            else:
                neck = None
                anchor_generator_cfg = {
                    'aspect_ratios': aspect_ratios,
                    'anchor_sizes': anchor_sizes,
                    'strides': [16]
                }
                train_proposal_cfg = {
                    'min_size': 0.0,
                    'nms_thresh': .7,
                    'pre_nms_top_n': 12000,
                    'post_nms_top_n': 2000,
                    'topk_after_collect': False
                }
                test_proposal_cfg = {
                    'min_size': 0.0,
                    'nms_thresh': .7,
                    'pre_nms_top_n': 6000
                    if test_pre_nms_top_n is None else test_pre_nms_top_n,
                    'post_nms_top_n': test_post_nms_top_n
                }
                head = ppdet.modeling.Res5Head()
                roi_extractor_cfg = {
                    'resolution': 14,
                    'spatial_scale':
                    [1. / i.stride for i in backbone.out_shape],
                    'sampling_ratio': 0,
                    'aligned': True
                }
                with_pool = True
            rpn_target_assign_cfg = {
                'batch_size_per_im': rpn_batch_size_per_im,
                'fg_fraction': rpn_fg_fraction,
                'negative_overlap': .3,
                'positive_overlap': .7,
                'use_random': True
            }
            rpn_head = ppdet.modeling.RPNHead(
                anchor_generator=anchor_generator_cfg,
                rpn_target_assign=rpn_target_assign_cfg,
                train_proposal=train_proposal_cfg,
                test_proposal=test_proposal_cfg,
                in_channel=rpn_in_channel)
            bbox_assigner = BBoxAssigner(num_classes=num_classes)
            bbox_head = ppdet.modeling.BBoxHead(
                head=head,
                in_channel=head.out_shape[0].channels,
                roi_extractor=roi_extractor_cfg,
                with_pool=with_pool,
                bbox_assigner=bbox_assigner,
                num_classes=num_classes)
            bbox_post_process = ppdet.modeling.BBoxPostProcess(
                num_classes=num_classes,
                decode=ppdet.modeling.RCNNBox(num_classes=num_classes),
                nms=ppdet.modeling.MultiClassNMS(
                    score_threshold=score_threshold,
                    keep_top_k=keep_top_k,
                    nms_threshold=nms_threshold))
            params.update({
                'backbone': backbone,
                'neck': neck,
                'rpn_head': rpn_head,
                'bbox_head': bbox_head,
                'bbox_post_process': bbox_post_process
            })
        else:
            if backbone not in {'ResNet50', 'ResNet50_vd'}:
                with_fpn = True
        self.with_fpn = with_fpn
        super(FasterRCNN, self).__init__(
            model_name='FasterRCNN', num_classes=num_classes, **params)
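    # Construction sketch (hypothetical class count):
    #
    #   model = FasterRCNN(num_classes=10, backbone='ResNet50', with_fpn=True)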
    def _pre_train(self, in_args):
        train_dataset = in_args['train_dataset']
        if train_dataset.pos_num < len(train_dataset.file_list):
            # In-place modification
            train_dataset.num_workers = 0
        return in_args
    def _compose_batch_transform(self, transforms, mode='train'):
        # The same padding is applied in both training and evaluation modes.
        default_batch_transforms = [
            _BatchPad(pad_to_stride=32 if self.with_fpn else -1)
        ]
        custom_batch_transforms = []
        for i, op in enumerate(transforms.transforms):
            if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
                if mode != 'train':
                    raise ValueError(
                        "{} cannot be present in the {} transforms. ".format(
                            op.__class__.__name__, mode) +
                        "Please check the {} transforms.".format(mode))
                custom_batch_transforms.insert(0, copy.deepcopy(op))
        batch_transforms = BatchCompose(
            custom_batch_transforms + default_batch_transforms,
            collate_batch=False)
        return batch_transforms
  1254. def _fix_transforms_shape(self, image_shape):
  1255. if getattr(self, 'test_transforms', None):
  1256. has_resize_op = False
  1257. resize_op_idx = -1
  1258. normalize_op_idx = len(self.test_transforms.transforms)
  1259. for idx, op in enumerate(self.test_transforms.transforms):
  1260. name = op.__class__.__name__
  1261. if name == 'ResizeByShort':
  1262. has_resize_op = True
  1263. resize_op_idx = idx
  1264. if name == 'Normalize':
  1265. normalize_op_idx = idx
  1266. if not has_resize_op:
  1267. self.test_transforms.transforms.insert(
  1268. normalize_op_idx,
  1269. Resize(
  1270. target_size=image_shape,
  1271. keep_ratio=True,
  1272. interp='CUBIC'))
  1273. else:
  1274. self.test_transforms.transforms[resize_op_idx] = Resize(
  1275. target_size=image_shape, keep_ratio=True, interp='CUBIC')
  1276. self.test_transforms.transforms.append(
  1277. Pad(im_padding_value=[0., 0., 0.]))

    def _get_test_inputs(self, image_shape):
        if image_shape is not None:
            image_shape = self._check_image_shape(image_shape)
            self._fix_transforms_shape(image_shape[-2:])
        else:
            image_shape = [None, 3, -1, -1]
            if self.with_fpn:
                self.test_transforms.transforms.append(
                    Pad(im_padding_value=[0., 0., 0.]))
        self.fixed_input_shape = image_shape
        return self._define_input_spec(image_shape)
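

# --- Editor's usage sketch (not part of the original file) ---
# A minimal, hedged example of constructing the detector defined above. Only
# keyword arguments that demonstrably flow through FasterRCNN.__init__() are
# used; the class count and backbone choice are illustrative assumptions.
def _example_build_faster_rcnn():
    # 'ResNet50_vd' keeps FPN enabled (see the backbone checks above), so
    # batch padding uses pad_to_stride=32 in _compose_batch_transform().
    return FasterRCNN(num_classes=20, backbone='ResNet50_vd', with_fpn=True)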


class PPYOLO(YOLOv3):
    def __init__(self,
                 num_classes=80,
                 backbone='ResNet50_vd_dcn',
                 anchors=None,
                 anchor_masks=None,
                 use_coord_conv=True,
                 use_iou_aware=True,
                 use_spp=True,
                 use_drop_block=True,
                 scale_x_y=1.05,
                 ignore_threshold=0.7,
                 label_smooth=False,
                 use_iou_loss=True,
                 use_matrix_nms=True,
                 nms_score_threshold=0.01,
                 nms_topk=-1,
                 nms_keep_topk=100,
                 nms_iou_threshold=0.45,
                 **params):
        self.init_params = locals()
        if backbone not in {
                'ResNet50_vd_dcn', 'ResNet18_vd', 'MobileNetV3_large',
                'MobileNetV3_small'
        }:
            # Braces in the message are doubled so that str.format() treats
            # them literally.
            raise ValueError(
                "backbone: {} is not supported. Please choose one of "
                "{{'ResNet50_vd_dcn', 'ResNet18_vd', 'MobileNetV3_large', "
                "'MobileNetV3_small'}}.".format(backbone))
        self.backbone_name = backbone
        self.downsample_ratios = [
            32, 16, 8
        ] if backbone == 'ResNet50_vd_dcn' else [32, 16]

        if params.get('with_net', True):
            # Use SyncBN when training on multiple GPUs, but never at export
            # time.
            if paddlers.env_info['place'] == 'gpu' and paddlers.env_info[
                    'num'] > 1 and not os.environ.get('PADDLERS_EXPORT_STAGE'):
                norm_type = 'sync_bn'
            else:
                norm_type = 'bn'

            if anchors is None and anchor_masks is None:
                if 'MobileNetV3' in backbone:
                    anchors = [[11, 18], [34, 47], [51, 126], [115, 71],
                               [120, 195], [254, 235]]
                    anchor_masks = [[3, 4, 5], [0, 1, 2]]
                elif backbone == 'ResNet50_vd_dcn':
                    anchors = [[10, 13], [16, 30], [33, 23], [30, 61],
                               [62, 45], [59, 119], [116, 90], [156, 198],
                               [373, 326]]
                    anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
                else:
                    anchors = [[10, 14], [23, 27], [37, 58], [81, 82],
                               [135, 169], [344, 319]]
                    anchor_masks = [[3, 4, 5], [0, 1, 2]]
            elif anchors is None or anchor_masks is None:
                raise ValueError(
                    "Please define both anchors and anchor_masks.")

            if backbone == 'ResNet50_vd_dcn':
                backbone = self._get_backbone(
                    'ResNet',
                    variant='d',
                    norm_type=norm_type,
                    return_idx=[1, 2, 3],
                    dcn_v2_stages=[3],
                    freeze_at=-1,
                    freeze_norm=False,
                    norm_decay=0.)
            elif backbone == 'ResNet18_vd':
                backbone = self._get_backbone(
                    'ResNet',
                    depth=18,
                    variant='d',
                    norm_type=norm_type,
                    return_idx=[2, 3],
                    freeze_at=-1,
                    freeze_norm=False,
                    norm_decay=0.)
            elif backbone == 'MobileNetV3_large':
                backbone = self._get_backbone(
                    'MobileNetV3',
                    model_name='large',
                    norm_type=norm_type,
                    scale=1,
                    with_extra_blocks=False,
                    extra_block_filters=[],
                    feature_maps=[13, 16])
            elif backbone == 'MobileNetV3_small':
                backbone = self._get_backbone(
                    'MobileNetV3',
                    model_name='small',
                    norm_type=norm_type,
                    scale=1,
                    with_extra_blocks=False,
                    extra_block_filters=[],
                    feature_maps=[9, 12])

            neck = ppdet.modeling.PPYOLOFPN(
                norm_type=norm_type,
                in_channels=[i.channels for i in backbone.out_shape],
                coord_conv=use_coord_conv,
                drop_block=use_drop_block,
                spp=use_spp,
                conv_block_num=0
                if ('MobileNetV3' in self.backbone_name or
                    self.backbone_name == 'ResNet18_vd') else 2)
            loss = ppdet.modeling.YOLOv3Loss(
                num_classes=num_classes,
                ignore_thresh=ignore_threshold,
                downsample=self.downsample_ratios,
                label_smooth=label_smooth,
                scale_x_y=scale_x_y,
                iou_loss=ppdet.modeling.IouLoss(
                    loss_weight=2.5, loss_square=True)
                if use_iou_loss else None,
                iou_aware_loss=ppdet.modeling.IouAwareLoss(loss_weight=1.0)
                if use_iou_aware else None)
            yolo_head = ppdet.modeling.YOLOv3Head(
                in_channels=[i.channels for i in neck.out_shape],
                anchors=anchors,
                anchor_masks=anchor_masks,
                num_classes=num_classes,
                loss=loss,
                iou_aware=use_iou_aware)
            if use_matrix_nms:
                nms = ppdet.modeling.MatrixNMS(
                    keep_top_k=nms_keep_topk,
                    score_threshold=nms_score_threshold,
                    post_threshold=.05
                    if 'MobileNetV3' in self.backbone_name else .01,
                    nms_top_k=nms_topk,
                    background_label=-1)
            else:
                nms = ppdet.modeling.MultiClassNMS(
                    score_threshold=nms_score_threshold,
                    nms_top_k=nms_topk,
                    keep_top_k=nms_keep_topk,
                    nms_threshold=nms_iou_threshold)
            post_process = ppdet.modeling.BBoxPostProcess(
                decode=ppdet.modeling.YOLOBox(
                    num_classes=num_classes,
                    conf_thresh=.005
                    if 'MobileNetV3' in self.backbone_name else .01,
                    scale_x_y=scale_x_y),
                nms=nms)
            params.update({
                'backbone': backbone,
                'neck': neck,
                'yolo_head': yolo_head,
                'post_process': post_process
            })

        # Bypass YOLOv3.__init__() and call BaseDetector.__init__() directly,
        # so the components assembled above are used unchanged.
        super(YOLOv3, self).__init__(
            model_name='YOLOv3', num_classes=num_classes, **params)
        self.anchors = anchors
        self.anchor_masks = anchor_masks
        self.model_name = 'PPYOLO'

    def _get_test_inputs(self, image_shape):
        if image_shape is not None:
            image_shape = self._check_image_shape(image_shape)
            self._fix_transforms_shape(image_shape[-2:])
        else:
            image_shape = [None, 3, 608, 608]
            if getattr(self, 'test_transforms', None):
                for idx, op in enumerate(self.test_transforms.transforms):
                    name = op.__class__.__name__
                    if name == 'Resize':
                        image_shape = [None, 3] + list(
                            self.test_transforms.transforms[idx].target_size)
            logging.warning(
                '[Important!!!] When exporting inference model for {}, '
                'if fixed_input_shape is not set, it will be forcibly set to '
                '{}. Please ensure the image shape after transforms is {}; '
                'if not, fixed_input_shape should be specified manually.'
                .format(self.__class__.__name__, image_shape, image_shape[1:]))
        self.fixed_input_shape = image_shape
        return self._define_input_spec(image_shape)
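

# --- Editor's usage sketch (not part of the original file) ---
# PPYOLO requires `anchors` and `anchor_masks` to be given together: each mask
# entry indexes into `anchors`, one mask list per detection scale. The values
# below mirror the MobileNetV3 defaults assigned in PPYOLO.__init__() above;
# num_classes is an illustrative assumption.
def _example_build_ppyolo_custom_anchors():
    anchors = [[11, 18], [34, 47], [51, 126], [115, 71], [120, 195],
               [254, 235]]
    anchor_masks = [[3, 4, 5], [0, 1, 2]]  # 2 scales x 3 anchors each
    return PPYOLO(
        num_classes=20,
        backbone='MobileNetV3_large',
        anchors=anchors,
        anchor_masks=anchor_masks)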


class PPYOLOTiny(YOLOv3):
    def __init__(self,
                 num_classes=80,
                 backbone='MobileNetV3',
                 anchors=[[10, 15], [24, 36], [72, 42], [35, 87], [102, 96],
                          [60, 170], [220, 125], [128, 222], [264, 266]],
                 anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
                 use_iou_aware=False,
                 use_spp=True,
                 use_drop_block=True,
                 scale_x_y=1.05,
                 ignore_threshold=0.5,
                 label_smooth=False,
                 use_iou_loss=True,
                 use_matrix_nms=False,
                 nms_score_threshold=0.005,
                 nms_topk=1000,
                 nms_keep_topk=100,
                 nms_iou_threshold=0.45,
                 **params):
        self.init_params = locals()
        if backbone != 'MobileNetV3':
            logging.warning(
                "PPYOLOTiny only supports MobileNetV3 as backbone. "
                "Backbone is forcibly set to MobileNetV3.")
        self.backbone_name = 'MobileNetV3'
        self.downsample_ratios = [32, 16, 8]

        if params.get('with_net', True):
            if paddlers.env_info['place'] == 'gpu' and paddlers.env_info[
                    'num'] > 1 and not os.environ.get('PADDLERS_EXPORT_STAGE'):
                norm_type = 'sync_bn'
            else:
                norm_type = 'bn'
            backbone = self._get_backbone(
                'MobileNetV3',
                model_name='large',
                norm_type=norm_type,
                scale=.5,
                with_extra_blocks=False,
                extra_block_filters=[],
                feature_maps=[7, 13, 16])
            neck = ppdet.modeling.PPYOLOTinyFPN(
                detection_block_channels=[160, 128, 96],
                in_channels=[i.channels for i in backbone.out_shape],
                spp=use_spp,
                drop_block=use_drop_block)
            loss = ppdet.modeling.YOLOv3Loss(
                num_classes=num_classes,
                ignore_thresh=ignore_threshold,
                downsample=self.downsample_ratios,
                label_smooth=label_smooth,
                scale_x_y=scale_x_y,
                iou_loss=ppdet.modeling.IouLoss(
                    loss_weight=2.5, loss_square=True)
                if use_iou_loss else None,
                iou_aware_loss=ppdet.modeling.IouAwareLoss(loss_weight=1.0)
                if use_iou_aware else None)
            yolo_head = ppdet.modeling.YOLOv3Head(
                in_channels=[i.channels for i in neck.out_shape],
                anchors=anchors,
                anchor_masks=anchor_masks,
                num_classes=num_classes,
                loss=loss,
                iou_aware=use_iou_aware)
            if use_matrix_nms:
                nms = ppdet.modeling.MatrixNMS(
                    keep_top_k=nms_keep_topk,
                    score_threshold=nms_score_threshold,
                    post_threshold=.05,
                    nms_top_k=nms_topk,
                    background_label=-1)
            else:
                nms = ppdet.modeling.MultiClassNMS(
                    score_threshold=nms_score_threshold,
                    nms_top_k=nms_topk,
                    keep_top_k=nms_keep_topk,
                    nms_threshold=nms_iou_threshold)
            post_process = ppdet.modeling.BBoxPostProcess(
                decode=ppdet.modeling.YOLOBox(
                    num_classes=num_classes,
                    conf_thresh=.005,
                    downsample_ratio=32,
                    clip_bbox=True,
                    scale_x_y=scale_x_y),
                nms=nms)
            params.update({
                'backbone': backbone,
                'neck': neck,
                'yolo_head': yolo_head,
                'post_process': post_process
            })

        super(YOLOv3, self).__init__(
            model_name='YOLOv3', num_classes=num_classes, **params)
        self.anchors = anchors
        self.anchor_masks = anchor_masks
        self.model_name = 'PPYOLOTiny'

    def _get_test_inputs(self, image_shape):
        if image_shape is not None:
            image_shape = self._check_image_shape(image_shape)
            self._fix_transforms_shape(image_shape[-2:])
        else:
            image_shape = [None, 3, 320, 320]
            if getattr(self, 'test_transforms', None):
                for idx, op in enumerate(self.test_transforms.transforms):
                    name = op.__class__.__name__
                    if name == 'Resize':
                        image_shape = [None, 3] + list(
                            self.test_transforms.transforms[idx].target_size)
            logging.warning(
                '[Important!!!] When exporting inference model for {}, '
                'if fixed_input_shape is not set, it will be forcibly set to '
                '{}. Please ensure the image shape after transforms is {}; '
                'if not, fixed_input_shape should be specified manually.'
                .format(self.__class__.__name__, image_shape, image_shape[1:]))
        self.fixed_input_shape = image_shape
        return self._define_input_spec(image_shape)
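

# --- Editor's usage sketch (not part of the original file) ---
# PPYOLOTiny enforces the MobileNetV3 backbone in __init__() above, so only
# the NMS-related knobs are varied here; all values are illustrative
# assumptions.
def _example_build_ppyolotiny():
    return PPYOLOTiny(
        num_classes=20,
        use_matrix_nms=True,  # switch from the default MultiClassNMS
        nms_keep_topk=100)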


class PPYOLOv2(YOLOv3):
    def __init__(self,
                 num_classes=80,
                 backbone='ResNet50_vd_dcn',
                 anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
                          [59, 119], [116, 90], [156, 198], [373, 326]],
                 anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
                 use_iou_aware=True,
                 use_spp=True,
                 use_drop_block=True,
                 scale_x_y=1.05,
                 ignore_threshold=0.7,
                 label_smooth=False,
                 use_iou_loss=True,
                 use_matrix_nms=True,
                 nms_score_threshold=0.01,
                 nms_topk=-1,
                 nms_keep_topk=100,
                 nms_iou_threshold=0.45,
                 **params):
        self.init_params = locals()
        if backbone not in {'ResNet50_vd_dcn', 'ResNet101_vd_dcn'}:
            raise ValueError(
                "backbone: {} is not supported. Please choose one of "
                "{{'ResNet50_vd_dcn', 'ResNet101_vd_dcn'}}.".format(backbone))
        self.backbone_name = backbone
        self.downsample_ratios = [32, 16, 8]

        if params.get('with_net', True):
            if paddlers.env_info['place'] == 'gpu' and paddlers.env_info[
                    'num'] > 1 and not os.environ.get('PADDLERS_EXPORT_STAGE'):
                norm_type = 'sync_bn'
            else:
                norm_type = 'bn'

            if backbone == 'ResNet50_vd_dcn':
                backbone = self._get_backbone(
                    'ResNet',
                    variant='d',
                    norm_type=norm_type,
                    return_idx=[1, 2, 3],
                    dcn_v2_stages=[3],
                    freeze_at=-1,
                    freeze_norm=False,
                    norm_decay=0.)
            elif backbone == 'ResNet101_vd_dcn':
                backbone = self._get_backbone(
                    'ResNet',
                    depth=101,
                    variant='d',
                    norm_type=norm_type,
                    return_idx=[1, 2, 3],
                    dcn_v2_stages=[3],
                    freeze_at=-1,
                    freeze_norm=False,
                    norm_decay=0.)

            neck = ppdet.modeling.PPYOLOPAN(
                norm_type=norm_type,
                in_channels=[i.channels for i in backbone.out_shape],
                drop_block=use_drop_block,
                block_size=3,
                keep_prob=.9,
                spp=use_spp)
            loss = ppdet.modeling.YOLOv3Loss(
                num_classes=num_classes,
                ignore_thresh=ignore_threshold,
                downsample=self.downsample_ratios,
                label_smooth=label_smooth,
                scale_x_y=scale_x_y,
                iou_loss=ppdet.modeling.IouLoss(
                    loss_weight=2.5, loss_square=True)
                if use_iou_loss else None,
                iou_aware_loss=ppdet.modeling.IouAwareLoss(loss_weight=1.0)
                if use_iou_aware else None)
            yolo_head = ppdet.modeling.YOLOv3Head(
                in_channels=[i.channels for i in neck.out_shape],
                anchors=anchors,
                anchor_masks=anchor_masks,
                num_classes=num_classes,
                loss=loss,
                iou_aware=use_iou_aware,
                iou_aware_factor=.5)
            if use_matrix_nms:
                nms = ppdet.modeling.MatrixNMS(
                    keep_top_k=nms_keep_topk,
                    score_threshold=nms_score_threshold,
                    post_threshold=.01,
                    nms_top_k=nms_topk,
                    background_label=-1)
            else:
                nms = ppdet.modeling.MultiClassNMS(
                    score_threshold=nms_score_threshold,
                    nms_top_k=nms_topk,
                    keep_top_k=nms_keep_topk,
                    nms_threshold=nms_iou_threshold)
            post_process = ppdet.modeling.BBoxPostProcess(
                decode=ppdet.modeling.YOLOBox(
                    num_classes=num_classes,
                    conf_thresh=.01,
                    downsample_ratio=32,
                    clip_bbox=True,
                    scale_x_y=scale_x_y),
                nms=nms)
            params.update({
                'backbone': backbone,
                'neck': neck,
                'yolo_head': yolo_head,
                'post_process': post_process
            })

        super(YOLOv3, self).__init__(
            model_name='YOLOv3', num_classes=num_classes, **params)
        self.anchors = anchors
        self.anchor_masks = anchor_masks
        self.model_name = 'PPYOLOv2'

    def _get_test_inputs(self, image_shape):
        if image_shape is not None:
            image_shape = self._check_image_shape(image_shape)
            self._fix_transforms_shape(image_shape[-2:])
        else:
            image_shape = [None, 3, 640, 640]
            if getattr(self, 'test_transforms', None):
                for idx, op in enumerate(self.test_transforms.transforms):
                    name = op.__class__.__name__
                    if name == 'Resize':
                        image_shape = [None, 3] + list(
                            self.test_transforms.transforms[idx].target_size)
            logging.warning(
                '[Important!!!] When exporting inference model for {}, '
                'if fixed_input_shape is not set, it will be forcibly set to '
                '{}. Please ensure the image shape after transforms is {}; '
                'if not, fixed_input_shape should be specified manually.'
                .format(self.__class__.__name__, image_shape, image_shape[1:]))
        self.fixed_input_shape = image_shape
        return self._define_input_spec(image_shape)
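

# --- Editor's usage sketch (not part of the original file) ---
# PPYOLOv2 accepts only the two DCN-enabled ResNet-vd backbones checked in
# __init__() above. Exporting without a fixed input shape triggers the
# [None, 3, 640, 640] fallback in _get_test_inputs(); values here are
# illustrative assumptions.
def _example_build_ppyolov2():
    return PPYOLOv2(num_classes=20, backbone='ResNet101_vd_dcn')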


class MaskRCNN(BaseDetector):
    def __init__(self,
                 num_classes=80,
                 backbone='ResNet50_vd',
                 with_fpn=True,
                 with_dcn=False,
                 aspect_ratios=[0.5, 1.0, 2.0],
                 anchor_sizes=[[32], [64], [128], [256], [512]],
                 keep_top_k=100,
                 nms_threshold=0.5,
                 score_threshold=0.05,
                 fpn_num_channels=256,
                 rpn_batch_size_per_im=256,
                 rpn_fg_fraction=0.5,
                 test_pre_nms_top_n=None,
                 test_post_nms_top_n=1000,
                 **params):
        self.init_params = locals()
        if backbone not in {
                'ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet101',
                'ResNet101_vd'
        }:
            raise ValueError(
                "backbone: {} is not supported. Please choose one of "
                "{{'ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet101', "
                "'ResNet101_vd'}}.".format(backbone))

        self.backbone_name = backbone + '_fpn' if with_fpn else backbone
        dcn_v2_stages = [1, 2, 3] if with_dcn else [-1]

        if params.get('with_net', True):
            if backbone == 'ResNet50':
                if with_fpn:
                    backbone = self._get_backbone(
                        'ResNet',
                        norm_type='bn',
                        freeze_at=0,
                        return_idx=[0, 1, 2, 3],
                        num_stages=4,
                        dcn_v2_stages=dcn_v2_stages)
                else:
                    if with_dcn:
                        logging.warning(
                            "Backbone {} should be used with DCN disabled; "
                            "'with_dcn' is forcibly set to False.".format(
                                backbone))
                    backbone = self._get_backbone(
                        'ResNet',
                        norm_type='bn',
                        freeze_at=0,
                        return_idx=[2],
                        num_stages=3)
            elif 'ResNet50_vd' in backbone:
                if not with_fpn:
                    logging.warning(
                        "Backbone {} should be used with FPN enabled; "
                        "'with_fpn' is forcibly set to True.".format(backbone))
                    with_fpn = True
                backbone = self._get_backbone(
                    'ResNet',
                    variant='d',
                    norm_type='bn',
                    freeze_at=0,
                    return_idx=[0, 1, 2, 3],
                    num_stages=4,
                    lr_mult_list=[0.05, 0.05, 0.1, 0.15]
                    if '_ssld' in backbone else [1.0, 1.0, 1.0, 1.0],
                    dcn_v2_stages=dcn_v2_stages)
            else:
                if not with_fpn:
                    logging.warning(
                        "Backbone {} should be used with FPN enabled; "
                        "'with_fpn' is forcibly set to True.".format(backbone))
                    with_fpn = True
                backbone = self._get_backbone(
                    'ResNet',
                    variant='d' if '_vd' in backbone else 'b',
                    depth=101,
                    norm_type='bn',
                    freeze_at=0,
                    return_idx=[0, 1, 2, 3],
                    num_stages=4,
                    dcn_v2_stages=dcn_v2_stages)

            rpn_in_channel = backbone.out_shape[0].channels

            if with_fpn:
                neck = ppdet.modeling.FPN(
                    in_channels=[i.channels for i in backbone.out_shape],
                    out_channel=fpn_num_channels,
                    spatial_scales=[
                        1.0 / i.stride for i in backbone.out_shape
                    ])
                rpn_in_channel = neck.out_shape[0].channels
                anchor_generator_cfg = {
                    'aspect_ratios': aspect_ratios,
                    'anchor_sizes': anchor_sizes,
                    'strides': [4, 8, 16, 32, 64]
                }
                train_proposal_cfg = {
                    'min_size': 0.0,
                    'nms_thresh': .7,
                    'pre_nms_top_n': 2000,
                    'post_nms_top_n': 1000,
                    'topk_after_collect': True
                }
                test_proposal_cfg = {
                    'min_size': 0.0,
                    'nms_thresh': .7,
                    'pre_nms_top_n': 1000
                    if test_pre_nms_top_n is None else test_pre_nms_top_n,
                    'post_nms_top_n': test_post_nms_top_n
                }
                bb_head = ppdet.modeling.TwoFCHead(
                    in_channel=neck.out_shape[0].channels, out_channel=1024)
                bb_roi_extractor_cfg = {
                    'resolution': 7,
                    'spatial_scale': [1. / i.stride for i in neck.out_shape],
                    'sampling_ratio': 0,
                    'aligned': True
                }
                with_pool = False
                m_head = ppdet.modeling.MaskFeat(
                    in_channel=neck.out_shape[0].channels,
                    out_channel=256,
                    num_convs=4)
                m_roi_extractor_cfg = {
                    'resolution': 14,
                    'spatial_scale': [1. / i.stride for i in neck.out_shape],
                    'sampling_ratio': 0,
                    'aligned': True
                }
                mask_assigner = MaskAssigner(
                    num_classes=num_classes, mask_resolution=28)
                share_bbox_feat = False
            else:
                neck = None
                anchor_generator_cfg = {
                    'aspect_ratios': aspect_ratios,
                    'anchor_sizes': anchor_sizes,
                    'strides': [16]
                }
                train_proposal_cfg = {
                    'min_size': 0.0,
                    'nms_thresh': .7,
                    'pre_nms_top_n': 12000,
                    'post_nms_top_n': 2000,
                    'topk_after_collect': False
                }
                test_proposal_cfg = {
                    'min_size': 0.0,
                    'nms_thresh': .7,
                    'pre_nms_top_n': 6000
                    if test_pre_nms_top_n is None else test_pre_nms_top_n,
                    'post_nms_top_n': test_post_nms_top_n
                }
                bb_head = ppdet.modeling.Res5Head()
                bb_roi_extractor_cfg = {
                    'resolution': 14,
                    'spatial_scale':
                    [1. / i.stride for i in backbone.out_shape],
                    'sampling_ratio': 0,
                    'aligned': True
                }
                with_pool = True
                m_head = ppdet.modeling.MaskFeat(
                    in_channel=bb_head.out_shape[0].channels,
                    out_channel=256,
                    num_convs=0)
                m_roi_extractor_cfg = {
                    'resolution': 14,
                    'spatial_scale':
                    [1. / i.stride for i in backbone.out_shape],
                    'sampling_ratio': 0,
                    'aligned': True
                }
                mask_assigner = MaskAssigner(
                    num_classes=num_classes, mask_resolution=14)
                share_bbox_feat = True

            rpn_target_assign_cfg = {
                'batch_size_per_im': rpn_batch_size_per_im,
                'fg_fraction': rpn_fg_fraction,
                'negative_overlap': .3,
                'positive_overlap': .7,
                'use_random': True
            }
            rpn_head = ppdet.modeling.RPNHead(
                anchor_generator=anchor_generator_cfg,
                rpn_target_assign=rpn_target_assign_cfg,
                train_proposal=train_proposal_cfg,
                test_proposal=test_proposal_cfg,
                in_channel=rpn_in_channel)
            bbox_assigner = BBoxAssigner(num_classes=num_classes)
            bbox_head = ppdet.modeling.BBoxHead(
                head=bb_head,
                in_channel=bb_head.out_shape[0].channels,
                roi_extractor=bb_roi_extractor_cfg,
                with_pool=with_pool,
                bbox_assigner=bbox_assigner,
                num_classes=num_classes)
            mask_head = ppdet.modeling.MaskHead(
                head=m_head,
                roi_extractor=m_roi_extractor_cfg,
                mask_assigner=mask_assigner,
                share_bbox_feat=share_bbox_feat,
                num_classes=num_classes)
            bbox_post_process = ppdet.modeling.BBoxPostProcess(
                num_classes=num_classes,
                decode=ppdet.modeling.RCNNBox(num_classes=num_classes),
                nms=ppdet.modeling.MultiClassNMS(
                    score_threshold=score_threshold,
                    keep_top_k=keep_top_k,
                    nms_threshold=nms_threshold))
            mask_post_process = ppdet.modeling.MaskPostProcess(
                binary_thresh=.5)
            params.update({
                'backbone': backbone,
                'neck': neck,
                'rpn_head': rpn_head,
                'bbox_head': bbox_head,
                'mask_head': mask_head,
                'bbox_post_process': bbox_post_process,
                'mask_post_process': mask_post_process
            })

        self.with_fpn = with_fpn
        super(MaskRCNN, self).__init__(
            model_name='MaskRCNN', num_classes=num_classes, **params)

    def _pre_train(self, in_args):
        train_dataset = in_args['train_dataset']
        if train_dataset.pos_num < len(train_dataset.file_list):
            # Negative (background-only) samples are present: modify the
            # dataset in place to disable multiprocess data loading.
            train_dataset.num_workers = 0
        return in_args

    def _compose_batch_transform(self, transforms, mode='train'):
        # Batch padding is identical for train and non-train modes.
        default_batch_transforms = [
            _BatchPad(pad_to_stride=32 if self.with_fpn else -1)
        ]
        custom_batch_transforms = []
        for op in transforms.transforms:
            if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
                if mode != 'train':
                    raise ValueError(
                        "{} cannot be present in the {} transforms. ".format(
                            op.__class__.__name__, mode) +
                        "Please check the {} transforms.".format(mode))
                custom_batch_transforms.insert(0, copy.deepcopy(op))
        return BatchCompose(
            custom_batch_transforms + default_batch_transforms,
            collate_batch=False)

    def _fix_transforms_shape(self, image_shape):
        if getattr(self, 'test_transforms', None):
            has_resize_op = False
            resize_op_idx = -1
            normalize_op_idx = len(self.test_transforms.transforms)
            for idx, op in enumerate(self.test_transforms.transforms):
                name = op.__class__.__name__
                if name == 'ResizeByShort':
                    has_resize_op = True
                    resize_op_idx = idx
                if name == 'Normalize':
                    normalize_op_idx = idx

            if not has_resize_op:
                self.test_transforms.transforms.insert(
                    normalize_op_idx,
                    Resize(
                        target_size=image_shape,
                        keep_ratio=True,
                        interp='CUBIC'))
            else:
                self.test_transforms.transforms[resize_op_idx] = Resize(
                    target_size=image_shape, keep_ratio=True, interp='CUBIC')
            self.test_transforms.transforms.append(
                Pad(im_padding_value=[0., 0., 0.]))

    def _get_test_inputs(self, image_shape):
        if image_shape is not None:
            image_shape = self._check_image_shape(image_shape)
            self._fix_transforms_shape(image_shape[-2:])
        else:
            image_shape = [None, 3, -1, -1]
            if self.with_fpn:
                self.test_transforms.transforms.append(
                    Pad(im_padding_value=[0., 0., 0.]))
        self.fixed_input_shape = image_shape
        return self._define_input_spec(image_shape)
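

# --- Editor's usage sketch (not part of the original file) ---
# MaskRCNN picks between the FPN branch (TwoFCHead, 28x28 masks) and the
# non-FPN branch (Res5Head, 14x14 masks) in __init__() above. A plain ResNet50
# backbone is the only one that permits with_fpn=False; values below are
# illustrative assumptions.
def _example_build_mask_rcnn():
    return MaskRCNN(
        num_classes=20,
        backbone='ResNet50',
        with_fpn=False,  # selects the Res5Head/global-pool branch
        score_threshold=0.05,
        nms_threshold=0.5,
        keep_top_k=100)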