# object_detector.py

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import

import collections
import copy
import os
import os.path as osp

import numpy as np
import paddle
from paddle.static import InputSpec

import paddlers
import paddlers.models.ppdet as ppdet
import paddlers.utils.logging as logging
from paddlers.models.ppdet.modeling.proposal_generator.target_layer import BBoxAssigner, MaskAssigner
from paddlers.models.ppdet.optimizer import ModelEMA
from paddlers.transforms import decode_image
from paddlers.transforms.operators import _NormalizeBox, _PadBox, _BboxXYXY2XYWH, Resize, Pad
from paddlers.transforms.batch_operators import BatchCompose, BatchRandomResize, \
    BatchRandomResizeByShort, _BatchPad, _Gt2YoloTarget
from paddlers.utils.checkpoint import det_pretrain_weights_dict
from .base import BaseModel
from .utils.det_metrics import VOCMetric, COCOMetric

__all__ = [
    "YOLOv3", "FasterRCNN", "PPYOLO", "PPYOLOTiny", "PPYOLOv2", "MaskRCNN",
    "PicoDet"
]


class BaseDetector(BaseModel):
    def __init__(self, model_name, num_classes=80, **params):
        self.init_params.update(locals())
        if 'with_net' in self.init_params:
            del self.init_params['with_net']
        super(BaseDetector, self).__init__('detector')
        if not hasattr(ppdet.modeling, model_name):
            raise ValueError("ERROR: There is no model named {}.".format(
                model_name))
        self.model_name = model_name
        self.num_classes = num_classes
        self.labels = None
        if params.get('with_net', True):
            params.pop('with_net', None)
            self.net = self.build_net(**params)

    def build_net(self, **params):
        with paddle.utils.unique_name.guard():
            net = ppdet.modeling.__dict__[self.model_name](**params)
        return net

    def _fix_transforms_shape(self, image_shape):
        raise NotImplementedError("_fix_transforms_shape: not implemented!")

    def _define_input_spec(self, image_shape):
        input_spec = [{
            "image": InputSpec(
                shape=image_shape, name='image', dtype='float32'),
            "im_shape": InputSpec(
                shape=[image_shape[0], 2], name='im_shape', dtype='float32'),
            "scale_factor": InputSpec(
                shape=[image_shape[0], 2],
                name='scale_factor',
                dtype='float32')
        }]
        return input_spec

    def _check_image_shape(self, image_shape):
        if len(image_shape) == 2:
            image_shape = [1, 3] + image_shape
        if image_shape[-2] % 32 > 0 or image_shape[-1] % 32 > 0:
            raise ValueError(
                "Height and width in fixed_input_shape must be a multiple of 32, but received {}.".
                format(image_shape[-2:]))
        return image_shape

    def _get_test_inputs(self, image_shape):
        if image_shape is not None:
            image_shape = self._check_image_shape(image_shape)
            self._fix_transforms_shape(image_shape[-2:])
        else:
            image_shape = [None, 3, -1, -1]
        self.fixed_input_shape = image_shape
        return self._define_input_spec(image_shape)

    def _get_backbone(self, backbone_name, **params):
        backbone = getattr(ppdet.modeling, backbone_name)(**params)
        return backbone

    def run(self, net, inputs, mode):
        net_out = net(inputs)
        if mode in ['train', 'eval']:
            outputs = net_out
        else:
            outputs = dict()
            for key in net_out:
                outputs[key] = net_out[key].numpy()
        return outputs

    def default_optimizer(self,
                          parameters,
                          learning_rate,
                          warmup_steps,
                          warmup_start_lr,
                          lr_decay_epochs,
                          lr_decay_gamma,
                          num_steps_each_epoch,
                          reg_coeff=1e-04,
                          scheduler='Piecewise',
                          num_epochs=None):
        if scheduler.lower() == 'piecewise':
            if warmup_steps > 0 and warmup_steps > lr_decay_epochs[
                    0] * num_steps_each_epoch:
                logging.error(
                    "In function train(), parameters must satisfy: "
                    "warmup_steps <= lr_decay_epochs[0] * num_samples_in_train_dataset. "
                    "See this doc for more information: "
                    "https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/parameters.md",
                    exit=False)
                logging.error(
                    "Either `warmup_steps` must be less than {}, or "
                    "`lr_decay_epochs[0]` must be greater than {}. Please "
                    "modify 'warmup_steps' or 'lr_decay_epochs' in the train() "
                    "function.".format(
                        lr_decay_epochs[0] * num_steps_each_epoch,
                        warmup_steps // num_steps_each_epoch),
                    exit=True)
            boundaries = [b * num_steps_each_epoch for b in lr_decay_epochs]
            values = [(lr_decay_gamma**i) * learning_rate
                      for i in range(len(lr_decay_epochs) + 1)]
            scheduler = paddle.optimizer.lr.PiecewiseDecay(boundaries, values)
        elif scheduler.lower() == 'cosine':
            if num_epochs is None:
                logging.error(
                    "`num_epochs` must be set while using cosine annealing decay scheduler, but received {}".
                    format(num_epochs),
                    exit=False)
            if warmup_steps > 0 and warmup_steps > num_epochs * num_steps_each_epoch:
                logging.error(
                    "In function train(), parameters must satisfy: "
                    "warmup_steps <= num_epochs * num_samples_in_train_dataset. "
                    "See this doc for more information: "
                    "https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/parameters.md",
                    exit=False)
                logging.error(
                    "`warmup_steps` must be less than the total number of steps ({}). "
                    "Please modify 'num_epochs' or 'warmup_steps' in the train() function.".
                    format(num_epochs * num_steps_each_epoch),
                    exit=True)
            T_max = num_epochs * num_steps_each_epoch - warmup_steps
            scheduler = paddle.optimizer.lr.CosineAnnealingDecay(
                learning_rate=learning_rate,
                T_max=T_max,
                eta_min=0.0,
                last_epoch=-1)
        else:
            logging.error(
                "Invalid learning rate scheduler: {}!".format(scheduler),
                exit=True)
        if warmup_steps > 0:
            scheduler = paddle.optimizer.lr.LinearWarmup(
                learning_rate=scheduler,
                warmup_steps=warmup_steps,
                start_lr=warmup_start_lr,
                end_lr=learning_rate)
        optimizer = paddle.optimizer.Momentum(
            scheduler,
            momentum=.9,
            weight_decay=paddle.regularizer.L2Decay(coeff=reg_coeff),
            parameters=parameters)
        return optimizer
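
    # A worked example of the default 'Piecewise' schedule (an illustration,
    # not part of the original module): with learning_rate=.001,
    # lr_decay_epochs=(216, 243), lr_decay_gamma=0.1, and
    # num_steps_each_epoch=100, the step boundaries are [21600, 24300] and the
    # values are [1e-3, 1e-4, 1e-5]; i.e., the learning rate is divided by 10
    # after epoch 216 and again after epoch 243. If warmup_steps > 0,
    # LinearWarmup first ramps the rate from warmup_start_lr up to
    # learning_rate, which is why the warm-up phase must end before the first
    # decay boundary.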

    def train(self,
              num_epochs,
              train_dataset,
              train_batch_size=64,
              eval_dataset=None,
              optimizer=None,
              save_interval_epochs=1,
              log_interval_steps=10,
              save_dir='output',
              pretrain_weights='IMAGENET',
              learning_rate=.001,
              warmup_steps=0,
              warmup_start_lr=0.0,
              lr_decay_epochs=(216, 243),
              lr_decay_gamma=0.1,
              metric=None,
              use_ema=False,
              early_stop=False,
              early_stop_patience=5,
              use_vdl=True,
              resume_checkpoint=None):
        """
        Train the model.

        Args:
            num_epochs (int): Number of epochs.
            train_dataset (paddlers.datasets.COCODetDataset|paddlers.datasets.VOCDetDataset):
                Training dataset.
            train_batch_size (int, optional): Total batch size among all cards used in
                training. Defaults to 64.
            eval_dataset (paddlers.datasets.COCODetDataset|paddlers.datasets.VOCDetDataset, optional):
                Evaluation dataset. If None, the model will not be evaluated during
                training. Defaults to None.
            optimizer (paddle.optimizer.Optimizer|None, optional): Optimizer used for
                training. If None, a default optimizer will be used. Defaults to None.
            save_interval_epochs (int, optional): Epoch interval for saving the model.
                Defaults to 1.
            log_interval_steps (int, optional): Step interval for printing training
                information. Defaults to 10.
            save_dir (str, optional): Directory to save the model. Defaults to 'output'.
            pretrain_weights (str|None, optional): None or name/path of pretrained
                weights. If None, no pretrained weights will be loaded.
                Defaults to 'IMAGENET'.
            learning_rate (float, optional): Learning rate for training. Defaults to .001.
            warmup_steps (int, optional): Number of steps of warm-up training.
                Defaults to 0.
            warmup_start_lr (float, optional): Start learning rate of warm-up training.
                Defaults to 0.0.
            lr_decay_epochs (list|tuple, optional): Epoch milestones for learning
                rate decay. Defaults to (216, 243).
            lr_decay_gamma (float, optional): Gamma coefficient of learning rate decay.
                Defaults to .1.
            metric (str|None, optional): Evaluation metric. Choices are {'VOC', 'COCO', None}.
                If None, the metric is determined according to the dataset format.
                Defaults to None.
            use_ema (bool, optional): Whether to use the exponential moving average
                strategy. Defaults to False.
            early_stop (bool, optional): Whether to adopt the early stopping strategy.
                Defaults to False.
            early_stop_patience (int, optional): Early stopping patience. Defaults to 5.
            use_vdl (bool, optional): Whether to use VisualDL to monitor the training
                process. Defaults to True.
            resume_checkpoint (str|None, optional): Path of the checkpoint to resume
                training from. If None, no training checkpoint will be resumed. At most
                one of `resume_checkpoint` and `pretrain_weights` can be set.
                Defaults to None.
        """
        if self.status == 'Infer':
            logging.error(
                "Exported inference model does not support training.",
                exit=True)
        if pretrain_weights is not None and resume_checkpoint is not None:
            logging.error(
                "pretrain_weights and resume_checkpoint cannot be set simultaneously.",
                exit=True)
        if train_dataset.__class__.__name__ == 'VOCDetDataset':
            train_dataset.data_fields = {
                'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
                'difficult'
            }
        elif train_dataset.__class__.__name__ == 'COCODetDataset':
            if self.__class__.__name__ == 'MaskRCNN':
                train_dataset.data_fields = {
                    'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
                    'gt_poly', 'is_crowd'
                }
            else:
                train_dataset.data_fields = {
                    'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
                    'is_crowd'
                }
        if metric is None:
            if eval_dataset.__class__.__name__ == 'VOCDetDataset':
                self.metric = 'voc'
            elif eval_dataset.__class__.__name__ == 'COCODetDataset':
                self.metric = 'coco'
        else:
            assert metric.lower() in ['coco', 'voc'], \
                "Evaluation metric {} is not supported. Please choose from 'COCO' and 'VOC'.".format(metric)
            self.metric = metric.lower()
        self.labels = train_dataset.labels
        self.num_max_boxes = train_dataset.num_max_boxes
        train_dataset.batch_transforms = self._compose_batch_transform(
            train_dataset.transforms, mode='train')

        # Build the optimizer if none is given.
        if optimizer is None:
            num_steps_each_epoch = len(train_dataset) // train_batch_size
            self.optimizer = self.default_optimizer(
                parameters=self.net.parameters(),
                learning_rate=learning_rate,
                warmup_steps=warmup_steps,
                warmup_start_lr=warmup_start_lr,
                lr_decay_epochs=lr_decay_epochs,
                lr_decay_gamma=lr_decay_gamma,
                num_steps_each_epoch=num_steps_each_epoch)
        else:
            self.optimizer = optimizer

        # Initialize weights.
        if pretrain_weights is not None and not osp.exists(pretrain_weights):
            if pretrain_weights not in det_pretrain_weights_dict['_'.join(
                    [self.model_name, self.backbone_name])]:
                logging.warning(
                    "Path of pretrain_weights('{}') does not exist!".format(
                        pretrain_weights))
                pretrain_weights = det_pretrain_weights_dict['_'.join(
                    [self.model_name, self.backbone_name])][0]
                logging.warning("pretrain_weights is forcibly set to '{}'. "
                                "If you don't want to use pretrained weights, "
                                "set pretrain_weights to None.".format(
                                    pretrain_weights))
        elif pretrain_weights is not None and osp.exists(pretrain_weights):
            if osp.splitext(pretrain_weights)[-1] != '.pdparams':
                logging.error(
                    "Invalid pretrained weights. Please specify a '.pdparams' file.",
                    exit=True)
        pretrained_dir = osp.join(save_dir, 'pretrain')
        self.net_initialize(
            pretrain_weights=pretrain_weights,
            save_dir=pretrained_dir,
            resume_checkpoint=resume_checkpoint,
            is_backbone_weights=(pretrain_weights == 'IMAGENET' and
                                 'ESNet_' in self.backbone_name))
        if use_ema:
            ema = ModelEMA(model=self.net, decay=.9998, use_thres_step=True)
        else:
            ema = None

        # Start the training loop.
        self.train_loop(
            num_epochs=num_epochs,
            train_dataset=train_dataset,
            train_batch_size=train_batch_size,
            eval_dataset=eval_dataset,
            save_interval_epochs=save_interval_epochs,
            log_interval_steps=log_interval_steps,
            save_dir=save_dir,
            ema=ema,
            early_stop=early_stop,
            early_stop_patience=early_stop_patience,
            use_vdl=use_vdl)

    def quant_aware_train(self,
                          num_epochs,
                          train_dataset,
                          train_batch_size=64,
                          eval_dataset=None,
                          optimizer=None,
                          save_interval_epochs=1,
                          log_interval_steps=10,
                          save_dir='output',
                          learning_rate=.00001,
                          warmup_steps=0,
                          warmup_start_lr=0.0,
                          lr_decay_epochs=(216, 243),
                          lr_decay_gamma=0.1,
                          metric=None,
                          use_ema=False,
                          early_stop=False,
                          early_stop_patience=5,
                          use_vdl=True,
                          resume_checkpoint=None,
                          quant_config=None):
        """
        Quantization-aware training.

        Args:
            num_epochs (int): Number of epochs.
            train_dataset (paddlers.datasets.COCODetDataset|paddlers.datasets.VOCDetDataset):
                Training dataset.
            train_batch_size (int, optional): Total batch size among all cards used in
                training. Defaults to 64.
            eval_dataset (paddlers.datasets.COCODetDataset|paddlers.datasets.VOCDetDataset, optional):
                Evaluation dataset. If None, the model will not be evaluated during
                training. Defaults to None.
            optimizer (paddle.optimizer.Optimizer|None, optional): Optimizer used for
                training. If None, a default optimizer will be used. Defaults to None.
            save_interval_epochs (int, optional): Epoch interval for saving the model.
                Defaults to 1.
            log_interval_steps (int, optional): Step interval for printing training
                information. Defaults to 10.
            save_dir (str, optional): Directory to save the model. Defaults to 'output'.
            learning_rate (float, optional): Learning rate for training.
                Defaults to .00001.
            warmup_steps (int, optional): Number of steps of warm-up training.
                Defaults to 0.
            warmup_start_lr (float, optional): Start learning rate of warm-up training.
                Defaults to 0.0.
            lr_decay_epochs (list|tuple, optional): Epoch milestones for learning rate
                decay. Defaults to (216, 243).
            lr_decay_gamma (float, optional): Gamma coefficient of learning rate decay.
                Defaults to .1.
            metric (str|None, optional): Evaluation metric. Choices are {'VOC', 'COCO', None}.
                If None, the metric is determined according to the dataset format.
                Defaults to None.
            use_ema (bool, optional): Whether to use the exponential moving average
                strategy. Defaults to False.
            early_stop (bool, optional): Whether to adopt the early stopping strategy.
                Defaults to False.
            early_stop_patience (int, optional): Early stopping patience. Defaults to 5.
            use_vdl (bool, optional): Whether to use VisualDL to monitor the training
                process. Defaults to True.
            resume_checkpoint (str|None, optional): Path of the checkpoint to resume
                quantization-aware training from. If None, no training checkpoint will
                be resumed. Defaults to None.
            quant_config (dict|None, optional): Quantization configuration. If None,
                a default rule-of-thumb configuration will be used. Defaults to None.
        """
        self._prepare_qat(quant_config)
        self.train(
            num_epochs=num_epochs,
            train_dataset=train_dataset,
            train_batch_size=train_batch_size,
            eval_dataset=eval_dataset,
            optimizer=optimizer,
            save_interval_epochs=save_interval_epochs,
            log_interval_steps=log_interval_steps,
            save_dir=save_dir,
            pretrain_weights=None,
            learning_rate=learning_rate,
            warmup_steps=warmup_steps,
            warmup_start_lr=warmup_start_lr,
            lr_decay_epochs=lr_decay_epochs,
            lr_decay_gamma=lr_decay_gamma,
            metric=metric,
            use_ema=use_ema,
            early_stop=early_stop,
            early_stop_patience=early_stop_patience,
            use_vdl=use_vdl,
            resume_checkpoint=resume_checkpoint)

    def evaluate(self,
                 eval_dataset,
                 batch_size=1,
                 metric=None,
                 return_details=False):
        """
        Evaluate the model.

        Args:
            eval_dataset (paddlers.datasets.COCODetDataset|paddlers.datasets.VOCDetDataset):
                Evaluation dataset.
            batch_size (int, optional): Total batch size among all cards used for
                evaluation. Defaults to 1.
            metric (str|None, optional): Evaluation metric. Choices are {'VOC', 'COCO', None}.
                If None, the metric is determined according to the dataset format.
                Defaults to None.
            return_details (bool, optional): Whether to return evaluation details.
                Defaults to False.

        Returns:
            collections.OrderedDict with key-value pairs such as
                {"mAP(0.50, 11point)": `mean average precision`}.
        """
        if metric is None:
            if not hasattr(self, 'metric'):
                if eval_dataset.__class__.__name__ == 'VOCDetDataset':
                    self.metric = 'voc'
                elif eval_dataset.__class__.__name__ == 'COCODetDataset':
                    self.metric = 'coco'
        else:
            assert metric.lower() in ['coco', 'voc'], \
                "Evaluation metric {} is not supported. Please choose from 'COCO' and 'VOC'.".format(metric)
            self.metric = metric.lower()
        if self.metric == 'voc':
            eval_dataset.data_fields = {
                'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
                'difficult'
            }
        elif self.metric == 'coco':
            if self.__class__.__name__ == 'MaskRCNN':
                eval_dataset.data_fields = {
                    'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
                    'gt_poly', 'is_crowd'
                }
            else:
                eval_dataset.data_fields = {
                    'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
                    'is_crowd'
                }
        eval_dataset.batch_transforms = self._compose_batch_transform(
            eval_dataset.transforms, mode='eval')
        self._check_transforms(eval_dataset.transforms, 'eval')
        self.net.eval()
        nranks = paddle.distributed.get_world_size()
        local_rank = paddle.distributed.get_rank()
        if nranks > 1:
            # Initialize the parallel environment if it has not been done yet.
            if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
            ):
                paddle.distributed.init_parallel_env()
        if batch_size > 1:
            logging.warning(
                "Detector only supports single-card evaluation with batch_size=1, "
                "so batch_size is forcibly set to 1.")
            batch_size = 1
        if nranks < 2 or local_rank == 0:
            self.eval_data_loader = self.build_data_loader(
                eval_dataset, batch_size=batch_size, mode='eval')
            is_bbox_normalized = False
            if eval_dataset.batch_transforms is not None:
                is_bbox_normalized = any(
                    isinstance(t, _NormalizeBox)
                    for t in eval_dataset.batch_transforms.batch_transforms)
            if self.metric == 'voc':
                eval_metric = VOCMetric(
                    labels=eval_dataset.labels,
                    coco_gt=copy.deepcopy(eval_dataset.coco_gt),
                    is_bbox_normalized=is_bbox_normalized,
                    classwise=False)
            else:
                eval_metric = COCOMetric(
                    coco_gt=copy.deepcopy(eval_dataset.coco_gt),
                    classwise=False)
            scores = collections.OrderedDict()
            logging.info(
                "Start to evaluate (total_samples={}, total_steps={})...".
                format(eval_dataset.num_samples, eval_dataset.num_samples))
            with paddle.no_grad():
                for step, data in enumerate(self.eval_data_loader):
                    outputs = self.run(self.net, data, 'eval')
                    eval_metric.update(data, outputs)
                eval_metric.accumulate()
                self.eval_details = eval_metric.details
                scores.update(eval_metric.get())
                eval_metric.reset()
            if return_details:
                return scores, self.eval_details
            return scores

    def predict(self, img_file, transforms=None):
        """
        Do inference.

        Args:
            img_file (list[np.ndarray|str] | str | np.ndarray): Image path or decoded
                image data, or a list of them, in which case all images are predicted
                as a mini-batch.
            transforms (paddlers.transforms.Compose|None, optional): Transforms for
                inputs. If None, the transforms for evaluation will be used.
                Defaults to None.

        Returns:
            If `img_file` is a string or np.ndarray, the result is a list of dicts,
            one per detected object. If `img_file` is a list, the result is a list
            of such lists, one per image. Each dict has the following fields:
                category_id (int): Predicted category ID. 0 represents the first
                    category in the dataset, and so on.
                category (str): Category name.
                bbox (list): Bounding box in [x, y, w, h] format.
                score (float): Confidence.
                mask (np.ndarray): Binary mask of the object. Only for instance
                    segmentation tasks.
        """
        if transforms is None and not hasattr(self, 'test_transforms'):
            raise ValueError("transforms need to be defined, but they are None.")
        if transforms is None:
            transforms = self.test_transforms
        if isinstance(img_file, (str, np.ndarray)):
            images = [img_file]
        else:
            images = img_file
        batch_samples = self._preprocess(images, transforms)
        self.net.eval()
        outputs = self.run(self.net, batch_samples, 'test')
        prediction = self._postprocess(outputs)
        if isinstance(img_file, (str, np.ndarray)):
            prediction = prediction[0]
        return prediction

    def _preprocess(self, images, transforms, to_tensor=True):
        self._check_transforms(transforms, 'test')
        batch_samples = list()
        for im in images:
            if isinstance(im, str):
                im = decode_image(im, to_rgb=False)
            sample = {'image': im}
            sample = transforms(sample)
            batch_samples.append(sample)
        batch_transforms = self._compose_batch_transform(transforms, 'test')
        batch_samples = batch_transforms(batch_samples)
        if to_tensor:
            for k in batch_samples:
                batch_samples[k] = paddle.to_tensor(batch_samples[k])
        return batch_samples

    def _postprocess(self, batch_pred):
        infer_result = {}
        if 'bbox' in batch_pred:
            bboxes = batch_pred['bbox']
            bbox_nums = batch_pred['bbox_num']
            det_res = []
            k = 0
            for i in range(len(bbox_nums)):
                det_nums = bbox_nums[i]
                for j in range(det_nums):
                    dt = bboxes[k]
                    k = k + 1
                    num_id, score, xmin, ymin, xmax, ymax = dt.tolist()
                    if int(num_id) < 0:
                        continue
                    category = self.labels[int(num_id)]
                    w = xmax - xmin
                    h = ymax - ymin
                    bbox = [xmin, ymin, w, h]
                    dt_res = {
                        'category_id': int(num_id),
                        'category': category,
                        'bbox': bbox,
                        'score': score
                    }
                    det_res.append(dt_res)
            infer_result['bbox'] = det_res
        if 'mask' in batch_pred:
            masks = batch_pred['mask']
            bboxes = batch_pred['bbox']
            mask_nums = batch_pred['bbox_num']
            seg_res = []
            k = 0
            for i in range(len(mask_nums)):
                det_nums = mask_nums[i]
                for j in range(det_nums):
                    mask = masks[k].astype(np.uint8)
                    score = float(bboxes[k][1])
                    label = int(bboxes[k][0])
                    k = k + 1
                    if label == -1:
                        continue
                    category = self.labels[int(label)]
                    sg_res = {
                        'category_id': int(label),
                        'category': category,
                        'mask': mask.astype('uint8'),
                        'score': score
                    }
                    seg_res.append(sg_res)
            infer_result['mask'] = seg_res

        # Split the flat per-batch predictions into per-image lists using
        # `bbox_num`, merging mask fields into the corresponding box dicts.
        bbox_num = batch_pred['bbox_num']
        results = []
        start = 0
        for num in bbox_num:
            end = start + num
            curr_res = infer_result['bbox'][start:end]
            if 'mask' in infer_result:
                mask_res = infer_result['mask'][start:end]
                for box, mask in zip(curr_res, mask_res):
                    box.update(mask)
            results.append(curr_res)
            start = end
        return results
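
    # Shape of the value _postprocess returns (an illustration, not part of
    # the original module): for a batch of two images with bbox_num = [2, 1],
    # the result is a per-image list of detection dicts, e.g.
    #     [[{'category_id': 0, 'category': 'plane', 'bbox': [x, y, w, h],
    #        'score': 0.91}, {...}],
    #      [{...}]]
    # with a 'mask' entry merged into each dict for instance segmentation.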

    def _check_transforms(self, transforms, mode):
        super()._check_transforms(transforms, mode)
        if not isinstance(transforms.arrange,
                          paddlers.transforms.ArrangeDetector):
            raise TypeError(
                "`transforms.arrange` must be an ArrangeDetector object.")


class PicoDet(BaseDetector):
    def __init__(self,
                 num_classes=80,
                 backbone='ESNet_m',
                 nms_score_threshold=.025,
                 nms_topk=1000,
                 nms_keep_topk=100,
                 nms_iou_threshold=.6,
                 **params):
        self.init_params = locals()
        if backbone not in {
                'ESNet_s', 'ESNet_m', 'ESNet_l', 'LCNet', 'MobileNetV3',
                'ResNet18_vd'
        }:
            raise ValueError(
                "backbone: {} is not supported. Please choose one of "
                "{{'ESNet_s', 'ESNet_m', 'ESNet_l', 'LCNet', 'MobileNetV3', "
                "'ResNet18_vd'}}.".format(backbone))
        self.backbone_name = backbone
        if params.get('with_net', True):
            if backbone == 'ESNet_s':
                backbone = self._get_backbone(
                    'ESNet',
                    scale=.75,
                    feature_maps=[4, 11, 14],
                    act="hard_swish",
                    channel_ratio=[
                        0.875, 0.5, 0.5, 0.5, 0.625, 0.5, 0.625, 0.5, 0.5,
                        0.5, 0.5, 0.5, 0.5
                    ])
                neck_out_channels = 96
                head_num_convs = 2
            elif backbone == 'ESNet_m':
                backbone = self._get_backbone(
                    'ESNet',
                    scale=1.0,
                    feature_maps=[4, 11, 14],
                    act="hard_swish",
                    channel_ratio=[
                        0.875, 0.5, 1.0, 0.625, 0.5, 0.75, 0.625, 0.625, 0.5,
                        0.625, 1.0, 0.625, 0.75
                    ])
                neck_out_channels = 128
                head_num_convs = 4
            elif backbone == 'ESNet_l':
                backbone = self._get_backbone(
                    'ESNet',
                    scale=1.25,
                    feature_maps=[4, 11, 14],
                    act="hard_swish",
                    channel_ratio=[
                        0.875, 0.5, 1.0, 0.625, 0.5, 0.75, 0.625, 0.625, 0.5,
                        0.625, 1.0, 0.625, 0.75
                    ])
                neck_out_channels = 160
                head_num_convs = 4
            elif backbone == 'LCNet':
                backbone = self._get_backbone(
                    'LCNet', scale=1.5, feature_maps=[3, 4, 5])
                neck_out_channels = 128
                head_num_convs = 4
            elif backbone == 'MobileNetV3':
                backbone = self._get_backbone(
                    'MobileNetV3',
                    scale=1.0,
                    with_extra_blocks=False,
                    extra_block_filters=[],
                    feature_maps=[7, 13, 16])
                neck_out_channels = 128
                head_num_convs = 4
            else:
                backbone = self._get_backbone(
                    'ResNet',
                    depth=18,
                    variant='d',
                    return_idx=[1, 2, 3],
                    freeze_at=-1,
                    freeze_norm=False,
                    norm_decay=0.)
                neck_out_channels = 128
                head_num_convs = 4
            neck = ppdet.modeling.CSPPAN(
                in_channels=[i.channels for i in backbone.out_shape],
                out_channels=neck_out_channels,
                num_features=4,
                num_csp_blocks=1,
                use_depthwise=True)
            head_conv_feat = ppdet.modeling.PicoFeat(
                feat_in=neck_out_channels,
                feat_out=neck_out_channels,
                num_fpn_stride=4,
                num_convs=head_num_convs,
                norm_type='bn',
                share_cls_reg=True)
            loss_class = ppdet.modeling.VarifocalLoss(
                use_sigmoid=True, iou_weighted=True, loss_weight=1.0)
            loss_dfl = ppdet.modeling.DistributionFocalLoss(loss_weight=.25)
            loss_bbox = ppdet.modeling.GIoULoss(loss_weight=2.0)
            assigner = ppdet.modeling.SimOTAAssigner(
                candidate_topk=10, iou_weight=6, num_classes=num_classes)
            nms = ppdet.modeling.MultiClassNMS(
                nms_top_k=nms_topk,
                keep_top_k=nms_keep_topk,
                score_threshold=nms_score_threshold,
                nms_threshold=nms_iou_threshold)
            head = ppdet.modeling.PicoHead(
                conv_feat=head_conv_feat,
                num_classes=num_classes,
                fpn_stride=[8, 16, 32, 64],
                prior_prob=0.01,
                reg_max=7,
                cell_offset=.5,
                loss_class=loss_class,
                loss_dfl=loss_dfl,
                loss_bbox=loss_bbox,
                assigner=assigner,
                feat_in_chan=neck_out_channels,
                nms=nms)
            params.update({
                'backbone': backbone,
                'neck': neck,
                'head': head,
            })
        super(PicoDet, self).__init__(
            model_name='PicoDet', num_classes=num_classes, **params)

    def _compose_batch_transform(self, transforms, mode='train'):
        default_batch_transforms = [_BatchPad(pad_to_stride=32)]
        if mode == 'eval':
            collate_batch = True
        else:
            collate_batch = False
        custom_batch_transforms = []
        for i, op in enumerate(transforms.transforms):
            if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
                if mode != 'train':
                    raise ValueError(
                        "{} cannot be present in the {} transforms. ".format(
                            op.__class__.__name__, mode) +
                        "Please check the {} transforms.".format(mode))
                custom_batch_transforms.insert(0, copy.deepcopy(op))
        batch_transforms = BatchCompose(
            custom_batch_transforms + default_batch_transforms,
            collate_batch=collate_batch)
        return batch_transforms

    def _fix_transforms_shape(self, image_shape):
        if getattr(self, 'test_transforms', None):
            has_resize_op = False
            resize_op_idx = -1
            normalize_op_idx = len(self.test_transforms.transforms)
            for idx, op in enumerate(self.test_transforms.transforms):
                name = op.__class__.__name__
                if name == 'Resize':
                    has_resize_op = True
                    resize_op_idx = idx
                if name == 'Normalize':
                    normalize_op_idx = idx
            if not has_resize_op:
                self.test_transforms.transforms.insert(
                    normalize_op_idx,
                    Resize(
                        target_size=image_shape, interp='CUBIC'))
            else:
                self.test_transforms.transforms[
                    resize_op_idx].target_size = image_shape

    def _get_test_inputs(self, image_shape):
        if image_shape is not None:
            image_shape = self._check_image_shape(image_shape)
            self._fix_transforms_shape(image_shape[-2:])
        else:
            image_shape = [None, 3, 320, 320]
            if getattr(self, 'test_transforms', None):
                for idx, op in enumerate(self.test_transforms.transforms):
                    name = op.__class__.__name__
                    if name == 'Resize':
                        image_shape = [None, 3] + list(
                            self.test_transforms.transforms[idx].target_size)
            logging.warning(
                '[Important!!!] When exporting inference model for {}, '
                'if fixed_input_shape is not set, it will be forcibly set to {}. '
                'Please ensure image shape after transforms is {}; if not, '
                'fixed_input_shape should be specified manually.'
                .format(self.__class__.__name__, image_shape, image_shape[1:]))
        self.fixed_input_shape = image_shape
        return self._define_input_spec(image_shape)

    def train(self,
              num_epochs,
              train_dataset,
              train_batch_size=64,
              eval_dataset=None,
              optimizer=None,
              save_interval_epochs=1,
              log_interval_steps=10,
              save_dir='output',
              pretrain_weights='IMAGENET',
              learning_rate=.001,
              warmup_steps=0,
              warmup_start_lr=0.0,
              lr_decay_epochs=(216, 243),
              lr_decay_gamma=0.1,
              metric=None,
              use_ema=False,
              early_stop=False,
              early_stop_patience=5,
              use_vdl=True,
              resume_checkpoint=None):
        """
        Train the model.

        Args:
            num_epochs (int): Number of epochs.
            train_dataset (paddlers.datasets.COCODetDataset|paddlers.datasets.VOCDetDataset):
                Training dataset.
            train_batch_size (int, optional): Total batch size among all cards used in
                training. Defaults to 64.
            eval_dataset (paddlers.datasets.COCODetDataset|paddlers.datasets.VOCDetDataset, optional):
                Evaluation dataset. If None, the model will not be evaluated during
                training. Defaults to None.
            optimizer (paddle.optimizer.Optimizer|None, optional): Optimizer used for
                training. If None, a default optimizer will be used. Defaults to None.
            save_interval_epochs (int, optional): Epoch interval for saving the model.
                Defaults to 1.
            log_interval_steps (int, optional): Step interval for printing training
                information. Defaults to 10.
            save_dir (str, optional): Directory to save the model. Defaults to 'output'.
            pretrain_weights (str|None, optional): None or name/path of pretrained
                weights. If None, no pretrained weights will be loaded.
                Defaults to 'IMAGENET'.
            learning_rate (float, optional): Learning rate for training. Defaults to .001.
            warmup_steps (int, optional): Number of steps of warm-up training.
                Defaults to 0.
            warmup_start_lr (float, optional): Start learning rate of warm-up training.
                Defaults to 0.0.
            lr_decay_epochs (list|tuple, optional): Epoch milestones for learning
                rate decay. Defaults to (216, 243).
            lr_decay_gamma (float, optional): Gamma coefficient of learning rate decay.
                Defaults to .1.
            metric (str|None, optional): Evaluation metric. Choices are {'VOC', 'COCO', None}.
                If None, the metric is determined according to the dataset format.
                Defaults to None.
            use_ema (bool, optional): Whether to use the exponential moving average
                strategy. Defaults to False.
            early_stop (bool, optional): Whether to adopt the early stopping strategy.
                Defaults to False.
            early_stop_patience (int, optional): Early stopping patience. Defaults to 5.
            use_vdl (bool, optional): Whether to use VisualDL to monitor the training
                process. Defaults to True.
            resume_checkpoint (str|None, optional): Path of the checkpoint to resume
                training from. If None, no training checkpoint will be resumed. At most
                one of `resume_checkpoint` and `pretrain_weights` can be set.
                Defaults to None.
        """
        if optimizer is None:
            num_steps_each_epoch = len(train_dataset) // train_batch_size
            optimizer = self.default_optimizer(
                parameters=self.net.parameters(),
                learning_rate=learning_rate,
                warmup_steps=warmup_steps,
                warmup_start_lr=warmup_start_lr,
                lr_decay_epochs=lr_decay_epochs,
                lr_decay_gamma=lr_decay_gamma,
                num_steps_each_epoch=num_steps_each_epoch,
                reg_coeff=4e-05,
                scheduler='Cosine',
                num_epochs=num_epochs)
        super(PicoDet, self).train(
            num_epochs=num_epochs,
            train_dataset=train_dataset,
            train_batch_size=train_batch_size,
            eval_dataset=eval_dataset,
            optimizer=optimizer,
            save_interval_epochs=save_interval_epochs,
            log_interval_steps=log_interval_steps,
            save_dir=save_dir,
            pretrain_weights=pretrain_weights,
            learning_rate=learning_rate,
            warmup_steps=warmup_steps,
            warmup_start_lr=warmup_start_lr,
            lr_decay_epochs=lr_decay_epochs,
            lr_decay_gamma=lr_decay_gamma,
            metric=metric,
            use_ema=use_ema,
            early_stop=early_stop,
            early_stop_patience=early_stop_patience,
            use_vdl=use_vdl,
            resume_checkpoint=resume_checkpoint)
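
# Export-shape note (an illustration, not part of the original module):
# fixed_input_shape must have height and width that are multiples of 32
# (enforced by BaseDetector._check_image_shape). For example, PicoDet exports
# cleanly with [None, 3, 320, 320], while [None, 3, 300, 300] raises a
# ValueError; when no shape is given, PicoDet falls back to the target_size
# of the Resize op in test_transforms (see _get_test_inputs above).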


class YOLOv3(BaseDetector):
    def __init__(self,
                 num_classes=80,
                 backbone='MobileNetV1',
                 anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
                          [59, 119], [116, 90], [156, 198], [373, 326]],
                 anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
                 ignore_threshold=0.7,
                 nms_score_threshold=0.01,
                 nms_topk=1000,
                 nms_keep_topk=100,
                 nms_iou_threshold=0.45,
                 label_smooth=False,
                 **params):
        self.init_params = locals()
        if backbone not in {
                'MobileNetV1', 'MobileNetV1_ssld', 'MobileNetV3',
                'MobileNetV3_ssld', 'DarkNet53', 'ResNet50_vd_dcn', 'ResNet34'
        }:
            raise ValueError(
                "backbone: {} is not supported. Please choose one of "
                "{{'MobileNetV1', 'MobileNetV1_ssld', 'MobileNetV3', "
                "'MobileNetV3_ssld', 'DarkNet53', 'ResNet50_vd_dcn', "
                "'ResNet34'}}.".format(backbone))
        self.backbone_name = backbone
        if params.get('with_net', True):
            if paddlers.env_info['place'] == 'gpu' and paddlers.env_info[
                    'num'] > 1 and not os.environ.get('PADDLERS_EXPORT_STAGE'):
                norm_type = 'sync_bn'
            else:
                norm_type = 'bn'
            if 'MobileNetV1' in backbone:
                norm_type = 'bn'
                backbone = self._get_backbone('MobileNet', norm_type=norm_type)
            elif 'MobileNetV3' in backbone:
                backbone = self._get_backbone(
                    'MobileNetV3',
                    norm_type=norm_type,
                    feature_maps=[7, 13, 16])
            elif backbone == 'ResNet50_vd_dcn':
                backbone = self._get_backbone(
                    'ResNet',
                    norm_type=norm_type,
                    variant='d',
                    return_idx=[1, 2, 3],
                    dcn_v2_stages=[3],
                    freeze_at=-1,
                    freeze_norm=False)
            elif backbone == 'ResNet34':
                backbone = self._get_backbone(
                    'ResNet',
                    depth=34,
                    norm_type=norm_type,
                    return_idx=[1, 2, 3],
                    freeze_at=-1,
                    freeze_norm=False,
                    norm_decay=0.)
            else:
                backbone = self._get_backbone('DarkNet', norm_type=norm_type)
            neck = ppdet.modeling.YOLOv3FPN(
                norm_type=norm_type,
                in_channels=[i.channels for i in backbone.out_shape])
            loss = ppdet.modeling.YOLOv3Loss(
                num_classes=num_classes,
                ignore_thresh=ignore_threshold,
                label_smooth=label_smooth)
            yolo_head = ppdet.modeling.YOLOv3Head(
                in_channels=[i.channels for i in neck.out_shape],
                anchors=anchors,
                anchor_masks=anchor_masks,
                num_classes=num_classes,
                loss=loss)
            post_process = ppdet.modeling.BBoxPostProcess(
                decode=ppdet.modeling.YOLOBox(num_classes=num_classes),
                nms=ppdet.modeling.MultiClassNMS(
                    score_threshold=nms_score_threshold,
                    nms_top_k=nms_topk,
                    keep_top_k=nms_keep_topk,
                    nms_threshold=nms_iou_threshold))
            params.update({
                'backbone': backbone,
                'neck': neck,
                'yolo_head': yolo_head,
                'post_process': post_process
            })
        super(YOLOv3, self).__init__(
            model_name='YOLOv3', num_classes=num_classes, **params)
        self.anchors = anchors
        self.anchor_masks = anchor_masks

    def _compose_batch_transform(self, transforms, mode='train'):
        if mode == 'train':
            default_batch_transforms = [
                _BatchPad(pad_to_stride=-1), _NormalizeBox(),
                _PadBox(getattr(self, 'num_max_boxes', 50)), _BboxXYXY2XYWH(),
                _Gt2YoloTarget(
                    anchor_masks=self.anchor_masks,
                    anchors=self.anchors,
                    downsample_ratios=getattr(self, 'downsample_ratios',
                                              [32, 16, 8]),
                    num_classes=self.num_classes)
            ]
        else:
            default_batch_transforms = [_BatchPad(pad_to_stride=-1)]
        if mode == 'eval' and self.metric == 'voc':
            collate_batch = False
        else:
            collate_batch = True
        custom_batch_transforms = []
        for i, op in enumerate(transforms.transforms):
            if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
                if mode != 'train':
                    raise ValueError(
                        "{} cannot be present in the {} transforms. ".format(
                            op.__class__.__name__, mode) +
                        "Please check the {} transforms.".format(mode))
                custom_batch_transforms.insert(0, copy.deepcopy(op))
        batch_transforms = BatchCompose(
            custom_batch_transforms + default_batch_transforms,
            collate_batch=collate_batch)
        return batch_transforms

    def _fix_transforms_shape(self, image_shape):
        if getattr(self, 'test_transforms', None):
            has_resize_op = False
            resize_op_idx = -1
            normalize_op_idx = len(self.test_transforms.transforms)
            for idx, op in enumerate(self.test_transforms.transforms):
                name = op.__class__.__name__
                if name == 'Resize':
                    has_resize_op = True
                    resize_op_idx = idx
                if name == 'Normalize':
                    normalize_op_idx = idx
            if not has_resize_op:
                self.test_transforms.transforms.insert(
                    normalize_op_idx,
                    Resize(
                        target_size=image_shape, interp='CUBIC'))
            else:
                self.test_transforms.transforms[
                    resize_op_idx].target_size = image_shape
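
# How YOLOv3's anchors map to feature levels (an illustration, not part of
# the original module): anchor_masks [[6, 7, 8], [3, 4, 5], [0, 1, 2]] assign
# the three largest anchors ([116, 90], [156, 198], [373, 326]) to the
# coarsest feature map (downsample ratio 32), the middle three to ratio 16,
# and the three smallest to ratio 8, matching downsample_ratios=[32, 16, 8]
# in _compose_batch_transform above.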
  1054. class FasterRCNN(BaseDetector):
  1055. def __init__(self,
  1056. num_classes=80,
  1057. backbone='ResNet50',
  1058. with_fpn=True,
  1059. with_dcn=False,
  1060. aspect_ratios=[0.5, 1.0, 2.0],
  1061. anchor_sizes=[[32], [64], [128], [256], [512]],
  1062. keep_top_k=100,
  1063. nms_threshold=0.5,
  1064. score_threshold=0.05,
  1065. fpn_num_channels=256,
  1066. rpn_batch_size_per_im=256,
  1067. rpn_fg_fraction=0.5,
  1068. test_pre_nms_top_n=None,
  1069. test_post_nms_top_n=1000,
  1070. **params):
  1071. self.init_params = locals()
  1072. if backbone not in {
  1073. 'ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet34',
  1074. 'ResNet34_vd', 'ResNet101', 'ResNet101_vd', 'HRNet_W18'
  1075. }:
  1076. raise ValueError(
  1077. "backbone: {} is not supported. Please choose one of "
  1078. "{'ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet34', 'ResNet34_vd', "
  1079. "'ResNet101', 'ResNet101_vd', 'HRNet_W18'}.".format(backbone))
  1080. self.backbone_name = backbone
  1081. if params.get('with_net', True):
  1082. dcn_v2_stages = [1, 2, 3] if with_dcn else [-1]
  1083. if backbone == 'HRNet_W18':
  1084. if not with_fpn:
  1085. logging.warning(
  1086. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  1087. format(backbone))
  1088. with_fpn = True
  1089. if with_dcn:
  1090. logging.warning(
  1091. "Backbone {} should be used along with dcn disabled, 'with_dcn' is forcibly set to False".
  1092. format(backbone))
  1093. backbone = self._get_backbone(
  1094. 'HRNet', width=18, freeze_at=0, return_idx=[0, 1, 2, 3])
  1095. elif backbone == 'ResNet50_vd_ssld':
  1096. if not with_fpn:
  1097. logging.warning(
  1098. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  1099. format(backbone))
  1100. with_fpn = True
  1101. backbone = self._get_backbone(
  1102. 'ResNet',
  1103. variant='d',
  1104. norm_type='bn',
  1105. freeze_at=0,
  1106. return_idx=[0, 1, 2, 3],
  1107. num_stages=4,
  1108. lr_mult_list=[0.05, 0.05, 0.1, 0.15],
  1109. dcn_v2_stages=dcn_v2_stages)
  1110. elif 'ResNet50' in backbone:
  1111. if with_fpn:
  1112. backbone = self._get_backbone(
  1113. 'ResNet',
  1114. variant='d' if '_vd' in backbone else 'b',
  1115. norm_type='bn',
  1116. freeze_at=0,
  1117. return_idx=[0, 1, 2, 3],
  1118. num_stages=4,
  1119. dcn_v2_stages=dcn_v2_stages)
  1120. else:
  1121. if with_dcn:
  1122. logging.warning(
  1123. "Backbone {} without fpn should be used along with dcn disabled, 'with_dcn' is forcibly set to False".
  1124. format(backbone))
  1125. backbone = self._get_backbone(
  1126. 'ResNet',
  1127. variant='d' if '_vd' in backbone else 'b',
  1128. norm_type='bn',
  1129. freeze_at=0,
                        return_idx=[2],
                        num_stages=3)
            elif 'ResNet34' in backbone:
                if not with_fpn:
                    logging.warning(
                        "Backbone {} should be used with FPN enabled; "
                        "'with_fpn' is forcibly set to True.".format(backbone))
                    with_fpn = True
                backbone = self._get_backbone(
                    'ResNet',
                    depth=34,
                    variant='d' if 'vd' in backbone else 'b',
                    norm_type='bn',
                    freeze_at=0,
                    return_idx=[0, 1, 2, 3],
                    num_stages=4,
                    dcn_v2_stages=dcn_v2_stages)
            else:
                if not with_fpn:
                    logging.warning(
                        "Backbone {} should be used with FPN enabled; "
                        "'with_fpn' is forcibly set to True.".format(backbone))
                    with_fpn = True
                backbone = self._get_backbone(
                    'ResNet',
                    depth=101,
                    variant='d' if 'vd' in backbone else 'b',
                    norm_type='bn',
                    freeze_at=0,
                    return_idx=[0, 1, 2, 3],
                    num_stages=4,
                    dcn_v2_stages=dcn_v2_stages)

            rpn_in_channel = backbone.out_shape[0].channels

            if with_fpn:
                self.backbone_name = self.backbone_name + '_fpn'
                if 'HRNet' in self.backbone_name:
                    neck = ppdet.modeling.HRFPN(
                        in_channels=[i.channels for i in backbone.out_shape],
                        out_channel=fpn_num_channels,
                        spatial_scales=[
                            1.0 / i.stride for i in backbone.out_shape
                        ],
                        share_conv=False)
                else:
                    neck = ppdet.modeling.FPN(
                        in_channels=[i.channels for i in backbone.out_shape],
                        out_channel=fpn_num_channels,
                        spatial_scales=[
                            1.0 / i.stride for i in backbone.out_shape
                        ])
                rpn_in_channel = neck.out_shape[0].channels
                anchor_generator_cfg = {
                    'aspect_ratios': aspect_ratios,
                    'anchor_sizes': anchor_sizes,
                    'strides': [4, 8, 16, 32, 64]
                }
                train_proposal_cfg = {
                    'min_size': 0.0,
                    'nms_thresh': .7,
                    'pre_nms_top_n': 2000,
                    'post_nms_top_n': 1000,
                    'topk_after_collect': True
                }
                test_proposal_cfg = {
                    'min_size': 0.0,
                    'nms_thresh': .7,
                    'pre_nms_top_n': 1000
                    if test_pre_nms_top_n is None else test_pre_nms_top_n,
                    'post_nms_top_n': test_post_nms_top_n
                }
                head = ppdet.modeling.TwoFCHead(
                    in_channel=neck.out_shape[0].channels, out_channel=1024)
                roi_extractor_cfg = {
                    'resolution': 7,
                    'spatial_scale': [1. / i.stride for i in neck.out_shape],
                    'sampling_ratio': 0,
                    'aligned': True
                }
                with_pool = False
            else:
                neck = None
                anchor_generator_cfg = {
                    'aspect_ratios': aspect_ratios,
                    'anchor_sizes': anchor_sizes,
                    'strides': [16]
                }
                train_proposal_cfg = {
                    'min_size': 0.0,
                    'nms_thresh': .7,
                    'pre_nms_top_n': 12000,
                    'post_nms_top_n': 2000,
                    'topk_after_collect': False
                }
                test_proposal_cfg = {
                    'min_size': 0.0,
                    'nms_thresh': .7,
                    'pre_nms_top_n': 6000
                    if test_pre_nms_top_n is None else test_pre_nms_top_n,
                    'post_nms_top_n': test_post_nms_top_n
                }
                head = ppdet.modeling.Res5Head()
                roi_extractor_cfg = {
                    'resolution': 14,
                    'spatial_scale':
                    [1. / i.stride for i in backbone.out_shape],
                    'sampling_ratio': 0,
                    'aligned': True
                }
                with_pool = True

            rpn_target_assign_cfg = {
                'batch_size_per_im': rpn_batch_size_per_im,
                'fg_fraction': rpn_fg_fraction,
                'negative_overlap': .3,
                'positive_overlap': .7,
                'use_random': True
            }
            rpn_head = ppdet.modeling.RPNHead(
                anchor_generator=anchor_generator_cfg,
                rpn_target_assign=rpn_target_assign_cfg,
                train_proposal=train_proposal_cfg,
                test_proposal=test_proposal_cfg,
                in_channel=rpn_in_channel)
            bbox_assigner = BBoxAssigner(num_classes=num_classes)
            bbox_head = ppdet.modeling.BBoxHead(
                head=head,
                in_channel=head.out_shape[0].channels,
                roi_extractor=roi_extractor_cfg,
                with_pool=with_pool,
                bbox_assigner=bbox_assigner,
                num_classes=num_classes)
            bbox_post_process = ppdet.modeling.BBoxPostProcess(
                num_classes=num_classes,
                decode=ppdet.modeling.RCNNBox(num_classes=num_classes),
                nms=ppdet.modeling.MultiClassNMS(
                    score_threshold=score_threshold,
                    keep_top_k=keep_top_k,
                    nms_threshold=nms_threshold))
            params.update({
                'backbone': backbone,
                'neck': neck,
                'rpn_head': rpn_head,
                'bbox_head': bbox_head,
                'bbox_post_process': bbox_post_process
            })
        else:
            if backbone not in {'ResNet50', 'ResNet50_vd'}:
                with_fpn = True

        self.with_fpn = with_fpn
        super(FasterRCNN, self).__init__(
            model_name='FasterRCNN', num_classes=num_classes, **params)

    def train(self,
              num_epochs,
              train_dataset,
              train_batch_size=64,
              eval_dataset=None,
              optimizer=None,
              save_interval_epochs=1,
              log_interval_steps=10,
              save_dir='output',
              pretrain_weights='IMAGENET',
              learning_rate=.001,
              warmup_steps=0,
              warmup_start_lr=0.0,
              lr_decay_epochs=(216, 243),
              lr_decay_gamma=0.1,
              metric=None,
              use_ema=False,
              early_stop=False,
              early_stop_patience=5,
              use_vdl=True,
              resume_checkpoint=None):
  1301. """
  1302. Train the model.
  1303. Args:
  1304. num_epochs (int): Number of epochs.
  1305. train_dataset (paddlers.datasets.COCODetDataset|paddlers.datasets.VOCDetDataset):
  1306. Training dataset.
  1307. train_batch_size (int, optional): Total batch size among all cards used in
  1308. training. Defaults to 64.
  1309. eval_dataset (paddlers.datasets.COCODetDataset|paddlers.datasets.VOCDetDataset, optional):
  1310. Evaluation dataset. If None, the model will not be evaluated during training
  1311. process. Defaults to None.
  1312. optimizer (paddle.optimizer.Optimizer|None, optional): Optimizer used for
  1313. training. If None, a default optimizer will be used. Defaults to None.
  1314. save_interval_epochs (int, optional): Epoch interval for saving the model.
  1315. Defaults to 1.
  1316. log_interval_steps (int, optional): Step interval for printing training
  1317. information. Defaults to 10.
  1318. save_dir (str, optional): Directory to save the model. Defaults to 'output'.
  1319. pretrain_weights (str|None, optional): None or name/path of pretrained
  1320. weights. If None, no pretrained weights will be loaded.
  1321. Defaults to 'IMAGENET'.
  1322. learning_rate (float, optional): Learning rate for training. Defaults to .001.
  1323. warmup_steps (int, optional): Number of steps of warm-up training.
  1324. Defaults to 0.
  1325. warmup_start_lr (float, optional): Start learning rate of warm-up training.
  1326. Defaults to 0..
  1327. lr_decay_epochs (list|tuple, optional): Epoch milestones for learning
  1328. rate decay. Defaults to (216, 243).
  1329. lr_decay_gamma (float, optional): Gamma coefficient of learning rate decay.
  1330. Defaults to .1.
  1331. metric (str|None, optional): Evaluation metric. Choices are {'VOC', 'COCO', None}.
  1332. If None, determine the metric according to the dataset format.
  1333. Defaults to None.
  1334. use_ema (bool, optional): Whether to use exponential moving average
  1335. strategy. Defaults to False.
  1336. early_stop (bool, optional): Whether to adopt early stop strategy.
  1337. Defaults to False.
  1338. early_stop_patience (int, optional): Early stop patience. Defaults to 5.
  1339. use_vdl(bool, optional): Whether to use VisualDL to monitor the training
  1340. process. Defaults to True.
  1341. resume_checkpoint (str|None, optional): Path of the checkpoint to resume
  1342. training from. If None, no training checkpoint will be resumed. At most
  1343. Aone of `resume_checkpoint` and `pretrain_weights` can be set simultaneously.
  1344. Defaults to None.
  1345. """
        if train_dataset.pos_num < len(train_dataset.file_list):
            # The dataset contains negative (background-only) samples, in which
            # case multi-process data loading is disabled.
            train_dataset.num_workers = 0
        super(FasterRCNN, self).train(
            num_epochs, train_dataset, train_batch_size, eval_dataset,
            optimizer, save_interval_epochs, log_interval_steps, save_dir,
            pretrain_weights, learning_rate, warmup_steps, warmup_start_lr,
            lr_decay_epochs, lr_decay_gamma, metric, use_ema, early_stop,
            early_stop_patience, use_vdl, resume_checkpoint)

    def _compose_batch_transform(self, transforms, mode='train'):
        # Training and inference share the same default batch transform:
        # pad to a multiple of 32 when an FPN is used.
        default_batch_transforms = [
            _BatchPad(pad_to_stride=32 if self.with_fpn else -1)
        ]
        custom_batch_transforms = []
        for i, op in enumerate(transforms.transforms):
            if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
                if mode != 'train':
                    raise ValueError(
                        "{} cannot be present in the {} transforms. ".format(
                            op.__class__.__name__, mode) +
                        "Please check the {} transforms.".format(mode))
                custom_batch_transforms.insert(0, copy.deepcopy(op))
        batch_transforms = BatchCompose(
            custom_batch_transforms + default_batch_transforms,
            collate_batch=False)
        return batch_transforms

    def _fix_transforms_shape(self, image_shape):
        if getattr(self, 'test_transforms', None):
            has_resize_op = False
            resize_op_idx = -1
            normalize_op_idx = len(self.test_transforms.transforms)
            for idx, op in enumerate(self.test_transforms.transforms):
                name = op.__class__.__name__
                if name == 'ResizeByShort':
                    has_resize_op = True
                    resize_op_idx = idx
                if name == 'Normalize':
                    normalize_op_idx = idx
            if not has_resize_op:
                self.test_transforms.transforms.insert(
                    normalize_op_idx,
                    Resize(
                        target_size=image_shape,
                        keep_ratio=True,
                        interp='CUBIC'))
            else:
                self.test_transforms.transforms[resize_op_idx] = Resize(
                    target_size=image_shape, keep_ratio=True, interp='CUBIC')
            self.test_transforms.transforms.append(
                Pad(im_padding_value=[0., 0., 0.]))

    def _get_test_inputs(self, image_shape):
        if image_shape is not None:
            image_shape = self._check_image_shape(image_shape)
            self._fix_transforms_shape(image_shape[-2:])
        else:
            image_shape = [None, 3, -1, -1]
            if self.with_fpn:
                self.test_transforms.transforms.append(
                    Pad(im_padding_value=[0., 0., 0.]))
        self.fixed_input_shape = image_shape
        return self._define_input_spec(image_shape)
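

# A minimal usage sketch, not part of the original module: it wires a
# FasterRCNN to the `train()` API defined above. The dataset paths, the
# transform pipeline, and all hyper-parameter values are illustrative
# assumptions; the `VOCDetDataset` argument names follow common PaddleRS
# usage and may need adjusting to the installed version.
def _example_train_faster_rcnn():
    import paddlers.transforms as T

    train_transforms = T.Compose([
        T.RandomHorizontalFlip(),
        T.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # Hypothetical VOC-format detection dataset; replace the placeholder paths.
    train_dataset = paddlers.datasets.VOCDetDataset(
        data_dir='dataset',
        file_list='dataset/train_list.txt',
        label_list='dataset/labels.txt',
        transforms=train_transforms)

    model = FasterRCNN(num_classes=20, backbone='ResNet50', with_fpn=True)
    model.train(
        num_epochs=12,
        train_dataset=train_dataset,
        train_batch_size=2,
        learning_rate=0.0025,
        lr_decay_epochs=(8, 11),
        save_dir='output/faster_rcnn')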


class PPYOLO(YOLOv3):
    def __init__(self,
                 num_classes=80,
                 backbone='ResNet50_vd_dcn',
                 anchors=None,
                 anchor_masks=None,
                 use_coord_conv=True,
                 use_iou_aware=True,
                 use_spp=True,
                 use_drop_block=True,
                 scale_x_y=1.05,
                 ignore_threshold=0.7,
                 label_smooth=False,
                 use_iou_loss=True,
                 use_matrix_nms=True,
                 nms_score_threshold=0.01,
                 nms_topk=-1,
                 nms_keep_topk=100,
                 nms_iou_threshold=0.45,
                 **params):
        self.init_params = locals()
        if backbone not in {
                'ResNet50_vd_dcn', 'ResNet18_vd', 'MobileNetV3_large',
                'MobileNetV3_small'
        }:
            # Double braces keep the set literal out of str.format's
            # replacement-field parsing.
            raise ValueError(
                "backbone: {} is not supported. Please choose one of "
                "{{'ResNet50_vd_dcn', 'ResNet18_vd', 'MobileNetV3_large', "
                "'MobileNetV3_small'}}.".format(backbone))
        self.backbone_name = backbone
        self.downsample_ratios = [
            32, 16, 8
        ] if backbone == 'ResNet50_vd_dcn' else [32, 16]

        if params.get('with_net', True):
            if paddlers.env_info['place'] == 'gpu' and paddlers.env_info[
                    'num'] > 1 and not os.environ.get('PADDLERS_EXPORT_STAGE'):
                norm_type = 'sync_bn'
            else:
                norm_type = 'bn'

            if anchors is None and anchor_masks is None:
                if 'MobileNetV3' in backbone:
                    anchors = [[11, 18], [34, 47], [51, 126], [115, 71],
                               [120, 195], [254, 235]]
                    anchor_masks = [[3, 4, 5], [0, 1, 2]]
                elif backbone == 'ResNet50_vd_dcn':
                    anchors = [[10, 13], [16, 30], [33, 23], [30, 61],
                               [62, 45], [59, 119], [116, 90], [156, 198],
                               [373, 326]]
                    anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
                else:
                    anchors = [[10, 14], [23, 27], [37, 58], [81, 82],
                               [135, 169], [344, 319]]
                    anchor_masks = [[3, 4, 5], [0, 1, 2]]
            elif anchors is None or anchor_masks is None:
                raise ValueError(
                    "Please define both anchors and anchor_masks.")

            if backbone == 'ResNet50_vd_dcn':
                backbone = self._get_backbone(
                    'ResNet',
                    variant='d',
                    norm_type=norm_type,
                    return_idx=[1, 2, 3],
                    dcn_v2_stages=[3],
                    freeze_at=-1,
                    freeze_norm=False,
                    norm_decay=0.)
            elif backbone == 'ResNet18_vd':
                backbone = self._get_backbone(
                    'ResNet',
                    depth=18,
                    variant='d',
                    norm_type=norm_type,
                    return_idx=[2, 3],
                    freeze_at=-1,
                    freeze_norm=False,
                    norm_decay=0.)
            elif backbone == 'MobileNetV3_large':
                backbone = self._get_backbone(
                    'MobileNetV3',
                    model_name='large',
                    norm_type=norm_type,
                    scale=1,
                    with_extra_blocks=False,
                    extra_block_filters=[],
                    feature_maps=[13, 16])
            elif backbone == 'MobileNetV3_small':
                backbone = self._get_backbone(
                    'MobileNetV3',
                    model_name='small',
                    norm_type=norm_type,
                    scale=1,
                    with_extra_blocks=False,
                    extra_block_filters=[],
                    feature_maps=[9, 12])
            neck = ppdet.modeling.PPYOLOFPN(
                norm_type=norm_type,
                in_channels=[i.channels for i in backbone.out_shape],
                coord_conv=use_coord_conv,
                drop_block=use_drop_block,
                spp=use_spp,
                conv_block_num=0
                if ('MobileNetV3' in self.backbone_name or
                    self.backbone_name == 'ResNet18_vd') else 2)
            loss = ppdet.modeling.YOLOv3Loss(
                num_classes=num_classes,
                ignore_thresh=ignore_threshold,
                downsample=self.downsample_ratios,
                label_smooth=label_smooth,
                scale_x_y=scale_x_y,
                iou_loss=ppdet.modeling.IouLoss(
                    loss_weight=2.5, loss_square=True)
                if use_iou_loss else None,
                iou_aware_loss=ppdet.modeling.IouAwareLoss(loss_weight=1.0)
                if use_iou_aware else None)
            yolo_head = ppdet.modeling.YOLOv3Head(
                in_channels=[i.channels for i in neck.out_shape],
                anchors=anchors,
                anchor_masks=anchor_masks,
                num_classes=num_classes,
                loss=loss,
                iou_aware=use_iou_aware)
            if use_matrix_nms:
                nms = ppdet.modeling.MatrixNMS(
                    keep_top_k=nms_keep_topk,
                    score_threshold=nms_score_threshold,
                    post_threshold=.05
                    if 'MobileNetV3' in self.backbone_name else .01,
                    nms_top_k=nms_topk,
                    background_label=-1)
            else:
                nms = ppdet.modeling.MultiClassNMS(
                    score_threshold=nms_score_threshold,
                    nms_top_k=nms_topk,
                    keep_top_k=nms_keep_topk,
                    nms_threshold=nms_iou_threshold)
            post_process = ppdet.modeling.BBoxPostProcess(
                decode=ppdet.modeling.YOLOBox(
                    num_classes=num_classes,
                    conf_thresh=.005
                    if 'MobileNetV3' in self.backbone_name else .01,
                    scale_x_y=scale_x_y),
                nms=nms)
            params.update({
                'backbone': backbone,
                'neck': neck,
                'yolo_head': yolo_head,
                'post_process': post_process
            })

        # Deliberately skip YOLOv3.__init__ and call BaseDetector.__init__
        # directly; `model_name` is overridden right below.
        super(YOLOv3, self).__init__(
            model_name='YOLOv3', num_classes=num_classes, **params)
        self.anchors = anchors
        self.anchor_masks = anchor_masks
        self.model_name = 'PPYOLO'

    def _get_test_inputs(self, image_shape):
        if image_shape is not None:
            image_shape = self._check_image_shape(image_shape)
            self._fix_transforms_shape(image_shape[-2:])
        else:
            image_shape = [None, 3, 608, 608]
            if getattr(self, 'test_transforms', None):
                for idx, op in enumerate(self.test_transforms.transforms):
                    name = op.__class__.__name__
                    if name == 'Resize':
                        image_shape = [None, 3] + list(
                            self.test_transforms.transforms[idx].target_size)
            logging.warning(
                '[Important!!!] When exporting an inference model for {}, '
                'if fixed_input_shape is not set, it will be forcibly set '
                'to {}. Please ensure that the image shape after transforms '
                'is {}; if not, fixed_input_shape should be specified '
                'manually.'.format(self.__class__.__name__, image_shape,
                                   image_shape[1:]))
        self.fixed_input_shape = image_shape
        return self._define_input_spec(image_shape)
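

# A minimal sketch, not part of the original module: it shows how `anchors`
# and `anchor_masks` pair up when passed to PPYOLO. Both must be supplied
# together (otherwise __init__ raises ValueError); each mask indexes into the
# anchor list for one output scale. The values below are the defaults this
# class picks for the 'ResNet50_vd_dcn' backbone.
def _example_ppyolo_custom_anchors():
    anchors = [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45], [59, 119],
               [116, 90], [156, 198], [373, 326]]
    anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]  # one mask per scale
    model = PPYOLO(
        num_classes=80,
        backbone='ResNet50_vd_dcn',
        anchors=anchors,
        anchor_masks=anchor_masks)
    return model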


class PPYOLOTiny(YOLOv3):
    def __init__(self,
                 num_classes=80,
                 backbone='MobileNetV3',
                 anchors=[[10, 15], [24, 36], [72, 42], [35, 87], [102, 96],
                          [60, 170], [220, 125], [128, 222], [264, 266]],
                 anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
                 use_iou_aware=False,
                 use_spp=True,
                 use_drop_block=True,
                 scale_x_y=1.05,
                 ignore_threshold=0.5,
                 label_smooth=False,
                 use_iou_loss=True,
                 use_matrix_nms=False,
                 nms_score_threshold=0.005,
                 nms_topk=1000,
                 nms_keep_topk=100,
                 nms_iou_threshold=0.45,
                 **params):
        self.init_params = locals()
        if backbone != 'MobileNetV3':
            logging.warning(
                "PPYOLOTiny only supports MobileNetV3 as its backbone. "
                "The backbone is forcibly set to MobileNetV3.")
        self.backbone_name = 'MobileNetV3'
        self.downsample_ratios = [32, 16, 8]

        if params.get('with_net', True):
            if paddlers.env_info['place'] == 'gpu' and paddlers.env_info[
                    'num'] > 1 and not os.environ.get('PADDLERS_EXPORT_STAGE'):
                norm_type = 'sync_bn'
            else:
                norm_type = 'bn'
            backbone = self._get_backbone(
                'MobileNetV3',
                model_name='large',
                norm_type=norm_type,
                scale=.5,
                with_extra_blocks=False,
                extra_block_filters=[],
                feature_maps=[7, 13, 16])
            neck = ppdet.modeling.PPYOLOTinyFPN(
                detection_block_channels=[160, 128, 96],
                in_channels=[i.channels for i in backbone.out_shape],
                spp=use_spp,
                drop_block=use_drop_block)
            loss = ppdet.modeling.YOLOv3Loss(
                num_classes=num_classes,
                ignore_thresh=ignore_threshold,
                downsample=self.downsample_ratios,
                label_smooth=label_smooth,
                scale_x_y=scale_x_y,
                iou_loss=ppdet.modeling.IouLoss(
                    loss_weight=2.5, loss_square=True)
                if use_iou_loss else None,
                iou_aware_loss=ppdet.modeling.IouAwareLoss(loss_weight=1.0)
                if use_iou_aware else None)
            yolo_head = ppdet.modeling.YOLOv3Head(
                in_channels=[i.channels for i in neck.out_shape],
                anchors=anchors,
                anchor_masks=anchor_masks,
                num_classes=num_classes,
                loss=loss,
                iou_aware=use_iou_aware)
            if use_matrix_nms:
                nms = ppdet.modeling.MatrixNMS(
                    keep_top_k=nms_keep_topk,
                    score_threshold=nms_score_threshold,
                    post_threshold=.05,
                    nms_top_k=nms_topk,
                    background_label=-1)
            else:
                nms = ppdet.modeling.MultiClassNMS(
                    score_threshold=nms_score_threshold,
                    nms_top_k=nms_topk,
                    keep_top_k=nms_keep_topk,
                    nms_threshold=nms_iou_threshold)
            post_process = ppdet.modeling.BBoxPostProcess(
                decode=ppdet.modeling.YOLOBox(
                    num_classes=num_classes,
                    conf_thresh=.005,
                    downsample_ratio=32,
                    clip_bbox=True,
                    scale_x_y=scale_x_y),
                nms=nms)
            params.update({
                'backbone': backbone,
                'neck': neck,
                'yolo_head': yolo_head,
                'post_process': post_process
            })

        # Skip YOLOv3.__init__ and call BaseDetector.__init__ directly;
        # `model_name` is overridden right below.
        super(YOLOv3, self).__init__(
            model_name='YOLOv3', num_classes=num_classes, **params)
        self.anchors = anchors
        self.anchor_masks = anchor_masks
        self.model_name = 'PPYOLOTiny'

    def _get_test_inputs(self, image_shape):
        if image_shape is not None:
            image_shape = self._check_image_shape(image_shape)
            self._fix_transforms_shape(image_shape[-2:])
        else:
            image_shape = [None, 3, 320, 320]
            if getattr(self, 'test_transforms', None):
                for idx, op in enumerate(self.test_transforms.transforms):
                    name = op.__class__.__name__
                    if name == 'Resize':
                        image_shape = [None, 3] + list(
                            self.test_transforms.transforms[idx].target_size)
            logging.warning(
                '[Important!!!] When exporting an inference model for {}, '
                'if fixed_input_shape is not set, it will be forcibly set '
                'to {}. Please ensure that the image shape after transforms '
                'is {}; if not, fixed_input_shape should be specified '
                'manually.'.format(self.__class__.__name__, image_shape,
                                   image_shape[1:]))
        self.fixed_input_shape = image_shape
        return self._define_input_spec(image_shape)
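

# A minimal sketch, not part of the original module: `**params` accepts a
# `with_net` flag (read via params.get('with_net', True) above). Passing
# with_net=False is assumed to skip building the ppdet network modules,
# leaving a lightweight configuration-only object; verify against the
# installed BaseDetector before relying on this.
def _example_ppyolotiny_skeleton():
    # Any backbone other than 'MobileNetV3' would only trigger a warning and
    # be overridden, so the default is used here.
    model = PPYOLOTiny(num_classes=80, with_net=False)
    return model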


class PPYOLOv2(YOLOv3):
    def __init__(self,
                 num_classes=80,
                 backbone='ResNet50_vd_dcn',
                 anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
                          [59, 119], [116, 90], [156, 198], [373, 326]],
                 anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
                 use_iou_aware=True,
                 use_spp=True,
                 use_drop_block=True,
                 scale_x_y=1.05,
                 ignore_threshold=0.7,
                 label_smooth=False,
                 use_iou_loss=True,
                 use_matrix_nms=True,
                 nms_score_threshold=0.01,
                 nms_topk=-1,
                 nms_keep_topk=100,
                 nms_iou_threshold=0.45,
                 **params):
        self.init_params = locals()
        if backbone not in {'ResNet50_vd_dcn', 'ResNet101_vd_dcn'}:
            # Double braces keep the set literal out of str.format's
            # replacement-field parsing.
            raise ValueError(
                "backbone: {} is not supported. Please choose one of "
                "{{'ResNet50_vd_dcn', 'ResNet101_vd_dcn'}}.".format(backbone))
        self.backbone_name = backbone
        self.downsample_ratios = [32, 16, 8]

        if params.get('with_net', True):
            if paddlers.env_info['place'] == 'gpu' and paddlers.env_info[
                    'num'] > 1 and not os.environ.get('PADDLERS_EXPORT_STAGE'):
                norm_type = 'sync_bn'
            else:
                norm_type = 'bn'

            if backbone == 'ResNet50_vd_dcn':
                backbone = self._get_backbone(
                    'ResNet',
                    variant='d',
                    norm_type=norm_type,
                    return_idx=[1, 2, 3],
                    dcn_v2_stages=[3],
                    freeze_at=-1,
                    freeze_norm=False,
                    norm_decay=0.)
            elif backbone == 'ResNet101_vd_dcn':
                backbone = self._get_backbone(
                    'ResNet',
                    depth=101,
                    variant='d',
                    norm_type=norm_type,
                    return_idx=[1, 2, 3],
                    dcn_v2_stages=[3],
                    freeze_at=-1,
                    freeze_norm=False,
                    norm_decay=0.)

            neck = ppdet.modeling.PPYOLOPAN(
                norm_type=norm_type,
                in_channels=[i.channels for i in backbone.out_shape],
                drop_block=use_drop_block,
                block_size=3,
                keep_prob=.9,
                spp=use_spp)
            loss = ppdet.modeling.YOLOv3Loss(
                num_classes=num_classes,
                ignore_thresh=ignore_threshold,
                downsample=self.downsample_ratios,
                label_smooth=label_smooth,
                scale_x_y=scale_x_y,
                iou_loss=ppdet.modeling.IouLoss(
                    loss_weight=2.5, loss_square=True)
                if use_iou_loss else None,
                iou_aware_loss=ppdet.modeling.IouAwareLoss(loss_weight=1.0)
                if use_iou_aware else None)
            yolo_head = ppdet.modeling.YOLOv3Head(
                in_channels=[i.channels for i in neck.out_shape],
                anchors=anchors,
                anchor_masks=anchor_masks,
                num_classes=num_classes,
                loss=loss,
                iou_aware=use_iou_aware,
                iou_aware_factor=.5)
            if use_matrix_nms:
                nms = ppdet.modeling.MatrixNMS(
                    keep_top_k=nms_keep_topk,
                    score_threshold=nms_score_threshold,
                    post_threshold=.01,
                    nms_top_k=nms_topk,
                    background_label=-1)
            else:
                nms = ppdet.modeling.MultiClassNMS(
                    score_threshold=nms_score_threshold,
                    nms_top_k=nms_topk,
                    keep_top_k=nms_keep_topk,
                    nms_threshold=nms_iou_threshold)
            post_process = ppdet.modeling.BBoxPostProcess(
                decode=ppdet.modeling.YOLOBox(
                    num_classes=num_classes,
                    conf_thresh=.01,
                    downsample_ratio=32,
                    clip_bbox=True,
                    scale_x_y=scale_x_y),
                nms=nms)
            params.update({
                'backbone': backbone,
                'neck': neck,
                'yolo_head': yolo_head,
                'post_process': post_process
            })

        # Skip YOLOv3.__init__ and call BaseDetector.__init__ directly;
        # `model_name` is overridden right below.
        super(YOLOv3, self).__init__(
            model_name='YOLOv3', num_classes=num_classes, **params)
        self.anchors = anchors
        self.anchor_masks = anchor_masks
        self.model_name = 'PPYOLOv2'

    def _get_test_inputs(self, image_shape):
        if image_shape is not None:
            image_shape = self._check_image_shape(image_shape)
            self._fix_transforms_shape(image_shape[-2:])
        else:
            image_shape = [None, 3, 640, 640]
            if getattr(self, 'test_transforms', None):
                for idx, op in enumerate(self.test_transforms.transforms):
                    name = op.__class__.__name__
                    if name == 'Resize':
                        image_shape = [None, 3] + list(
                            self.test_transforms.transforms[idx].target_size)
            logging.warning(
                '[Important!!!] When exporting an inference model for {}, '
                'if fixed_input_shape is not set, it will be forcibly set '
                'to {}. Please ensure that the image shape after transforms '
                'is {}; if not, fixed_input_shape should be specified '
                'manually.'.format(self.__class__.__name__, image_shape,
                                   image_shape[1:]))
        self.fixed_input_shape = image_shape
        return self._define_input_spec(image_shape)
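

# A minimal sketch, not part of the original module: a PPYOLOv2 variant with
# the heavier ResNet101 backbone and MultiClassNMS in place of MatrixNMS
# (use_matrix_nms=False selects the ppdet.modeling.MultiClassNMS branch in
# __init__ above). All arguments shown are parameters of this class.
def _example_ppyolov2_variant():
    model = PPYOLOv2(
        num_classes=80,
        backbone='ResNet101_vd_dcn',
        use_matrix_nms=False,
        nms_score_threshold=0.01,
        nms_iou_threshold=0.45)
    return model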


class MaskRCNN(BaseDetector):
    def __init__(self,
                 num_classes=80,
                 backbone='ResNet50_vd',
                 with_fpn=True,
                 with_dcn=False,
                 aspect_ratios=[0.5, 1.0, 2.0],
                 anchor_sizes=[[32], [64], [128], [256], [512]],
                 keep_top_k=100,
                 nms_threshold=0.5,
                 score_threshold=0.05,
                 fpn_num_channels=256,
                 rpn_batch_size_per_im=256,
                 rpn_fg_fraction=0.5,
                 test_pre_nms_top_n=None,
                 test_post_nms_top_n=1000,
                 **params):
        self.init_params = locals()
        if backbone not in {
                'ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet101',
                'ResNet101_vd'
        }:
            # Double braces keep the set literal out of str.format's
            # replacement-field parsing.
            raise ValueError(
                "backbone: {} is not supported. Please choose one of "
                "{{'ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet101', "
                "'ResNet101_vd'}}.".format(backbone))
        self.backbone_name = backbone + '_fpn' if with_fpn else backbone
        dcn_v2_stages = [1, 2, 3] if with_dcn else [-1]

        if params.get('with_net', True):
            if backbone == 'ResNet50':
                if with_fpn:
                    backbone = self._get_backbone(
                        'ResNet',
                        norm_type='bn',
                        freeze_at=0,
                        return_idx=[0, 1, 2, 3],
                        num_stages=4,
                        dcn_v2_stages=dcn_v2_stages)
                else:
                    if with_dcn:
                        logging.warning(
                            "Backbone {} should be used with DCN disabled; "
                            "'with_dcn' is forcibly set to False.".format(
                                backbone))
                    backbone = self._get_backbone(
                        'ResNet',
                        norm_type='bn',
                        freeze_at=0,
                        return_idx=[2],
                        num_stages=3)
            elif 'ResNet50_vd' in backbone:
                if not with_fpn:
                    logging.warning(
                        "Backbone {} should be used with FPN enabled; "
                        "'with_fpn' is forcibly set to True.".format(backbone))
                    with_fpn = True
                backbone = self._get_backbone(
                    'ResNet',
                    variant='d',
                    norm_type='bn',
                    freeze_at=0,
                    return_idx=[0, 1, 2, 3],
                    num_stages=4,
                    lr_mult_list=[0.05, 0.05, 0.1, 0.15]
                    if '_ssld' in backbone else [1.0, 1.0, 1.0, 1.0],
                    dcn_v2_stages=dcn_v2_stages)
            else:
                if not with_fpn:
                    logging.warning(
                        "Backbone {} should be used with FPN enabled; "
                        "'with_fpn' is forcibly set to True.".format(backbone))
                    with_fpn = True
                backbone = self._get_backbone(
                    'ResNet',
                    variant='d' if '_vd' in backbone else 'b',
                    depth=101,
                    norm_type='bn',
                    freeze_at=0,
                    return_idx=[0, 1, 2, 3],
                    num_stages=4,
                    dcn_v2_stages=dcn_v2_stages)

            rpn_in_channel = backbone.out_shape[0].channels

            if with_fpn:
                neck = ppdet.modeling.FPN(
                    in_channels=[i.channels for i in backbone.out_shape],
                    out_channel=fpn_num_channels,
                    spatial_scales=[
                        1.0 / i.stride for i in backbone.out_shape
                    ])
                rpn_in_channel = neck.out_shape[0].channels
                anchor_generator_cfg = {
                    'aspect_ratios': aspect_ratios,
                    'anchor_sizes': anchor_sizes,
                    'strides': [4, 8, 16, 32, 64]
                }
                train_proposal_cfg = {
                    'min_size': 0.0,
                    'nms_thresh': .7,
                    'pre_nms_top_n': 2000,
                    'post_nms_top_n': 1000,
                    'topk_after_collect': True
                }
                test_proposal_cfg = {
                    'min_size': 0.0,
                    'nms_thresh': .7,
                    'pre_nms_top_n': 1000
                    if test_pre_nms_top_n is None else test_pre_nms_top_n,
                    'post_nms_top_n': test_post_nms_top_n
                }
                bb_head = ppdet.modeling.TwoFCHead(
                    in_channel=neck.out_shape[0].channels, out_channel=1024)
                bb_roi_extractor_cfg = {
                    'resolution': 7,
                    'spatial_scale': [1. / i.stride for i in neck.out_shape],
                    'sampling_ratio': 0,
                    'aligned': True
                }
                with_pool = False
                m_head = ppdet.modeling.MaskFeat(
                    in_channel=neck.out_shape[0].channels,
                    out_channel=256,
                    num_convs=4)
                m_roi_extractor_cfg = {
                    'resolution': 14,
                    'spatial_scale': [1. / i.stride for i in neck.out_shape],
                    'sampling_ratio': 0,
                    'aligned': True
                }
                mask_assigner = MaskAssigner(
                    num_classes=num_classes, mask_resolution=28)
                share_bbox_feat = False
            else:
                neck = None
                anchor_generator_cfg = {
                    'aspect_ratios': aspect_ratios,
                    'anchor_sizes': anchor_sizes,
                    'strides': [16]
                }
                train_proposal_cfg = {
                    'min_size': 0.0,
                    'nms_thresh': .7,
                    'pre_nms_top_n': 12000,
                    'post_nms_top_n': 2000,
                    'topk_after_collect': False
                }
                test_proposal_cfg = {
                    'min_size': 0.0,
                    'nms_thresh': .7,
                    'pre_nms_top_n': 6000
                    if test_pre_nms_top_n is None else test_pre_nms_top_n,
                    'post_nms_top_n': test_post_nms_top_n
                }
                bb_head = ppdet.modeling.Res5Head()
                bb_roi_extractor_cfg = {
                    'resolution': 14,
                    'spatial_scale':
                    [1. / i.stride for i in backbone.out_shape],
                    'sampling_ratio': 0,
                    'aligned': True
                }
                with_pool = True
                m_head = ppdet.modeling.MaskFeat(
                    in_channel=bb_head.out_shape[0].channels,
                    out_channel=256,
                    num_convs=0)
                m_roi_extractor_cfg = {
                    'resolution': 14,
                    'spatial_scale':
                    [1. / i.stride for i in backbone.out_shape],
                    'sampling_ratio': 0,
                    'aligned': True
                }
                mask_assigner = MaskAssigner(
                    num_classes=num_classes, mask_resolution=14)
                share_bbox_feat = True

            rpn_target_assign_cfg = {
                'batch_size_per_im': rpn_batch_size_per_im,
                'fg_fraction': rpn_fg_fraction,
                'negative_overlap': .3,
                'positive_overlap': .7,
                'use_random': True
            }
            rpn_head = ppdet.modeling.RPNHead(
                anchor_generator=anchor_generator_cfg,
                rpn_target_assign=rpn_target_assign_cfg,
                train_proposal=train_proposal_cfg,
                test_proposal=test_proposal_cfg,
                in_channel=rpn_in_channel)
            bbox_assigner = BBoxAssigner(num_classes=num_classes)
            bbox_head = ppdet.modeling.BBoxHead(
                head=bb_head,
                in_channel=bb_head.out_shape[0].channels,
                roi_extractor=bb_roi_extractor_cfg,
                with_pool=with_pool,
                bbox_assigner=bbox_assigner,
                num_classes=num_classes)
            mask_head = ppdet.modeling.MaskHead(
                head=m_head,
                roi_extractor=m_roi_extractor_cfg,
                mask_assigner=mask_assigner,
                share_bbox_feat=share_bbox_feat,
                num_classes=num_classes)
            bbox_post_process = ppdet.modeling.BBoxPostProcess(
                num_classes=num_classes,
                decode=ppdet.modeling.RCNNBox(num_classes=num_classes),
                nms=ppdet.modeling.MultiClassNMS(
                    score_threshold=score_threshold,
                    keep_top_k=keep_top_k,
                    nms_threshold=nms_threshold))
            mask_post_process = ppdet.modeling.MaskPostProcess(
                binary_thresh=.5)
            params.update({
                'backbone': backbone,
                'neck': neck,
                'rpn_head': rpn_head,
                'bbox_head': bbox_head,
                'mask_head': mask_head,
                'bbox_post_process': bbox_post_process,
                'mask_post_process': mask_post_process
            })

        self.with_fpn = with_fpn
        super(MaskRCNN, self).__init__(
            model_name='MaskRCNN', num_classes=num_classes, **params)

    def train(self,
              num_epochs,
              train_dataset,
              train_batch_size=64,
              eval_dataset=None,
              optimizer=None,
              save_interval_epochs=1,
              log_interval_steps=10,
              save_dir='output',
              pretrain_weights='IMAGENET',
              learning_rate=.001,
              warmup_steps=0,
              warmup_start_lr=0.0,
              lr_decay_epochs=(216, 243),
              lr_decay_gamma=0.1,
              metric=None,
              use_ema=False,
              early_stop=False,
              early_stop_patience=5,
              use_vdl=True,
              resume_checkpoint=None):
  2074. """
  2075. Train the model.
  2076. Args:
  2077. num_epochs (int): Number of epochs.
  2078. train_dataset (paddlers.datasets.COCODetDataset|paddlers.datasets.VOCDetDataset):
  2079. Training dataset.
  2080. train_batch_size (int, optional): Total batch size among all cards used in
  2081. training. Defaults to 64.
  2082. eval_dataset (paddlers.datasets.COCODetDataset|paddlers.datasets.VOCDetDataset, optional):
  2083. Evaluation dataset. If None, the model will not be evaluated during training
  2084. process. Defaults to None.
  2085. optimizer (paddle.optimizer.Optimizer|None, optional): Optimizer used for
  2086. training. If None, a default optimizer will be used. Defaults to None.
  2087. save_interval_epochs (int, optional): Epoch interval for saving the model.
  2088. Defaults to 1.
  2089. log_interval_steps (int, optional): Step interval for printing training
  2090. information. Defaults to 10.
  2091. save_dir (str, optional): Directory to save the model. Defaults to 'output'.
  2092. pretrain_weights (str|None, optional): None or name/path of pretrained
  2093. weights. If None, no pretrained weights will be loaded.
  2094. Defaults to 'IMAGENET'.
  2095. learning_rate (float, optional): Learning rate for training. Defaults to .001.
  2096. warmup_steps (int, optional): Number of steps of warm-up training.
  2097. Defaults to 0.
  2098. warmup_start_lr (float, optional): Start learning rate of warm-up training.
  2099. Defaults to 0..
  2100. lr_decay_epochs (list|tuple, optional): Epoch milestones for learning
  2101. rate decay. Defaults to (216, 243).
  2102. lr_decay_gamma (float, optional): Gamma coefficient of learning rate decay.
  2103. Defaults to .1.
  2104. metric (str|None, optional): Evaluation metric. Choices are {'VOC', 'COCO', None}.
  2105. If None, determine the metric according to the dataset format.
  2106. Defaults to None.
  2107. use_ema (bool, optional): Whether to use exponential moving average
  2108. strategy. Defaults to False.
  2109. early_stop (bool, optional): Whether to adopt early stop strategy.
  2110. Defaults to False.
  2111. early_stop_patience (int, optional): Early stop patience. Defaults to 5.
  2112. use_vdl(bool, optional): Whether to use VisualDL to monitor the training
  2113. process. Defaults to True.
  2114. resume_checkpoint (str|None, optional): Path of the checkpoint to resume
  2115. training from. If None, no training checkpoint will be resumed. At most
  2116. Aone of `resume_checkpoint` and `pretrain_weights` can be set simultaneously.
  2117. Defaults to None.
  2118. """
        if train_dataset.pos_num < len(train_dataset.file_list):
            # The dataset contains negative (background-only) samples, in which
            # case multi-process data loading is disabled.
            train_dataset.num_workers = 0
        super(MaskRCNN, self).train(
            num_epochs, train_dataset, train_batch_size, eval_dataset,
            optimizer, save_interval_epochs, log_interval_steps, save_dir,
            pretrain_weights, learning_rate, warmup_steps, warmup_start_lr,
            lr_decay_epochs, lr_decay_gamma, metric, use_ema, early_stop,
            early_stop_patience, use_vdl, resume_checkpoint)

    def _compose_batch_transform(self, transforms, mode='train'):
        # Training and inference share the same default batch transform:
        # pad to a multiple of 32 when an FPN is used.
        default_batch_transforms = [
            _BatchPad(pad_to_stride=32 if self.with_fpn else -1)
        ]
        custom_batch_transforms = []
        for i, op in enumerate(transforms.transforms):
            if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
                if mode != 'train':
                    raise ValueError(
                        "{} cannot be present in the {} transforms. ".format(
                            op.__class__.__name__, mode) +
                        "Please check the {} transforms.".format(mode))
                custom_batch_transforms.insert(0, copy.deepcopy(op))
        batch_transforms = BatchCompose(
            custom_batch_transforms + default_batch_transforms,
            collate_batch=False)
        return batch_transforms

    def _fix_transforms_shape(self, image_shape):
        if getattr(self, 'test_transforms', None):
            has_resize_op = False
            resize_op_idx = -1
            normalize_op_idx = len(self.test_transforms.transforms)
            for idx, op in enumerate(self.test_transforms.transforms):
                name = op.__class__.__name__
                if name == 'ResizeByShort':
                    has_resize_op = True
                    resize_op_idx = idx
                if name == 'Normalize':
                    normalize_op_idx = idx
            if not has_resize_op:
                self.test_transforms.transforms.insert(
                    normalize_op_idx,
                    Resize(
                        target_size=image_shape,
                        keep_ratio=True,
                        interp='CUBIC'))
            else:
                self.test_transforms.transforms[resize_op_idx] = Resize(
                    target_size=image_shape, keep_ratio=True, interp='CUBIC')
            self.test_transforms.transforms.append(
                Pad(im_padding_value=[0., 0., 0.]))

    def _get_test_inputs(self, image_shape):
        if image_shape is not None:
            image_shape = self._check_image_shape(image_shape)
            self._fix_transforms_shape(image_shape[-2:])
        else:
            image_shape = [None, 3, -1, -1]
            if self.with_fpn:
                self.test_transforms.transforms.append(
                    Pad(im_padding_value=[0., 0., 0.]))
        self.fixed_input_shape = image_shape
        return self._define_input_spec(image_shape)
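

# A minimal usage sketch, not part of the original module: MaskRCNN requires
# instance masks, so a COCO-format dataset is assumed here. The
# `COCODetDataset` argument names, the paths, and the hyper-parameters below
# are illustrative assumptions and may need adjusting to the installed
# paddlers version.
def _example_train_mask_rcnn():
    import paddlers.transforms as T

    train_transforms = T.Compose([
        T.RandomHorizontalFlip(),
        T.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # Hypothetical COCO-format dataset; replace the placeholder paths.
    train_dataset = paddlers.datasets.COCODetDataset(
        data_dir='dataset',
        image_dir='images',
        anno_path='annotations/train.json',
        transforms=train_transforms)

    model = MaskRCNN(num_classes=80, backbone='ResNet50_vd', with_fpn=True)
    model.train(
        num_epochs=12,
        train_dataset=train_dataset,
        train_batch_size=2,
        learning_rate=0.0025,
        lr_decay_epochs=(8, 11),
        save_dir='output/mask_rcnn')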