base.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import os.path as osp
  16. import time
  17. import copy
  18. import math
  19. import json
  20. from functools import partial, wraps
  21. from inspect import signature
  22. import yaml
  23. import paddle
  24. from paddle.io import DataLoader, DistributedBatchSampler
  25. from paddleslim import QAT
  26. from paddleslim.analysis import flops
  27. from paddleslim import L1NormFilterPruner, FPGMFilterPruner
  28. import paddlers
  29. import paddlers.utils.logging as logging
  30. from paddlers.transforms import arrange_transforms
  31. from paddlers.utils import (seconds_to_hms, get_single_card_bs, dict2str,
  32. get_pretrain_weights, load_pretrain_weights,
  33. load_checkpoint, SmoothedValue, TrainingStats,
  34. _get_shared_memory_size_in_M, EarlyStop)
  35. from .slim.prune import _pruner_eval_fn, _pruner_template_input, sensitive_prune
  36. from .utils.infer_nets import InferNet, InferCDNet
  37. class ModelMeta(type):
  38. def __new__(cls, name, bases, attrs):
  39. def _deco(init_func):
  40. @wraps(init_func)
  41. def _wrapper(self, *args, **kwargs):
  42. if hasattr(self, '_raw_params'):
  43. ret = init_func(self, *args, **kwargs)
  44. else:
  45. sig = signature(init_func)
  46. bnd_args = sig.bind(self, *args, **kwargs)
  47. raw_params = bnd_args.arguments
  48. raw_params.pop('self')
  49. self._raw_params = raw_params
  50. ret = init_func(self, *args, **kwargs)
  51. return ret
  52. return _wrapper
  53. old_init_func = attrs['__init__']
  54. attrs['__init__'] = _deco(old_init_func)
  55. return type.__new__(cls, name, bases, attrs)
  56. class BaseModel(metaclass=ModelMeta):
  57. def __init__(self, model_type):
  58. self.model_type = model_type
  59. self.in_channels = None
  60. self.num_classes = None
  61. self.labels = None
  62. self.version = paddlers.__version__
  63. self.net = None
  64. self.optimizer = None
  65. self.test_inputs = None
  66. self.train_data_loader = None
  67. self.eval_data_loader = None
  68. self.eval_metrics = None
  69. self.best_accuracy = -1.
  70. self.best_model_epoch = -1
  71. # 是否使用多卡间同步BatchNorm均值和方差
  72. self.sync_bn = False
  73. self.status = 'Normal'
  74. # 已完成迭代轮数,为恢复训练时的起始轮数
  75. self.completed_epochs = 0
  76. self.pruner = None
  77. self.pruning_ratios = None
  78. self.quantizer = None
  79. self.quant_config = None
  80. self.fixed_input_shape = None
  81. def net_initialize(self,
  82. pretrain_weights=None,
  83. save_dir='.',
  84. resume_checkpoint=None,
  85. is_backbone_weights=False):
  86. if pretrain_weights is not None and \
  87. not osp.exists(pretrain_weights):
  88. if not osp.isdir(save_dir):
  89. if osp.exists(save_dir):
  90. os.remove(save_dir)
  91. os.makedirs(save_dir)
  92. if self.model_type == 'classifier':
  93. pretrain_weights = get_pretrain_weights(
  94. pretrain_weights, self.model_name, save_dir)
  95. else:
  96. backbone_name = getattr(self, 'backbone_name', None)
  97. pretrain_weights = get_pretrain_weights(
  98. pretrain_weights,
  99. self.__class__.__name__,
  100. save_dir,
  101. backbone_name=backbone_name)
  102. if pretrain_weights is not None:
  103. if is_backbone_weights:
  104. load_pretrain_weights(
  105. self.net.backbone,
  106. pretrain_weights,
  107. model_name='backbone of ' + self.model_name)
  108. else:
  109. load_pretrain_weights(
  110. self.net, pretrain_weights, model_name=self.model_name)
  111. if resume_checkpoint is not None:
  112. if not osp.exists(resume_checkpoint):
  113. logging.error(
  114. "The checkpoint path {} to resume training from does not exist."
  115. .format(resume_checkpoint),
  116. exit=True)
  117. if not osp.exists(osp.join(resume_checkpoint, 'model.pdparams')):
  118. logging.error(
  119. "Model parameter state dictionary file 'model.pdparams' "
  120. "not found under given checkpoint path {}".format(
  121. resume_checkpoint),
  122. exit=True)
  123. if not osp.exists(osp.join(resume_checkpoint, 'model.pdopt')):
  124. logging.error(
  125. "Optimizer state dictionary file 'model.pdparams' "
  126. "not found under given checkpoint path {}".format(
  127. resume_checkpoint),
  128. exit=True)
  129. if not osp.exists(osp.join(resume_checkpoint, 'model.yml')):
  130. logging.error(
  131. "'model.yml' not found under given checkpoint path {}".
  132. format(resume_checkpoint),
  133. exit=True)
  134. with open(osp.join(resume_checkpoint, "model.yml")) as f:
  135. info = yaml.load(f.read(), Loader=yaml.Loader)
  136. self.completed_epochs = info['completed_epochs']
  137. self.best_accuracy = info['_Attributes']['best_accuracy']
  138. self.best_model_epoch = info['_Attributes']['best_model_epoch']
  139. load_checkpoint(
  140. self.net,
  141. self.optimizer,
  142. model_name=self.model_name,
  143. checkpoint=resume_checkpoint)
  144. def get_model_info(self, get_raw_params=False, inplace=True):
  145. if inplace:
  146. init_params = self.init_params
  147. else:
  148. init_params = copy.deepcopy(self.init_params)
  149. info = dict()
  150. info['version'] = paddlers.__version__
  151. info['Model'] = self.__class__.__name__
  152. info['_Attributes'] = dict(
  153. [('model_type', self.model_type), ('in_channels', self.in_channels),
  154. ('num_classes', self.num_classes), ('labels', self.labels),
  155. ('fixed_input_shape', self.fixed_input_shape),
  156. ('best_accuracy', self.best_accuracy),
  157. ('best_model_epoch', self.best_model_epoch)])
  158. if 'self' in init_params:
  159. del init_params['self']
  160. if '__class__' in init_params:
  161. del init_params['__class__']
  162. if 'model_name' in init_params:
  163. del init_params['model_name']
  164. if 'params' in init_params:
  165. del init_params['params']
  166. info['_init_params'] = init_params
  167. if get_raw_params:
  168. info['raw_params'] = self._raw_params
  169. try:
  170. primary_metric_key = list(self.eval_metrics.keys())[0]
  171. primary_metric_value = float(self.eval_metrics[primary_metric_key])
  172. info['_Attributes']['eval_metrics'] = {
  173. primary_metric_key: primary_metric_value
  174. }
  175. except:
  176. pass
  177. if hasattr(self, 'test_transforms'):
  178. if self.test_transforms is not None:
  179. info['Transforms'] = list()
  180. for op in self.test_transforms.transforms:
  181. name = op.__class__.__name__
  182. if name.startswith('Arrange'):
  183. continue
  184. attr = op.__dict__
  185. info['Transforms'].append({name: attr})
  186. info['completed_epochs'] = self.completed_epochs
  187. return info
  188. def get_pruning_info(self):
  189. info = dict()
  190. info['pruner'] = self.pruner.__class__.__name__
  191. info['pruning_ratios'] = self.pruning_ratios
  192. pruner_inputs = self.pruner.inputs
  193. if self.model_type == 'detector':
  194. pruner_inputs = {k: v.tolist() for k, v in pruner_inputs[0].items()}
  195. info['pruner_inputs'] = pruner_inputs
  196. return info
  197. def get_quant_info(self):
  198. info = dict()
  199. info['quant_config'] = self.quant_config
  200. return info
  201. def save_model(self, save_dir):
  202. if not osp.isdir(save_dir):
  203. if osp.exists(save_dir):
  204. os.remove(save_dir)
  205. os.makedirs(save_dir)
  206. model_info = self.get_model_info(get_raw_params=True)
  207. model_info['status'] = self.status
  208. paddle.save(self.net.state_dict(), osp.join(save_dir, 'model.pdparams'))
  209. paddle.save(self.optimizer.state_dict(),
  210. osp.join(save_dir, 'model.pdopt'))
  211. with open(
  212. osp.join(save_dir, 'model.yml'), encoding='utf-8',
  213. mode='w') as f:
  214. yaml.dump(model_info, f)
  215. # 评估结果保存
  216. if hasattr(self, 'eval_details'):
  217. with open(osp.join(save_dir, 'eval_details.json'), 'w') as f:
  218. json.dump(self.eval_details, f)
  219. if self.status == 'Pruned' and self.pruner is not None:
  220. pruning_info = self.get_pruning_info()
  221. with open(
  222. osp.join(save_dir, 'prune.yml'), encoding='utf-8',
  223. mode='w') as f:
  224. yaml.dump(pruning_info, f)
  225. if self.status == 'Quantized' and self.quantizer is not None:
  226. quant_info = self.get_quant_info()
  227. with open(
  228. osp.join(save_dir, 'quant.yml'), encoding='utf-8',
  229. mode='w') as f:
  230. yaml.dump(quant_info, f)
  231. # 模型保存成功的标志
  232. open(osp.join(save_dir, '.success'), 'w').close()
  233. logging.info("Model saved in {}.".format(save_dir))
  234. def build_data_loader(self, dataset, batch_size, mode='train'):
  235. if dataset.num_samples < batch_size:
  236. raise Exception(
  237. 'The volume of dataset({}) must be larger than batch size({}).'
  238. .format(dataset.num_samples, batch_size))
  239. batch_size_each_card = get_single_card_bs(batch_size=batch_size)
  240. # TODO detection eval阶段需做判断
  241. batch_sampler = DistributedBatchSampler(
  242. dataset,
  243. batch_size=batch_size_each_card,
  244. shuffle=dataset.shuffle,
  245. drop_last=mode == 'train')
  246. if dataset.num_workers > 0:
  247. shm_size = _get_shared_memory_size_in_M()
  248. if shm_size is None or shm_size < 1024.:
  249. use_shared_memory = False
  250. else:
  251. use_shared_memory = True
  252. else:
  253. use_shared_memory = False
  254. loader = DataLoader(
  255. dataset,
  256. batch_sampler=batch_sampler,
  257. collate_fn=dataset.batch_transforms,
  258. num_workers=dataset.num_workers,
  259. return_list=True,
  260. use_shared_memory=use_shared_memory)
  261. return loader
  262. def train_loop(self,
  263. num_epochs,
  264. train_dataset,
  265. train_batch_size,
  266. eval_dataset=None,
  267. save_interval_epochs=1,
  268. log_interval_steps=10,
  269. save_dir='output',
  270. ema=None,
  271. early_stop=False,
  272. early_stop_patience=5,
  273. use_vdl=True):
  274. arrange_transforms(
  275. model_type=self.model_type,
  276. transforms=train_dataset.transforms,
  277. mode='train')
  278. if "RCNN" in self.__class__.__name__ and train_dataset.pos_num < len(
  279. train_dataset.file_list):
  280. nranks = 1
  281. else:
  282. nranks = paddle.distributed.get_world_size()
  283. local_rank = paddle.distributed.get_rank()
  284. if nranks > 1:
  285. find_unused_parameters = getattr(self, 'find_unused_parameters',
  286. False)
  287. # Initialize parallel environment if not done.
  288. if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
  289. ):
  290. paddle.distributed.init_parallel_env()
  291. ddp_net = paddle.DataParallel(
  292. self.net, find_unused_parameters=find_unused_parameters)
  293. else:
  294. ddp_net = paddle.DataParallel(
  295. self.net, find_unused_parameters=find_unused_parameters)
  296. if use_vdl:
  297. from visualdl import LogWriter
  298. vdl_logdir = osp.join(save_dir, 'vdl_log')
  299. log_writer = LogWriter(vdl_logdir)
  300. # task_id: refer to paddlers
  301. task_id = getattr(paddlers, "task_id", "")
  302. thresh = .0001
  303. if early_stop:
  304. earlystop = EarlyStop(early_stop_patience, thresh)
  305. self.train_data_loader = self.build_data_loader(
  306. train_dataset, batch_size=train_batch_size, mode='train')
  307. if eval_dataset is not None:
  308. self.test_transforms = copy.deepcopy(eval_dataset.transforms)
  309. start_epoch = self.completed_epochs
  310. train_step_time = SmoothedValue(log_interval_steps)
  311. train_step_each_epoch = math.floor(train_dataset.num_samples /
  312. train_batch_size)
  313. train_total_step = train_step_each_epoch * (num_epochs - start_epoch)
  314. if eval_dataset is not None:
  315. eval_batch_size = train_batch_size
  316. eval_epoch_time = 0
  317. current_step = 0
  318. for i in range(start_epoch, num_epochs):
  319. self.net.train()
  320. if callable(
  321. getattr(self.train_data_loader.dataset, 'set_epoch', None)):
  322. self.train_data_loader.dataset.set_epoch(i)
  323. train_avg_metrics = TrainingStats()
  324. step_time_tic = time.time()
  325. for step, data in enumerate(self.train_data_loader()):
  326. if nranks > 1:
  327. outputs = self.run(ddp_net, data, mode='train')
  328. else:
  329. outputs = self.run(self.net, data, mode='train')
  330. loss = outputs['loss']
  331. loss.backward()
  332. self.optimizer.step()
  333. self.optimizer.clear_grad()
  334. lr = self.optimizer.get_lr()
  335. if isinstance(self.optimizer._learning_rate,
  336. paddle.optimizer.lr.LRScheduler):
  337. self.optimizer._learning_rate.step()
  338. train_avg_metrics.update(outputs)
  339. outputs['lr'] = lr
  340. if ema is not None:
  341. ema.update(self.net)
  342. step_time_toc = time.time()
  343. train_step_time.update(step_time_toc - step_time_tic)
  344. step_time_tic = step_time_toc
  345. current_step += 1
  346. # 每间隔log_interval_steps,输出loss信息
  347. if current_step % log_interval_steps == 0 and local_rank == 0:
  348. if use_vdl:
  349. for k, v in outputs.items():
  350. log_writer.add_scalar(
  351. '{}-Metrics/Training(Step): {}'.format(
  352. task_id, k), v, current_step)
  353. # 估算剩余时间
  354. avg_step_time = train_step_time.avg()
  355. eta = avg_step_time * (train_total_step - current_step)
  356. if eval_dataset is not None:
  357. eval_num_epochs = math.ceil(
  358. (num_epochs - i - 1) / save_interval_epochs)
  359. if eval_epoch_time == 0:
  360. eta += avg_step_time * math.ceil(
  361. eval_dataset.num_samples / eval_batch_size)
  362. else:
  363. eta += eval_epoch_time * eval_num_epochs
  364. logging.info(
  365. "[TRAIN] Epoch={}/{}, Step={}/{}, {}, time_each_step={}s, eta={}"
  366. .format(i + 1, num_epochs, step + 1,
  367. train_step_each_epoch,
  368. dict2str(outputs),
  369. round(avg_step_time, 2), seconds_to_hms(eta)))
  370. logging.info('[TRAIN] Epoch {} finished, {} .'
  371. .format(i + 1, train_avg_metrics.log()))
  372. self.completed_epochs += 1
  373. if ema is not None:
  374. weight = copy.deepcopy(self.net.state_dict())
  375. self.net.set_state_dict(ema.apply())
  376. eval_epoch_tic = time.time()
  377. # 每间隔save_interval_epochs, 在验证集上评估和对模型进行保存
  378. if (i + 1) % save_interval_epochs == 0 or i == num_epochs - 1:
  379. if eval_dataset is not None and eval_dataset.num_samples > 0:
  380. eval_result = self.evaluate(
  381. eval_dataset,
  382. batch_size=eval_batch_size,
  383. return_details=True)
  384. # 保存最优模型
  385. if local_rank == 0:
  386. self.eval_metrics, self.eval_details = eval_result
  387. if use_vdl:
  388. for k, v in self.eval_metrics.items():
  389. try:
  390. log_writer.add_scalar(
  391. '{}-Metrics/Eval(Epoch): {}'.format(
  392. task_id, k), v, i + 1)
  393. except TypeError:
  394. pass
  395. logging.info('[EVAL] Finished, Epoch={}, {} .'.format(
  396. i + 1, dict2str(self.eval_metrics)))
  397. best_accuracy_key = list(self.eval_metrics.keys())[0]
  398. current_accuracy = self.eval_metrics[best_accuracy_key]
  399. if current_accuracy > self.best_accuracy:
  400. self.best_accuracy = current_accuracy
  401. self.best_model_epoch = i + 1
  402. best_model_dir = osp.join(save_dir, "best_model")
  403. self.save_model(save_dir=best_model_dir)
  404. if self.best_model_epoch > 0:
  405. logging.info(
  406. 'Current evaluated best model on eval_dataset is epoch_{}, {}={}'
  407. .format(self.best_model_epoch,
  408. best_accuracy_key, self.best_accuracy))
  409. eval_epoch_time = time.time() - eval_epoch_tic
  410. current_save_dir = osp.join(save_dir, "epoch_{}".format(i + 1))
  411. if local_rank == 0:
  412. self.save_model(save_dir=current_save_dir)
  413. if eval_dataset is not None and early_stop:
  414. if earlystop(current_accuracy):
  415. break
  416. if ema is not None:
  417. self.net.set_state_dict(weight)
  418. def analyze_sensitivity(self,
  419. dataset,
  420. batch_size=8,
  421. criterion='l1_norm',
  422. save_dir='output'):
  423. """
  424. Args:
  425. dataset(paddlers.dataset): Dataset used for evaluation during sensitivity analysis.
  426. batch_size(int, optional): Batch size used in evaluation. Defaults to 8.
  427. criterion({'l1_norm', 'fpgm'}, optional): Pruning criterion. Defaults to 'l1_norm'.
  428. save_dir(str, optional): The directory to save sensitivity file of the model. Defaults to 'output'.
  429. """
  430. if self.__class__.__name__ in {'FasterRCNN', 'MaskRCNN', 'PicoDet'}:
  431. raise Exception("{} does not support pruning currently!".format(
  432. self.__class__.__name__))
  433. assert criterion in {'l1_norm', 'fpgm'}, \
  434. "Pruning criterion {} is not supported. Please choose from ['l1_norm', 'fpgm']"
  435. arrange_transforms(
  436. model_type=self.model_type,
  437. transforms=dataset.transforms,
  438. mode='eval')
  439. if self.model_type == 'detector':
  440. self.net.eval()
  441. else:
  442. self.net.train()
  443. inputs = _pruner_template_input(
  444. sample=dataset[0], model_type=self.model_type)
  445. if criterion == 'l1_norm':
  446. self.pruner = L1NormFilterPruner(self.net, inputs=inputs)
  447. else:
  448. self.pruner = FPGMFilterPruner(self.net, inputs=inputs)
  449. if not osp.isdir(save_dir):
  450. os.makedirs(save_dir)
  451. sen_file = osp.join(save_dir, 'model.sensi.data')
  452. logging.info('Sensitivity analysis of model parameters starts...')
  453. self.pruner.sensitive(
  454. eval_func=partial(_pruner_eval_fn, self, dataset, batch_size),
  455. sen_file=sen_file)
  456. logging.info(
  457. 'Sensitivity analysis is complete. The result is saved at {}.'.
  458. format(sen_file))
  459. def prune(self, pruned_flops, save_dir=None):
  460. """
  461. Args:
  462. pruned_flops(float): Ratio of FLOPs to be pruned.
  463. save_dir(None or str, optional): If None, the pruned model will not be saved.
  464. Otherwise, the pruned model will be saved at save_dir. Defaults to None.
  465. """
  466. if self.status == "Pruned":
  467. raise Exception(
  468. "A pruned model cannot be done model pruning again!")
  469. pre_pruning_flops = flops(self.net, self.pruner.inputs)
  470. logging.info("Pre-pruning FLOPs: {}. Pruning starts...".format(
  471. pre_pruning_flops))
  472. _, self.pruning_ratios = sensitive_prune(self.pruner, pruned_flops)
  473. post_pruning_flops = flops(self.net, self.pruner.inputs)
  474. logging.info("Pruning is complete. Post-pruning FLOPs: {}".format(
  475. post_pruning_flops))
  476. logging.warning("Pruning the model may hurt its performance, "
  477. "retraining is highly recommended")
  478. self.status = 'Pruned'
  479. if save_dir is not None:
  480. self.save_model(save_dir)
  481. logging.info("Pruned model is saved at {}".format(save_dir))
  482. def _prepare_qat(self, quant_config):
  483. if self.status == 'Infer':
  484. logging.error(
  485. "Exported inference model does not support quantization aware training.",
  486. exit=True)
  487. if quant_config is None:
  488. # default quantization configuration
  489. quant_config = {
  490. # {None, 'PACT'}. Weight preprocess type. If None, no preprocessing is performed.
  491. 'weight_preprocess_type': None,
  492. # {None, 'PACT'}. Activation preprocess type. If None, no preprocessing is performed.
  493. 'activation_preprocess_type': None,
  494. # {'abs_max', 'channel_wise_abs_max', 'range_abs_max', 'moving_average_abs_max'}.
  495. # Weight quantization type.
  496. 'weight_quantize_type': 'channel_wise_abs_max',
  497. # {'abs_max', 'range_abs_max', 'moving_average_abs_max'}. Activation quantization type.
  498. 'activation_quantize_type': 'moving_average_abs_max',
  499. # The number of bits of weights after quantization.
  500. 'weight_bits': 8,
  501. # The number of bits of activation after quantization.
  502. 'activation_bits': 8,
  503. # Data type after quantization, such as 'uint8', 'int8', etc.
  504. 'dtype': 'int8',
  505. # Window size for 'range_abs_max' quantization.
  506. 'window_size': 10000,
  507. # Decay coefficient of moving average.
  508. 'moving_rate': .9,
  509. # Types of layers that will be quantized.
  510. 'quantizable_layer_type': ['Conv2D', 'Linear']
  511. }
  512. if self.status != 'Quantized':
  513. self.quant_config = quant_config
  514. self.quantizer = QAT(config=self.quant_config)
  515. logging.info(
  516. "Preparing the model for quantization-aware training...")
  517. self.quantizer.quantize(self.net)
  518. logging.info("Model is ready for quantization-aware training.")
  519. self.status = 'Quantized'
  520. elif quant_config != self.quant_config:
  521. logging.error(
  522. "The model has been quantized with the following quant_config: {}."
  523. "Doing quantization-aware training with a quantized model "
  524. "using a different configuration is not supported."
  525. .format(self.quant_config),
  526. exit=True)
  527. def _get_pipeline_info(self, save_dir):
  528. pipeline_info = {}
  529. pipeline_info["pipeline_name"] = self.model_type
  530. nodes = [{
  531. "src0": {
  532. "type": "Source",
  533. "next": "decode0"
  534. }
  535. }, {
  536. "decode0": {
  537. "type": "Decode",
  538. "next": "predict0"
  539. }
  540. }, {
  541. "predict0": {
  542. "type": "Predict",
  543. "init_params": {
  544. "use_gpu": False,
  545. "gpu_id": 0,
  546. "use_trt": False,
  547. "model_dir": save_dir,
  548. },
  549. "next": "sink0"
  550. }
  551. }, {
  552. "sink0": {
  553. "type": "Sink"
  554. }
  555. }]
  556. pipeline_info["pipeline_nodes"] = nodes
  557. pipeline_info["version"] = "1.0.0"
  558. return pipeline_info
  559. def _build_inference_net(self):
  560. if self.model_type == 'detector':
  561. infer_net = self.net
  562. elif self.model_type == 'changedetector':
  563. infer_net = InferCDNet(self.net)
  564. else:
  565. infer_net = InferNet(self.net, self.model_type)
  566. infer_net.eval()
  567. return infer_net
  568. def _export_inference_model(self, save_dir, image_shape=None):
  569. self.test_inputs = self._get_test_inputs(image_shape)
  570. infer_net = self._build_inference_net()
  571. if self.status == 'Quantized':
  572. self.quantizer.save_quantized_model(infer_net,
  573. osp.join(save_dir, 'model'),
  574. self.test_inputs)
  575. quant_info = self.get_quant_info()
  576. with open(
  577. osp.join(save_dir, 'quant.yml'), encoding='utf-8',
  578. mode='w') as f:
  579. yaml.dump(quant_info, f)
  580. else:
  581. static_net = paddle.jit.to_static(
  582. infer_net, input_spec=self.test_inputs)
  583. paddle.jit.save(static_net, osp.join(save_dir, 'model'))
  584. if self.status == 'Pruned':
  585. pruning_info = self.get_pruning_info()
  586. with open(
  587. osp.join(save_dir, 'prune.yml'), encoding='utf-8',
  588. mode='w') as f:
  589. yaml.dump(pruning_info, f)
  590. model_info = self.get_model_info()
  591. model_info['status'] = 'Infer'
  592. with open(
  593. osp.join(save_dir, 'model.yml'), encoding='utf-8',
  594. mode='w') as f:
  595. yaml.dump(model_info, f)
  596. pipeline_info = self._get_pipeline_info(save_dir)
  597. with open(
  598. osp.join(save_dir, 'pipeline.yml'), encoding='utf-8',
  599. mode='w') as f:
  600. yaml.dump(pipeline_info, f)
  601. # 模型保存成功的标志
  602. open(osp.join(save_dir, '.success'), 'w').close()
  603. logging.info("The model for the inference deployment is saved in {}.".
  604. format(save_dir))