# segmenter.py

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import os.path as osp
from collections import OrderedDict

import numpy as np
import cv2
import paddle
import paddle.nn.functional as F
from paddle.static import InputSpec

import paddlers
import paddlers.models.paddleseg as ppseg
import paddlers.rs_models.seg as cmseg
import paddlers.utils.logging as logging
from paddlers.models import seg_losses
from paddlers.transforms import Resize, decode_image, construct_sample
from paddlers.transforms.operators import ArrangeSegmenter
from paddlers.utils import get_single_card_bs, DisablePrint
from paddlers.utils.checkpoint import seg_pretrain_weights_dict
from .base import BaseModel
from .utils import seg_metrics as metrics
from .utils.infer_nets import InferSegNet
from .utils.slider_predict import slider_predict

__all__ = [
    "UNet", "DeepLabV3P", "FastSCNN", "HRNet", "BiSeNetV2", "FarSeg", "FactSeg",
    "C2FNet"
]


class BaseSegmenter(BaseModel):
    _arrange = ArrangeSegmenter

    def __init__(self,
                 model_name,
                 num_classes=2,
                 use_mixed_loss=False,
                 losses=None,
                 **params):
        self.init_params = locals()
        if 'with_net' in self.init_params:
            del self.init_params['with_net']
        super(BaseSegmenter, self).__init__('segmenter')
        if not hasattr(ppseg.models, model_name) and \
                not hasattr(cmseg, model_name):
            raise ValueError("ERROR: There is no model named {}.".format(
                model_name))
        self.model_name = model_name
        self.num_classes = num_classes
        self.use_mixed_loss = use_mixed_loss
        self.losses = losses
        self.labels = None
        if params.get('with_net', True):
            params.pop('with_net', None)
            self.net = self.build_net(**params)

    def build_net(self, **params):
        # TODO: when using paddle.utils.unique_name.guard,
        # DeepLabv3p and HRNet will raise an error.
        net = dict(ppseg.models.__dict__, **cmseg.__dict__)[self.model_name](
            num_classes=self.num_classes, **params)
        return net

    def _build_inference_net(self):
        infer_net = InferSegNet(self.net)
        infer_net.eval()
        return infer_net

    def _fix_transforms_shape(self, image_shape):
        if hasattr(self, 'test_transforms'):
            if self.test_transforms is not None:
                has_resize_op = False
                resize_op_idx = -1
                normalize_op_idx = len(self.test_transforms.transforms)
                for idx, op in enumerate(self.test_transforms.transforms):
                    name = op.__class__.__name__
                    if name == 'Normalize':
                        normalize_op_idx = idx
                    if 'Resize' in name:
                        has_resize_op = True
                        resize_op_idx = idx
                if not has_resize_op:
                    self.test_transforms.transforms.insert(
                        normalize_op_idx, Resize(target_size=image_shape))
                else:
                    self.test_transforms.transforms[resize_op_idx] = Resize(
                        target_size=image_shape)

    def _get_test_inputs(self, image_shape):
        if image_shape is not None:
            if len(image_shape) == 2:
                image_shape = [1, 3] + image_shape
            self._fix_transforms_shape(image_shape[-2:])
        else:
            image_shape = [None, 3, -1, -1]
        self.fixed_input_shape = image_shape
        input_spec = [
            InputSpec(
                shape=image_shape, name='image', dtype='float32')
        ]
        return input_spec

    def run(self, net, inputs, mode):
        inputs, batch_restore_list = inputs
        net_out = net(inputs[0])
        logit = net_out[0]
        outputs = OrderedDict()
        if mode == 'test':
            if self.status == 'Infer':
                label_map_list, score_map_list = self.postprocess(
                    net_out, batch_restore_list)
            else:
                logit_list = self.postprocess(logit, batch_restore_list)
                label_map_list = []
                score_map_list = []
                for logit in logit_list:
                    logit = paddle.transpose(logit, perm=[0, 2, 3, 1])  # NHWC
                    label_map_list.append(
                        paddle.argmax(
                            logit, axis=-1, keepdim=False, dtype='int32')
                        .squeeze().numpy())
                    score_map_list.append(
                        F.softmax(
                            logit, axis=-1).squeeze().numpy().astype('float32'))
            outputs['label_map'] = label_map_list
            outputs['score_map'] = score_map_list
        if mode == 'eval':
            if self.status == 'Infer':
                pred = paddle.unsqueeze(net_out[0], axis=1)  # NCHW
            else:
                pred = paddle.argmax(logit, axis=1, keepdim=True, dtype='int32')
            label = inputs[1]
            if label.ndim == 3:
                paddle.unsqueeze_(label, axis=1)
            if label.ndim != 4:
                raise ValueError("Expected label.ndim == 4 but got {}".format(
                    label.ndim))
            pred = self.postprocess(pred, batch_restore_list)[0]  # NCHW
            intersect_area, pred_area, label_area = ppseg.utils.metrics.calculate_area(
                pred, label, self.num_classes)
            outputs['intersect_area'] = intersect_area
            outputs['pred_area'] = pred_area
            outputs['label_area'] = label_area
            outputs['conf_mat'] = metrics.confusion_matrix(pred, label,
                                                           self.num_classes)
        if mode == 'train':
            loss_list = metrics.loss_computation(
                logits_list=net_out, labels=inputs[1], losses=self.losses)
            loss = sum(loss_list)
            outputs['loss'] = loss
        return outputs

    def default_loss(self):
        if isinstance(self.use_mixed_loss, bool):
            if self.use_mixed_loss:
                losses = [
                    seg_losses.CrossEntropyLoss(),
                    seg_losses.LovaszSoftmaxLoss()
                ]
                coef = [.8, .2]
                loss_type = [seg_losses.MixedLoss(losses=losses, coef=coef), ]
            else:
                loss_type = [seg_losses.CrossEntropyLoss()]
        else:
            losses, coef = list(zip(*self.use_mixed_loss))
            if not set(losses).issubset(
                    ['CrossEntropyLoss', 'DiceLoss', 'LovaszSoftmaxLoss']):
                raise ValueError(
                    "Only 'CrossEntropyLoss', 'DiceLoss', 'LovaszSoftmaxLoss' are supported."
                )
            losses = [getattr(seg_losses, loss)() for loss in losses]
            loss_type = [seg_losses.MixedLoss(losses=losses, coef=list(coef))]
        loss_coef = [1.0]
        losses = {'types': loss_type, 'coef': loss_coef}
        return losses
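
    # Illustrative note (not part of the original docs): when `use_mixed_loss`
    # is not a bool, it is expected to be an iterable of
    # (loss_name, coefficient) pairs. A minimal sketch, assuming a 0.8/0.2
    # weighting between cross-entropy and Lovasz-Softmax:
    #
    #     model = UNet(num_classes=2,
    #                  use_mixed_loss=[('CrossEntropyLoss', 0.8),
    #                                  ('LovaszSoftmaxLoss', 0.2)])
    #
    # `default_loss` turns this into a single `MixedLoss` with those
    # coefficients.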
    def default_optimizer(self,
                          parameters,
                          learning_rate,
                          num_epochs,
                          num_steps_each_epoch,
                          lr_decay_power=0.9):
        decay_step = num_epochs * num_steps_each_epoch
        lr_scheduler = paddle.optimizer.lr.PolynomialDecay(
            learning_rate, decay_step, end_lr=0, power=lr_decay_power)
        optimizer = paddle.optimizer.Momentum(
            learning_rate=lr_scheduler,
            parameters=parameters,
            momentum=0.9,
            weight_decay=4e-5)
        return optimizer

    def train(self,
              num_epochs,
              train_dataset,
              train_batch_size=2,
              eval_dataset=None,
              optimizer=None,
              save_interval_epochs=1,
              log_interval_steps=2,
              save_dir='output',
              pretrain_weights='CITYSCAPES',
              learning_rate=0.01,
              lr_decay_power=0.9,
              early_stop=False,
              early_stop_patience=5,
              use_vdl=True,
              resume_checkpoint=None):
        """
        Train the model.

        Args:
            num_epochs (int): Number of epochs.
            train_dataset (paddlers.datasets.SegDataset): Training dataset.
            train_batch_size (int, optional): Total batch size among all cards used in
                training. Defaults to 2.
            eval_dataset (paddlers.datasets.SegDataset|None, optional): Evaluation dataset.
                If None, the model will not be evaluated during the training process.
                Defaults to None.
            optimizer (paddle.optimizer.Optimizer|None, optional): Optimizer used in
                training. If None, a default optimizer will be used. Defaults to None.
            save_interval_epochs (int, optional): Epoch interval for saving the model.
                Defaults to 1.
            log_interval_steps (int, optional): Step interval for printing training
                information. Defaults to 2.
            save_dir (str, optional): Directory to save the model. Defaults to 'output'.
            pretrain_weights (str|None, optional): None or name/path of pretrained
                weights. If None, no pretrained weights will be loaded.
                Defaults to 'CITYSCAPES'.
            learning_rate (float, optional): Learning rate for training. Defaults to .01.
            lr_decay_power (float, optional): Learning rate decay power. Defaults to .9.
            early_stop (bool, optional): Whether to adopt the early stopping strategy.
                Defaults to False.
            early_stop_patience (int, optional): Early stopping patience. Defaults to 5.
            use_vdl (bool, optional): Whether to use VisualDL to monitor the training
                process. Defaults to True.
            resume_checkpoint (str|None, optional): Path of the checkpoint to resume
                training from. If None, no training checkpoint will be resumed. At most
                one of `resume_checkpoint` and `pretrain_weights` can be set simultaneously.
                Defaults to None.
        """
        if self.status == 'Infer':
            logging.error(
                "Exported inference model does not support training.",
                exit=True)
        if pretrain_weights is not None and resume_checkpoint is not None:
            logging.error(
                "`pretrain_weights` and `resume_checkpoint` cannot be set simultaneously.",
                exit=True)
        self.labels = train_dataset.labels
        if self.losses is None:
            self.losses = self.default_loss()

        if optimizer is None:
            num_steps_each_epoch = train_dataset.num_samples // train_batch_size
            self.optimizer = self.default_optimizer(
                self.net.parameters(), learning_rate, num_epochs,
                num_steps_each_epoch, lr_decay_power)
        else:
            self.optimizer = optimizer

        if pretrain_weights is not None:
            if not osp.exists(pretrain_weights):
                if self.model_name not in seg_pretrain_weights_dict:
                    logging.warning(
                        "Path of pretrained weights ('{}') does not exist!".
                        format(pretrain_weights))
                    pretrain_weights = None
                elif pretrain_weights not in seg_pretrain_weights_dict[
                        self.model_name]:
                    logging.warning(
                        "Path of pretrained weights ('{}') does not exist!".
                        format(pretrain_weights))
                    pretrain_weights = seg_pretrain_weights_dict[
                        self.model_name][0]
                    logging.warning(
                        "`pretrain_weights` is forcibly set to '{}'. "
                        "If you don't want to use pretrained weights, "
                        "please set `pretrain_weights` to None.".format(
                            pretrain_weights))
            else:
                if osp.splitext(pretrain_weights)[-1] != '.pdparams':
                    logging.error(
                        "Invalid pretrained weights. Please specify a .pdparams file.",
                        exit=True)
        pretrained_dir = osp.join(save_dir, 'pretrain')
        is_backbone_weights = pretrain_weights == 'IMAGENET'
        self.initialize_net(
            pretrain_weights=pretrain_weights,
            save_dir=pretrained_dir,
            resume_checkpoint=resume_checkpoint,
            is_backbone_weights=is_backbone_weights)

        self.train_loop(
            num_epochs=num_epochs,
            train_dataset=train_dataset,
            train_batch_size=train_batch_size,
            eval_dataset=eval_dataset,
            save_interval_epochs=save_interval_epochs,
            log_interval_steps=log_interval_steps,
            save_dir=save_dir,
            early_stop=early_stop,
            early_stop_patience=early_stop_patience,
            use_vdl=use_vdl)
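
    # Illustrative usage (a minimal sketch, not from the original file; the
    # dataset paths and the transform pipeline below are assumptions):
    #
    #     import paddlers.transforms as T
    #     from paddlers.datasets import SegDataset
    #
    #     train_transforms = T.Compose([
    #         T.DecodeImg(), T.Resize(target_size=512),
    #         T.Normalize(), T.ArrangeSegmenter('train')
    #     ])
    #     train_data = SegDataset(
    #         data_dir='data', file_list='data/train.txt',
    #         label_list='data/labels.txt', transforms=train_transforms)
    #     model = UNet(num_classes=2)
    #     model.train(num_epochs=10, train_dataset=train_data,
    #                 train_batch_size=4, save_dir='output/unet')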
    def quant_aware_train(self,
                          num_epochs,
                          train_dataset,
                          train_batch_size=2,
                          eval_dataset=None,
                          optimizer=None,
                          save_interval_epochs=1,
                          log_interval_steps=2,
                          save_dir='output',
                          learning_rate=0.0001,
                          lr_decay_power=0.9,
                          early_stop=False,
                          early_stop_patience=5,
                          use_vdl=True,
                          resume_checkpoint=None,
                          quant_config=None):
        """
        Quantization-aware training.

        Args:
            num_epochs (int): Number of epochs.
            train_dataset (paddlers.datasets.SegDataset): Training dataset.
            train_batch_size (int, optional): Total batch size among all cards used in
                training. Defaults to 2.
            eval_dataset (paddlers.datasets.SegDataset|None, optional): Evaluation dataset.
                If None, the model will not be evaluated during the training process.
                Defaults to None.
            optimizer (paddle.optimizer.Optimizer|None, optional): Optimizer used in
                training. If None, a default optimizer will be used. Defaults to None.
            save_interval_epochs (int, optional): Epoch interval for saving the model.
                Defaults to 1.
            log_interval_steps (int, optional): Step interval for printing training
                information. Defaults to 2.
            save_dir (str, optional): Directory to save the model. Defaults to 'output'.
            learning_rate (float, optional): Learning rate for training.
                Defaults to .0001.
            lr_decay_power (float, optional): Learning rate decay power. Defaults to .9.
            early_stop (bool, optional): Whether to adopt the early stopping strategy.
                Defaults to False.
            early_stop_patience (int, optional): Early stopping patience. Defaults to 5.
            use_vdl (bool, optional): Whether to use VisualDL to monitor the training
                process. Defaults to True.
            quant_config (dict|None, optional): Quantization configuration. If None,
                a default rule-of-thumb configuration will be used. Defaults to None.
            resume_checkpoint (str|None, optional): Path of the checkpoint to resume
                quantization-aware training from. If None, no training checkpoint will
                be resumed. Defaults to None.
        """
        self._prepare_qat(quant_config)
        self.train(
            num_epochs=num_epochs,
            train_dataset=train_dataset,
            train_batch_size=train_batch_size,
            eval_dataset=eval_dataset,
            optimizer=optimizer,
            save_interval_epochs=save_interval_epochs,
            log_interval_steps=log_interval_steps,
            save_dir=save_dir,
            pretrain_weights=None,
            learning_rate=learning_rate,
            lr_decay_power=lr_decay_power,
            early_stop=early_stop,
            early_stop_patience=early_stop_patience,
            use_vdl=use_vdl,
            resume_checkpoint=resume_checkpoint)
    def evaluate(self, eval_dataset, batch_size=1, return_details=False):
        """
        Evaluate the model.

        Args:
            eval_dataset (paddlers.datasets.SegDataset): Evaluation dataset.
            batch_size (int, optional): Total batch size among all cards used for
                evaluation. Defaults to 1.
            return_details (bool, optional): Whether to return evaluation details.
                Defaults to False.

        Returns:
            collections.OrderedDict with key-value pairs:
                {"miou": mean intersection over union,
                 "category_iou": category-wise mean intersection over union,
                 "oacc": overall accuracy,
                 "category_acc": category-wise accuracy,
                 "kappa": kappa coefficient,
                 "category_F1-score": F1 score}.
        """
        self._check_transforms(eval_dataset.transforms)
        self.net.eval()
        nranks = paddle.distributed.get_world_size()
        local_rank = paddle.distributed.get_rank()
        if nranks > 1:
            # Initialize parallel environment if not done.
            if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
            ):
                paddle.distributed.init_parallel_env()

        batch_size_each_card = get_single_card_bs(batch_size)
        if batch_size_each_card > 1:
            batch_size_each_card = 1
            batch_size = batch_size_each_card * paddlers.env_info['num']
            logging.warning(
                "Segmenter only supports batch_size=1 for each gpu/cpu card "
                "during evaluation, so batch_size "
                "is forcibly set to {}.".format(batch_size))

        self.eval_data_loader = self.build_data_loader(
            eval_dataset, batch_size=batch_size, mode='eval')
        self._check_arrange(eval_dataset.transforms, 'eval')

        intersect_area_all = 0
        pred_area_all = 0
        label_area_all = 0
        conf_mat_all = []
        logging.info(
            "Start to evaluate (total_samples={}, total_steps={})...".format(
                eval_dataset.num_samples,
                math.ceil(eval_dataset.num_samples * 1.0 / batch_size)))

        with paddle.no_grad():
            for step, data in enumerate(self.eval_data_loader):
                outputs = self.run(self.net, data, 'eval')
                pred_area = outputs['pred_area']
                label_area = outputs['label_area']
                intersect_area = outputs['intersect_area']
                conf_mat = outputs['conf_mat']

                # Gather from all ranks
                if nranks > 1:
                    intersect_area_list = []
                    pred_area_list = []
                    label_area_list = []
                    conf_mat_list = []
                    paddle.distributed.all_gather(intersect_area_list,
                                                  intersect_area)
                    paddle.distributed.all_gather(pred_area_list, pred_area)
                    paddle.distributed.all_gather(label_area_list, label_area)
                    paddle.distributed.all_gather(conf_mat_list, conf_mat)
                    # Some images have already been evaluated and should be
                    # excluded in the last iteration.
                    if (step + 1) * nranks > len(eval_dataset):
                        valid = len(eval_dataset) - step * nranks
                        intersect_area_list = intersect_area_list[:valid]
                        pred_area_list = pred_area_list[:valid]
                        label_area_list = label_area_list[:valid]
                        conf_mat_list = conf_mat_list[:valid]

                    intersect_area_all += sum(intersect_area_list)
                    pred_area_all += sum(pred_area_list)
                    label_area_all += sum(label_area_list)
                    conf_mat_all.extend(conf_mat_list)
                else:
                    intersect_area_all = intersect_area_all + intersect_area
                    pred_area_all = pred_area_all + pred_area
                    label_area_all = label_area_all + label_area
                    conf_mat_all.append(conf_mat)

        class_iou, miou = ppseg.utils.metrics.mean_iou(
            intersect_area_all, pred_area_all, label_area_all)
        class_acc, oacc = ppseg.utils.metrics.accuracy(intersect_area_all,
                                                       pred_area_all)
        kappa = ppseg.utils.metrics.kappa(intersect_area_all, pred_area_all,
                                          label_area_all)
        category_f1score = metrics.f1_score(intersect_area_all, pred_area_all,
                                            label_area_all)

        eval_metrics = OrderedDict(
            zip([
                'miou', 'category_iou', 'oacc', 'category_acc', 'kappa',
                'category_F1-score'
            ], [miou, class_iou, oacc, class_acc, kappa, category_f1score]))

        if return_details:
            conf_mat = sum(conf_mat_all)
            eval_details = {'confusion_matrix': conf_mat.tolist()}
            return eval_metrics, eval_details
        return eval_metrics
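
    # Illustrative usage (a minimal sketch; `eval_data` is assumed to be a
    # SegDataset whose transforms end with ArrangeSegmenter('eval')):
    #
    #     eval_metrics = model.evaluate(eval_data, batch_size=1)
    #     print(eval_metrics['miou'], eval_metrics['kappa'])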
    @paddle.no_grad()
    def predict(self, img_file, transforms=None):
        """
        Do inference.

        Args:
            img_file (list[np.ndarray|str] | str | np.ndarray): Image path or decoded
                image data, which also could constitute a list, meaning all images to be
                predicted as a mini-batch.
            transforms (paddlers.transforms.Compose|None, optional): Transforms for
                inputs. If None, the transforms for the evaluation process will be used.
                Defaults to None.

        Returns:
            If `img_file` is a string or np.ndarray, the result is a dict with the
                following key-value pairs:
                label_map (np.ndarray): Predicted label map (HW).
                score_map (np.ndarray): Prediction score map (HWC).
            If `img_file` is a list, the result is a list composed of dicts with the
                above keys.
        """
        if transforms is None and not hasattr(self, 'test_transforms'):
            raise ValueError(
                "transforms need to be defined, but got None.")
        if transforms is None:
            transforms = self.test_transforms
        if isinstance(img_file, (str, np.ndarray)):
            images = [img_file]
        else:
            images = img_file
        transforms = self._build_transforms(transforms, "test")
        self._check_arrange(transforms, "test")
        data = self.preprocess(images, transforms, self.model_type)
        self.net.eval()

        outputs = self.run(self.net, data, 'test')
        label_map_list = outputs['label_map']
        score_map_list = outputs['score_map']
        if isinstance(img_file, list):
            prediction = [{
                'label_map': l,
                'score_map': s
            } for l, s in zip(label_map_list, score_map_list)]
        else:
            prediction = {
                'label_map': label_map_list[0],
                'score_map': score_map_list[0]
            }
        return prediction
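
    # Illustrative usage (a minimal sketch; 'demo.tif' is a placeholder path):
    #
    #     result = model.predict('demo.tif')
    #     label_map = result['label_map']    # (H, W) int32 class indices
    #     score_map = result['score_map']    # (H, W, C) float32 probabilities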
    def slider_predict(self,
                       img_file,
                       save_dir,
                       block_size,
                       overlap=36,
                       transforms=None,
                       invalid_value=255,
                       merge_strategy='keep_last',
                       batch_size=1,
                       eager_load=False,
                       quiet=False):
        """
        Do inference using sliding windows.

        Args:
            img_file (str): Image path.
            save_dir (str): Directory to save the predicted GeoTIFF file.
            block_size (list[int] | tuple[int] | int):
                Size of block. If `block_size` is a list or tuple, it should be in
                (W, H) format.
            overlap (list[int] | tuple[int] | int, optional):
                Overlap between two blocks. If `overlap` is a list or tuple, it should
                be in (W, H) format. Defaults to 36.
            transforms (paddlers.transforms.Compose|None, optional): Transforms for
                inputs. If None, the transforms for the evaluation process will be used.
                Defaults to None.
            invalid_value (int, optional): Value that marks invalid pixels in the
                output image. Defaults to 255.
            merge_strategy (str, optional): Strategy to merge overlapping blocks. Choices
                are {'keep_first', 'keep_last', 'accum'}. 'keep_first' and 'keep_last'
                mean keeping the values of the first and the last block in traversal
                order, respectively. 'accum' means determining the class of an overlapping
                pixel according to accumulated probabilities. Defaults to 'keep_last'.
            batch_size (int, optional): Batch size used in inference. Defaults to 1.
            eager_load (bool, optional): Whether to load the whole image(s) eagerly.
                Defaults to False.
            quiet (bool, optional): If True, disable the progress bar. Defaults to False.
        """
        slider_predict(self.predict, img_file, save_dir, block_size, overlap,
                       transforms, invalid_value, merge_strategy, batch_size,
                       eager_load, not quiet)
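
    # Illustrative usage (a minimal sketch; 'large_scene.tif' is a placeholder
    # path): predict a large image in 512x512 blocks with a 64-pixel overlap,
    # resolving overlaps by accumulated probabilities:
    #
    #     model.slider_predict('large_scene.tif', save_dir='output/pred',
    #                          block_size=512, overlap=64,
    #                          merge_strategy='accum')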
    def preprocess(self, images, transforms, to_tensor=True):
        self._check_transforms(transforms)
        batch_im = list()
        batch_trans_info = list()
        for im in images:
            if isinstance(im, str):
                im = decode_image(im, read_raw=True)
            sample = construct_sample(image=im)
            data = transforms(sample)
            im = data[0][0]
            trans_info = data[1]
            batch_im.append(im)
            batch_trans_info.append(trans_info)
        if to_tensor:
            batch_im = paddle.to_tensor(batch_im)
        else:
            batch_im = np.asarray(batch_im)

        return (batch_im, ), batch_trans_info

    def postprocess(self, batch_pred, batch_restore_list):
        if isinstance(batch_pred, (tuple, list)) and self.status == 'Infer':
            return self._infer_postprocess(
                batch_label_map=batch_pred[0],
                batch_score_map=batch_pred[1],
                batch_restore_list=batch_restore_list)
        results = []
        if batch_pred.dtype == paddle.float32:
            mode = 'bilinear'
        else:
            mode = 'nearest'
        for pred, restore_list in zip(batch_pred, batch_restore_list):
            pred = paddle.unsqueeze(pred, axis=0)
            for item in restore_list[::-1]:
                h, w = item[1][0], item[1][1]
                if item[0] == 'resize':
                    pred = F.interpolate(
                        pred, (h, w), mode=mode, data_format='NCHW')
                elif item[0] == 'padding':
                    x, y = item[2]
                    pred = pred[:, :, y:y + h, x:x + w]
                else:
                    raise RuntimeError
            results.append(pred)
        return results
    def _infer_postprocess(self, batch_label_map, batch_score_map,
                           batch_restore_list):
        label_maps = []
        score_maps = []
        for label_map, score_map, restore_list in zip(
                batch_label_map, batch_score_map, batch_restore_list):
            if not isinstance(label_map, np.ndarray):
                label_map = paddle.unsqueeze(label_map, axis=[0, 3])
                score_map = paddle.unsqueeze(score_map, axis=0)
            for item in restore_list[::-1]:
                h, w = item[1][0], item[1][1]
                if item[0] == 'resize':
                    if isinstance(label_map, np.ndarray):
                        label_map = cv2.resize(
                            label_map, (w, h), interpolation=cv2.INTER_NEAREST)
                        score_map = cv2.resize(
                            score_map, (w, h), interpolation=cv2.INTER_LINEAR)
                    else:
                        label_map = F.interpolate(
                            label_map, (h, w),
                            mode='nearest',
                            data_format='NHWC')
                        score_map = F.interpolate(
                            score_map, (h, w),
                            mode='bilinear',
                            data_format='NHWC')
                elif item[0] == 'padding':
                    x, y = item[2]
                    if isinstance(label_map, np.ndarray):
                        label_map = label_map[y:y + h, x:x + w]
                        score_map = score_map[y:y + h, x:x + w]
                    else:
                        label_map = label_map[:, y:y + h, x:x + w, :]
                        score_map = score_map[:, y:y + h, x:x + w, :]
                else:
                    raise RuntimeError
            label_map = label_map.squeeze()
            score_map = score_map.squeeze()
            if not isinstance(label_map, np.ndarray):
                label_map = label_map.numpy()
                score_map = score_map.numpy()
            label_maps.append(label_map.squeeze())
            score_maps.append(score_map.squeeze())
        return label_maps, score_maps
    def _check_arrange(self, transforms, mode):
        super()._check_arrange(transforms, mode)
        if not isinstance(transforms.arrange, ArrangeSegmenter):
            raise TypeError(
                "`transforms.arrange` must be an `ArrangeSegmenter` object.")

    def set_losses(self, losses, weights=None):
        if weights is None:
            weights = [1. for _ in range(len(losses))]
        self.losses = {'types': losses, 'coef': weights}
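
    # Illustrative usage (a minimal sketch, mirroring the structure built by
    # `default_loss` above): override the default loss with a weighted
    # combination of cross-entropy and Dice loss before calling `train()`:
    #
    #     model.set_losses([
    #         seg_losses.MixedLoss(
    #             losses=[seg_losses.CrossEntropyLoss(),
    #                     seg_losses.DiceLoss()],
    #             coef=[0.7, 0.3])
    #     ])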

class UNet(BaseSegmenter):
    def __init__(self,
                 in_channels=3,
                 num_classes=2,
                 use_mixed_loss=False,
                 losses=None,
                 use_deconv=False,
                 align_corners=False,
                 **params):
        params.update({
            'use_deconv': use_deconv,
            'align_corners': align_corners
        })
        super(UNet, self).__init__(
            model_name='UNet',
            in_channels=in_channels,
            num_classes=num_classes,
            use_mixed_loss=use_mixed_loss,
            losses=losses,
            **params)


class DeepLabV3P(BaseSegmenter):
    def __init__(self,
                 in_channels=3,
                 num_classes=2,
                 backbone='ResNet50_vd',
                 use_mixed_loss=False,
                 losses=None,
                 output_stride=8,
                 backbone_indices=(0, 3),
                 aspp_ratios=(1, 12, 24, 36),
                 aspp_out_channels=256,
                 align_corners=False,
                 **params):
        self.backbone_name = backbone
        if backbone not in ['ResNet50_vd', 'ResNet101_vd']:
            raise ValueError(
                "backbone: {} is not supported. Please choose one of "
                "{{'ResNet50_vd', 'ResNet101_vd'}}.".format(backbone))
        if params.get('with_net', True):
            with DisablePrint():
                backbone = getattr(ppseg.models, backbone)(
                    in_channels=in_channels, output_stride=output_stride)
        else:
            backbone = None
        params.update({
            'backbone': backbone,
            'backbone_indices': backbone_indices,
            'aspp_ratios': aspp_ratios,
            'aspp_out_channels': aspp_out_channels,
            'align_corners': align_corners
        })
        super(DeepLabV3P, self).__init__(
            model_name='DeepLabV3P',
            num_classes=num_classes,
            use_mixed_loss=use_mixed_loss,
            losses=losses,
            **params)


class FastSCNN(BaseSegmenter):
    def __init__(self,
                 in_channels=3,
                 num_classes=2,
                 use_mixed_loss=False,
                 losses=None,
                 align_corners=False,
                 **params):
        params.update({'align_corners': align_corners})
        super(FastSCNN, self).__init__(
            model_name='FastSCNN',
            in_channels=in_channels,
            num_classes=num_classes,
            use_mixed_loss=use_mixed_loss,
            losses=losses,
            **params)

    def default_loss(self):
        losses = super(FastSCNN, self).default_loss()
        # FastSCNN has an auxiliary output, so duplicate the loss type and
        # weight the auxiliary branch by 0.4.
        losses['types'] *= 2
        losses['coef'] = [1.0, 0.4]
        return losses

class HRNet(BaseSegmenter):
    def __init__(self,
                 in_channels=3,
                 num_classes=2,
                 width=48,
                 use_mixed_loss=False,
                 losses=None,
                 align_corners=False,
                 **params):
        if width not in (18, 48):
            raise ValueError(
                "width={} is not supported, please choose from {{18, 48}}.".
                format(width))
        self.backbone_name = 'HRNet_W{}'.format(width)
        if params.get('with_net', True):
            with DisablePrint():
                backbone = getattr(ppseg.models, self.backbone_name)(
                    in_channels=in_channels, align_corners=align_corners)
        else:
            backbone = None
        params.update({'backbone': backbone, 'align_corners': align_corners})
        super(HRNet, self).__init__(
            model_name='FCN',
            num_classes=num_classes,
            use_mixed_loss=use_mixed_loss,
            losses=losses,
            **params)
        self.model_name = 'HRNet'


class BiSeNetV2(BaseSegmenter):
    def __init__(self,
                 in_channels=3,
                 num_classes=2,
                 use_mixed_loss=False,
                 losses=None,
                 align_corners=False,
                 **params):
        params.update({'align_corners': align_corners})
        super(BiSeNetV2, self).__init__(
            model_name='BiSeNetV2',
            in_channels=in_channels,
            num_classes=num_classes,
            use_mixed_loss=use_mixed_loss,
            losses=losses,
            **params)

    def default_loss(self):
        losses = super(BiSeNetV2, self).default_loss()
        # BiSeNetV2 produces four auxiliary outputs besides the main one, so
        # five equally weighted losses are needed.
        losses['types'] *= 5
        losses['coef'] = [1.0] * 5
        return losses


class FarSeg(BaseSegmenter):
    def __init__(self,
                 in_channels=3,
                 num_classes=2,
                 use_mixed_loss=False,
                 losses=None,
                 **params):
        super(FarSeg, self).__init__(
            model_name='FarSeg',
            num_classes=num_classes,
            use_mixed_loss=use_mixed_loss,
            losses=losses,
            in_channels=in_channels,
            **params)


class FactSeg(BaseSegmenter):
    def __init__(self,
                 in_channels=3,
                 num_classes=2,
                 use_mixed_loss=False,
                 losses=None,
                 **params):
        super(FactSeg, self).__init__(
            model_name='FactSeg',
            num_classes=num_classes,
            use_mixed_loss=use_mixed_loss,
            losses=losses,
            in_channels=in_channels,
            **params)

class C2FNet(BaseSegmenter):
    def __init__(self,
                 in_channels=3,
                 num_classes=2,
                 backbone='HRNet_W18',
                 use_mixed_loss=False,
                 losses=None,
                 output_stride=8,
                 backbone_indices=(-1, ),
                 kernel_sizes=(128, 128),
                 training_stride=32,
                 samples_per_gpu=32,
                 channels=None,
                 align_corners=False,
                 coarse_model=None,
                 coarse_model_backbone=None,
                 coarse_model_path=None,
                 **params):
        self.backbone_name = backbone
        if params.get('with_net', True):
            with DisablePrint():
                backbone = getattr(ppseg.models, self.backbone_name)(
                    in_channels=in_channels, align_corners=align_corners)
        else:
            backbone = None
        if coarse_model_backbone in ['ResNet50_vd', 'ResNet101_vd']:
            self.coarse_model_backbone = getattr(
                ppseg.models, coarse_model_backbone)(output_stride=8)
        elif coarse_model_backbone in ['HRNet_W18', 'HRNet_W48']:
            self.coarse_model_backbone = getattr(
                ppseg.models, coarse_model_backbone)(align_corners=False)
        else:
            raise ValueError(
                "coarse_model_backbone: {} is not supported. Please choose one of "
                "{{'ResNet50_vd', 'ResNet101_vd', 'HRNet_W18', 'HRNet_W48'}}.".
                format(coarse_model_backbone))
        self.coarse_model = dict(ppseg.models.__dict__)[coarse_model](
            num_classes=num_classes, backbone=self.coarse_model_backbone)
        self.coarse_params = paddle.load(coarse_model_path)
        self.coarse_model.set_state_dict(self.coarse_params)
        self.coarse_model.eval()
        params.update({
            'backbone': backbone,
            'backbone_indices': backbone_indices,
            'kernel_sizes': kernel_sizes,
            'training_stride': training_stride,
            'samples_per_gpu': samples_per_gpu,
            'align_corners': align_corners
        })
        super(C2FNet, self).__init__(
            model_name='C2FNet',
            num_classes=num_classes,
            use_mixed_loss=use_mixed_loss,
            losses=losses,
            **params)
    def run(self, net, inputs, mode):
        inputs, batch_restore_list = inputs
        with paddle.no_grad():
            pre_coarse = self.coarse_model(inputs[0])
            pre_coarse = pre_coarse[0]
            heatmaps = pre_coarse

        if mode == 'test':
            net_out = net(inputs[0], heatmaps)
            logit = net_out[0]
            outputs = OrderedDict()
            origin_shape = inputs[1]
            if self.status == 'Infer':
                label_map_list, score_map_list = self.postprocess(
                    net_out, batch_restore_list)
            else:
                logit_list = self.postprocess(logit, batch_restore_list)
                label_map_list = []
                score_map_list = []
                for logit in logit_list:
                    logit = paddle.transpose(logit, perm=[0, 2, 3, 1])  # NHWC
                    label_map_list.append(
                        paddle.argmax(
                            logit, axis=-1, keepdim=False, dtype='int32')
                        .squeeze().numpy())
                    score_map_list.append(
                        F.softmax(
                            logit, axis=-1).squeeze().numpy().astype('float32'))
            outputs['label_map'] = label_map_list
            outputs['score_map'] = score_map_list

        if mode == 'eval':
            net_out = net(inputs[0], heatmaps)
            logit = net_out[0]
            outputs = OrderedDict()
            if self.status == 'Infer':
                pred = paddle.unsqueeze(net_out[0], axis=1)  # NCHW
            else:
                pred = paddle.argmax(logit, axis=1, keepdim=True, dtype='int32')
            label = inputs[1]
            if label.ndim == 3:
                paddle.unsqueeze_(label, axis=1)
            if label.ndim != 4:
                raise ValueError("Expected label.ndim == 4 but got {}".format(
                    label.ndim))
            pred = self.postprocess(pred, batch_restore_list)[0]  # NCHW
            intersect_area, pred_area, label_area = ppseg.utils.metrics.calculate_area(
                pred, label, self.num_classes)
            outputs['intersect_area'] = intersect_area
            outputs['pred_area'] = pred_area
            outputs['label_area'] = label_area
            outputs['conf_mat'] = metrics.confusion_matrix(pred, label,
                                                           self.num_classes)

        if mode == 'train':
            net_out = net(inputs[0], heatmaps, inputs[1])
            logit = [net_out[0], ]
            labels = net_out[1]
            outputs = OrderedDict()
            loss_list = metrics.loss_computation(
                logits_list=logit, labels=labels, losses=self.losses)
            loss = sum(loss_list)
            outputs['loss'] = loss
        return outputs
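
# Illustrative usage of C2FNet (a minimal sketch; the coarse model name and
# weight path below are assumptions): C2FNet refines the predictions of a
# frozen, pretrained coarse model, so one must be supplied at construction:
#
#     model = C2FNet(num_classes=2,
#                    coarse_model='FCN',
#                    coarse_model_backbone='HRNet_W18',
#                    coarse_model_path='output/fcn/best_model/model.pdparams')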