# segmenter.py
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import os.path as osp
from collections import OrderedDict

import numpy as np
import cv2
import paddle
import paddle.nn.functional as F
from paddle.static import InputSpec

import paddlers
import paddlers.models.ppseg as paddleseg
import paddlers.utils.logging as logging
from paddlers.transforms import arrange_transforms, ImgDecoder, Resize
from paddlers.utils import get_single_card_bs, DisablePrint
from paddlers.utils.checkpoint import seg_pretrain_weights_dict
from .base import BaseModel
from .utils import seg_metrics as metrics

__all__ = ["UNet", "DeepLabV3P", "FastSCNN", "HRNet", "BiSeNetV2", "FarSeg"]


class BaseSegmenter(BaseModel):
    def __init__(self,
                 model_name,
                 num_classes=2,
                 use_mixed_loss=False,
                 **params):
        self.init_params = locals()
        if 'with_net' in self.init_params:
            del self.init_params['with_net']
        super(BaseSegmenter, self).__init__('segmenter')
        if not hasattr(paddleseg.models, model_name) and \
                not hasattr(paddleseg.rs_models, model_name):
            raise Exception("ERROR: There's no model named {}.".format(
                model_name))
        self.model_name = model_name
        self.num_classes = num_classes
        self.use_mixed_loss = use_mixed_loss
        self.losses = None
        self.labels = None
        if params.get('with_net', True):
            params.pop('with_net', None)
            self.net = self.build_net(**params)
        self.find_unused_parameters = True

    def build_net(self, **params):
        # TODO: When using paddle.utils.unique_name.guard,
        # DeepLabV3P and HRNet will raise an error.
        net = dict(paddleseg.models.__dict__,
                   **paddleseg.rs_models.__dict__)[self.model_name](
                       num_classes=self.num_classes, **params)
        return net

    def _fix_transforms_shape(self, image_shape):
        if hasattr(self, 'test_transforms'):
            if self.test_transforms is not None:
                has_resize_op = False
                resize_op_idx = -1
                normalize_op_idx = len(self.test_transforms.transforms)
                for idx, op in enumerate(self.test_transforms.transforms):
                    name = op.__class__.__name__
                    if name == 'Normalize':
                        normalize_op_idx = idx
                    if 'Resize' in name:
                        has_resize_op = True
                        resize_op_idx = idx
                if not has_resize_op:
                    self.test_transforms.transforms.insert(
                        normalize_op_idx, Resize(target_size=image_shape))
                else:
                    self.test_transforms.transforms[resize_op_idx] = Resize(
                        target_size=image_shape)

    def _get_test_inputs(self, image_shape):
        if image_shape is not None:
            if len(image_shape) == 2:
                image_shape = [1, 3] + image_shape
            self._fix_transforms_shape(image_shape[-2:])
        else:
            image_shape = [None, 3, -1, -1]
        self.fixed_input_shape = image_shape
        input_spec = [
            InputSpec(
                shape=image_shape, name='image', dtype='float32')
        ]
        return input_spec
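
    # A hedged sketch of how the export input spec behaves (not executed):
    #
    #     model._get_test_inputs([512, 512])  # -> InputSpec([1, 3, 512, 512])
    #     model._get_test_inputs(None)        # -> InputSpec([None, 3, -1, -1])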

    def run(self, net, inputs, mode):
        net_out = net(inputs[0])
        logit = net_out[0]
        outputs = OrderedDict()
        if mode == 'test':
            origin_shape = inputs[1]
            if self.status == 'Infer':
                label_map_list, score_map_list = self._postprocess(
                    net_out, origin_shape, transforms=inputs[2])
            else:
                logit_list = self._postprocess(
                    logit, origin_shape, transforms=inputs[2])
                label_map_list = []
                score_map_list = []
                for logit in logit_list:
                    logit = paddle.transpose(logit, perm=[0, 2, 3, 1])  # NHWC
                    label_map_list.append(
                        paddle.argmax(
                            logit, axis=-1, keepdim=False, dtype='int32')
                        .squeeze().numpy())
                    score_map_list.append(
                        F.softmax(
                            logit, axis=-1).squeeze().numpy().astype(
                                'float32'))
            outputs['label_map'] = label_map_list
            outputs['score_map'] = score_map_list
        if mode == 'eval':
            if self.status == 'Infer':
                pred = paddle.unsqueeze(net_out[0], axis=1)  # NCHW
            else:
                pred = paddle.argmax(
                    logit, axis=1, keepdim=True, dtype='int32')
            label = inputs[1]
            origin_shape = [label.shape[-2:]]
            pred = self._postprocess(
                pred, origin_shape, transforms=inputs[2])[0]  # NCHW
            intersect_area, pred_area, label_area = \
                paddleseg.utils.metrics.calculate_area(
                    pred, label, self.num_classes)
            outputs['intersect_area'] = intersect_area
            outputs['pred_area'] = pred_area
            outputs['label_area'] = label_area
            outputs['conf_mat'] = metrics.confusion_matrix(pred, label,
                                                           self.num_classes)
        if mode == 'train':
            loss_list = metrics.loss_computation(
                logits_list=net_out, labels=inputs[1], losses=self.losses)
            loss = sum(loss_list)
            outputs['loss'] = loss
        return outputs

    def default_loss(self):
        if isinstance(self.use_mixed_loss, bool):
            if self.use_mixed_loss:
                losses = [
                    paddleseg.models.CrossEntropyLoss(),
                    paddleseg.models.LovaszSoftmaxLoss()
                ]
                coef = [.8, .2]
                loss_type = [
                    paddleseg.models.MixedLoss(
                        losses=losses, coef=coef),
                ]
            else:
                loss_type = [paddleseg.models.CrossEntropyLoss()]
        else:
            losses, coef = list(zip(*self.use_mixed_loss))
            if not set(losses).issubset(
                    ['CrossEntropyLoss', 'DiceLoss', 'LovaszSoftmaxLoss']):
                raise ValueError(
                    "Only 'CrossEntropyLoss', 'DiceLoss', 'LovaszSoftmaxLoss' are supported."
                )
            losses = [getattr(paddleseg.models, loss)() for loss in losses]
            loss_type = [
                paddleseg.models.MixedLoss(
                    losses=losses, coef=list(coef))
            ]
        if self.model_name == 'FastSCNN':
            loss_type *= 2
            loss_coef = [1.0, 0.4]
        elif self.model_name == 'BiSeNetV2':
            loss_type *= 5
            loss_coef = [1.0] * 5
        else:
            loss_coef = [1.0]
        losses = {'types': loss_type, 'coef': loss_coef}
        return losses
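
    # A hedged configuration sketch (not executed): `use_mixed_loss` accepts
    # either a bool or a list of (loss_name, coef) pairs, e.g.
    #
    #     model = UNet(num_classes=2,
    #                  use_mixed_loss=[('CrossEntropyLoss', 0.8),
    #                                  ('DiceLoss', 0.2)])
    #     model.default_loss()
    #     # -> {'types': [MixedLoss(...)], 'coef': [1.0]}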

    def default_optimizer(self,
                          parameters,
                          learning_rate,
                          num_epochs,
                          num_steps_each_epoch,
                          lr_decay_power=0.9):
        decay_step = num_epochs * num_steps_each_epoch
        lr_scheduler = paddle.optimizer.lr.PolynomialDecay(
            learning_rate, decay_step, end_lr=0, power=lr_decay_power)
        optimizer = paddle.optimizer.Momentum(
            learning_rate=lr_scheduler,
            parameters=parameters,
            momentum=0.9,
            weight_decay=4e-5)
        return optimizer
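
    # The schedule above decays the learning rate polynomially to 0 over the
    # whole run. A hedged numeric sketch (values are illustrative): with
    # learning_rate=0.01, num_epochs=10 and num_steps_each_epoch=100,
    # decay_step is 1000 and the rate at step t is roughly
    # 0.01 * (1 - t / 1000) ** 0.9.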

    def train(self,
              num_epochs,
              train_dataset,
              train_batch_size=2,
              eval_dataset=None,
              optimizer=None,
              save_interval_epochs=1,
              log_interval_steps=2,
              save_dir='output',
              pretrain_weights='CITYSCAPES',
              learning_rate=0.01,
              lr_decay_power=0.9,
              early_stop=False,
              early_stop_patience=5,
              use_vdl=True,
              resume_checkpoint=None):
        """
        Train the model.

        Args:
            num_epochs(int): The number of epochs.
            train_dataset(paddlers.dataset): Training dataset.
            train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 2.
            eval_dataset(paddlers.dataset, optional):
                Evaluation dataset. If None, the model will not be evaluated during the training process. Defaults to None.
            optimizer(paddle.optimizer.Optimizer or None, optional):
                Optimizer used in training. If None, a default optimizer is used. Defaults to None.
            save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
            log_interval_steps(int, optional): Step interval for printing training information. Defaults to 2.
            save_dir(str, optional): Directory to save the model. Defaults to 'output'.
            pretrain_weights(str or None, optional):
                None or name/path of pretrained weights. If None, no pretrained weights will be loaded. Defaults to 'CITYSCAPES'.
            learning_rate(float, optional): Learning rate for training. Defaults to 0.01.
            lr_decay_power(float, optional): Learning rate decay power. Defaults to 0.9.
            early_stop(bool, optional): Whether to adopt the early stopping strategy. Defaults to False.
            early_stop_patience(int, optional): Early stopping patience. Defaults to 5.
            use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
            resume_checkpoint(str or None, optional): The path of the checkpoint to resume training from.
                If None, no training checkpoint will be resumed. At most one of `resume_checkpoint` and
                `pretrain_weights` can be set simultaneously. Defaults to None.
        """
        if self.status == 'Infer':
            logging.error(
                "Exported inference model does not support training.",
                exit=True)
        if pretrain_weights is not None and resume_checkpoint is not None:
            logging.error(
                "pretrain_weights and resume_checkpoint cannot be set simultaneously.",
                exit=True)
        self.labels = train_dataset.labels
        if self.losses is None:
            self.losses = self.default_loss()

        if optimizer is None:
            num_steps_each_epoch = train_dataset.num_samples // train_batch_size
            self.optimizer = self.default_optimizer(
                self.net.parameters(), learning_rate, num_epochs,
                num_steps_each_epoch, lr_decay_power)
        else:
            self.optimizer = optimizer

        if pretrain_weights is not None and not osp.exists(pretrain_weights):
            if pretrain_weights not in seg_pretrain_weights_dict[
                    self.model_name]:
                logging.warning(
                    "Path of pretrain_weights('{}') does not exist!".format(
                        pretrain_weights))
                logging.warning("pretrain_weights is forcibly set to '{}'. "
                                "If you don't want to use pretrained weights, "
                                "set pretrain_weights to None.".format(
                                    seg_pretrain_weights_dict[self.model_name][
                                        0]))
                pretrain_weights = seg_pretrain_weights_dict[self.model_name][
                    0]
        elif pretrain_weights is not None and osp.exists(pretrain_weights):
            if osp.splitext(pretrain_weights)[-1] != '.pdparams':
                logging.error(
                    "Invalid pretrained weights. Please specify a '.pdparams' file.",
                    exit=True)
        pretrained_dir = osp.join(save_dir, 'pretrain')
        is_backbone_weights = pretrain_weights == 'IMAGENET'
        self.net_initialize(
            pretrain_weights=pretrain_weights,
            save_dir=pretrained_dir,
            resume_checkpoint=resume_checkpoint,
            is_backbone_weights=is_backbone_weights)

        self.train_loop(
            num_epochs=num_epochs,
            train_dataset=train_dataset,
            train_batch_size=train_batch_size,
            eval_dataset=eval_dataset,
            save_interval_epochs=save_interval_epochs,
            log_interval_steps=log_interval_steps,
            save_dir=save_dir,
            early_stop=early_stop,
            early_stop_patience=early_stop_patience,
            use_vdl=use_vdl)
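
    # A hedged usage sketch (dataset construction omitted; `train_ds` and
    # `val_ds` are hypothetical paddlers datasets):
    #
    #     model = UNet(num_classes=2)
    #     model.train(num_epochs=20,
    #                 train_dataset=train_ds,
    #                 train_batch_size=4,
    #                 eval_dataset=val_ds,
    #                 learning_rate=0.01,
    #                 save_dir='output/unet')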

    def quant_aware_train(self,
                          num_epochs,
                          train_dataset,
                          train_batch_size=2,
                          eval_dataset=None,
                          optimizer=None,
                          save_interval_epochs=1,
                          log_interval_steps=2,
                          save_dir='output',
                          learning_rate=0.0001,
                          lr_decay_power=0.9,
                          early_stop=False,
                          early_stop_patience=5,
                          use_vdl=True,
                          resume_checkpoint=None,
                          quant_config=None):
        """
        Quantization-aware training.

        Args:
            num_epochs(int): The number of epochs.
            train_dataset(paddlers.dataset): Training dataset.
            train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 2.
            eval_dataset(paddlers.dataset, optional):
                Evaluation dataset. If None, the model will not be evaluated during the training process. Defaults to None.
            optimizer(paddle.optimizer.Optimizer or None, optional):
                Optimizer used in training. If None, a default optimizer is used. Defaults to None.
            save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
            log_interval_steps(int, optional): Step interval for printing training information. Defaults to 2.
            save_dir(str, optional): Directory to save the model. Defaults to 'output'.
            learning_rate(float, optional): Learning rate for training. Defaults to 0.0001.
            lr_decay_power(float, optional): Learning rate decay power. Defaults to 0.9.
            early_stop(bool, optional): Whether to adopt the early stopping strategy. Defaults to False.
            early_stop_patience(int, optional): Early stopping patience. Defaults to 5.
            use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
            quant_config(dict or None, optional): Quantization configuration. If None, a default rule-of-thumb
                configuration will be used. Defaults to None.
            resume_checkpoint(str or None, optional): The path of the checkpoint to resume quantization-aware training
                from. If None, no training checkpoint will be resumed. Defaults to None.
        """
        self._prepare_qat(quant_config)
        self.train(
            num_epochs=num_epochs,
            train_dataset=train_dataset,
            train_batch_size=train_batch_size,
            eval_dataset=eval_dataset,
            optimizer=optimizer,
            save_interval_epochs=save_interval_epochs,
            log_interval_steps=log_interval_steps,
            save_dir=save_dir,
            pretrain_weights=None,
            learning_rate=learning_rate,
            lr_decay_power=lr_decay_power,
            early_stop=early_stop,
            early_stop_patience=early_stop_patience,
            use_vdl=use_vdl,
            resume_checkpoint=resume_checkpoint)
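
    # A hedged QAT sketch (not executed): fine-tune an already-trained model
    # with a small learning rate (`train_ds` and `val_ds` are hypothetical):
    #
    #     model.quant_aware_train(num_epochs=5,
    #                             train_dataset=train_ds,
    #                             eval_dataset=val_ds,
    #                             learning_rate=0.0001,
    #                             save_dir='output/unet_quant')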

    def evaluate(self, eval_dataset, batch_size=1, return_details=False):
        """
        Evaluate the model.

        Args:
            eval_dataset(paddlers.dataset): Evaluation dataset.
            batch_size(int, optional): Total batch size among all cards used for evaluation. Defaults to 1.
            return_details(bool, optional): Whether to return evaluation details. Defaults to False.

        Returns:
            collections.OrderedDict with key-value pairs:
                {"miou": `mean intersection over union`,
                 "category_iou": `category-wise mean intersection over union`,
                 "oacc": `overall accuracy`,
                 "category_acc": `category-wise accuracy`,
                 "kappa": `kappa coefficient`,
                 "category_F1-score": `F1 score`}.
        """
        arrange_transforms(
            model_type=self.model_type,
            transforms=eval_dataset.transforms,
            mode='eval')
        self.net.eval()
        nranks = paddle.distributed.get_world_size()
        local_rank = paddle.distributed.get_rank()
        if nranks > 1:
            # Initialize the parallel environment if it has not been done yet.
            if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
            ):
                paddle.distributed.init_parallel_env()

        batch_size_each_card = get_single_card_bs(batch_size)
        if batch_size_each_card > 1:
            batch_size_each_card = 1
            batch_size = batch_size_each_card * paddlers.env_info['num']
            logging.warning(
                "Segmenter only supports batch_size=1 for each gpu/cpu card "
                "during evaluation, so batch_size "
                "is forcibly set to {}.".format(batch_size))
        self.eval_data_loader = self.build_data_loader(
            eval_dataset, batch_size=batch_size, mode='eval')

        intersect_area_all = 0
        pred_area_all = 0
        label_area_all = 0
        conf_mat_all = []
        logging.info(
            "Start to evaluate(total_samples={}, total_steps={})...".format(
                eval_dataset.num_samples,
                math.ceil(eval_dataset.num_samples * 1.0 / batch_size)))
        with paddle.no_grad():
            for step, data in enumerate(self.eval_data_loader):
                data.append(eval_dataset.transforms.transforms)
                outputs = self.run(self.net, data, 'eval')
                pred_area = outputs['pred_area']
                label_area = outputs['label_area']
                intersect_area = outputs['intersect_area']
                conf_mat = outputs['conf_mat']

                # Gather from all ranks.
                if nranks > 1:
                    intersect_area_list = []
                    pred_area_list = []
                    label_area_list = []
                    conf_mat_list = []
                    paddle.distributed.all_gather(intersect_area_list,
                                                  intersect_area)
                    paddle.distributed.all_gather(pred_area_list, pred_area)
                    paddle.distributed.all_gather(label_area_list, label_area)
                    paddle.distributed.all_gather(conf_mat_list, conf_mat)

                    # Some images have already been evaluated and should be
                    # excluded in the last iteration.
                    if (step + 1) * nranks > len(eval_dataset):
                        valid = len(eval_dataset) - step * nranks
                        intersect_area_list = intersect_area_list[:valid]
                        pred_area_list = pred_area_list[:valid]
                        label_area_list = label_area_list[:valid]
                        conf_mat_list = conf_mat_list[:valid]

                    intersect_area_all += sum(intersect_area_list)
                    pred_area_all += sum(pred_area_list)
                    label_area_all += sum(label_area_list)
                    conf_mat_all.extend(conf_mat_list)
                else:
                    intersect_area_all = intersect_area_all + intersect_area
                    pred_area_all = pred_area_all + pred_area
                    label_area_all = label_area_all + label_area
                    conf_mat_all.append(conf_mat)
        class_iou, miou = paddleseg.utils.metrics.mean_iou(
            intersect_area_all, pred_area_all, label_area_all)
        # TODO: Confirm whether oacc or macc should be reported here.
        class_acc, oacc = paddleseg.utils.metrics.accuracy(intersect_area_all,
                                                           pred_area_all)
        kappa = paddleseg.utils.metrics.kappa(intersect_area_all,
                                              pred_area_all, label_area_all)
        category_f1score = metrics.f1_score(intersect_area_all, pred_area_all,
                                            label_area_all)

        eval_metrics = OrderedDict(
            zip([
                'miou', 'category_iou', 'oacc', 'category_acc', 'kappa',
                'category_F1-score'
            ], [miou, class_iou, oacc, class_acc, kappa, category_f1score]))

        if return_details:
            conf_mat = sum(conf_mat_all)
            eval_details = {'confusion_matrix': conf_mat.tolist()}
            return eval_metrics, eval_details
        return eval_metrics
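
    # A hedged usage sketch (`val_ds` is a hypothetical paddlers dataset):
    #
    #     eval_metrics = model.evaluate(val_ds)
    #     print(eval_metrics['miou'], eval_metrics['kappa'])
    #     # With return_details=True, a confusion matrix is also returned:
    #     eval_metrics, details = model.evaluate(val_ds, return_details=True)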

    def predict(self, img_file, transforms=None):
        """
        Do inference.

        Args:
            img_file(List[np.ndarray or str], str or np.ndarray):
                Image path or decoded image data in BGR format, which also could constitute a list,
                meaning all images to be predicted as a mini-batch.
            transforms(paddlers.transforms.Compose or None, optional):
                Transforms for inputs. If None, the transforms for evaluation process will be used. Defaults to None.

        Returns:
            If img_file is a string or np.ndarray, the result is a dict with key-value pairs:
            {"label_map": `label map`, "score_map": `score map`}.
            If img_file is a list, the result is a list composed of dicts with the corresponding fields:
                label_map(np.ndarray): the predicted label map (HW)
                score_map(np.ndarray): the prediction score map (HWC)
        """
        if transforms is None and not hasattr(self, 'test_transforms'):
            raise Exception(
                "transforms need to be defined, but it is currently None.")
        if transforms is None:
            transforms = self.test_transforms
        if isinstance(img_file, (str, np.ndarray)):
            images = [img_file]
        else:
            images = img_file
        batch_im, batch_origin_shape = self._preprocess(images, transforms,
                                                        self.model_type)
        self.net.eval()
        data = (batch_im, batch_origin_shape, transforms.transforms)
        outputs = self.run(self.net, data, 'test')
        label_map_list = outputs['label_map']
        score_map_list = outputs['score_map']
        if isinstance(img_file, list):
            prediction = [{
                'label_map': l,
                'score_map': s
            } for l, s in zip(label_map_list, score_map_list)]
        else:
            prediction = {
                'label_map': label_map_list[0],
                'score_map': score_map_list[0]
            }
        return prediction
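
    # A hedged usage sketch: predict a single image by path or a mini-batch
    # of images (paths are hypothetical):
    #
    #     pred = model.predict('demo/aerial.png')
    #     label_map = pred['label_map']  # (H, W) int32
    #     score_map = pred['score_map']  # (H, W, C) float32
    #     preds = model.predict(['a.png', 'b.png'])  # list of such dicts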

    def _preprocess(self, images, transforms, to_tensor=True):
        arrange_transforms(
            model_type=self.model_type, transforms=transforms, mode='test')
        batch_im = list()
        batch_ori_shape = list()
        for im in images:
            sample = {'image': im}
            if isinstance(sample['image'], str):
                sample = ImgDecoder(to_rgb=False)(sample)
            ori_shape = sample['image'].shape[:2]
            im = transforms(sample)[0]
            batch_im.append(im)
            batch_ori_shape.append(ori_shape)
        if to_tensor:
            batch_im = paddle.to_tensor(batch_im)
        else:
            batch_im = np.asarray(batch_im)
        return batch_im, batch_ori_shape

    @staticmethod
    def get_transforms_shape_info(batch_ori_shape, transforms):
        batch_restore_list = list()
        for ori_shape in batch_ori_shape:
            restore_list = list()
            h, w = ori_shape[0], ori_shape[1]
            for op in transforms:
                if op.__class__.__name__ == 'Resize':
                    restore_list.append(('resize', (h, w)))
                    h, w = op.target_size
                elif op.__class__.__name__ == 'ResizeByShort':
                    restore_list.append(('resize', (h, w)))
                    im_short_size = min(h, w)
                    im_long_size = max(h, w)
                    scale = float(op.short_size) / float(im_short_size)
                    if 0 < op.max_size < np.round(scale * im_long_size):
                        scale = float(op.max_size) / float(im_long_size)
                    h = int(round(h * scale))
                    w = int(round(w * scale))
                elif op.__class__.__name__ == 'ResizeByLong':
                    restore_list.append(('resize', (h, w)))
                    im_long_size = max(h, w)
                    scale = float(op.long_size) / float(im_long_size)
                    h = int(round(h * scale))
                    w = int(round(w * scale))
                elif op.__class__.__name__ == 'Padding':
                    if op.target_size:
                        target_h, target_w = op.target_size
                    else:
                        target_h = int(
                            (np.ceil(h / op.size_divisor) * op.size_divisor))
                        target_w = int(
                            (np.ceil(w / op.size_divisor) * op.size_divisor))
                    if op.pad_mode == -1:
                        offsets = op.offsets
                    elif op.pad_mode == 0:
                        offsets = [0, 0]
                    elif op.pad_mode == 1:
                        offsets = [(target_h - h) // 2, (target_w - w) // 2]
                    else:
                        offsets = [target_h - h, target_w - w]
                    restore_list.append(('padding', (h, w), offsets))
                    h, w = target_h, target_w
            batch_restore_list.append(restore_list)
        return batch_restore_list
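
    # A hedged sketch of the restore-list bookkeeping above (not executed):
    # for a 1000x1500 input passed through a Resize op with
    # target_size=[512, 512], the method records the shape to restore to,
    #
    #     get_transforms_shape_info([(1000, 1500)], [resize_op])
    #     # -> [[('resize', (1000, 1500))]]
    #
    # so that _postprocess can undo each op in reverse order.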

    def _postprocess(self, batch_pred, batch_origin_shape, transforms):
        batch_restore_list = BaseSegmenter.get_transforms_shape_info(
            batch_origin_shape, transforms)
        if isinstance(batch_pred, (tuple, list)) and self.status == 'Infer':
            return self._infer_postprocess(
                batch_label_map=batch_pred[0],
                batch_score_map=batch_pred[1],
                batch_restore_list=batch_restore_list)
        results = []
        if batch_pred.dtype == paddle.float32:
            mode = 'bilinear'
        else:
            mode = 'nearest'
        for pred, restore_list in zip(batch_pred, batch_restore_list):
            pred = paddle.unsqueeze(pred, axis=0)
            for item in restore_list[::-1]:
                h, w = item[1][0], item[1][1]
                if item[0] == 'resize':
                    pred = F.interpolate(
                        pred, (h, w), mode=mode, data_format='NCHW')
                elif item[0] == 'padding':
                    x, y = item[2]
                    pred = pred[:, :, y:y + h, x:x + w]
                else:
                    pass
            results.append(pred)
        return results

    def _infer_postprocess(self, batch_label_map, batch_score_map,
                           batch_restore_list):
        label_maps = []
        score_maps = []
        for label_map, score_map, restore_list in zip(
                batch_label_map, batch_score_map, batch_restore_list):
            if not isinstance(label_map, np.ndarray):
                label_map = paddle.unsqueeze(label_map, axis=[0, 3])
                score_map = paddle.unsqueeze(score_map, axis=0)
            for item in restore_list[::-1]:
                h, w = item[1][0], item[1][1]
                if item[0] == 'resize':
                    if isinstance(label_map, np.ndarray):
                        label_map = cv2.resize(
                            label_map, (w, h),
                            interpolation=cv2.INTER_NEAREST)
                        score_map = cv2.resize(
                            score_map, (w, h), interpolation=cv2.INTER_LINEAR)
                    else:
                        label_map = F.interpolate(
                            label_map, (h, w),
                            mode='nearest',
                            data_format='NHWC')
                        score_map = F.interpolate(
                            score_map, (h, w),
                            mode='bilinear',
                            data_format='NHWC')
                elif item[0] == 'padding':
                    x, y = item[2]
                    if isinstance(label_map, np.ndarray):
                        label_map = label_map[..., y:y + h, x:x + w]
                        score_map = score_map[..., y:y + h, x:x + w]
                    else:
                        label_map = label_map[:, :, y:y + h, x:x + w]
                        score_map = score_map[:, :, y:y + h, x:x + w]
                else:
                    pass
            label_map = label_map.squeeze()
            score_map = score_map.squeeze()
            if not isinstance(label_map, np.ndarray):
                label_map = label_map.numpy()
                score_map = score_map.numpy()
            label_maps.append(label_map.squeeze())
            score_maps.append(score_map.squeeze())
        return label_maps, score_maps


class UNet(BaseSegmenter):
    def __init__(self,
                 num_classes=2,
                 use_mixed_loss=False,
                 use_deconv=False,
                 align_corners=False,
                 **params):
        params.update({
            'use_deconv': use_deconv,
            'align_corners': align_corners
        })
        super(UNet, self).__init__(
            model_name='UNet',
            num_classes=num_classes,
            use_mixed_loss=use_mixed_loss,
            **params)


class DeepLabV3P(BaseSegmenter):
    def __init__(self,
                 input_channel=3,
                 num_classes=2,
                 backbone='ResNet50_vd',
                 use_mixed_loss=False,
                 output_stride=8,
                 backbone_indices=(0, 3),
                 aspp_ratios=(1, 12, 24, 36),
                 aspp_out_channels=256,
                 align_corners=False,
                 **params):
        self.backbone_name = backbone
        if backbone not in ['ResNet50_vd', 'ResNet101_vd']:
            raise ValueError(
                "backbone: {} is not supported. Please choose one of "
                "('ResNet50_vd', 'ResNet101_vd').".format(backbone))
        if params.get('with_net', True):
            with DisablePrint():
                backbone = getattr(paddleseg.models, backbone)(
                    input_channel=input_channel, output_stride=output_stride)
        else:
            backbone = None
        params.update({
            'backbone': backbone,
            'backbone_indices': backbone_indices,
            'aspp_ratios': aspp_ratios,
            'aspp_out_channels': aspp_out_channels,
            'align_corners': align_corners
        })
        super(DeepLabV3P, self).__init__(
            model_name='DeepLabV3P',
            num_classes=num_classes,
            use_mixed_loss=use_mixed_loss,
            **params)


class FastSCNN(BaseSegmenter):
    def __init__(self,
                 num_classes=2,
                 use_mixed_loss=False,
                 align_corners=False,
                 **params):
        params.update({'align_corners': align_corners})
        super(FastSCNN, self).__init__(
            model_name='FastSCNN',
            num_classes=num_classes,
            use_mixed_loss=use_mixed_loss,
            **params)


class HRNet(BaseSegmenter):
    def __init__(self,
                 num_classes=2,
                 width=48,
                 use_mixed_loss=False,
                 align_corners=False,
                 **params):
        if width not in (18, 48):
            raise ValueError(
                "width={} is not supported, please choose from [18, 48].".
                format(width))
        self.backbone_name = 'HRNet_W{}'.format(width)
        if params.get('with_net', True):
            with DisablePrint():
                backbone = getattr(paddleseg.models, self.backbone_name)(
                    align_corners=align_corners)
        else:
            backbone = None
        params.update({'backbone': backbone, 'align_corners': align_corners})
        super(HRNet, self).__init__(
            model_name='FCN',
            num_classes=num_classes,
            use_mixed_loss=use_mixed_loss,
            **params)
        self.model_name = 'HRNet'


class BiSeNetV2(BaseSegmenter):
    def __init__(self,
                 num_classes=2,
                 use_mixed_loss=False,
                 align_corners=False,
                 **params):
        params.update({'align_corners': align_corners})
        super(BiSeNetV2, self).__init__(
            model_name='BiSeNetV2',
            num_classes=num_classes,
            use_mixed_loss=use_mixed_loss,
            **params)


class FarSeg(BaseSegmenter):
    def __init__(self,
                 num_classes=2,
                 use_mixed_loss=False,
                 **params):
        super(FarSeg, self).__init__(
            model_name='FarSeg',
            num_classes=num_classes,
            use_mixed_loss=use_mixed_loss,
            **params)
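

# A minimal, hedged end-to-end sketch, not part of the library: run an
# untrained UNet on a random BGR array just to exercise the predict() path.
# The transform pipeline below is an assumption; real use would normalize
# with dataset statistics and load trained weights first.
if __name__ == '__main__':
    from paddlers import transforms as T

    model = UNet(num_classes=2)
    test_tf = T.Compose([T.Resize(target_size=[512, 512]), T.Normalize()])
    im = np.random.randint(0, 256, (600, 800, 3)).astype('uint8')
    pred = model.predict(im, transforms=test_tf)
    print(pred['label_map'].shape, pred['score_map'].shape)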