segmenter.py 37 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. import os
  16. import os.path as osp
  17. from collections import OrderedDict
  18. import numpy as np
  19. import cv2
  20. import paddle
  21. import paddle.nn.functional as F
  22. from paddle.static import InputSpec
  23. import paddlers
  24. import paddlers.models.ppseg as ppseg
  25. import paddlers.rs_models.seg as cmseg
  26. import paddlers.utils.logging as logging
  27. from paddlers.models import seg_losses
  28. from paddlers.transforms import Resize, decode_image
  29. from paddlers.utils import get_single_card_bs, DisablePrint
  30. from paddlers.utils.checkpoint import seg_pretrain_weights_dict
  31. from .base import BaseModel
  32. from .utils import seg_metrics as metrics
  33. from .utils.infer_nets import InferSegNet
  34. from .utils.slider_predict import slider_predict
  35. __all__ = ["UNet", "DeepLabV3P", "FastSCNN", "HRNet", "BiSeNetV2", "FarSeg"]
  36. class BaseSegmenter(BaseModel):
  37. def __init__(self,
  38. model_name,
  39. num_classes=2,
  40. use_mixed_loss=False,
  41. losses=None,
  42. **params):
  43. self.init_params = locals()
  44. if 'with_net' in self.init_params:
  45. del self.init_params['with_net']
  46. super(BaseSegmenter, self).__init__('segmenter')
  47. if not hasattr(ppseg.models, model_name) and \
  48. not hasattr(cmseg, model_name):
  49. raise ValueError("ERROR: There is no model named {}.".format(
  50. model_name))
  51. self.model_name = model_name
  52. self.num_classes = num_classes
  53. self.use_mixed_loss = use_mixed_loss
  54. self.losses = losses
  55. self.labels = None
  56. if params.get('with_net', True):
  57. params.pop('with_net', None)
  58. self.net = self.build_net(**params)
  59. self.find_unused_parameters = True
  60. def build_net(self, **params):
  61. # TODO: when using paddle.utils.unique_name.guard,
  62. # DeepLabv3p and HRNet will raise an error.
  63. net = dict(ppseg.models.__dict__, **cmseg.__dict__)[self.model_name](
  64. num_classes=self.num_classes, **params)
  65. return net
  66. def _build_inference_net(self):
  67. infer_net = InferSegNet(self.net)
  68. infer_net.eval()
  69. return infer_net
  70. def _fix_transforms_shape(self, image_shape):
  71. if hasattr(self, 'test_transforms'):
  72. if self.test_transforms is not None:
  73. has_resize_op = False
  74. resize_op_idx = -1
  75. normalize_op_idx = len(self.test_transforms.transforms)
  76. for idx, op in enumerate(self.test_transforms.transforms):
  77. name = op.__class__.__name__
  78. if name == 'Normalize':
  79. normalize_op_idx = idx
  80. if 'Resize' in name:
  81. has_resize_op = True
  82. resize_op_idx = idx
  83. if not has_resize_op:
  84. self.test_transforms.transforms.insert(
  85. normalize_op_idx, Resize(target_size=image_shape))
  86. else:
  87. self.test_transforms.transforms[resize_op_idx] = Resize(
  88. target_size=image_shape)
  89. def _get_test_inputs(self, image_shape):
  90. if image_shape is not None:
  91. if len(image_shape) == 2:
  92. image_shape = [1, 3] + image_shape
  93. self._fix_transforms_shape(image_shape[-2:])
  94. else:
  95. image_shape = [None, 3, -1, -1]
  96. self.fixed_input_shape = image_shape
  97. input_spec = [
  98. InputSpec(
  99. shape=image_shape, name='image', dtype='float32')
  100. ]
  101. return input_spec
  102. def run(self, net, inputs, mode):
  103. net_out = net(inputs[0])
  104. logit = net_out[0]
  105. outputs = OrderedDict()
  106. if mode == 'test':
  107. origin_shape = inputs[1]
  108. if self.status == 'Infer':
  109. label_map_list, score_map_list = self.postprocess(
  110. net_out, origin_shape, transforms=inputs[2])
  111. else:
  112. logit_list = self.postprocess(
  113. logit, origin_shape, transforms=inputs[2])
  114. label_map_list = []
  115. score_map_list = []
  116. for logit in logit_list:
  117. logit = paddle.transpose(logit, perm=[0, 2, 3, 1]) # NHWC
  118. label_map_list.append(
  119. paddle.argmax(
  120. logit, axis=-1, keepdim=False, dtype='int32')
  121. .squeeze().numpy())
  122. score_map_list.append(
  123. F.softmax(
  124. logit, axis=-1).squeeze().numpy().astype('float32'))
  125. outputs['label_map'] = label_map_list
  126. outputs['score_map'] = score_map_list
  127. if mode == 'eval':
  128. if self.status == 'Infer':
  129. pred = paddle.unsqueeze(net_out[0], axis=1) # NCHW
  130. else:
  131. pred = paddle.argmax(logit, axis=1, keepdim=True, dtype='int32')
  132. label = inputs[1]
  133. if label.ndim == 3:
  134. paddle.unsqueeze_(label, axis=1)
  135. if label.ndim != 4:
  136. raise ValueError("Expected label.ndim == 4 but got {}".format(
  137. label.ndim))
  138. origin_shape = [label.shape[-2:]]
  139. pred = self.postprocess(
  140. pred, origin_shape, transforms=inputs[2])[0] # NCHW
  141. intersect_area, pred_area, label_area = ppseg.utils.metrics.calculate_area(
  142. pred, label, self.num_classes)
  143. outputs['intersect_area'] = intersect_area
  144. outputs['pred_area'] = pred_area
  145. outputs['label_area'] = label_area
  146. outputs['conf_mat'] = metrics.confusion_matrix(pred, label,
  147. self.num_classes)
  148. if mode == 'train':
  149. loss_list = metrics.loss_computation(
  150. logits_list=net_out, labels=inputs[1], losses=self.losses)
  151. loss = sum(loss_list)
  152. outputs['loss'] = loss
  153. return outputs
  154. def default_loss(self):
  155. if isinstance(self.use_mixed_loss, bool):
  156. if self.use_mixed_loss:
  157. losses = [
  158. seg_losses.CrossEntropyLoss(),
  159. seg_losses.LovaszSoftmaxLoss()
  160. ]
  161. coef = [.8, .2]
  162. loss_type = [seg_losses.MixedLoss(losses=losses, coef=coef), ]
  163. else:
  164. loss_type = [seg_losses.CrossEntropyLoss()]
  165. else:
  166. losses, coef = list(zip(*self.use_mixed_loss))
  167. if not set(losses).issubset(
  168. ['CrossEntropyLoss', 'DiceLoss', 'LovaszSoftmaxLoss']):
  169. raise ValueError(
  170. "Only 'CrossEntropyLoss', 'DiceLoss', 'LovaszSoftmaxLoss' are supported."
  171. )
  172. losses = [getattr(seg_losses, loss)() for loss in losses]
  173. loss_type = [seg_losses.MixedLoss(losses=losses, coef=list(coef))]
  174. if self.model_name == 'FastSCNN':
  175. loss_type *= 2
  176. loss_coef = [1.0, 0.4]
  177. elif self.model_name == 'BiSeNetV2':
  178. loss_type *= 5
  179. loss_coef = [1.0] * 5
  180. else:
  181. loss_coef = [1.0]
  182. losses = {'types': loss_type, 'coef': loss_coef}
  183. return losses
  184. def default_optimizer(self,
  185. parameters,
  186. learning_rate,
  187. num_epochs,
  188. num_steps_each_epoch,
  189. lr_decay_power=0.9):
  190. decay_step = num_epochs * num_steps_each_epoch
  191. lr_scheduler = paddle.optimizer.lr.PolynomialDecay(
  192. learning_rate, decay_step, end_lr=0, power=lr_decay_power)
  193. optimizer = paddle.optimizer.Momentum(
  194. learning_rate=lr_scheduler,
  195. parameters=parameters,
  196. momentum=0.9,
  197. weight_decay=4e-5)
  198. return optimizer
  199. def train(self,
  200. num_epochs,
  201. train_dataset,
  202. train_batch_size=2,
  203. eval_dataset=None,
  204. optimizer=None,
  205. save_interval_epochs=1,
  206. log_interval_steps=2,
  207. save_dir='output',
  208. pretrain_weights='CITYSCAPES',
  209. learning_rate=0.01,
  210. lr_decay_power=0.9,
  211. early_stop=False,
  212. early_stop_patience=5,
  213. use_vdl=True,
  214. resume_checkpoint=None):
  215. """
  216. Train the model.
  217. Args:
  218. num_epochs (int): Number of epochs.
  219. train_dataset (paddlers.datasets.SegDataset): Training dataset.
  220. train_batch_size (int, optional): Total batch size among all cards used in
  221. training. Defaults to 2.
  222. eval_dataset (paddlers.datasets.SegDataset|None, optional): Evaluation dataset.
  223. If None, the model will not be evaluated during training process.
  224. Defaults to None.
  225. optimizer (paddle.optimizer.Optimizer|None, optional): Optimizer used in
  226. training. If None, a default optimizer will be used. Defaults to None.
  227. save_interval_epochs (int, optional): Epoch interval for saving the model.
  228. Defaults to 1.
  229. log_interval_steps (int, optional): Step interval for printing training
  230. information. Defaults to 2.
  231. save_dir (str, optional): Directory to save the model. Defaults to 'output'.
  232. pretrain_weights (str|None, optional): None or name/path of pretrained
  233. weights. If None, no pretrained weights will be loaded.
  234. Defaults to 'CITYSCAPES'.
  235. learning_rate (float, optional): Learning rate for training. Defaults to .01.
  236. lr_decay_power (float, optional): Learning decay power. Defaults to .9.
  237. early_stop (bool, optional): Whether to adopt early stop strategy. Defaults
  238. to False.
  239. early_stop_patience (int, optional): Early stop patience. Defaults to 5.
  240. use_vdl (bool, optional): Whether to use VisualDL to monitor the training
  241. process. Defaults to True.
  242. resume_checkpoint (str|None, optional): Path of the checkpoint to resume
  243. training from. If None, no training checkpoint will be resumed. At most
  244. Aone of `resume_checkpoint` and `pretrain_weights` can be set simultaneously.
  245. Defaults to None.
  246. """
  247. if self.status == 'Infer':
  248. logging.error(
  249. "Exported inference model does not support training.",
  250. exit=True)
  251. if pretrain_weights is not None and resume_checkpoint is not None:
  252. logging.error(
  253. "`pretrain_weights` and `resume_checkpoint` cannot be set simultaneously.",
  254. exit=True)
  255. self.labels = train_dataset.labels
  256. if self.losses is None:
  257. self.losses = self.default_loss()
  258. if optimizer is None:
  259. num_steps_each_epoch = train_dataset.num_samples // train_batch_size
  260. self.optimizer = self.default_optimizer(
  261. self.net.parameters(), learning_rate, num_epochs,
  262. num_steps_each_epoch, lr_decay_power)
  263. else:
  264. self.optimizer = optimizer
  265. if pretrain_weights is not None:
  266. if not osp.exists(pretrain_weights):
  267. if self.model_name not in seg_pretrain_weights_dict:
  268. logging.warning(
  269. "Path of pretrained weights ('{}') does not exist!".
  270. format(pretrain_weights))
  271. pretrain_weights = None
  272. elif pretrain_weights not in seg_pretrain_weights_dict[
  273. self.model_name]:
  274. logging.warning(
  275. "Path of pretrained weights ('{}') does not exist!".
  276. format(pretrain_weights))
  277. pretrain_weights = seg_pretrain_weights_dict[
  278. self.model_name][0]
  279. logging.warning(
  280. "`pretrain_weights` is forcibly set to '{}'. "
  281. "If you don't want to use pretrained weights, "
  282. "please set `pretrain_weights` to None.".format(
  283. pretrain_weights))
  284. else:
  285. if osp.splitext(pretrain_weights)[-1] != '.pdparams':
  286. logging.error(
  287. "Invalid pretrained weights. Please specify a .pdparams file.",
  288. exit=True)
  289. pretrained_dir = osp.join(save_dir, 'pretrain')
  290. is_backbone_weights = pretrain_weights == 'IMAGENET'
  291. self.initialize_net(
  292. pretrain_weights=pretrain_weights,
  293. save_dir=pretrained_dir,
  294. resume_checkpoint=resume_checkpoint,
  295. is_backbone_weights=is_backbone_weights)
  296. self.train_loop(
  297. num_epochs=num_epochs,
  298. train_dataset=train_dataset,
  299. train_batch_size=train_batch_size,
  300. eval_dataset=eval_dataset,
  301. save_interval_epochs=save_interval_epochs,
  302. log_interval_steps=log_interval_steps,
  303. save_dir=save_dir,
  304. early_stop=early_stop,
  305. early_stop_patience=early_stop_patience,
  306. use_vdl=use_vdl)
  307. def quant_aware_train(self,
  308. num_epochs,
  309. train_dataset,
  310. train_batch_size=2,
  311. eval_dataset=None,
  312. optimizer=None,
  313. save_interval_epochs=1,
  314. log_interval_steps=2,
  315. save_dir='output',
  316. learning_rate=0.0001,
  317. lr_decay_power=0.9,
  318. early_stop=False,
  319. early_stop_patience=5,
  320. use_vdl=True,
  321. resume_checkpoint=None,
  322. quant_config=None):
  323. """
  324. Quantization-aware training.
  325. Args:
  326. num_epochs (int): Number of epochs.
  327. train_dataset (paddlers.datasets.SegDataset): Training dataset.
  328. train_batch_size (int, optional): Total batch size among all cards used in
  329. training. Defaults to 2.
  330. eval_dataset (paddlers.datasets.SegDataset|None, optional): Evaluation dataset.
  331. If None, the model will not be evaluated during training process.
  332. Defaults to None.
  333. optimizer (paddle.optimizer.Optimizer|None, optional): Optimizer used in
  334. training. If None, a default optimizer will be used. Defaults to None.
  335. save_interval_epochs (int, optional): Epoch interval for saving the model.
  336. Defaults to 1.
  337. log_interval_steps (int, optional): Step interval for printing training
  338. information. Defaults to 2.
  339. save_dir (str, optional): Directory to save the model. Defaults to 'output'.
  340. learning_rate (float, optional): Learning rate for training.
  341. Defaults to .0001.
  342. lr_decay_power (float, optional): Learning decay power. Defaults to .9.
  343. early_stop (bool, optional): Whether to adopt early stop strategy.
  344. Defaults to False.
  345. early_stop_patience (int, optional): Early stop patience. Defaults to 5.
  346. use_vdl (bool, optional): Whether to use VisualDL to monitor the training
  347. process. Defaults to True.
  348. quant_config (dict|None, optional): Quantization configuration. If None,
  349. a default rule of thumb configuration will be used. Defaults to None.
  350. resume_checkpoint (str|None, optional): Path of the checkpoint to resume
  351. quantization-aware training from. If None, no training checkpoint will
  352. be resumed. Defaults to None.
  353. """
  354. self._prepare_qat(quant_config)
  355. self.train(
  356. num_epochs=num_epochs,
  357. train_dataset=train_dataset,
  358. train_batch_size=train_batch_size,
  359. eval_dataset=eval_dataset,
  360. optimizer=optimizer,
  361. save_interval_epochs=save_interval_epochs,
  362. log_interval_steps=log_interval_steps,
  363. save_dir=save_dir,
  364. pretrain_weights=None,
  365. learning_rate=learning_rate,
  366. lr_decay_power=lr_decay_power,
  367. early_stop=early_stop,
  368. early_stop_patience=early_stop_patience,
  369. use_vdl=use_vdl,
  370. resume_checkpoint=resume_checkpoint)
  371. def evaluate(self, eval_dataset, batch_size=1, return_details=False):
  372. """
  373. Evaluate the model.
  374. Args:
  375. eval_dataset (paddlers.datasets.SegDataset): Evaluation dataset.
  376. batch_size (int, optional): Total batch size among all cards used for
  377. evaluation. Defaults to 1.
  378. return_details (bool, optional): Whether to return evaluation details.
  379. Defaults to False.
  380. Returns:
  381. collections.OrderedDict with key-value pairs:
  382. {"miou": mean intersection over union,
  383. "category_iou": category-wise mean intersection over union,
  384. "oacc": overall accuracy,
  385. "category_acc": category-wise accuracy,
  386. "kappa": kappa coefficient,
  387. "category_F1-score": F1 score}.
  388. """
  389. self._check_transforms(eval_dataset.transforms, 'eval')
  390. self.net.eval()
  391. nranks = paddle.distributed.get_world_size()
  392. local_rank = paddle.distributed.get_rank()
  393. if nranks > 1:
  394. # Initialize parallel environment if not done.
  395. if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
  396. ):
  397. paddle.distributed.init_parallel_env()
  398. batch_size_each_card = get_single_card_bs(batch_size)
  399. if batch_size_each_card > 1:
  400. batch_size_each_card = 1
  401. batch_size = batch_size_each_card * paddlers.env_info['num']
  402. logging.warning(
  403. "Segmenter only supports batch_size=1 for each gpu/cpu card " \
  404. "during evaluation, so batch_size " \
  405. "is forcibly set to {}.".format(batch_size))
  406. self.eval_data_loader = self.build_data_loader(
  407. eval_dataset, batch_size=batch_size, mode='eval')
  408. intersect_area_all = 0
  409. pred_area_all = 0
  410. label_area_all = 0
  411. conf_mat_all = []
  412. logging.info(
  413. "Start to evaluate(total_samples={}, total_steps={})...".format(
  414. eval_dataset.num_samples,
  415. math.ceil(eval_dataset.num_samples * 1.0 / batch_size)))
  416. with paddle.no_grad():
  417. for step, data in enumerate(self.eval_data_loader):
  418. data.append(eval_dataset.transforms.transforms)
  419. outputs = self.run(self.net, data, 'eval')
  420. pred_area = outputs['pred_area']
  421. label_area = outputs['label_area']
  422. intersect_area = outputs['intersect_area']
  423. conf_mat = outputs['conf_mat']
  424. # Gather from all ranks
  425. if nranks > 1:
  426. intersect_area_list = []
  427. pred_area_list = []
  428. label_area_list = []
  429. conf_mat_list = []
  430. paddle.distributed.all_gather(intersect_area_list,
  431. intersect_area)
  432. paddle.distributed.all_gather(pred_area_list, pred_area)
  433. paddle.distributed.all_gather(label_area_list, label_area)
  434. paddle.distributed.all_gather(conf_mat_list, conf_mat)
  435. # Some image has been evaluated and should be eliminated in last iter
  436. if (step + 1) * nranks > len(eval_dataset):
  437. valid = len(eval_dataset) - step * nranks
  438. intersect_area_list = intersect_area_list[:valid]
  439. pred_area_list = pred_area_list[:valid]
  440. label_area_list = label_area_list[:valid]
  441. conf_mat_list = conf_mat_list[:valid]
  442. intersect_area_all += sum(intersect_area_list)
  443. pred_area_all += sum(pred_area_list)
  444. label_area_all += sum(label_area_list)
  445. conf_mat_all.extend(conf_mat_list)
  446. else:
  447. intersect_area_all = intersect_area_all + intersect_area
  448. pred_area_all = pred_area_all + pred_area
  449. label_area_all = label_area_all + label_area
  450. conf_mat_all.append(conf_mat)
  451. class_iou, miou = ppseg.utils.metrics.mean_iou(
  452. intersect_area_all, pred_area_all, label_area_all)
  453. class_acc, oacc = ppseg.utils.metrics.accuracy(intersect_area_all,
  454. pred_area_all)
  455. kappa = ppseg.utils.metrics.kappa(intersect_area_all, pred_area_all,
  456. label_area_all)
  457. category_f1score = metrics.f1_score(intersect_area_all, pred_area_all,
  458. label_area_all)
  459. eval_metrics = OrderedDict(
  460. zip([
  461. 'miou', 'category_iou', 'oacc', 'category_acc', 'kappa',
  462. 'category_F1-score'
  463. ], [miou, class_iou, oacc, class_acc, kappa, category_f1score]))
  464. if return_details:
  465. conf_mat = sum(conf_mat_all)
  466. eval_details = {'confusion_matrix': conf_mat.tolist()}
  467. return eval_metrics, eval_details
  468. return eval_metrics
  469. def predict(self, img_file, transforms=None):
  470. """
  471. Do inference.
  472. Args:
  473. img_file (list[np.ndarray|str] | str | np.ndarray): Image path or decoded
  474. image data, which also could constitute a list, meaning all images to be
  475. predicted as a mini-batch.
  476. transforms (paddlers.transforms.Compose|None, optional): Transforms for
  477. inputs. If None, the transforms for evaluation process will be used.
  478. Defaults to None.
  479. Returns:
  480. If `img_file` is a tuple of string or np.array, the result is a dict with
  481. the following key-value pairs:
  482. label_map (np.ndarray): Predicted label map (HW).
  483. score_map (np.ndarray): Prediction score map (HWC).
  484. If `img_file` is a list, the result is a list composed of dicts with the
  485. above keys.
  486. """
  487. if transforms is None and not hasattr(self, 'test_transforms'):
  488. raise ValueError("transforms need to be defined, now is None.")
  489. if transforms is None:
  490. transforms = self.test_transforms
  491. if isinstance(img_file, (str, np.ndarray)):
  492. images = [img_file]
  493. else:
  494. images = img_file
  495. batch_im, batch_origin_shape = self.preprocess(images, transforms,
  496. self.model_type)
  497. self.net.eval()
  498. data = (batch_im, batch_origin_shape, transforms.transforms)
  499. outputs = self.run(self.net, data, 'test')
  500. label_map_list = outputs['label_map']
  501. score_map_list = outputs['score_map']
  502. if isinstance(img_file, list):
  503. prediction = [{
  504. 'label_map': l,
  505. 'score_map': s
  506. } for l, s in zip(label_map_list, score_map_list)]
  507. else:
  508. prediction = {
  509. 'label_map': label_map_list[0],
  510. 'score_map': score_map_list[0]
  511. }
  512. return prediction
  513. def slider_predict(self,
  514. img_file,
  515. save_dir,
  516. block_size,
  517. overlap=36,
  518. transforms=None,
  519. invalid_value=255,
  520. merge_strategy='keep_last',
  521. batch_size=1,
  522. quiet=False):
  523. """
  524. Do inference using sliding windows.
  525. Args:
  526. img_file (str): Image path.
  527. save_dir (str): Directory that contains saved geotiff file.
  528. block_size (list[int] | tuple[int] | int):
  529. Size of block. If `block_size` is list or tuple, it should be in
  530. (W, H) format.
  531. overlap (list[int] | tuple[int] | int, optional):
  532. Overlap between two blocks. If `overlap` is list or tuple, it should
  533. be in (W, H) format. Defaults to 36.
  534. transforms (paddlers.transforms.Compose|None, optional): Transforms for
  535. inputs. If None, the transforms for evaluation process will be used.
  536. Defaults to None.
  537. invalid_value (int, optional): Value that marks invalid pixels in output
  538. image. Defaults to 255.
  539. merge_strategy (str, optional): Strategy to merge overlapping blocks. Choices
  540. are {'keep_first', 'keep_last', 'accum'}. 'keep_first' and 'keep_last'
  541. means keeping the values of the first and the last block in traversal
  542. order, respectively. 'accum' means determining the class of an overlapping
  543. pixel according to accumulated probabilities. Defaults to 'keep_last'.
  544. batch_size (int, optional): Batch size used in inference. Defaults to 1.
  545. quiet (bool, optional): If True, disable the progress bar. Defaults to False.
  546. """
  547. slider_predict(self.predict, img_file, save_dir, block_size, overlap,
  548. transforms, invalid_value, merge_strategy, batch_size,
  549. not quiet)
  550. def preprocess(self, images, transforms, to_tensor=True):
  551. self._check_transforms(transforms, 'test')
  552. batch_im = list()
  553. batch_ori_shape = list()
  554. for im in images:
  555. if isinstance(im, str):
  556. im = decode_image(im, read_raw=True)
  557. ori_shape = im.shape[:2]
  558. sample = {'image': im}
  559. im = transforms(sample)[0]
  560. batch_im.append(im)
  561. batch_ori_shape.append(ori_shape)
  562. if to_tensor:
  563. batch_im = paddle.to_tensor(batch_im)
  564. else:
  565. batch_im = np.asarray(batch_im)
  566. return batch_im, batch_ori_shape
  567. @staticmethod
  568. def get_transforms_shape_info(batch_ori_shape, transforms):
  569. batch_restore_list = list()
  570. for ori_shape in batch_ori_shape:
  571. restore_list = list()
  572. h, w = ori_shape[0], ori_shape[1]
  573. for op in transforms:
  574. if op.__class__.__name__ == 'Resize':
  575. restore_list.append(('resize', (h, w)))
  576. h, w = op.target_size
  577. elif op.__class__.__name__ == 'ResizeByShort':
  578. restore_list.append(('resize', (h, w)))
  579. im_short_size = min(h, w)
  580. im_long_size = max(h, w)
  581. scale = float(op.short_size) / float(im_short_size)
  582. if 0 < op.max_size < np.round(scale * im_long_size):
  583. scale = float(op.max_size) / float(im_long_size)
  584. h = int(round(h * scale))
  585. w = int(round(w * scale))
  586. elif op.__class__.__name__ == 'ResizeByLong':
  587. restore_list.append(('resize', (h, w)))
  588. im_long_size = max(h, w)
  589. scale = float(op.long_size) / float(im_long_size)
  590. h = int(round(h * scale))
  591. w = int(round(w * scale))
  592. elif op.__class__.__name__ == 'Pad':
  593. if op.target_size:
  594. target_h, target_w = op.target_size
  595. else:
  596. target_h = int(
  597. (np.ceil(h / op.size_divisor) * op.size_divisor))
  598. target_w = int(
  599. (np.ceil(w / op.size_divisor) * op.size_divisor))
  600. if op.pad_mode == -1:
  601. offsets = op.offsets
  602. elif op.pad_mode == 0:
  603. offsets = [0, 0]
  604. elif op.pad_mode == 1:
  605. offsets = [(target_h - h) // 2, (target_w - w) // 2]
  606. else:
  607. offsets = [target_h - h, target_w - w]
  608. restore_list.append(('padding', (h, w), offsets))
  609. h, w = target_h, target_w
  610. batch_restore_list.append(restore_list)
  611. return batch_restore_list
  612. def postprocess(self, batch_pred, batch_origin_shape, transforms):
  613. batch_restore_list = BaseSegmenter.get_transforms_shape_info(
  614. batch_origin_shape, transforms)
  615. if isinstance(batch_pred, (tuple, list)) and self.status == 'Infer':
  616. return self._infer_postprocess(
  617. batch_label_map=batch_pred[0],
  618. batch_score_map=batch_pred[1],
  619. batch_restore_list=batch_restore_list)
  620. results = []
  621. if batch_pred.dtype == paddle.float32:
  622. mode = 'bilinear'
  623. else:
  624. mode = 'nearest'
  625. for pred, restore_list in zip(batch_pred, batch_restore_list):
  626. pred = paddle.unsqueeze(pred, axis=0)
  627. for item in restore_list[::-1]:
  628. h, w = item[1][0], item[1][1]
  629. if item[0] == 'resize':
  630. pred = F.interpolate(
  631. pred, (h, w), mode=mode, data_format='NCHW')
  632. elif item[0] == 'padding':
  633. x, y = item[2]
  634. pred = pred[:, :, y:y + h, x:x + w]
  635. else:
  636. pass
  637. results.append(pred)
  638. return results
  639. def _infer_postprocess(self, batch_label_map, batch_score_map,
  640. batch_restore_list):
  641. label_maps = []
  642. score_maps = []
  643. for label_map, score_map, restore_list in zip(
  644. batch_label_map, batch_score_map, batch_restore_list):
  645. if not isinstance(label_map, np.ndarray):
  646. label_map = paddle.unsqueeze(label_map, axis=[0, 3])
  647. score_map = paddle.unsqueeze(score_map, axis=0)
  648. for item in restore_list[::-1]:
  649. h, w = item[1][0], item[1][1]
  650. if item[0] == 'resize':
  651. if isinstance(label_map, np.ndarray):
  652. label_map = cv2.resize(
  653. label_map, (w, h), interpolation=cv2.INTER_NEAREST)
  654. score_map = cv2.resize(
  655. score_map, (w, h), interpolation=cv2.INTER_LINEAR)
  656. else:
  657. label_map = F.interpolate(
  658. label_map, (h, w),
  659. mode='nearest',
  660. data_format='NHWC')
  661. score_map = F.interpolate(
  662. score_map, (h, w),
  663. mode='bilinear',
  664. data_format='NHWC')
  665. elif item[0] == 'padding':
  666. x, y = item[2]
  667. if isinstance(label_map, np.ndarray):
  668. label_map = label_map[y:y + h, x:x + w]
  669. score_map = score_map[y:y + h, x:x + w]
  670. else:
  671. label_map = label_map[:, y:y + h, x:x + w, :]
  672. score_map = score_map[:, y:y + h, x:x + w, :]
  673. else:
  674. pass
  675. label_map = label_map.squeeze()
  676. score_map = score_map.squeeze()
  677. if not isinstance(label_map, np.ndarray):
  678. label_map = label_map.numpy()
  679. score_map = score_map.numpy()
  680. label_maps.append(label_map.squeeze())
  681. score_maps.append(score_map.squeeze())
  682. return label_maps, score_maps
  683. def _check_transforms(self, transforms, mode):
  684. super()._check_transforms(transforms, mode)
  685. if not isinstance(transforms.arrange,
  686. paddlers.transforms.ArrangeSegmenter):
  687. raise TypeError(
  688. "`transforms.arrange` must be an ArrangeSegmenter object.")
  689. def set_losses(self, losses, weights=None):
  690. if weights is None:
  691. weights = [1. for _ in range(len(losses))]
  692. self.losses = {'types': losses, 'coef': weights}
  693. class UNet(BaseSegmenter):
  694. def __init__(self,
  695. in_channels=3,
  696. num_classes=2,
  697. use_mixed_loss=False,
  698. losses=None,
  699. use_deconv=False,
  700. align_corners=False,
  701. **params):
  702. params.update({
  703. 'use_deconv': use_deconv,
  704. 'align_corners': align_corners
  705. })
  706. super(UNet, self).__init__(
  707. model_name='UNet',
  708. input_channel=in_channels,
  709. num_classes=num_classes,
  710. use_mixed_loss=use_mixed_loss,
  711. losses=losses,
  712. **params)
  713. class DeepLabV3P(BaseSegmenter):
  714. def __init__(self,
  715. in_channels=3,
  716. num_classes=2,
  717. backbone='ResNet50_vd',
  718. use_mixed_loss=False,
  719. losses=None,
  720. output_stride=8,
  721. backbone_indices=(0, 3),
  722. aspp_ratios=(1, 12, 24, 36),
  723. aspp_out_channels=256,
  724. align_corners=False,
  725. **params):
  726. self.backbone_name = backbone
  727. if backbone not in ['ResNet50_vd', 'ResNet101_vd']:
  728. raise ValueError(
  729. "backbone: {} is not supported. Please choose one of "
  730. "{'ResNet50_vd', 'ResNet101_vd'}.".format(backbone))
  731. if params.get('with_net', True):
  732. with DisablePrint():
  733. backbone = getattr(ppseg.models, backbone)(
  734. input_channel=in_channels, output_stride=output_stride)
  735. else:
  736. backbone = None
  737. params.update({
  738. 'backbone': backbone,
  739. 'backbone_indices': backbone_indices,
  740. 'aspp_ratios': aspp_ratios,
  741. 'aspp_out_channels': aspp_out_channels,
  742. 'align_corners': align_corners
  743. })
  744. super(DeepLabV3P, self).__init__(
  745. model_name='DeepLabV3P',
  746. num_classes=num_classes,
  747. use_mixed_loss=use_mixed_loss,
  748. losses=losses,
  749. **params)
  750. class FastSCNN(BaseSegmenter):
  751. def __init__(self,
  752. num_classes=2,
  753. use_mixed_loss=False,
  754. losses=None,
  755. align_corners=False,
  756. **params):
  757. params.update({'align_corners': align_corners})
  758. super(FastSCNN, self).__init__(
  759. model_name='FastSCNN',
  760. num_classes=num_classes,
  761. use_mixed_loss=use_mixed_loss,
  762. losses=losses,
  763. **params)
  764. class HRNet(BaseSegmenter):
  765. def __init__(self,
  766. num_classes=2,
  767. width=48,
  768. use_mixed_loss=False,
  769. losses=None,
  770. align_corners=False,
  771. **params):
  772. if width not in (18, 48):
  773. raise ValueError(
  774. "width={} is not supported, please choose from {18, 48}.".
  775. format(width))
  776. self.backbone_name = 'HRNet_W{}'.format(width)
  777. if params.get('with_net', True):
  778. with DisablePrint():
  779. backbone = getattr(ppseg.models, self.backbone_name)(
  780. align_corners=align_corners)
  781. else:
  782. backbone = None
  783. params.update({'backbone': backbone, 'align_corners': align_corners})
  784. super(HRNet, self).__init__(
  785. model_name='FCN',
  786. num_classes=num_classes,
  787. use_mixed_loss=use_mixed_loss,
  788. losses=losses,
  789. **params)
  790. self.model_name = 'HRNet'
  791. class BiSeNetV2(BaseSegmenter):
  792. def __init__(self,
  793. num_classes=2,
  794. use_mixed_loss=False,
  795. losses=None,
  796. align_corners=False,
  797. **params):
  798. params.update({'align_corners': align_corners})
  799. super(BiSeNetV2, self).__init__(
  800. model_name='BiSeNetV2',
  801. num_classes=num_classes,
  802. use_mixed_loss=use_mixed_loss,
  803. losses=losses,
  804. **params)
  805. class FarSeg(BaseSegmenter):
  806. def __init__(self,
  807. in_channels=3,
  808. num_classes=2,
  809. use_mixed_loss=False,
  810. losses=None,
  811. **params):
  812. super(FarSeg, self).__init__(
  813. model_name='FarSeg',
  814. num_classes=num_classes,
  815. use_mixed_loss=use_mixed_loss,
  816. losses=losses,
  817. in_channels=in_channels,
  818. **params)