base.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import os.path as osp
  16. import time
  17. import copy
  18. import math
  19. import json
  20. from functools import partial, wraps
  21. from inspect import signature
  22. import yaml
  23. import paddle
  24. from paddle.io import DataLoader, DistributedBatchSampler
  25. from paddleslim import QAT
  26. from paddleslim.analysis import flops
  27. from paddleslim import L1NormFilterPruner, FPGMFilterPruner
  28. import paddlers
  29. import paddlers.utils.logging as logging
  30. from paddlers.utils import (
  31. seconds_to_hms, get_single_card_bs, dict2str, get_pretrain_weights,
  32. load_pretrain_weights, load_checkpoint, SmoothedValue, TrainingStats,
  33. _get_shared_memory_size_in_M, EarlyStop, to_data_parallel, scheduler_step)
  34. from .slim.prune import _pruner_eval_fn, _pruner_template_input, sensitive_prune
  35. class ModelMeta(type):
  36. def __new__(cls, name, bases, attrs):
  37. def _deco(init_func):
  38. @wraps(init_func)
  39. def _wrapper(self, *args, **kwargs):
  40. if hasattr(self, '_raw_params'):
  41. ret = init_func(self, *args, **kwargs)
  42. else:
  43. sig = signature(init_func)
  44. bnd_args = sig.bind(self, *args, **kwargs)
  45. raw_params = bnd_args.arguments
  46. raw_params.pop('self')
  47. self._raw_params = raw_params
  48. ret = init_func(self, *args, **kwargs)
  49. return ret
  50. return _wrapper
  51. old_init_func = attrs['__init__']
  52. attrs['__init__'] = _deco(old_init_func)
  53. return type.__new__(cls, name, bases, attrs)
class BaseModel(metaclass=ModelMeta):
    """Base class of all PaddleRS trainers.

    Holds the state shared by every task-specific trainer: the underlying
    network, optimizer, data loaders, evaluation results, and bookkeeping
    for pruning, quantization, and training resumption.
    """

    def __init__(self, model_type):
        # Task type, e.g. 'classifier' or 'detector' (both branched on in
        # this file). NOTE(review): the full set of valid values is not
        # visible here — confirm against subclasses.
        self.model_type = model_type
        self.in_channels = None
        self.num_classes = None
        self.labels = None
        # PaddleRS version used to create the model; written into model.yml.
        self.version = paddlers.__version__
        # The underlying network; built by subclasses.
        self.net = None
        self.optimizer = None
        # Input specs used when exporting a static inference model.
        self.test_inputs = None
        self.train_data_loader = None
        self.eval_data_loader = None
        self.eval_metrics = None
        # Best evaluation accuracy seen so far, and the epoch it occurred at.
        self.best_accuracy = -1.
        self.best_model_epoch = -1
        # Whether to use synchronized BN
        self.sync_bn = False
        # One of {'Normal', 'Pruned', 'Quantized', 'Infer'} (see save_model,
        # prune, _prepare_qat, and _export_inference_model).
        self.status = 'Normal'
        # The initial epoch when training is resumed
        self.completed_epochs = 0
        # Pruning state (set by analyze_sensitivity()/prune()).
        self.pruner = None
        self.pruning_ratios = None
        # Quantization state (set by _prepare_qat()).
        self.quantizer = None
        self.quant_config = None
        self.fixed_input_shape = None
  79. def initialize_net(self,
  80. pretrain_weights=None,
  81. save_dir='.',
  82. resume_checkpoint=None,
  83. is_backbone_weights=False):
  84. if pretrain_weights is not None and \
  85. not osp.exists(pretrain_weights):
  86. if not osp.isdir(save_dir):
  87. if osp.exists(save_dir):
  88. os.remove(save_dir)
  89. os.makedirs(save_dir)
  90. if self.model_type == 'classifier':
  91. pretrain_weights = get_pretrain_weights(
  92. pretrain_weights, self.model_name, save_dir)
  93. else:
  94. backbone_name = getattr(self, 'backbone_name', None)
  95. pretrain_weights = get_pretrain_weights(
  96. pretrain_weights,
  97. self.__class__.__name__,
  98. save_dir,
  99. backbone_name=backbone_name)
  100. if pretrain_weights is not None:
  101. if is_backbone_weights:
  102. load_pretrain_weights(
  103. self.net.backbone,
  104. pretrain_weights,
  105. model_name='backbone of ' + self.model_name)
  106. else:
  107. load_pretrain_weights(
  108. self.net, pretrain_weights, model_name=self.model_name)
  109. if resume_checkpoint is not None:
  110. if not osp.exists(resume_checkpoint):
  111. logging.error(
  112. "The checkpoint path {} to resume training from does not exist."
  113. .format(resume_checkpoint),
  114. exit=True)
  115. if not osp.exists(osp.join(resume_checkpoint, 'model.pdparams')):
  116. logging.error(
  117. "Model parameter state dictionary file 'model.pdparams' "
  118. "was not found in given checkpoint path {}!".format(
  119. resume_checkpoint),
  120. exit=True)
  121. if not osp.exists(osp.join(resume_checkpoint, 'model.pdopt')):
  122. logging.error(
  123. "Optimizer state dictionary file 'model.pdparams' "
  124. "was not found in given checkpoint path {}!".format(
  125. resume_checkpoint),
  126. exit=True)
  127. if not osp.exists(osp.join(resume_checkpoint, 'model.yml')):
  128. logging.error(
  129. "'model.yml' was not found in given checkpoint path {}!".
  130. format(resume_checkpoint),
  131. exit=True)
  132. with open(osp.join(resume_checkpoint, "model.yml")) as f:
  133. info = yaml.load(f.read(), Loader=yaml.Loader)
  134. self.completed_epochs = info['completed_epochs']
  135. self.best_accuracy = info['_Attributes']['best_accuracy']
  136. self.best_model_epoch = info['_Attributes']['best_model_epoch']
  137. load_checkpoint(
  138. self.net,
  139. self.optimizer,
  140. model_name=self.model_name,
  141. checkpoint=resume_checkpoint)
  142. def get_model_info(self, get_raw_params=False, inplace=True):
  143. if inplace:
  144. init_params = self.init_params
  145. else:
  146. init_params = copy.deepcopy(self.init_params)
  147. info = dict()
  148. info['version'] = paddlers.__version__
  149. info['Model'] = self.__class__.__name__
  150. info['_Attributes'] = dict(
  151. [('model_type', self.model_type), ('in_channels', self.in_channels),
  152. ('num_classes', self.num_classes), ('labels', self.labels),
  153. ('fixed_input_shape', self.fixed_input_shape),
  154. ('best_accuracy', self.best_accuracy),
  155. ('best_model_epoch', self.best_model_epoch)])
  156. if 'self' in init_params:
  157. del init_params['self']
  158. if '__class__' in init_params:
  159. del init_params['__class__']
  160. if 'model_name' in init_params:
  161. del init_params['model_name']
  162. if 'params' in init_params:
  163. del init_params['params']
  164. info['_init_params'] = init_params
  165. if get_raw_params:
  166. info['raw_params'] = self._raw_params
  167. try:
  168. primary_metric_key = list(self.eval_metrics.keys())[0]
  169. primary_metric_value = float(self.eval_metrics[primary_metric_key])
  170. info['_Attributes']['eval_metrics'] = {
  171. primary_metric_key: primary_metric_value
  172. }
  173. except:
  174. pass
  175. if hasattr(self, 'test_transforms'):
  176. if self.test_transforms is not None:
  177. info['Transforms'] = list()
  178. for op in self.test_transforms.transforms:
  179. name = op.__class__.__name__
  180. attr = op.__dict__
  181. info['Transforms'].append({name: attr})
  182. arrange = self.test_transforms.arrange
  183. if arrange is not None:
  184. info['Transforms'].append({
  185. arrange.__class__.__name__: {
  186. 'mode': 'test'
  187. }
  188. })
  189. info['completed_epochs'] = self.completed_epochs
  190. return info
  191. def get_pruning_info(self):
  192. info = dict()
  193. info['pruner'] = self.pruner.__class__.__name__
  194. info['pruning_ratios'] = self.pruning_ratios
  195. pruner_inputs = self.pruner.inputs
  196. if self.model_type == 'detector':
  197. pruner_inputs = {k: v.tolist() for k, v in pruner_inputs[0].items()}
  198. info['pruner_inputs'] = pruner_inputs
  199. return info
  200. def get_quant_info(self):
  201. info = dict()
  202. info['quant_config'] = self.quant_config
  203. return info
  204. def save_model(self, save_dir):
  205. if not osp.isdir(save_dir):
  206. if osp.exists(save_dir):
  207. os.remove(save_dir)
  208. os.makedirs(save_dir)
  209. model_info = self.get_model_info(get_raw_params=True)
  210. model_info['status'] = self.status
  211. paddle.save(self.net.state_dict(), osp.join(save_dir, 'model.pdparams'))
  212. paddle.save(self.optimizer.state_dict(),
  213. osp.join(save_dir, 'model.pdopt'))
  214. with open(
  215. osp.join(save_dir, 'model.yml'), encoding='utf-8',
  216. mode='w') as f:
  217. yaml.dump(model_info, f)
  218. # Save evaluation details
  219. if hasattr(self, 'eval_details'):
  220. with open(osp.join(save_dir, 'eval_details.json'), 'w') as f:
  221. json.dump(self.eval_details, f)
  222. if self.status == 'Pruned' and self.pruner is not None:
  223. pruning_info = self.get_pruning_info()
  224. with open(
  225. osp.join(save_dir, 'prune.yml'), encoding='utf-8',
  226. mode='w') as f:
  227. yaml.dump(pruning_info, f)
  228. if self.status == 'Quantized' and self.quantizer is not None:
  229. quant_info = self.get_quant_info()
  230. with open(
  231. osp.join(save_dir, 'quant.yml'), encoding='utf-8',
  232. mode='w') as f:
  233. yaml.dump(quant_info, f)
  234. # Success flag
  235. open(osp.join(save_dir, '.success'), 'w').close()
  236. logging.info("Model saved in {}.".format(save_dir))
  237. def build_data_loader(self, dataset, batch_size, mode='train'):
  238. if dataset.num_samples < batch_size:
  239. raise ValueError(
  240. 'The volume of dataset({}) must be larger than batch size({}).'
  241. .format(dataset.num_samples, batch_size))
  242. batch_size_each_card = get_single_card_bs(batch_size=batch_size)
  243. batch_sampler = DistributedBatchSampler(
  244. dataset,
  245. batch_size=batch_size_each_card,
  246. shuffle=dataset.shuffle,
  247. drop_last=mode == 'train')
  248. if dataset.num_workers > 0:
  249. shm_size = _get_shared_memory_size_in_M()
  250. if shm_size is None or shm_size < 1024.:
  251. use_shared_memory = False
  252. else:
  253. use_shared_memory = True
  254. else:
  255. use_shared_memory = False
  256. loader = DataLoader(
  257. dataset,
  258. batch_sampler=batch_sampler,
  259. collate_fn=dataset.batch_transforms,
  260. num_workers=dataset.num_workers,
  261. return_list=True,
  262. use_shared_memory=use_shared_memory)
  263. return loader
    def train_loop(self,
                   num_epochs,
                   train_dataset,
                   train_batch_size,
                   eval_dataset=None,
                   save_interval_epochs=1,
                   log_interval_steps=10,
                   save_dir='output',
                   ema=None,
                   early_stop=False,
                   early_stop_patience=5,
                   use_vdl=True):
        """Generic epoch/step training loop shared by all trainers.

        Trains for ``num_epochs`` epochs (resuming from
        ``self.completed_epochs``), logging every ``log_interval_steps``
        steps and evaluating/saving every ``save_interval_epochs`` epochs.

        Args:
            num_epochs (int): Total number of epochs to train to.
            train_dataset: Training dataset.
            train_batch_size (int): Global training batch size; also reused
                as the evaluation batch size.
            eval_dataset: Optional evaluation dataset. Defaults to None.
            save_interval_epochs (int, optional): Epoch interval between
                evaluations/saves. Defaults to 1.
            log_interval_steps (int, optional): Step interval between log
                lines. Defaults to 10.
            save_dir (str, optional): Output directory. Defaults to 'output'.
            ema: Optional exponential-moving-average tracker with
                ``update``/``apply`` methods. Defaults to None.
            early_stop (bool, optional): Enable early stopping on the primary
                eval metric. Defaults to False.
            early_stop_patience (int, optional): Patience for early stopping.
                Defaults to 5.
            use_vdl (bool, optional): Enable VisualDL logging. Defaults to
                True.
        """
        self._check_transforms(train_dataset.transforms, 'train')
        # RCNN detectors with negative samples present fall back to
        # single-card mode.
        if self.model_type == 'detector' and 'RCNN' in self.__class__.__name__ and train_dataset.pos_num < len(
                train_dataset.file_list):
            nranks = 1
        else:
            nranks = paddle.distributed.get_world_size()
        local_rank = paddle.distributed.get_rank()
        if nranks > 1:
            find_unused_parameters = getattr(self, 'find_unused_parameters',
                                             False)
            # Initialize parallel environment if not done.
            if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
            ):
                paddle.distributed.init_parallel_env()
                ddp_net = to_data_parallel(
                    self.net, find_unused_parameters=find_unused_parameters)
            else:
                ddp_net = to_data_parallel(
                    self.net, find_unused_parameters=find_unused_parameters)
        if use_vdl:
            from visualdl import LogWriter
            vdl_logdir = osp.join(save_dir, 'vdl_log')
            log_writer = LogWriter(vdl_logdir)
        # task_id: refer to paddlers
        task_id = getattr(paddlers, "task_id", "")
        # Minimum improvement to reset early-stop patience.
        thresh = .0001
        if early_stop:
            earlystop = EarlyStop(early_stop_patience, thresh)
        self.train_data_loader = self.build_data_loader(
            train_dataset, batch_size=train_batch_size, mode='train')
        if eval_dataset is not None:
            self.test_transforms = copy.deepcopy(eval_dataset.transforms)
        start_epoch = self.completed_epochs
        # Smoothed per-step wall time, used for the ETA estimate below.
        train_step_time = SmoothedValue(log_interval_steps)
        train_step_each_epoch = math.floor(train_dataset.num_samples /
                                           train_batch_size)
        train_total_step = train_step_each_epoch * (num_epochs - start_epoch)
        if eval_dataset is not None:
            eval_batch_size = train_batch_size
            eval_epoch_time = 0
        current_step = 0
        for i in range(start_epoch, num_epochs):
            self.net.train()
            # Datasets may implement set_epoch() for per-epoch shuffling.
            if callable(
                    getattr(self.train_data_loader.dataset, 'set_epoch', None)):
                self.train_data_loader.dataset.set_epoch(i)
            train_avg_metrics = TrainingStats()
            step_time_tic = time.time()
            for step, data in enumerate(self.train_data_loader()):
                if nranks > 1:
                    outputs = self.train_step(step, data, ddp_net)
                else:
                    outputs = self.train_step(step, data, self.net)
                scheduler_step(self.optimizer, outputs['loss'])
                train_avg_metrics.update(outputs)
                lr = self.optimizer.get_lr()
                outputs['lr'] = lr
                if ema is not None:
                    ema.update(self.net)
                step_time_toc = time.time()
                train_step_time.update(step_time_toc - step_time_tic)
                step_time_tic = step_time_toc
                current_step += 1
                # Log loss info every log_interval_steps
                if current_step % log_interval_steps == 0 and local_rank == 0:
                    if use_vdl:
                        for k, v in outputs.items():
                            log_writer.add_scalar(
                                '{}-Metrics/Training(Step): {}'.format(
                                    task_id, k), v, current_step)
                    # Estimation remaining time
                    avg_step_time = train_step_time.avg()
                    eta = avg_step_time * (train_total_step - current_step)
                    if eval_dataset is not None:
                        # Add the expected cost of remaining evaluations; fall
                        # back to a step-time estimate before the first eval.
                        eval_num_epochs = math.ceil(
                            (num_epochs - i - 1) / save_interval_epochs)
                        if eval_epoch_time == 0:
                            eta += avg_step_time * math.ceil(
                                eval_dataset.num_samples / eval_batch_size)
                        else:
                            eta += eval_epoch_time * eval_num_epochs
                    logging.info(
                        "[TRAIN] Epoch={}/{}, Step={}/{}, {}, time_each_step={}s, eta={}"
                        .format(i + 1, num_epochs, step + 1,
                                train_step_each_epoch,
                                dict2str(outputs),
                                round(avg_step_time, 2), seconds_to_hms(eta)))
            logging.info('[TRAIN] Epoch {} finished, {} .'
                         .format(i + 1, train_avg_metrics.log()))
            self.completed_epochs += 1
            if ema is not None:
                # Evaluate/save with EMA weights; restored at the epoch end.
                weight = copy.deepcopy(self.net.state_dict())
                self.net.set_state_dict(ema.apply())
            eval_epoch_tic = time.time()
            # Every save_interval_epochs, evaluate and save the model
            if (i + 1) % save_interval_epochs == 0 or i == num_epochs - 1:
                if eval_dataset is not None and eval_dataset.num_samples > 0:
                    eval_result = self.evaluate(
                        eval_dataset,
                        batch_size=eval_batch_size,
                        return_details=True)
                    # Save the optimal model
                    if local_rank == 0:
                        self.eval_metrics, self.eval_details = eval_result
                        if use_vdl:
                            for k, v in self.eval_metrics.items():
                                try:
                                    log_writer.add_scalar(
                                        '{}-Metrics/Eval(Epoch): {}'.format(
                                            task_id, k), v, i + 1)
                                except TypeError:
                                    # Non-scalar metrics are skipped.
                                    pass
                        logging.info('[EVAL] Finished, Epoch={}, {} .'.format(
                            i + 1, dict2str(self.eval_metrics)))
                        # The first metric key is treated as the primary one.
                        best_accuracy_key = list(self.eval_metrics.keys())[0]
                        current_accuracy = self.eval_metrics[best_accuracy_key]
                        if current_accuracy > self.best_accuracy:
                            self.best_accuracy = current_accuracy
                            self.best_model_epoch = i + 1
                            best_model_dir = osp.join(save_dir, "best_model")
                            self.save_model(save_dir=best_model_dir)
                        if self.best_model_epoch > 0:
                            logging.info(
                                'Current evaluated best model on eval_dataset is epoch_{}, {}={}'
                                .format(self.best_model_epoch,
                                        best_accuracy_key, self.best_accuracy))
                    eval_epoch_time = time.time() - eval_epoch_tic
                current_save_dir = osp.join(save_dir, "epoch_{}".format(i + 1))
                if local_rank == 0:
                    self.save_model(save_dir=current_save_dir)
                    # NOTE(review): `current_accuracy` is only assigned on
                    # rank 0 after a successful evaluation; with early_stop
                    # enabled but evaluation skipped (empty eval set) this
                    # would raise NameError — confirm upstream guarantees.
                    if eval_dataset is not None and early_stop:
                        if earlystop(current_accuracy):
                            break
            if ema is not None:
                self.net.set_state_dict(weight)
  411. def analyze_sensitivity(self,
  412. dataset,
  413. batch_size=8,
  414. criterion='l1_norm',
  415. save_dir='output'):
  416. """
  417. Args:
  418. dataset (paddlers.datasets.BaseDataset): Dataset used for evaluation during
  419. sensitivity analysis.
  420. batch_size (int, optional): Batch size used in evaluation. Defaults to 8.
  421. criterion (str, optional): Pruning criterion. Choices are {'l1_norm', 'fpgm'}.
  422. Defaults to 'l1_norm'.
  423. save_dir (str, optional): Directory to save sensitivity file of the model.
  424. Defaults to 'output'.
  425. """
  426. if self.__class__.__name__ in {'FasterRCNN', 'MaskRCNN', 'PicoDet'}:
  427. raise ValueError("{} does not support pruning currently!".format(
  428. self.__class__.__name__))
  429. assert criterion in {'l1_norm', 'fpgm'}, \
  430. "Pruning criterion {} is not supported. Please choose from {'l1_norm', 'fpgm'}."
  431. self._check_transforms(dataset.transforms, 'eval')
  432. if self.model_type == 'detector':
  433. self.net.eval()
  434. else:
  435. self.net.train()
  436. inputs = _pruner_template_input(
  437. sample=dataset[0], model_type=self.model_type)
  438. if criterion == 'l1_norm':
  439. self.pruner = L1NormFilterPruner(self.net, inputs=inputs)
  440. else:
  441. self.pruner = FPGMFilterPruner(self.net, inputs=inputs)
  442. if not osp.isdir(save_dir):
  443. os.makedirs(save_dir)
  444. sen_file = osp.join(save_dir, 'model.sensi.data')
  445. logging.info('Sensitivity analysis of model parameters starts...')
  446. self.pruner.sensitive(
  447. eval_func=partial(_pruner_eval_fn, self, dataset, batch_size),
  448. sen_file=sen_file)
  449. logging.info(
  450. 'Sensitivity analysis is complete. The result is saved at {}.'.
  451. format(sen_file))
  452. def prune(self, pruned_flops, save_dir=None):
  453. """
  454. Args:
  455. pruned_flops (float): Ratio of FLOPs to be pruned.
  456. save_dir (str|None, optional): If None, the pruned model will not be
  457. saved. Otherwise, the pruned model will be saved at `save_dir`.
  458. Defaults to None.
  459. """
  460. if self.status == "Pruned":
  461. raise ValueError(
  462. "A pruned model cannot be pruned for a second time!")
  463. pre_pruning_flops = flops(self.net, self.pruner.inputs)
  464. logging.info("Pre-pruning FLOPs: {}. Pruning starts...".format(
  465. pre_pruning_flops))
  466. _, self.pruning_ratios = sensitive_prune(self.pruner, pruned_flops)
  467. post_pruning_flops = flops(self.net, self.pruner.inputs)
  468. logging.info("Pruning is complete. Post-pruning FLOPs: {}".format(
  469. post_pruning_flops))
  470. logging.warning("Pruning the model may hurt its performance. "
  471. "Re-training is highly recommended.")
  472. self.status = 'Pruned'
  473. if save_dir is not None:
  474. self.save_model(save_dir)
  475. logging.info("Pruned model is saved at {}".format(save_dir))
    def _prepare_qat(self, quant_config):
        """Prepare the network for quantization-aware training (QAT).

        Args:
            quant_config (dict|None): PaddleSlim QAT configuration. If None,
                a default configuration is used. Re-preparing an already
                quantized model with a different configuration is rejected.
        """
        # Exported inference models cannot be further trained.
        if self.status == 'Infer':
            logging.error(
                "Exported inference model does not support quantization-aware training.",
                exit=True)
        if quant_config is None:
            # Default quantization configuration
            quant_config = {
                # {None, 'PACT'}. Weight preprocess type. If None, no preprocessing is performed.
                'weight_preprocess_type': None,
                # {None, 'PACT'}. Activation preprocess type. If None, no preprocessing is performed.
                'activation_preprocess_type': None,
                # {'abs_max', 'channel_wise_abs_max', 'range_abs_max', 'moving_average_abs_max'}.
                # Weight quantization type.
                'weight_quantize_type': 'channel_wise_abs_max',
                # {'abs_max', 'range_abs_max', 'moving_average_abs_max'}. Activation quantization type.
                'activation_quantize_type': 'moving_average_abs_max',
                # The number of bits of weights after quantization.
                'weight_bits': 8,
                # The number of bits of activation after quantization.
                'activation_bits': 8,
                # Data type after quantization, such as 'uint8', 'int8', etc.
                'dtype': 'int8',
                # Window size for 'range_abs_max' quantization.
                'window_size': 10000,
                # Decay coefficient of moving average.
                'moving_rate': .9,
                # Types of layers that will be quantized.
                'quantizable_layer_type': ['Conv2D', 'Linear']
            }
        if self.status != 'Quantized':
            # First-time preparation: insert fake-quant ops in place.
            self.quant_config = quant_config
            self.quantizer = QAT(config=self.quant_config)
            logging.info(
                "Preparing the model for quantization-aware training...")
            self.quantizer.quantize(self.net)
            logging.info("Model is ready for quantization-aware training.")
            self.status = 'Quantized'
        elif quant_config != self.quant_config:
            # Already quantized with a different config — refuse.
            logging.error(
                "The model has been quantized with the following quant_config: {}."
                "Performing quantization-aware training with a quantized model "
                "using a different configuration is not supported."
                .format(self.quant_config),
                exit=True)
  521. def _get_pipeline_info(self, save_dir):
  522. pipeline_info = {}
  523. pipeline_info["pipeline_name"] = self.model_type
  524. nodes = [{
  525. "src0": {
  526. "type": "Source",
  527. "next": "decode0"
  528. }
  529. }, {
  530. "decode0": {
  531. "type": "Decode",
  532. "next": "predict0"
  533. }
  534. }, {
  535. "predict0": {
  536. "type": "Predict",
  537. "init_params": {
  538. "use_gpu": False,
  539. "gpu_id": 0,
  540. "use_trt": False,
  541. "model_dir": save_dir,
  542. },
  543. "next": "sink0"
  544. }
  545. }, {
  546. "sink0": {
  547. "type": "Sink"
  548. }
  549. }]
  550. pipeline_info["pipeline_nodes"] = nodes
  551. pipeline_info["version"] = "1.0.0"
  552. return pipeline_info
    def _build_inference_net(self):
        """Build and return the network used for inference/export.

        Raises:
            NotImplementedError: Always, in the base class; subclasses must
                override.
        """
        raise NotImplementedError
    def _export_inference_model(self, save_dir, image_shape=None):
        """Export a deployable inference model plus metadata to ``save_dir``.

        Writes the static model files, plus 'model.yml' (with status
        'Infer'), 'pipeline.yml', optional 'prune.yml'/'quant.yml', and a
        '.success' flag file.

        Args:
            save_dir (str): Output directory.
            image_shape: Optional fixed input shape forwarded to
                ``_get_test_inputs``. Defaults to None.
        """
        self.test_inputs = self._get_test_inputs(image_shape)
        infer_net = self._build_inference_net()
        if self.status == 'Quantized':
            # Quantized models are exported through the quantizer so the
            # fake-quant ops are handled correctly.
            self.quantizer.save_quantized_model(infer_net,
                                                osp.join(save_dir, 'model'),
                                                self.test_inputs)
            quant_info = self.get_quant_info()
            with open(
                    osp.join(save_dir, 'quant.yml'), encoding='utf-8',
                    mode='w') as f:
                yaml.dump(quant_info, f)
        else:
            static_net = paddle.jit.to_static(
                infer_net, input_spec=self.test_inputs)
            paddle.jit.save(static_net, osp.join(save_dir, 'model'))
        if self.status == 'Pruned':
            pruning_info = self.get_pruning_info()
            with open(
                    osp.join(save_dir, 'prune.yml'), encoding='utf-8',
                    mode='w') as f:
                yaml.dump(pruning_info, f)
        model_info = self.get_model_info()
        model_info['status'] = 'Infer'
        with open(
                osp.join(save_dir, 'model.yml'), encoding='utf-8',
                mode='w') as f:
            yaml.dump(model_info, f)
        pipeline_info = self._get_pipeline_info(save_dir)
        with open(
                osp.join(save_dir, 'pipeline.yml'), encoding='utf-8',
                mode='w') as f:
            yaml.dump(pipeline_info, f)
        # Success flag
        open(osp.join(save_dir, '.success'), 'w').close()
        logging.info("The inference model for deployment is saved in {}.".
                     format(save_dir))
  592. def train_step(self, step, data, net):
  593. outputs = self.run(net, data, mode='train')
  594. loss = outputs['loss']
  595. loss.backward()
  596. self.optimizer.step()
  597. self.optimizer.clear_grad()
  598. return outputs
  599. def _check_transforms(self, transforms, mode):
  600. # NOTE: Check transforms and transforms.arrange and give user-friendly error messages.
  601. if not isinstance(transforms, paddlers.transforms.Compose):
  602. raise TypeError("`transforms` must be paddlers.transforms.Compose.")
  603. arrange_obj = transforms.arrange
  604. if not isinstance(arrange_obj, paddlers.transforms.operators.Arrange):
  605. raise TypeError("`transforms.arrange` must be an Arrange object.")
  606. if arrange_obj.mode != mode:
  607. raise ValueError(
  608. f"Incorrect arrange mode! Expected {mode} but got {arrange_obj.mode}."
  609. )
    def run(self, net, inputs, mode):
        """Forward ``inputs`` through ``net`` in ``mode``. Must be overridden."""
        raise NotImplementedError

    def train(self, *args, **kwargs):
        """Train the model. Must be overridden by subclasses."""
        raise NotImplementedError

    def evaluate(self, *args, **kwargs):
        """Evaluate the model. Must be overridden by subclasses."""
        raise NotImplementedError

    def preprocess(self, images, transforms, to_tensor):
        """Preprocess ``images`` with ``transforms``. Must be overridden."""
        raise NotImplementedError

    def postprocess(self, *args, **kwargs):
        """Postprocess raw network outputs. Must be overridden by subclasses."""
        raise NotImplementedError