segmenter.py
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import os
import os.path as osp
from collections import OrderedDict

import numpy as np
import cv2
import paddle
import paddle.nn.functional as F
from paddle.static import InputSpec

import paddlers
import paddlers.models.ppseg as ppseg
import paddlers.rs_models.seg as cmseg
import paddlers.utils.logging as logging
from paddlers.models import seg_losses
from paddlers.transforms import Resize, decode_image
from paddlers.utils import get_single_card_bs, DisablePrint
from paddlers.utils.checkpoint import seg_pretrain_weights_dict
from .base import BaseModel
from .utils import seg_metrics as metrics

__all__ = ["UNet", "DeepLabV3P", "FastSCNN", "HRNet", "BiSeNetV2", "FarSeg"]


class BaseSegmenter(BaseModel):
    def __init__(self,
                 model_name,
                 num_classes=2,
                 use_mixed_loss=False,
                 losses=None,
                 **params):
        self.init_params = locals()
        if 'with_net' in self.init_params:
            del self.init_params['with_net']
        super(BaseSegmenter, self).__init__('segmenter')
        if not hasattr(ppseg.models, model_name) and \
                not hasattr(cmseg, model_name):
            raise ValueError("ERROR: There is no model named {}.".format(
                model_name))
        self.model_name = model_name
        self.num_classes = num_classes
        self.use_mixed_loss = use_mixed_loss
        self.losses = losses
        self.labels = None
        if params.get('with_net', True):
            params.pop('with_net', None)
            self.net = self.build_net(**params)
        self.find_unused_parameters = True

    def build_net(self, **params):
        # TODO: When using paddle.utils.unique_name.guard,
        # DeepLabv3p and HRNet will raise an error.
        net = dict(ppseg.models.__dict__, **cmseg.__dict__)[self.model_name](
            num_classes=self.num_classes, **params)
        return net

    def _fix_transforms_shape(self, image_shape):
        if hasattr(self, 'test_transforms'):
            if self.test_transforms is not None:
                has_resize_op = False
                resize_op_idx = -1
                normalize_op_idx = len(self.test_transforms.transforms)
                for idx, op in enumerate(self.test_transforms.transforms):
                    name = op.__class__.__name__
                    if name == 'Normalize':
                        normalize_op_idx = idx
                    if 'Resize' in name:
                        has_resize_op = True
                        resize_op_idx = idx
                if not has_resize_op:
                    self.test_transforms.transforms.insert(
                        normalize_op_idx, Resize(target_size=image_shape))
                else:
                    self.test_transforms.transforms[resize_op_idx] = Resize(
                        target_size=image_shape)

    def _get_test_inputs(self, image_shape):
        if image_shape is not None:
            if len(image_shape) == 2:
                image_shape = [1, 3] + image_shape
            self._fix_transforms_shape(image_shape[-2:])
        else:
            image_shape = [None, 3, -1, -1]
        self.fixed_input_shape = image_shape
        input_spec = [
            InputSpec(
                shape=image_shape, name='image', dtype='float32')
        ]
        return input_spec
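
    # For illustration: image_shape=[512, 512] yields the fixed spec
    # InputSpec(shape=[1, 3, 512, 512]), while image_shape=None yields the
    # dynamic-shape spec InputSpec(shape=[None, 3, -1, -1]).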

    def run(self, net, inputs, mode):
        net_out = net(inputs[0])
        logit = net_out[0]
        outputs = OrderedDict()
        if mode == 'test':
            origin_shape = inputs[1]
            if self.status == 'Infer':
                label_map_list, score_map_list = self.postprocess(
                    net_out, origin_shape, transforms=inputs[2])
            else:
                logit_list = self.postprocess(
                    logit, origin_shape, transforms=inputs[2])
                label_map_list = []
                score_map_list = []
                for logit in logit_list:
                    logit = paddle.transpose(logit, perm=[0, 2, 3, 1])  # NHWC
                    label_map_list.append(
                        paddle.argmax(
                            logit, axis=-1, keepdim=False, dtype='int32')
                        .squeeze().numpy())
                    score_map_list.append(
                        F.softmax(
                            logit,
                            axis=-1).squeeze().numpy().astype('float32'))
            outputs['label_map'] = label_map_list
            outputs['score_map'] = score_map_list
        if mode == 'eval':
            if self.status == 'Infer':
                pred = paddle.unsqueeze(net_out[0], axis=1)  # NCHW
            else:
                pred = paddle.argmax(
                    logit, axis=1, keepdim=True, dtype='int32')
            label = inputs[1]
            if label.ndim == 3:
                paddle.unsqueeze_(label, axis=1)
            if label.ndim != 4:
                raise ValueError("Expected label.ndim == 4 but got {}".format(
                    label.ndim))
            origin_shape = [label.shape[-2:]]
            pred = self.postprocess(
                pred, origin_shape, transforms=inputs[2])[0]  # NCHW
            intersect_area, pred_area, label_area = \
                ppseg.utils.metrics.calculate_area(
                    pred, label, self.num_classes)
            outputs['intersect_area'] = intersect_area
            outputs['pred_area'] = pred_area
            outputs['label_area'] = label_area
            outputs['conf_mat'] = metrics.confusion_matrix(pred, label,
                                                           self.num_classes)
        if mode == 'train':
            loss_list = metrics.loss_computation(
                logits_list=net_out, labels=inputs[1], losses=self.losses)
            loss = sum(loss_list)
            outputs['loss'] = loss
        return outputs

    def default_loss(self):
        if isinstance(self.use_mixed_loss, bool):
            if self.use_mixed_loss:
                losses = [
                    seg_losses.CrossEntropyLoss(),
                    seg_losses.LovaszSoftmaxLoss()
                ]
                coef = [.8, .2]
                loss_type = [seg_losses.MixedLoss(losses=losses, coef=coef), ]
            else:
                loss_type = [seg_losses.CrossEntropyLoss()]
        else:
            losses, coef = list(zip(*self.use_mixed_loss))
            if not set(losses).issubset(
                ['CrossEntropyLoss', 'DiceLoss', 'LovaszSoftmaxLoss']):
                raise ValueError(
                    "Only 'CrossEntropyLoss', 'DiceLoss', 'LovaszSoftmaxLoss' "
                    "are supported.")
            losses = [getattr(seg_losses, loss)() for loss in losses]
            loss_type = [seg_losses.MixedLoss(losses=losses, coef=list(coef))]
        if self.model_name == 'FastSCNN':
            loss_type *= 2
            loss_coef = [1.0, 0.4]
        elif self.model_name == 'BiSeNetV2':
            loss_type *= 5
            loss_coef = [1.0] * 5
        else:
            loss_coef = [1.0]
        losses = {'types': loss_type, 'coef': loss_coef}
        return losses
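
    # An illustrative sketch of the non-boolean `use_mixed_loss` format: a
    # list of (loss_name, coefficient) pairs, e.g.
    #
    #   model = UNet(num_classes=2,
    #                use_mixed_loss=[('CrossEntropyLoss', 0.8),
    #                                ('DiceLoss', 0.2)])
    #
    # which this method combines into a single MixedLoss with the given
    # coefficients.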

    def default_optimizer(self,
                          parameters,
                          learning_rate,
                          num_epochs,
                          num_steps_each_epoch,
                          lr_decay_power=0.9):
        decay_step = num_epochs * num_steps_each_epoch
        lr_scheduler = paddle.optimizer.lr.PolynomialDecay(
            learning_rate, decay_step, end_lr=0, power=lr_decay_power)
        optimizer = paddle.optimizer.Momentum(
            learning_rate=lr_scheduler,
            parameters=parameters,
            momentum=0.9,
            weight_decay=4e-5)
        return optimizer
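
    # Under Paddle's PolynomialDecay with end_lr=0 (as configured above), the
    # learning rate at step t is roughly
    #   lr(t) = learning_rate * (1 - t / decay_step) ** lr_decay_power,
    # decaying to 0 at the last training step.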

    def train(self,
              num_epochs,
              train_dataset,
              train_batch_size=2,
              eval_dataset=None,
              optimizer=None,
              save_interval_epochs=1,
              log_interval_steps=2,
              save_dir='output',
              pretrain_weights='CITYSCAPES',
              learning_rate=0.01,
              lr_decay_power=0.9,
              early_stop=False,
              early_stop_patience=5,
              use_vdl=True,
              resume_checkpoint=None):
        """
        Train the model.

        Args:
            num_epochs (int): Number of epochs.
            train_dataset (paddlers.datasets.SegDataset): Training dataset.
            train_batch_size (int, optional): Total batch size among all cards used in
                training. Defaults to 2.
            eval_dataset (paddlers.datasets.SegDataset|None, optional): Evaluation dataset.
                If None, the model will not be evaluated during the training process.
                Defaults to None.
            optimizer (paddle.optimizer.Optimizer|None, optional): Optimizer used in
                training. If None, a default optimizer will be used. Defaults to None.
            save_interval_epochs (int, optional): Epoch interval for saving the model.
                Defaults to 1.
            log_interval_steps (int, optional): Step interval for printing training
                information. Defaults to 2.
            save_dir (str, optional): Directory to save the model. Defaults to 'output'.
            pretrain_weights (str|None, optional): None or name/path of pretrained
                weights. If None, no pretrained weights will be loaded.
                Defaults to 'CITYSCAPES'.
            learning_rate (float, optional): Learning rate for training. Defaults to .01.
            lr_decay_power (float, optional): Learning rate decay power. Defaults to .9.
            early_stop (bool, optional): Whether to adopt the early stopping strategy.
                Defaults to False.
            early_stop_patience (int, optional): Early stopping patience. Defaults to 5.
            use_vdl (bool, optional): Whether to use VisualDL to monitor the training
                process. Defaults to True.
            resume_checkpoint (str|None, optional): Path of the checkpoint to resume
                training from. If None, no training checkpoint will be resumed. At most
                one of `resume_checkpoint` and `pretrain_weights` can be set.
                Defaults to None.
        """
        if self.status == 'Infer':
            logging.error(
                "Exported inference model does not support training.",
                exit=True)
        if pretrain_weights is not None and resume_checkpoint is not None:
            logging.error(
                "pretrain_weights and resume_checkpoint cannot be set simultaneously.",
                exit=True)
        self.labels = train_dataset.labels
        if self.losses is None:
            self.losses = self.default_loss()
        if optimizer is None:
            num_steps_each_epoch = train_dataset.num_samples // train_batch_size
            self.optimizer = self.default_optimizer(
                self.net.parameters(), learning_rate, num_epochs,
                num_steps_each_epoch, lr_decay_power)
        else:
            self.optimizer = optimizer
        if pretrain_weights is not None and not osp.exists(pretrain_weights):
            if pretrain_weights not in seg_pretrain_weights_dict[
                    self.model_name]:
                logging.warning(
                    "Path of pretrain_weights('{}') does not exist!".format(
                        pretrain_weights))
                logging.warning(
                    "pretrain_weights is forcibly set to '{}'. "
                    "If you don't want to use pretrained weights, "
                    "set pretrain_weights to None.".format(
                        seg_pretrain_weights_dict[self.model_name][0]))
                pretrain_weights = seg_pretrain_weights_dict[self.model_name][
                    0]
        elif pretrain_weights is not None and osp.exists(pretrain_weights):
            if osp.splitext(pretrain_weights)[-1] != '.pdparams':
                logging.error(
                    "Invalid pretrain weights. Please specify a '.pdparams' file.",
                    exit=True)
        pretrained_dir = osp.join(save_dir, 'pretrain')
        is_backbone_weights = pretrain_weights == 'IMAGENET'
        self.net_initialize(
            pretrain_weights=pretrain_weights,
            save_dir=pretrained_dir,
            resume_checkpoint=resume_checkpoint,
            is_backbone_weights=is_backbone_weights)

        self.train_loop(
            num_epochs=num_epochs,
            train_dataset=train_dataset,
            train_batch_size=train_batch_size,
            eval_dataset=eval_dataset,
            save_interval_epochs=save_interval_epochs,
            log_interval_steps=log_interval_steps,
            save_dir=save_dir,
            early_stop=early_stop,
            early_stop_patience=early_stop_patience,
            use_vdl=use_vdl)
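
    # A minimal usage sketch (dataset construction elided; paths and
    # hyperparameters are hypothetical):
    #
    #   train_ds = paddlers.datasets.SegDataset(...)  # a configured dataset
    #   model = UNet(num_classes=2)
    #   model.train(num_epochs=10, train_dataset=train_ds, train_batch_size=4,
    #               save_dir='output/unet', pretrain_weights='CITYSCAPES')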

    def quant_aware_train(self,
                          num_epochs,
                          train_dataset,
                          train_batch_size=2,
                          eval_dataset=None,
                          optimizer=None,
                          save_interval_epochs=1,
                          log_interval_steps=2,
                          save_dir='output',
                          learning_rate=0.0001,
                          lr_decay_power=0.9,
                          early_stop=False,
                          early_stop_patience=5,
                          use_vdl=True,
                          resume_checkpoint=None,
                          quant_config=None):
        """
        Quantization-aware training.

        Args:
            num_epochs (int): Number of epochs.
            train_dataset (paddlers.datasets.SegDataset): Training dataset.
            train_batch_size (int, optional): Total batch size among all cards used in
                training. Defaults to 2.
            eval_dataset (paddlers.datasets.SegDataset|None, optional): Evaluation dataset.
                If None, the model will not be evaluated during the training process.
                Defaults to None.
            optimizer (paddle.optimizer.Optimizer|None, optional): Optimizer used in
                training. If None, a default optimizer will be used. Defaults to None.
            save_interval_epochs (int, optional): Epoch interval for saving the model.
                Defaults to 1.
            log_interval_steps (int, optional): Step interval for printing training
                information. Defaults to 2.
            save_dir (str, optional): Directory to save the model. Defaults to 'output'.
            learning_rate (float, optional): Learning rate for training.
                Defaults to .0001.
            lr_decay_power (float, optional): Learning rate decay power. Defaults to .9.
            early_stop (bool, optional): Whether to adopt the early stopping strategy.
                Defaults to False.
            early_stop_patience (int, optional): Early stopping patience. Defaults to 5.
            use_vdl (bool, optional): Whether to use VisualDL to monitor the training
                process. Defaults to True.
            resume_checkpoint (str|None, optional): Path of the checkpoint to resume
                quantization-aware training from. If None, no training checkpoint will
                be resumed. Defaults to None.
            quant_config (dict|None, optional): Quantization configuration. If None,
                a default rule-of-thumb configuration will be used. Defaults to None.
        """
        self._prepare_qat(quant_config)
        self.train(
            num_epochs=num_epochs,
            train_dataset=train_dataset,
            train_batch_size=train_batch_size,
            eval_dataset=eval_dataset,
            optimizer=optimizer,
            save_interval_epochs=save_interval_epochs,
            log_interval_steps=log_interval_steps,
            save_dir=save_dir,
            pretrain_weights=None,
            learning_rate=learning_rate,
            lr_decay_power=lr_decay_power,
            early_stop=early_stop,
            early_stop_patience=early_stop_patience,
            use_vdl=use_vdl,
            resume_checkpoint=resume_checkpoint)

    def evaluate(self, eval_dataset, batch_size=1, return_details=False):
        """
        Evaluate the model.

        Args:
            eval_dataset (paddlers.datasets.SegDataset): Evaluation dataset.
            batch_size (int, optional): Total batch size among all cards used for
                evaluation. Defaults to 1.
            return_details (bool, optional): Whether to return evaluation details.
                Defaults to False.

        Returns:
            collections.OrderedDict with key-value pairs:
                {"miou": `mean intersection over union`,
                 "category_iou": `category-wise mean intersection over union`,
                 "oacc": `overall accuracy`,
                 "category_acc": `category-wise accuracy`,
                 "kappa": `kappa coefficient`,
                 "category_F1-score": `F1 score`}.
        """
        self._check_transforms(eval_dataset.transforms, 'eval')
        self.net.eval()
        nranks = paddle.distributed.get_world_size()
        local_rank = paddle.distributed.get_rank()
        if nranks > 1:
            # Initialize parallel environment if not done.
            if not paddle.distributed.parallel.parallel_helper \
                    ._is_parallel_ctx_initialized():
                paddle.distributed.init_parallel_env()

        batch_size_each_card = get_single_card_bs(batch_size)
        if batch_size_each_card > 1:
            batch_size_each_card = 1
            batch_size = batch_size_each_card * paddlers.env_info['num']
            logging.warning(
                "Segmenter only supports batch_size=1 for each gpu/cpu card "
                "during evaluation, so batch_size "
                "is forcibly set to {}.".format(batch_size))
        self.eval_data_loader = self.build_data_loader(
            eval_dataset, batch_size=batch_size, mode='eval')

        intersect_area_all = 0
        pred_area_all = 0
        label_area_all = 0
        conf_mat_all = []
        logging.info(
            "Start to evaluate (total_samples={}, total_steps={})...".format(
                eval_dataset.num_samples,
                math.ceil(eval_dataset.num_samples * 1.0 / batch_size)))
        with paddle.no_grad():
            for step, data in enumerate(self.eval_data_loader):
                data.append(eval_dataset.transforms.transforms)
                outputs = self.run(self.net, data, 'eval')
                pred_area = outputs['pred_area']
                label_area = outputs['label_area']
                intersect_area = outputs['intersect_area']
                conf_mat = outputs['conf_mat']

                # Gather from all ranks
                if nranks > 1:
                    intersect_area_list = []
                    pred_area_list = []
                    label_area_list = []
                    conf_mat_list = []
                    paddle.distributed.all_gather(intersect_area_list,
                                                  intersect_area)
                    paddle.distributed.all_gather(pred_area_list, pred_area)
                    paddle.distributed.all_gather(label_area_list, label_area)
                    paddle.distributed.all_gather(conf_mat_list, conf_mat)

                    # Some images have already been evaluated and should be
                    # excluded in the last iteration.
                    if (step + 1) * nranks > len(eval_dataset):
                        valid = len(eval_dataset) - step * nranks
                        intersect_area_list = intersect_area_list[:valid]
                        pred_area_list = pred_area_list[:valid]
                        label_area_list = label_area_list[:valid]
                        conf_mat_list = conf_mat_list[:valid]

                    intersect_area_all += sum(intersect_area_list)
                    pred_area_all += sum(pred_area_list)
                    label_area_all += sum(label_area_list)
                    conf_mat_all.extend(conf_mat_list)
                else:
                    intersect_area_all = intersect_area_all + intersect_area
                    pred_area_all = pred_area_all + pred_area
                    label_area_all = label_area_all + label_area
                    conf_mat_all.append(conf_mat)
        class_iou, miou = ppseg.utils.metrics.mean_iou(
            intersect_area_all, pred_area_all, label_area_all)
        # TODO: Confirm whether oacc or macc should be used here.
        class_acc, oacc = ppseg.utils.metrics.accuracy(intersect_area_all,
                                                       pred_area_all)
        kappa = ppseg.utils.metrics.kappa(intersect_area_all, pred_area_all,
                                          label_area_all)
        category_f1score = metrics.f1_score(intersect_area_all, pred_area_all,
                                            label_area_all)
        eval_metrics = OrderedDict(
            zip([
                'miou', 'category_iou', 'oacc', 'category_acc', 'kappa',
                'category_F1-score'
            ], [miou, class_iou, oacc, class_acc, kappa, category_f1score]))

        if return_details:
            conf_mat = sum(conf_mat_all)
            eval_details = {'confusion_matrix': conf_mat.tolist()}
            return eval_metrics, eval_details
        return eval_metrics
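
    # A minimal usage sketch (assuming a configured evaluation dataset):
    #
    #   eval_metrics = model.evaluate(eval_ds)
    #   print(eval_metrics['miou'], eval_metrics['oacc'])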

    def predict(self, img_file, transforms=None):
        """
        Do inference.

        Args:
            img_file (list[np.ndarray|str] | str | np.ndarray): Image path or decoded
                image data, which also could constitute a list, meaning all images to be
                predicted as a mini-batch.
            transforms (paddlers.transforms.Compose|None, optional): Transforms for
                inputs. If None, the transforms for evaluation process will be used.
                Defaults to None.

        Returns:
            If `img_file` is a string or np.ndarray, the result is a dict with
            key-value pairs:
                {"label_map": `label map`, "score_map": `score map`}.
            If `img_file` is a list, the result is a list composed of dicts with the
            corresponding fields:
                label_map (np.ndarray): the predicted label map (HW)
                score_map (np.ndarray): the prediction score map (HWC)
        """
        if transforms is None and not hasattr(self, 'test_transforms'):
            raise ValueError(
                "transforms need to be defined, but none were given.")
        if transforms is None:
            transforms = self.test_transforms
        if isinstance(img_file, (str, np.ndarray)):
            images = [img_file]
        else:
            images = img_file
        batch_im, batch_origin_shape = self.preprocess(images, transforms,
                                                       self.model_type)
        self.net.eval()
        data = (batch_im, batch_origin_shape, transforms.transforms)
        outputs = self.run(self.net, data, 'test')
        label_map_list = outputs['label_map']
        score_map_list = outputs['score_map']
        if isinstance(img_file, list):
            prediction = [{
                'label_map': l,
                'score_map': s
            } for l, s in zip(label_map_list, score_map_list)]
        else:
            prediction = {
                'label_map': label_map_list[0],
                'score_map': score_map_list[0]
            }
        return prediction
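
    # A minimal usage sketch (hypothetical image path, assuming test
    # transforms are configured):
    #
    #   result = model.predict('demo.tif')
    #   label_map = result['label_map']  # (H, W) int32 class indices
    #   score_map = result['score_map']  # (H, W, C) per-class scores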

    def slider_predict(self,
                       img_file,
                       save_dir,
                       block_size,
                       overlap=36,
                       transforms=None):
        """
        Do inference using a sliding window.

        Args:
            img_file (str): Image path.
            save_dir (str): Directory to save the output GeoTiff file.
            block_size (list[int] | tuple[int] | int): Size of block.
            overlap (list[int] | tuple[int] | int, optional): Overlap between two
                blocks. Defaults to 36.
            transforms (paddlers.transforms.Compose|None, optional): Transforms for
                inputs. If None, the transforms for evaluation process will be used.
                Defaults to None.
        """
        try:
            from osgeo import gdal
        except ImportError:
            import gdal

        if isinstance(block_size, int):
            block_size = (block_size, block_size)
        elif isinstance(block_size, (tuple, list)) and len(block_size) == 2:
            block_size = tuple(block_size)
        else:
            raise ValueError(
                "`block_size` must be a tuple/list of length 2 or an integer.")
        if isinstance(overlap, int):
            overlap = (overlap, overlap)
        elif isinstance(overlap, (tuple, list)) and len(overlap) == 2:
            overlap = tuple(overlap)
        else:
            raise ValueError(
                "`overlap` must be a tuple/list of length 2 or an integer.")

        src_data = gdal.Open(img_file)
        width = src_data.RasterXSize
        height = src_data.RasterYSize
        bands = src_data.RasterCount

        driver = gdal.GetDriverByName("GTiff")
        file_name = osp.splitext(osp.normpath(img_file).split(os.sep)[-1])[
            0] + ".tif"
        if not osp.exists(save_dir):
            os.makedirs(save_dir)
        save_file = osp.join(save_dir, file_name)
        dst_data = driver.Create(save_file, width, height, 1, gdal.GDT_Byte)
        dst_data.SetGeoTransform(src_data.GetGeoTransform())
        dst_data.SetProjection(src_data.GetProjection())
        band = dst_data.GetRasterBand(1)
        band.WriteArray(255 * np.ones((height, width), dtype="uint8"))

        step = np.array(block_size) - np.array(overlap)
        for yoff in range(0, height, step[1]):
            for xoff in range(0, width, step[0]):
                xsize, ysize = block_size
                if xoff + xsize > width:
                    xsize = int(width - xoff)
                if yoff + ysize > height:
                    ysize = int(height - yoff)
                im = src_data.ReadAsArray(int(xoff), int(yoff), xsize,
                                          ysize).transpose((1, 2, 0))
                # Fill
                h, w = im.shape[:2]
                im_fill = np.zeros(
                    (block_size[1], block_size[0], bands), dtype=im.dtype)
                im_fill[:h, :w, :] = im
                # Predict
                pred = self.predict(im_fill,
                                    transforms)["label_map"].astype("uint8")
                # Overlap
                rd_block = band.ReadAsArray(int(xoff), int(yoff), xsize, ysize)
                mask = (rd_block == pred[:h, :w]) | (rd_block == 255)
                temp = pred[:h, :w].copy()
                temp[~mask] = 0
                band.WriteArray(temp, int(xoff), int(yoff))
                dst_data.FlushCache()
        dst_data = None
        print("GeoTiff saved in {}.".format(save_file))
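
    # A minimal usage sketch (hypothetical GeoTiff path): predict a large
    # raster in 512x512 blocks with a 36-pixel overlap.
    #
    #   model.slider_predict('scene.tif', save_dir='output/pred',
    #                        block_size=512, overlap=36)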

    def preprocess(self, images, transforms, to_tensor=True):
        self._check_transforms(transforms, 'test')
        batch_im = list()
        batch_ori_shape = list()
        for im in images:
            if isinstance(im, str):
                im = decode_image(im, to_rgb=False)
            ori_shape = im.shape[:2]
            sample = {'image': im}
            im = transforms(sample)[0]
            batch_im.append(im)
            batch_ori_shape.append(ori_shape)
        if to_tensor:
            batch_im = paddle.to_tensor(batch_im)
        else:
            batch_im = np.asarray(batch_im)
        return batch_im, batch_ori_shape

    @staticmethod
    def get_transforms_shape_info(batch_ori_shape, transforms):
        batch_restore_list = list()
        for ori_shape in batch_ori_shape:
            restore_list = list()
            h, w = ori_shape[0], ori_shape[1]
            for op in transforms:
                if op.__class__.__name__ == 'Resize':
                    restore_list.append(('resize', (h, w)))
                    h, w = op.target_size
                elif op.__class__.__name__ == 'ResizeByShort':
                    restore_list.append(('resize', (h, w)))
                    im_short_size = min(h, w)
                    im_long_size = max(h, w)
                    scale = float(op.short_size) / float(im_short_size)
                    if 0 < op.max_size < np.round(scale * im_long_size):
                        scale = float(op.max_size) / float(im_long_size)
                    h = int(round(h * scale))
                    w = int(round(w * scale))
                elif op.__class__.__name__ == 'ResizeByLong':
                    restore_list.append(('resize', (h, w)))
                    im_long_size = max(h, w)
                    scale = float(op.long_size) / float(im_long_size)
                    h = int(round(h * scale))
                    w = int(round(w * scale))
                elif op.__class__.__name__ == 'Pad':
                    if op.target_size:
                        target_h, target_w = op.target_size
                    else:
                        target_h = int(
                            (np.ceil(h / op.size_divisor) * op.size_divisor))
                        target_w = int(
                            (np.ceil(w / op.size_divisor) * op.size_divisor))
                    if op.pad_mode == -1:
                        offsets = op.offsets
                    elif op.pad_mode == 0:
                        offsets = [0, 0]
                    elif op.pad_mode == 1:
                        offsets = [(target_h - h) // 2, (target_w - w) // 2]
                    else:
                        offsets = [target_h - h, target_w - w]
                    restore_list.append(('padding', (h, w), offsets))
                    h, w = target_h, target_w
            batch_restore_list.append(restore_list)
        return batch_restore_list
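
    # For illustration (assumed transform settings): resizing a 1000x800
    # image to 512x512 and then center-padding to 544x544 would produce
    #
    #   [('resize', (1000, 800)), ('padding', (512, 512), [16, 16])]
    #
    # which postprocess() replays in reverse to restore predictions to the
    # original image size.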

    def postprocess(self, batch_pred, batch_origin_shape, transforms):
        batch_restore_list = BaseSegmenter.get_transforms_shape_info(
            batch_origin_shape, transforms)
        if isinstance(batch_pred, (tuple, list)) and self.status == 'Infer':
            return self._infer_postprocess(
                batch_label_map=batch_pred[0],
                batch_score_map=batch_pred[1],
                batch_restore_list=batch_restore_list)
        results = []
        if batch_pred.dtype == paddle.float32:
            mode = 'bilinear'
        else:
            mode = 'nearest'
        for pred, restore_list in zip(batch_pred, batch_restore_list):
            pred = paddle.unsqueeze(pred, axis=0)
            for item in restore_list[::-1]:
                h, w = item[1][0], item[1][1]
                if item[0] == 'resize':
                    pred = F.interpolate(
                        pred, (h, w), mode=mode, data_format='NCHW')
                elif item[0] == 'padding':
                    x, y = item[2]
                    pred = pred[:, :, y:y + h, x:x + w]
                else:
                    pass
            results.append(pred)
        return results

    def _infer_postprocess(self, batch_label_map, batch_score_map,
                           batch_restore_list):
        label_maps = []
        score_maps = []
        for label_map, score_map, restore_list in zip(
                batch_label_map, batch_score_map, batch_restore_list):
            if not isinstance(label_map, np.ndarray):
                label_map = paddle.unsqueeze(label_map, axis=[0, 3])
                score_map = paddle.unsqueeze(score_map, axis=0)
            for item in restore_list[::-1]:
                h, w = item[1][0], item[1][1]
                if item[0] == 'resize':
                    if isinstance(label_map, np.ndarray):
                        label_map = cv2.resize(
                            label_map, (w, h),
                            interpolation=cv2.INTER_NEAREST)
                        score_map = cv2.resize(
                            score_map, (w, h), interpolation=cv2.INTER_LINEAR)
                    else:
                        label_map = F.interpolate(
                            label_map, (h, w),
                            mode='nearest',
                            data_format='NHWC')
                        score_map = F.interpolate(
                            score_map, (h, w),
                            mode='bilinear',
                            data_format='NHWC')
                elif item[0] == 'padding':
                    x, y = item[2]
                    if isinstance(label_map, np.ndarray):
                        label_map = label_map[..., y:y + h, x:x + w]
                        score_map = score_map[..., y:y + h, x:x + w]
                    else:
                        label_map = label_map[:, :, y:y + h, x:x + w]
                        score_map = score_map[:, :, y:y + h, x:x + w]
                else:
                    pass
            label_map = label_map.squeeze()
            score_map = score_map.squeeze()
            if not isinstance(label_map, np.ndarray):
                label_map = label_map.numpy()
                score_map = score_map.numpy()
            label_maps.append(label_map.squeeze())
            score_maps.append(score_map.squeeze())
        return label_maps, score_maps

    def _check_transforms(self, transforms, mode):
        super()._check_transforms(transforms, mode)
        if not isinstance(transforms.arrange,
                          paddlers.transforms.ArrangeSegmenter):
            raise TypeError(
                "`transforms.arrange` must be an ArrangeSegmenter object.")

    def set_losses(self, losses, weights=None):
        if weights is None:
            weights = [1. for _ in range(len(losses))]
        self.losses = {'types': losses, 'coef': weights}
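
    # For example (an illustrative sketch):
    #
    #   model.set_losses(
    #       [seg_losses.CrossEntropyLoss(), seg_losses.DiceLoss()],
    #       weights=[0.7, 0.3])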


class UNet(BaseSegmenter):
    def __init__(self,
                 in_channels=3,
                 num_classes=2,
                 use_mixed_loss=False,
                 losses=None,
                 use_deconv=False,
                 align_corners=False,
                 **params):
        params.update({
            'use_deconv': use_deconv,
            'align_corners': align_corners
        })
        super(UNet, self).__init__(
            model_name='UNet',
            input_channel=in_channels,
            num_classes=num_classes,
            use_mixed_loss=use_mixed_loss,
            losses=losses,
            **params)
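
# Note that UNet forwards `in_channels` to the underlying ppseg network as
# `input_channel`, so multispectral inputs, e.g. the hypothetical
# UNet(in_channels=4, num_classes=2), can be built for 4-band imagery.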


class DeepLabV3P(BaseSegmenter):
    def __init__(self,
                 in_channels=3,
                 num_classes=2,
                 backbone='ResNet50_vd',
                 use_mixed_loss=False,
                 losses=None,
                 output_stride=8,
                 backbone_indices=(0, 3),
                 aspp_ratios=(1, 12, 24, 36),
                 aspp_out_channels=256,
                 align_corners=False,
                 **params):
        self.backbone_name = backbone
        if backbone not in ['ResNet50_vd', 'ResNet101_vd']:
            raise ValueError(
                "backbone: {} is not supported. Please choose one of "
                "{{'ResNet50_vd', 'ResNet101_vd'}}.".format(backbone))
        if params.get('with_net', True):
            with DisablePrint():
                backbone = getattr(ppseg.models, backbone)(
                    input_channel=in_channels, output_stride=output_stride)
        else:
            backbone = None
        params.update({
            'backbone': backbone,
            'backbone_indices': backbone_indices,
            'aspp_ratios': aspp_ratios,
            'aspp_out_channels': aspp_out_channels,
            'align_corners': align_corners
        })
        super(DeepLabV3P, self).__init__(
            model_name='DeepLabV3P',
            num_classes=num_classes,
            use_mixed_loss=use_mixed_loss,
            losses=losses,
            **params)


class FastSCNN(BaseSegmenter):
    def __init__(self,
                 num_classes=2,
                 use_mixed_loss=False,
                 losses=None,
                 align_corners=False,
                 **params):
        params.update({'align_corners': align_corners})
        super(FastSCNN, self).__init__(
            model_name='FastSCNN',
            num_classes=num_classes,
            use_mixed_loss=use_mixed_loss,
            losses=losses,
            **params)


class HRNet(BaseSegmenter):
    def __init__(self,
                 num_classes=2,
                 width=48,
                 use_mixed_loss=False,
                 losses=None,
                 align_corners=False,
                 **params):
        if width not in (18, 48):
            raise ValueError(
                "width={} is not supported, please choose from {{18, 48}}.".
                format(width))
        self.backbone_name = 'HRNet_W{}'.format(width)
        if params.get('with_net', True):
            with DisablePrint():
                backbone = getattr(ppseg.models, self.backbone_name)(
                    align_corners=align_corners)
        else:
            backbone = None
        params.update({'backbone': backbone, 'align_corners': align_corners})
        # In ppseg, HRNet-based segmentation is implemented as an FCN head on
        # an HRNet backbone, so the network is built as 'FCN' and the exposed
        # model name is reset to 'HRNet' afterwards.
        super(HRNet, self).__init__(
            model_name='FCN',
            num_classes=num_classes,
            use_mixed_loss=use_mixed_loss,
            losses=losses,
            **params)
        self.model_name = 'HRNet'


class BiSeNetV2(BaseSegmenter):
    def __init__(self,
                 num_classes=2,
                 use_mixed_loss=False,
                 losses=None,
                 align_corners=False,
                 **params):
        params.update({'align_corners': align_corners})
        super(BiSeNetV2, self).__init__(
            model_name='BiSeNetV2',
            num_classes=num_classes,
            use_mixed_loss=use_mixed_loss,
            losses=losses,
            **params)


class FarSeg(BaseSegmenter):
    def __init__(self,
                 num_classes=2,
                 use_mixed_loss=False,
                 losses=None,
                 **params):
        super(FarSeg, self).__init__(
            model_name='FarSeg',
            num_classes=num_classes,
            use_mixed_loss=use_mixed_loss,
            losses=losses,
            **params)