# segmenter.py
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import math
  15. import os
  16. import os.path as osp
  17. from collections import OrderedDict
  18. import numpy as np
  19. import cv2
  20. import paddle
  21. import paddle.nn.functional as F
  22. from paddle.static import InputSpec
  23. import paddlers.models.ppseg as paddleseg
  24. import paddlers.rs_models.seg as cmseg
  25. import paddlers
  26. from paddlers.utils import get_single_card_bs, DisablePrint
  27. import paddlers.utils.logging as logging
  28. from .base import BaseModel
  29. from .utils import seg_metrics as metrics
  30. from paddlers.utils.checkpoint import seg_pretrain_weights_dict
  31. from paddlers.transforms import Resize, decode_image
  32. __all__ = ["UNet", "DeepLabV3P", "FastSCNN", "HRNet", "BiSeNetV2", "FarSeg"]
class BaseSegmenter(BaseModel):
    """Base class for semantic segmentation models.

    Args:
        model_name (str): Name of a model class found in either
            ``paddleseg.models`` or ``paddlers.rs_models.seg``.
        num_classes (int, optional): Number of target classes. Defaults to 2.
        use_mixed_loss (bool or list, optional): Whether/how to mix losses;
            see ``default_loss()``. Defaults to False.
        **params: Extra keyword arguments forwarded to the network
            constructor. The special key ``with_net`` (default True)
            controls whether the network is built immediately.
    """

    def __init__(self,
                 model_name,
                 num_classes=2,
                 use_mixed_loss=False,
                 **params):
        # Record constructor arguments so the model can be re-instantiated
        # later (e.g. when restoring from a checkpoint); `with_net` is a
        # construction-time-only switch and must not be replayed.
        self.init_params = locals()
        if 'with_net' in self.init_params:
            del self.init_params['with_net']
        super(BaseSegmenter, self).__init__('segmenter')
        if not hasattr(paddleseg.models, model_name) and \
                not hasattr(cmseg, model_name):
            raise ValueError("ERROR: There is no model named {}.".format(
                model_name))
        self.model_name = model_name
        self.num_classes = num_classes
        self.use_mixed_loss = use_mixed_loss
        self.losses = None
        self.labels = None
        # Build the underlying network unless the caller opted out with
        # with_net=False (useful when only loading exported weights).
        if params.get('with_net', True):
            params.pop('with_net', None)
            self.net = self.build_net(**params)
        self.find_unused_parameters = True
  56. def build_net(self, **params):
  57. # TODO: when using paddle.utils.unique_name.guard,
  58. # DeepLabv3p and HRNet will raise a error
  59. net = dict(paddleseg.models.__dict__,
  60. **cmseg.__dict__)[self.model_name](
  61. num_classes=self.num_classes, **params)
  62. return net
  63. def _fix_transforms_shape(self, image_shape):
  64. if hasattr(self, 'test_transforms'):
  65. if self.test_transforms is not None:
  66. has_resize_op = False
  67. resize_op_idx = -1
  68. normalize_op_idx = len(self.test_transforms.transforms)
  69. for idx, op in enumerate(self.test_transforms.transforms):
  70. name = op.__class__.__name__
  71. if name == 'Normalize':
  72. normalize_op_idx = idx
  73. if 'Resize' in name:
  74. has_resize_op = True
  75. resize_op_idx = idx
  76. if not has_resize_op:
  77. self.test_transforms.transforms.insert(
  78. normalize_op_idx, Resize(target_size=image_shape))
  79. else:
  80. self.test_transforms.transforms[resize_op_idx] = Resize(
  81. target_size=image_shape)
  82. def _get_test_inputs(self, image_shape):
  83. if image_shape is not None:
  84. if len(image_shape) == 2:
  85. image_shape = [1, 3] + image_shape
  86. self._fix_transforms_shape(image_shape[-2:])
  87. else:
  88. image_shape = [None, 3, -1, -1]
  89. self.fixed_input_shape = image_shape
  90. input_spec = [
  91. InputSpec(
  92. shape=image_shape, name='image', dtype='float32')
  93. ]
  94. return input_spec
    def run(self, net, inputs, mode):
        """Forward `inputs` through `net` and assemble mode-specific outputs.

        Args:
            net: The segmentation network.
            inputs: For 'test': (image, origin_shape, transform ops).
                For 'eval': (image, label, transform ops).
                For 'train': (image, label).
            mode (str): One of 'test', 'eval' or 'train'.

        Returns:
            collections.OrderedDict: 'label_map'/'score_map' lists for 'test',
                area counters and confusion matrix for 'eval', 'loss' for
                'train'.
        """
        net_out = net(inputs[0])
        logit = net_out[0]
        outputs = OrderedDict()
        if mode == 'test':
            origin_shape = inputs[1]
            if self.status == 'Infer':
                # An exported inference model already outputs
                # (label_map, score_map); only undo the preprocessing.
                label_map_list, score_map_list = self._postprocess(
                    net_out, origin_shape, transforms=inputs[2])
            else:
                # Restore logits to original sizes, then derive label and
                # score maps per sample.
                logit_list = self._postprocess(
                    logit, origin_shape, transforms=inputs[2])
                label_map_list = []
                score_map_list = []
                for logit in logit_list:
                    logit = paddle.transpose(logit, perm=[0, 2, 3, 1])  # NHWC
                    label_map_list.append(
                        paddle.argmax(
                            logit, axis=-1, keepdim=False, dtype='int32')
                        .squeeze().numpy())
                    score_map_list.append(
                        F.softmax(
                            logit, axis=-1).squeeze().numpy().astype('float32'))
            outputs['label_map'] = label_map_list
            outputs['score_map'] = score_map_list

        if mode == 'eval':
            if self.status == 'Infer':
                pred = paddle.unsqueeze(net_out[0], axis=1)  # NCHW
            else:
                pred = paddle.argmax(logit, axis=1, keepdim=True, dtype='int32')
            label = inputs[1]
            # Labels may arrive as NHW; normalize to NCHW with C == 1.
            if label.ndim == 3:
                paddle.unsqueeze_(label, axis=1)
            if label.ndim != 4:
                raise ValueError("Expected label.ndim == 4 but got {}".format(
                    label.ndim))
            origin_shape = [label.shape[-2:]]
            # Restore the prediction to the label's spatial size before
            # computing metrics.
            pred = self._postprocess(
                pred, origin_shape, transforms=inputs[2])[0]  # NCHW
            intersect_area, pred_area, label_area = paddleseg.utils.metrics.calculate_area(
                pred, label, self.num_classes)
            outputs['intersect_area'] = intersect_area
            outputs['pred_area'] = pred_area
            outputs['label_area'] = label_area
            outputs['conf_mat'] = metrics.confusion_matrix(pred, label,
                                                           self.num_classes)
        if mode == 'train':
            loss_list = metrics.loss_computation(
                logits_list=net_out, labels=inputs[1], losses=self.losses)
            loss = sum(loss_list)
            outputs['loss'] = loss
        return outputs
    def default_loss(self):
        """Construct the default loss configuration.

        Returns:
            dict: {'types': list of loss objects (one per network output),
                'coef': matching list of per-output weights}.
        """
        if isinstance(self.use_mixed_loss, bool):
            if self.use_mixed_loss:
                # Fixed 0.8/0.2 mix of cross-entropy and Lovasz-softmax.
                losses = [
                    paddleseg.models.CrossEntropyLoss(),
                    paddleseg.models.LovaszSoftmaxLoss()
                ]
                coef = [.8, .2]
                loss_type = [
                    paddleseg.models.MixedLoss(
                        losses=losses, coef=coef),
                ]
            else:
                loss_type = [paddleseg.models.CrossEntropyLoss()]
        else:
            # use_mixed_loss is an iterable of (loss_name, coefficient) pairs.
            losses, coef = list(zip(*self.use_mixed_loss))
            if not set(losses).issubset(
                    ['CrossEntropyLoss', 'DiceLoss', 'LovaszSoftmaxLoss']):
                raise ValueError(
                    "Only 'CrossEntropyLoss', 'DiceLoss', 'LovaszSoftmaxLoss' are supported."
                )
            losses = [getattr(paddleseg.models, loss)() for loss in losses]
            loss_type = [
                paddleseg.models.MixedLoss(
                    losses=losses, coef=list(coef))
            ]
        # Models with auxiliary heads need one loss (and weight) per output.
        if self.model_name == 'FastSCNN':
            loss_type *= 2
            loss_coef = [1.0, 0.4]
        elif self.model_name == 'BiSeNetV2':
            loss_type *= 5
            loss_coef = [1.0] * 5
        else:
            loss_coef = [1.0]
        losses = {'types': loss_type, 'coef': loss_coef}
        return losses
  183. def default_optimizer(self,
  184. parameters,
  185. learning_rate,
  186. num_epochs,
  187. num_steps_each_epoch,
  188. lr_decay_power=0.9):
  189. decay_step = num_epochs * num_steps_each_epoch
  190. lr_scheduler = paddle.optimizer.lr.PolynomialDecay(
  191. learning_rate, decay_step, end_lr=0, power=lr_decay_power)
  192. optimizer = paddle.optimizer.Momentum(
  193. learning_rate=lr_scheduler,
  194. parameters=parameters,
  195. momentum=0.9,
  196. weight_decay=4e-5)
  197. return optimizer
    def train(self,
              num_epochs,
              train_dataset,
              train_batch_size=2,
              eval_dataset=None,
              optimizer=None,
              save_interval_epochs=1,
              log_interval_steps=2,
              save_dir='output',
              pretrain_weights='CITYSCAPES',
              learning_rate=0.01,
              lr_decay_power=0.9,
              early_stop=False,
              early_stop_patience=5,
              use_vdl=True,
              resume_checkpoint=None):
        """
        Train the model.

        Args:
            num_epochs (int): The number of epochs.
            train_dataset (paddlers.dataset): Training dataset.
            train_batch_size (int, optional): Total batch size among all cards used in
                training. Defaults to 2.
            eval_dataset (paddlers.dataset, optional): Evaluation dataset. If None, the
                model will not be evaluated during the training process. Defaults to None.
            optimizer (paddle.optimizer.Optimizer or None, optional): Optimizer used in
                training. If None, a default optimizer is used. Defaults to None.
            save_interval_epochs (int, optional): Epoch interval for saving the model.
                Defaults to 1.
            log_interval_steps (int, optional): Step interval for printing training
                information. Defaults to 2.
            save_dir (str, optional): Directory to save the model. Defaults to 'output'.
            pretrain_weights (str or None, optional): None or name/path of pretrained
                weights. If None, no pretrained weights will be loaded.
                Defaults to 'CITYSCAPES'.
            learning_rate (float, optional): Learning rate for training.
                Defaults to 0.01.
            lr_decay_power (float, optional): Learning decay power. Defaults to 0.9.
            early_stop (bool, optional): Whether to adopt early stop strategy.
                Defaults to False.
            early_stop_patience (int, optional): Early stop patience. Defaults to 5.
            use_vdl (bool, optional): Whether to use VisualDL to monitor the training
                process. Defaults to True.
            resume_checkpoint (str or None, optional): The path of the checkpoint to
                resume training from. If None, no training checkpoint will be resumed.
                At most one of `resume_checkpoint` and `pretrain_weights` can be set
                simultaneously. Defaults to None.
        """
        if self.status == 'Infer':
            logging.error(
                "Exported inference model does not support training.",
                exit=True)
        if pretrain_weights is not None and resume_checkpoint is not None:
            logging.error(
                "pretrain_weights and resume_checkpoint cannot be set simultaneously.",
                exit=True)
        self.labels = train_dataset.labels
        if self.losses is None:
            self.losses = self.default_loss()

        if optimizer is None:
            num_steps_each_epoch = train_dataset.num_samples // train_batch_size
            self.optimizer = self.default_optimizer(
                self.net.parameters(), learning_rate, num_epochs,
                num_steps_each_epoch, lr_decay_power)
        else:
            self.optimizer = optimizer

        if pretrain_weights is not None and not osp.exists(pretrain_weights):
            # Unknown name: fall back to the first registered weight set for
            # this model and warn loudly.
            if pretrain_weights not in seg_pretrain_weights_dict[
                    self.model_name]:
                logging.warning(
                    "Path of pretrain_weights('{}') does not exist!".format(
                        pretrain_weights))
                logging.warning("Pretrain_weights is forcibly set to '{}'. "
                                "If don't want to use pretrain weights, "
                                "set pretrain_weights to be None.".format(
                                    seg_pretrain_weights_dict[self.model_name][
                                        0]))
                pretrain_weights = seg_pretrain_weights_dict[self.model_name][0]
        elif pretrain_weights is not None and osp.exists(pretrain_weights):
            if osp.splitext(pretrain_weights)[-1] != '.pdparams':
                logging.error(
                    "Invalid pretrain weights. Please specify a '.pdparams' file.",
                    exit=True)
        pretrained_dir = osp.join(save_dir, 'pretrain')
        # 'IMAGENET' weights cover only the backbone, not the whole net.
        is_backbone_weights = pretrain_weights == 'IMAGENET'
        self.net_initialize(
            pretrain_weights=pretrain_weights,
            save_dir=pretrained_dir,
            resume_checkpoint=resume_checkpoint,
            is_backbone_weights=is_backbone_weights)

        self.train_loop(
            num_epochs=num_epochs,
            train_dataset=train_dataset,
            train_batch_size=train_batch_size,
            eval_dataset=eval_dataset,
            save_interval_epochs=save_interval_epochs,
            log_interval_steps=log_interval_steps,
            save_dir=save_dir,
            early_stop=early_stop,
            early_stop_patience=early_stop_patience,
            use_vdl=use_vdl)
    def quant_aware_train(self,
                          num_epochs,
                          train_dataset,
                          train_batch_size=2,
                          eval_dataset=None,
                          optimizer=None,
                          save_interval_epochs=1,
                          log_interval_steps=2,
                          save_dir='output',
                          learning_rate=0.0001,
                          lr_decay_power=0.9,
                          early_stop=False,
                          early_stop_patience=5,
                          use_vdl=True,
                          resume_checkpoint=None,
                          quant_config=None):
        """
        Quantization-aware training.

        Args:
            num_epochs (int): The number of epochs.
            train_dataset (paddlers.dataset): Training dataset.
            train_batch_size (int, optional): Total batch size among all cards used in
                training. Defaults to 2.
            eval_dataset (paddlers.dataset, optional): Evaluation dataset. If None, the
                model will not be evaluated during the training process. Defaults to None.
            optimizer (paddle.optimizer.Optimizer or None, optional): Optimizer used in
                training. If None, a default optimizer is used. Defaults to None.
            save_interval_epochs (int, optional): Epoch interval for saving the model.
                Defaults to 1.
            log_interval_steps (int, optional): Step interval for printing training
                information. Defaults to 2.
            save_dir (str, optional): Directory to save the model. Defaults to 'output'.
            learning_rate (float, optional): Learning rate for training.
                Defaults to 0.0001.
            lr_decay_power (float, optional): Learning decay power. Defaults to 0.9.
            early_stop (bool, optional): Whether to adopt early stop strategy.
                Defaults to False.
            early_stop_patience (int, optional): Early stop patience. Defaults to 5.
            use_vdl (bool, optional): Whether to use VisualDL to monitor the training
                process. Defaults to True.
            quant_config (dict or None, optional): Quantization configuration. If None,
                a default rule of thumb configuration will be used. Defaults to None.
            resume_checkpoint (str or None, optional): The path of the checkpoint to
                resume quantization-aware training from. If None, no training
                checkpoint will be resumed. Defaults to None.
        """
        # Wrap the network for quantization, then reuse the regular training
        # loop; pretrained weights are disabled because the wrapped network
        # keeps its current parameters.
        self._prepare_qat(quant_config)
        self.train(
            num_epochs=num_epochs,
            train_dataset=train_dataset,
            train_batch_size=train_batch_size,
            eval_dataset=eval_dataset,
            optimizer=optimizer,
            save_interval_epochs=save_interval_epochs,
            log_interval_steps=log_interval_steps,
            save_dir=save_dir,
            pretrain_weights=None,
            learning_rate=learning_rate,
            lr_decay_power=lr_decay_power,
            early_stop=early_stop,
            early_stop_patience=early_stop_patience,
            use_vdl=use_vdl,
            resume_checkpoint=resume_checkpoint)
    def evaluate(self, eval_dataset, batch_size=1, return_details=False):
        """
        Evaluate the model.

        Args:
            eval_dataset (paddlers.dataset): Evaluation dataset.
            batch_size (int, optional): Total batch size among all cards used for
                evaluation. Defaults to 1.
            return_details (bool, optional): Whether to return evaluation details.
                Defaults to False.

        Returns:
            collections.OrderedDict with key-value pairs:
                {"miou": `mean intersection over union`,
                 "category_iou": `category-wise mean intersection over union`,
                 "oacc": `overall accuracy`,
                 "category_acc": `category-wise accuracy`,
                 "kappa": `kappa coefficient`,
                 "category_F1-score": `F1 score`}.
            If `return_details` is True, also a dict with the accumulated
            confusion matrix.
        """
        self._check_transforms(eval_dataset.transforms, 'eval')
        self.net.eval()
        nranks = paddle.distributed.get_world_size()
        local_rank = paddle.distributed.get_rank()
        if nranks > 1:
            # Initialize parallel environment if not done.
            if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
            ):
                paddle.distributed.init_parallel_env()

        batch_size_each_card = get_single_card_bs(batch_size)
        # Per-card batch size above 1 is not supported for evaluation.
        if batch_size_each_card > 1:
            batch_size_each_card = 1
            batch_size = batch_size_each_card * paddlers.env_info['num']
            logging.warning(
                "Segmenter only supports batch_size=1 for each gpu/cpu card " \
                "during evaluation, so batch_size " \
                "is forcibly set to {}.".format(batch_size))
        self.eval_data_loader = self.build_data_loader(
            eval_dataset, batch_size=batch_size, mode='eval')

        intersect_area_all = 0
        pred_area_all = 0
        label_area_all = 0
        conf_mat_all = []
        logging.info(
            "Start to evaluate(total_samples={}, total_steps={})...".format(
                eval_dataset.num_samples,
                math.ceil(eval_dataset.num_samples * 1.0 / batch_size)))
        with paddle.no_grad():
            for step, data in enumerate(self.eval_data_loader):
                data.append(eval_dataset.transforms.transforms)
                outputs = self.run(self.net, data, 'eval')
                pred_area = outputs['pred_area']
                label_area = outputs['label_area']
                intersect_area = outputs['intersect_area']
                conf_mat = outputs['conf_mat']

                # Gather from all ranks
                if nranks > 1:
                    intersect_area_list = []
                    pred_area_list = []
                    label_area_list = []
                    conf_mat_list = []
                    paddle.distributed.all_gather(intersect_area_list,
                                                  intersect_area)
                    paddle.distributed.all_gather(pred_area_list, pred_area)
                    paddle.distributed.all_gather(label_area_list, label_area)
                    paddle.distributed.all_gather(conf_mat_list, conf_mat)

                    # Some image has been evaluated and should be eliminated in last iter
                    if (step + 1) * nranks > len(eval_dataset):
                        valid = len(eval_dataset) - step * nranks
                        intersect_area_list = intersect_area_list[:valid]
                        pred_area_list = pred_area_list[:valid]
                        label_area_list = label_area_list[:valid]
                        conf_mat_list = conf_mat_list[:valid]

                    intersect_area_all += sum(intersect_area_list)
                    pred_area_all += sum(pred_area_list)
                    label_area_all += sum(label_area_list)
                    conf_mat_all.extend(conf_mat_list)
                else:
                    intersect_area_all = intersect_area_all + intersect_area
                    pred_area_all = pred_area_all + pred_area
                    label_area_all = label_area_all + label_area
                    conf_mat_all.append(conf_mat)
        class_iou, miou = paddleseg.utils.metrics.mean_iou(
            intersect_area_all, pred_area_all, label_area_all)
        # TODO: confirm whether overall accuracy (oacc) or mean accuracy
        # (macc) should be reported here.
        class_acc, oacc = paddleseg.utils.metrics.accuracy(intersect_area_all,
                                                           pred_area_all)
        kappa = paddleseg.utils.metrics.kappa(intersect_area_all, pred_area_all,
                                              label_area_all)
        category_f1score = metrics.f1_score(intersect_area_all, pred_area_all,
                                            label_area_all)

        eval_metrics = OrderedDict(
            zip([
                'miou', 'category_iou', 'oacc', 'category_acc', 'kappa',
                'category_F1-score'
            ], [miou, class_iou, oacc, class_acc, kappa, category_f1score]))

        if return_details:
            conf_mat = sum(conf_mat_all)
            eval_details = {'confusion_matrix': conf_mat.tolist()}
            return eval_metrics, eval_details
        return eval_metrics
  444. def predict(self, img_file, transforms=None):
  445. """
  446. Do inference.
  447. Args:
  448. Args:
  449. img_file(list[np.ndarray | str] | str | np.ndarray):
  450. Image path or decoded image data, which also could constitute a list,meaning all images to be
  451. predicted as a mini-batch.
  452. transforms(paddlers.transforms.Compose or None, optional):
  453. Transforms for inputs. If None, the transforms for evaluation process will be used. Defaults to None.
  454. Returns:
  455. If img_file is a string or np.array, the result is a dict with key-value pairs:
  456. {"label map": `label map`, "score_map": `score map`}.
  457. If img_file is a list, the result is a list composed of dicts with the corresponding fields:
  458. label_map(np.ndarray): the predicted label map (HW)
  459. score_map(np.ndarray): the prediction score map (HWC)
  460. """
  461. if transforms is None and not hasattr(self, 'test_transforms'):
  462. raise Exception("transforms need to be defined, now is None.")
  463. if transforms is None:
  464. transforms = self.test_transforms
  465. if isinstance(img_file, (str, np.ndarray)):
  466. images = [img_file]
  467. else:
  468. images = img_file
  469. batch_im, batch_origin_shape = self._preprocess(images, transforms,
  470. self.model_type)
  471. self.net.eval()
  472. data = (batch_im, batch_origin_shape, transforms.transforms)
  473. outputs = self.run(self.net, data, 'test')
  474. label_map_list = outputs['label_map']
  475. score_map_list = outputs['score_map']
  476. if isinstance(img_file, list):
  477. prediction = [{
  478. 'label_map': l,
  479. 'score_map': s
  480. } for l, s in zip(label_map_list, score_map_list)]
  481. else:
  482. prediction = {
  483. 'label_map': label_map_list[0],
  484. 'score_map': score_map_list[0]
  485. }
  486. return prediction
  487. def slider_predict(self,
  488. img_file,
  489. save_dir,
  490. block_size,
  491. overlap=36,
  492. transforms=None):
  493. """
  494. Do inference.
  495. Args:
  496. Args:
  497. img_file(str):
  498. Image path.
  499. save_dir(str):
  500. Directory that contains saved geotiff file.
  501. block_size(list[int] | tuple[int] | int):
  502. Size of block.
  503. overlap(list[int] | tuple[int] | int, optional):
  504. Overlap between two blocks. Defaults to 36.
  505. transforms(paddlers.transforms.Compose or None, optional):
  506. Transforms for inputs. If None, the transforms for evaluation process will be used. Defaults to None.
  507. """
  508. try:
  509. from osgeo import gdal
  510. except:
  511. import gdal
  512. if isinstance(block_size, int):
  513. block_size = (block_size, block_size)
  514. elif isinstance(block_size, (tuple, list)) and len(block_size) == 2:
  515. block_size = tuple(block_size)
  516. else:
  517. raise ValueError(
  518. "`block_size` must be a tuple/list of length 2 or an integer.")
  519. if isinstance(overlap, int):
  520. overlap = (overlap, overlap)
  521. elif isinstance(overlap, (tuple, list)) and len(overlap) == 2:
  522. overlap = tuple(overlap)
  523. else:
  524. raise ValueError(
  525. "`overlap` must be a tuple/list of length 2 or an integer.")
  526. src_data = gdal.Open(img_file)
  527. width = src_data.RasterXSize
  528. height = src_data.RasterYSize
  529. bands = src_data.RasterCount
  530. driver = gdal.GetDriverByName("GTiff")
  531. file_name = osp.splitext(osp.normpath(img_file).split(os.sep)[-1])[
  532. 0] + ".tif"
  533. if not osp.exists(save_dir):
  534. os.makedirs(save_dir)
  535. save_file = osp.join(save_dir, file_name)
  536. dst_data = driver.Create(save_file, width, height, 1, gdal.GDT_Byte)
  537. dst_data.SetGeoTransform(src_data.GetGeoTransform())
  538. dst_data.SetProjection(src_data.GetProjection())
  539. band = dst_data.GetRasterBand(1)
  540. band.WriteArray(255 * np.ones((height, width), dtype="uint8"))
  541. step = np.array(block_size) - np.array(overlap)
  542. for yoff in range(0, height, step[1]):
  543. for xoff in range(0, width, step[0]):
  544. xsize, ysize = block_size
  545. if xoff + xsize > width:
  546. xsize = int(width - xoff)
  547. if yoff + ysize > height:
  548. ysize = int(height - yoff)
  549. im = src_data.ReadAsArray(int(xoff), int(yoff), xsize,
  550. ysize).transpose((1, 2, 0))
  551. # fill
  552. h, w = im.shape[:2]
  553. im_fill = np.zeros(
  554. (block_size[1], block_size[0], bands), dtype=im.dtype)
  555. im_fill[:h, :w, :] = im
  556. # predict
  557. pred = self.predict(im_fill,
  558. transforms)["label_map"].astype("uint8")
  559. # overlap
  560. rd_block = band.ReadAsArray(int(xoff), int(yoff), xsize, ysize)
  561. mask = (rd_block == pred[:h, :w]) | (rd_block == 255)
  562. temp = pred[:h, :w].copy()
  563. temp[mask == False] = 0
  564. band.WriteArray(temp, int(xoff), int(yoff))
  565. dst_data.FlushCache()
  566. dst_data = None
  567. print("GeoTiff saved in {}.".format(save_file))
  568. def _preprocess(self, images, transforms, to_tensor=True):
  569. self._check_transforms(transforms, 'test')
  570. batch_im = list()
  571. batch_ori_shape = list()
  572. for im in images:
  573. if isinstance(im, str):
  574. im = decode_image(im, to_rgb=False)
  575. ori_shape = im.shape[:2]
  576. sample = {'image': im}
  577. im = transforms(sample)[0]
  578. batch_im.append(im)
  579. batch_ori_shape.append(ori_shape)
  580. if to_tensor:
  581. batch_im = paddle.to_tensor(batch_im)
  582. else:
  583. batch_im = np.asarray(batch_im)
  584. return batch_im, batch_ori_shape
    @staticmethod
    def get_transforms_shape_info(batch_ori_shape, transforms):
        """For each sample, record how the geometric ops changed its shape.

        Args:
            batch_ori_shape (list): Original (h, w) of each sample.
            transforms (list): Transform ops applied during preprocessing.

        Returns:
            list: Per sample, a list of restoration records in application
                order: ('resize', (h, w)) or
                ('padding', (h, w), [offset, offset]), where (h, w) is the
                shape BEFORE the op was applied.
        """
        batch_restore_list = list()
        for ori_shape in batch_ori_shape:
            restore_list = list()
            h, w = ori_shape[0], ori_shape[1]
            for op in transforms:
                if op.__class__.__name__ == 'Resize':
                    restore_list.append(('resize', (h, w)))
                    h, w = op.target_size
                elif op.__class__.__name__ == 'ResizeByShort':
                    restore_list.append(('resize', (h, w)))
                    im_short_size = min(h, w)
                    im_long_size = max(h, w)
                    scale = float(op.short_size) / float(im_short_size)
                    # Cap the scale so the long side never exceeds max_size
                    # (max_size <= 0 disables the cap).
                    if 0 < op.max_size < np.round(scale * im_long_size):
                        scale = float(op.max_size) / float(im_long_size)
                    h = int(round(h * scale))
                    w = int(round(w * scale))
                elif op.__class__.__name__ == 'ResizeByLong':
                    restore_list.append(('resize', (h, w)))
                    im_long_size = max(h, w)
                    scale = float(op.long_size) / float(im_long_size)
                    h = int(round(h * scale))
                    w = int(round(w * scale))
                elif op.__class__.__name__ == 'Pad':
                    if op.target_size:
                        target_h, target_w = op.target_size
                    else:
                        # No explicit target: pad up to the nearest multiple
                        # of size_divisor.
                        target_h = int(
                            (np.ceil(h / op.size_divisor) * op.size_divisor))
                        target_w = int(
                            (np.ceil(w / op.size_divisor) * op.size_divisor))
                    # pad_mode: -1 explicit offsets, 0 top-left, 1 centered,
                    # otherwise bottom-right.
                    if op.pad_mode == -1:
                        offsets = op.offsets
                    elif op.pad_mode == 0:
                        offsets = [0, 0]
                    elif op.pad_mode == 1:
                        offsets = [(target_h - h) // 2, (target_w - w) // 2]
                    else:
                        offsets = [target_h - h, target_w - w]
                    restore_list.append(('padding', (h, w), offsets))
                    h, w = target_h, target_w
            batch_restore_list.append(restore_list)
        return batch_restore_list
    def _postprocess(self, batch_pred, batch_origin_shape, transforms):
        """Undo the geometric transforms so predictions match original sizes.

        Args:
            batch_pred: NCHW tensor of logits/labels, or — for exported
                inference models — a (label_map, score_map) pair.
            batch_origin_shape (list): Original (h, w) per sample.
            transforms (list): Transform ops applied during preprocessing.

        Returns:
            list of per-sample NCHW tensors restored to original size, or
            the result of `_infer_postprocess` for exported models.
        """
        batch_restore_list = BaseSegmenter.get_transforms_shape_info(
            batch_origin_shape, transforms)
        if isinstance(batch_pred, (tuple, list)) and self.status == 'Infer':
            return self._infer_postprocess(
                batch_label_map=batch_pred[0],
                batch_score_map=batch_pred[1],
                batch_restore_list=batch_restore_list)
        results = []
        # Float maps (logits/scores) are interpolated bilinearly; integer
        # label maps use nearest-neighbour to keep class ids intact.
        if batch_pred.dtype == paddle.float32:
            mode = 'bilinear'
        else:
            mode = 'nearest'
        for pred, restore_list in zip(batch_pred, batch_restore_list):
            pred = paddle.unsqueeze(pred, axis=0)
            # Undo the recorded ops in reverse order.
            for item in restore_list[::-1]:
                h, w = item[1][0], item[1][1]
                if item[0] == 'resize':
                    pred = F.interpolate(
                        pred, (h, w), mode=mode, data_format='NCHW')
                elif item[0] == 'padding':
                    # NOTE(review): item[2] is built as [row_off, col_off] in
                    # get_transforms_shape_info but unpacked here as (x, y)
                    # and used as (row, col) slices — verify the offset order
                    # for non-centered padding.
                    x, y = item[2]
                    pred = pred[:, :, y:y + h, x:x + w]
                else:
                    pass
            results.append(pred)
        return results
    def _infer_postprocess(self, batch_label_map, batch_score_map,
                           batch_restore_list):
        """Restore exported-model outputs to the original image sizes.

        Handles both np.ndarray outputs (restored with OpenCV) and paddle
        tensors in NHWC layout (restored with F.interpolate).

        Args:
            batch_label_map: Per-sample label maps.
            batch_score_map: Per-sample score maps.
            batch_restore_list: Restoration records from
                `get_transforms_shape_info`.

        Returns:
            (label_maps, score_maps): lists of per-sample numpy arrays.
        """
        label_maps = []
        score_maps = []
        for label_map, score_map, restore_list in zip(
                batch_label_map, batch_score_map, batch_restore_list):
            if not isinstance(label_map, np.ndarray):
                # Tensor path: add batch (and channel) dims -> NHWC.
                label_map = paddle.unsqueeze(label_map, axis=[0, 3])
                score_map = paddle.unsqueeze(score_map, axis=0)
            # Undo the recorded ops in reverse order.
            for item in restore_list[::-1]:
                h, w = item[1][0], item[1][1]
                if item[0] == 'resize':
                    if isinstance(label_map, np.ndarray):
                        # Nearest for class ids, linear for scores.
                        label_map = cv2.resize(
                            label_map, (w, h), interpolation=cv2.INTER_NEAREST)
                        score_map = cv2.resize(
                            score_map, (w, h), interpolation=cv2.INTER_LINEAR)
                    else:
                        label_map = F.interpolate(
                            label_map, (h, w),
                            mode='nearest',
                            data_format='NHWC')
                        score_map = F.interpolate(
                            score_map, (h, w),
                            mode='bilinear',
                            data_format='NHWC')
                elif item[0] == 'padding':
                    x, y = item[2]
                    if isinstance(label_map, np.ndarray):
                        label_map = label_map[..., y:y + h, x:x + w]
                        score_map = score_map[..., y:y + h, x:x + w]
                    else:
                        label_map = label_map[:, :, y:y + h, x:x + w]
                        score_map = score_map[:, :, y:y + h, x:x + w]
                else:
                    pass
            label_map = label_map.squeeze()
            score_map = score_map.squeeze()
            if not isinstance(label_map, np.ndarray):
                label_map = label_map.numpy()
                score_map = score_map.numpy()
            label_maps.append(label_map.squeeze())
            score_maps.append(score_map.squeeze())
        return label_maps, score_maps
  701. def _check_transforms(self, transforms, mode):
  702. super()._check_transforms(transforms, mode)
  703. if not isinstance(transforms.arrange,
  704. paddlers.transforms.ArrangeSegmenter):
  705. raise TypeError(
  706. "`transforms.arrange` must be an ArrangeSegmenter object.")
  707. class UNet(BaseSegmenter):
  708. def __init__(self,
  709. input_channel=3,
  710. num_classes=2,
  711. use_mixed_loss=False,
  712. use_deconv=False,
  713. align_corners=False,
  714. **params):
  715. params.update({
  716. 'use_deconv': use_deconv,
  717. 'align_corners': align_corners
  718. })
  719. super(UNet, self).__init__(
  720. model_name='UNet',
  721. input_channel=input_channel,
  722. num_classes=num_classes,
  723. use_mixed_loss=use_mixed_loss,
  724. **params)
  725. class DeepLabV3P(BaseSegmenter):
  726. def __init__(self,
  727. input_channel=3,
  728. num_classes=2,
  729. backbone='ResNet50_vd',
  730. use_mixed_loss=False,
  731. output_stride=8,
  732. backbone_indices=(0, 3),
  733. aspp_ratios=(1, 12, 24, 36),
  734. aspp_out_channels=256,
  735. align_corners=False,
  736. **params):
  737. self.backbone_name = backbone
  738. if backbone not in ['ResNet50_vd', 'ResNet101_vd']:
  739. raise ValueError(
  740. "backbone: {} is not supported. Please choose one of "
  741. "('ResNet50_vd', 'ResNet101_vd')".format(backbone))
  742. if params.get('with_net', True):
  743. with DisablePrint():
  744. backbone = getattr(paddleseg.models, backbone)(
  745. input_channel=input_channel, output_stride=output_stride)
  746. else:
  747. backbone = None
  748. params.update({
  749. 'backbone': backbone,
  750. 'backbone_indices': backbone_indices,
  751. 'aspp_ratios': aspp_ratios,
  752. 'aspp_out_channels': aspp_out_channels,
  753. 'align_corners': align_corners
  754. })
  755. super(DeepLabV3P, self).__init__(
  756. model_name='DeepLabV3P',
  757. num_classes=num_classes,
  758. use_mixed_loss=use_mixed_loss,
  759. **params)
  760. class FastSCNN(BaseSegmenter):
  761. def __init__(self,
  762. num_classes=2,
  763. use_mixed_loss=False,
  764. align_corners=False,
  765. **params):
  766. params.update({'align_corners': align_corners})
  767. super(FastSCNN, self).__init__(
  768. model_name='FastSCNN',
  769. num_classes=num_classes,
  770. use_mixed_loss=use_mixed_loss,
  771. **params)
  772. class HRNet(BaseSegmenter):
  773. def __init__(self,
  774. num_classes=2,
  775. width=48,
  776. use_mixed_loss=False,
  777. align_corners=False,
  778. **params):
  779. if width not in (18, 48):
  780. raise ValueError(
  781. "width={} is not supported, please choose from [18, 48]".format(
  782. width))
  783. self.backbone_name = 'HRNet_W{}'.format(width)
  784. if params.get('with_net', True):
  785. with DisablePrint():
  786. backbone = getattr(paddleseg.models, self.backbone_name)(
  787. align_corners=align_corners)
  788. else:
  789. backbone = None
  790. params.update({'backbone': backbone, 'align_corners': align_corners})
  791. super(HRNet, self).__init__(
  792. model_name='FCN',
  793. num_classes=num_classes,
  794. use_mixed_loss=use_mixed_loss,
  795. **params)
  796. self.model_name = 'HRNet'
  797. class BiSeNetV2(BaseSegmenter):
  798. def __init__(self,
  799. num_classes=2,
  800. use_mixed_loss=False,
  801. align_corners=False,
  802. **params):
  803. params.update({'align_corners': align_corners})
  804. super(BiSeNetV2, self).__init__(
  805. model_name='BiSeNetV2',
  806. num_classes=num_classes,
  807. use_mixed_loss=use_mixed_loss,
  808. **params)
class FarSeg(BaseSegmenter):
    """FarSeg semantic segmenter (from the paddlers remote-sensing zoo)."""

    def __init__(self, num_classes=2, use_mixed_loss=False, **params):
        super(FarSeg, self).__init__(
            model_name='FarSeg',
            num_classes=num_classes,
            use_mixed_loss=use_mixed_loss,
            **params)