changedetector.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. import os.path as osp
  16. from collections import OrderedDict
  17. from operator import attrgetter
  18. import cv2
  19. import numpy as np
  20. import paddle
  21. import paddle.nn.functional as F
  22. from paddle.static import InputSpec
  23. import paddlers
  24. import paddlers.custom_models.cd as cmcd
  25. import paddlers.utils.logging as logging
  26. import paddlers.models.ppseg as paddleseg
  27. from paddlers.transforms import arrange_transforms
  28. from paddlers.transforms import ImgDecoder, Resize
  29. from paddlers.utils import get_single_card_bs, DisablePrint
  30. from paddlers.utils.checkpoint import seg_pretrain_weights_dict
  31. from .base import BaseModel
  32. from .utils import seg_metrics as metrics
  # Public trainer classes exported by this module. BaseChangeDetector.__init__
  # also checks `model_name` against this list.
  __all__ = [
      "CDNet",
      "UNetEarlyFusion",
      "UNetSiamConc",
      "UNetSiamDiff",
      "STANet",
      "BIT",
      "SNUNet",
      "DSIFN",
      "DSAMNet",
      "ChangeStar",
  ]
  37. class BaseChangeDetector(BaseModel):
  def __init__(self,
               model_name,
               num_classes=2,
               use_mixed_loss=False,
               **params):
      """
      Set up the state shared by all change detection trainers.

      Args:
          model_name(str): Name of the network to build. Must be listed in `__all__`.
          num_classes(int, optional): Number of classes passed to the network. Defaults to 2.
          use_mixed_loss(bool or list, optional): Stored and later consumed by `default_loss()`.
              Defaults to False.
          **params: Extra keyword arguments forwarded to `build_net()`. If `with_net` is
              present and falsy, no network is built.

      Raises:
          Exception: If `model_name` is not listed in `__all__`.
      """
      # Snapshot the constructor arguments before any other local name is bound.
      # NOTE: the explicit two-argument super() below is kept on purpose; the
      # zero-argument form would make `__class__` appear in this snapshot.
      self.init_params = locals()
      self.init_params.pop('with_net', None)
      super(BaseChangeDetector, self).__init__('changedetector')

      if model_name not in __all__:
          raise Exception("ERROR: There's no model named {}.".format(
              model_name))

      self.model_name = model_name
      self.num_classes = num_classes
      self.use_mixed_loss = use_mixed_loss
      # Both are filled in by train().
      self.losses = None
      self.labels = None

      if params.get('with_net', True):
          # `with_net` is a control flag only and must not reach the network.
          params.pop('with_net', None)
          self.net = self.build_net(**params)
      self.find_unused_parameters = True
  def build_net(self, **params):
      """
      Instantiate the network named `self.model_name` from `paddlers.custom_models.cd`.

      Args:
          **params: Keyword arguments forwarded to the network constructor, in addition
              to `num_classes`.

      Returns:
          The constructed network object.
      """
      # TODO: add other model
      net_cls = vars(cmcd)[self.model_name]
      return net_cls(num_classes=self.num_classes, **params)
  def _fix_transforms_shape(self, image_shape):
      """
      Force `self.test_transforms` to resize its inputs to `image_shape`.

      If the pipeline already contains an operator whose class name includes 'Resize',
      the last such operator is replaced. Otherwise a new `Resize` is inserted in front
      of the last 'Normalize' operator (or appended when there is none). Nothing happens
      if `self.test_transforms` is missing or None.

      Args:
          image_shape(list or tuple): Target size handed to `Resize(target_size=...)`.
      """
      if not hasattr(self, 'test_transforms'):
          return
      if self.test_transforms is None:
          return

      ops = self.test_transforms.transforms
      found_resize = False
      resize_pos = -1
      # Default insertion point is the end of the pipeline.
      normalize_pos = len(ops)
      for pos, op in enumerate(ops):
          op_name = op.__class__.__name__
          if op_name == 'Normalize':
              normalize_pos = pos
          if 'Resize' in op_name:
              found_resize = True
              resize_pos = pos

      if found_resize:
          self.test_transforms.transforms[resize_pos] = Resize(
              target_size=image_shape)
      else:
          self.test_transforms.transforms.insert(
              normalize_pos, Resize(target_size=image_shape))
  def _get_test_inputs(self, image_shape):
      """
      Build the input specification used when exporting the model.

      Args:
          image_shape(list, tuple or None): Either a 2-element (H, W) size, which is
              expanded to [1, 3, H, W], or an already complete shape. If None, a shape
              of [None, 3, -1, -1] is used and the test transforms are left untouched.

      Returns:
          list[InputSpec]: A single-element list holding the spec of the 'image' input.
      """
      if image_shape is not None:
          if len(image_shape) == 2:
              # Coerce to list so that a (H, W) tuple is accepted as well;
              # `[1, 3] + tuple` would raise TypeError.
              image_shape = [1, 3] + list(image_shape)
          # Keep the test-time resize consistent with the fixed spatial size.
          self._fix_transforms_shape(image_shape[-2:])
      else:
          image_shape = [None, 3, -1, -1]
      self.fixed_input_shape = image_shape
      input_spec = [
          InputSpec(
              shape=image_shape, name='image', dtype='float32')
      ]
      return input_spec
  def run(self, net, inputs, mode):
      """
      Run one forward pass and post-process the result according to `mode`.

      Args:
          net: Callable network, invoked as `net(inputs[0], inputs[1])`.
          inputs(list or tuple): `inputs[0]` and `inputs[1]` are fed to the network.
              In 'test' mode `inputs[2]` holds the original shapes and `inputs[3]` the
              transform operators. In 'eval' mode `inputs[2]` is the label and
              `inputs[3]` the transform operators. In 'train' mode `inputs[2]` is the
              label, or, for multi-task networks, `inputs[2:5]` hold the labels.
          mode(str): One of 'test', 'eval' or 'train'.

      Returns:
          collections.OrderedDict: 'label_map' and 'score_map' in 'test' mode;
              'intersect_area', 'pred_area', 'label_area' and 'conf_mat' in 'eval' mode;
              'loss' in 'train' mode.

      Raises:
          ValueError: In 'train' mode, if a multi-task network does not receive
              exactly 5 inputs.
      """
      net_out = net(inputs[0], inputs[1])
      logit = net_out[0]
      outputs = OrderedDict()

      if mode == 'test':
          origin_shape = inputs[2]
          if self.status == 'Infer':
              label_map_list, score_map_list = self._postprocess(
                  net_out, origin_shape, transforms=inputs[3])
          else:
              restored_logits = self._postprocess(
                  logit, origin_shape, transforms=inputs[3])
              label_map_list = []
              score_map_list = []
              for restored in restored_logits:
                  nhwc = paddle.transpose(restored, perm=[0, 2, 3, 1])  # NHWC
                  label_map = paddle.argmax(
                      nhwc, axis=-1, keepdim=False, dtype='int32')
                  label_map_list.append(label_map.squeeze().numpy())
                  score_map = F.softmax(nhwc, axis=-1)
                  score_map_list.append(
                      score_map.squeeze().numpy().astype('float32'))
          outputs['label_map'] = label_map_list
          outputs['score_map'] = score_map_list
      elif mode == 'eval':
          if self.status == 'Infer':
              pred = paddle.unsqueeze(net_out[0], axis=1)  # NCHW
          else:
              pred = paddle.argmax(logit, axis=1, keepdim=True, dtype='int32')
          label = inputs[2]
          # Restore the prediction to the spatial size of the label.
          origin_shape = [label.shape[-2:]]
          pred = self._postprocess(
              pred, origin_shape, transforms=inputs[3])[0]  # NCHW
          areas = paddleseg.utils.metrics.calculate_area(pred, label,
                                                         self.num_classes)
          intersect_area, pred_area, label_area = areas
          outputs['intersect_area'] = intersect_area
          outputs['pred_area'] = pred_area
          outputs['label_area'] = label_area
          outputs['conf_mat'] = metrics.confusion_matrix(pred, label,
                                                         self.num_classes)
      elif mode == 'train':
          if getattr(net, 'USE_MULTITASK_DECODER', None) is True:
              # CD+Seg
              if len(inputs) != 5:
                  raise ValueError(
                      "Cannot perform loss computation with {} inputs.".
                      format(len(inputs)))
              # Pick the label matching each output type, in output order.
              labels_list = [
                  inputs[2 + offset]
                  for offset in map(attrgetter('value'), net.OUT_TYPES)
              ]
              loss_list = metrics.multitask_loss_computation(
                  logits_list=net_out,
                  labels_list=labels_list,
                  losses=self.losses)
          else:
              loss_list = metrics.loss_computation(
                  logits_list=net_out, labels=inputs[2], losses=self.losses)
          outputs['loss'] = sum(loss_list)
      return outputs
  def default_loss(self):
      """
      Build the default loss configuration from `self.use_mixed_loss`.

      If `use_mixed_loss` is a bool, True selects a 0.8/0.2 mix of cross-entropy and
      Lovasz-softmax losses and False selects plain cross-entropy. Otherwise it is
      treated as an iterable of (loss_name, coefficient) pairs.

      Returns:
          dict: {'types': [loss], 'coef': [1.0]}.

      Raises:
          ValueError: If a loss name other than 'CrossEntropyLoss', 'DiceLoss' or
              'LovaszSoftmaxLoss' is requested.
      """
      if not isinstance(self.use_mixed_loss, bool):
          names, weights = list(zip(*self.use_mixed_loss))
          supported = ['CrossEntropyLoss', 'DiceLoss', 'LovaszSoftmaxLoss']
          if not set(names).issubset(supported):
              raise ValueError(
                  "Only 'CrossEntropyLoss', 'DiceLoss', 'LovaszSoftmaxLoss' are supported."
              )
          components = [getattr(paddleseg.models, name)() for name in names]
          loss_type = [
              paddleseg.models.MixedLoss(
                  losses=components, coef=list(weights))
          ]
      elif self.use_mixed_loss:
          components = [
              paddleseg.models.CrossEntropyLoss(),
              paddleseg.models.LovaszSoftmaxLoss()
          ]
          loss_type = [
              paddleseg.models.MixedLoss(
                  losses=components, coef=[.8, .2]),
          ]
      else:
          loss_type = [paddleseg.models.CrossEntropyLoss()]

      return {'types': loss_type, 'coef': [1.0]}
  def default_optimizer(self,
                        parameters,
                        learning_rate,
                        num_epochs,
                        num_steps_each_epoch,
                        lr_decay_power=0.9):
      """
      Create the default optimizer: Momentum with a polynomial learning rate decay.

      Args:
          parameters: Parameters to optimize.
          learning_rate(float): Initial learning rate.
          num_epochs(int): Number of training epochs.
          num_steps_each_epoch(int): Number of steps in each epoch.
          lr_decay_power(float, optional): Power of the polynomial decay. Defaults to 0.9.

      Returns:
          paddle.optimizer.Momentum: Optimizer with momentum 0.9 and weight decay 4e-5.
      """
      # The learning rate decays to 0 over the whole training run.
      total_steps = num_epochs * num_steps_each_epoch
      scheduler = paddle.optimizer.lr.PolynomialDecay(
          learning_rate, total_steps, end_lr=0, power=lr_decay_power)
      return paddle.optimizer.Momentum(
          learning_rate=scheduler,
          parameters=parameters,
          momentum=0.9,
          weight_decay=4e-5)
  def train(self,
            num_epochs,
            train_dataset,
            train_batch_size=2,
            eval_dataset=None,
            optimizer=None,
            save_interval_epochs=1,
            log_interval_steps=2,
            save_dir='output',
            pretrain_weights=None,
            learning_rate=0.01,
            lr_decay_power=0.9,
            early_stop=False,
            early_stop_patience=5,
            use_vdl=True,
            resume_checkpoint=None):
      """
      Train the model.

      Args:
          num_epochs(int): The number of epochs.
          train_dataset(paddlers.dataset): Training dataset.
          train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 2.
          eval_dataset(paddlers.dataset, optional):
              Evaluation dataset. If None, the model will not be evaluated during training process. Defaults to None.
          optimizer(paddle.optimizer.Optimizer or None, optional):
              Optimizer used in training. If None, a default optimizer is used. Defaults to None.
          save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
          log_interval_steps(int, optional): Step interval for printing training information. Defaults to 2.
          save_dir(str, optional): Directory to save the model. Defaults to 'output'.
          pretrain_weights(str or None, optional):
              None or name/path of pretrained weights. If None, no pretrained weights will be loaded. Defaults to None.
          learning_rate(float, optional): Learning rate for training. Defaults to .01.
          lr_decay_power(float, optional): Learning decay power. Defaults to .9.
          early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
          early_stop_patience(int, optional): Early stop patience. Defaults to 5.
          use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
          resume_checkpoint(str or None, optional): The path of the checkpoint to resume training from.
              If None, no training checkpoint will be resumed. At most one of `resume_checkpoint` and
              `pretrain_weights` can be set simultaneously. Defaults to None.
      """
      if self.status == 'Infer':
          logging.error(
              "Exported inference model does not support training.",
              exit=True)
      if pretrain_weights is not None and resume_checkpoint is not None:
          logging.error(
              "pretrain_weights and resume_checkpoint cannot be set simultaneously.",
              exit=True)

      self.labels = train_dataset.labels
      if self.losses is None:
          self.losses = self.default_loss()

      if optimizer is not None:
          self.optimizer = optimizer
      else:
          num_steps_each_epoch = train_dataset.num_samples // train_batch_size
          self.optimizer = self.default_optimizer(
              self.net.parameters(), learning_rate, num_epochs,
              num_steps_each_epoch, lr_decay_power)

      if pretrain_weights is not None:
          if osp.exists(pretrain_weights):
              # A local file was given; only '.pdparams' files are accepted.
              if osp.splitext(pretrain_weights)[-1] != '.pdparams':
                  logging.error(
                      "Invalid pretrain weights. Please specify a '.pdparams' file.",
                      exit=True)
          elif pretrain_weights not in seg_pretrain_weights_dict[
                  self.model_name]:
              # Neither an existing path nor a known weights name: fall back to
              # the first registered entry for this model.
              logging.warning(
                  "Path of pretrain_weights('{}') does not exist!".format(
                      pretrain_weights))
              fallback = seg_pretrain_weights_dict[self.model_name][0]
              logging.warning("Pretrain_weights is forcibly set to '{}'. "
                              "If don't want to use pretrain weights, "
                              "set pretrain_weights to be None.".format(
                                  fallback))
              pretrain_weights = fallback

      self.net_initialize(
          pretrain_weights=pretrain_weights,
          save_dir=osp.join(save_dir, 'pretrain'),
          resume_checkpoint=resume_checkpoint,
          is_backbone_weights=(pretrain_weights == 'IMAGENET'))

      self.train_loop(
          num_epochs=num_epochs,
          train_dataset=train_dataset,
          train_batch_size=train_batch_size,
          eval_dataset=eval_dataset,
          save_interval_epochs=save_interval_epochs,
          log_interval_steps=log_interval_steps,
          save_dir=save_dir,
          early_stop=early_stop,
          early_stop_patience=early_stop_patience,
          use_vdl=use_vdl)
  def quant_aware_train(self,
                        num_epochs,
                        train_dataset,
                        train_batch_size=2,
                        eval_dataset=None,
                        optimizer=None,
                        save_interval_epochs=1,
                        log_interval_steps=2,
                        save_dir='output',
                        learning_rate=0.0001,
                        lr_decay_power=0.9,
                        early_stop=False,
                        early_stop_patience=5,
                        use_vdl=True,
                        resume_checkpoint=None,
                        quant_config=None):
      """
      Quantization-aware training.

      Prepares the model for quantization-aware training with `quant_config`, then
      delegates to `train()` without pretrained weights.

      Args:
          num_epochs(int): The number of epochs.
          train_dataset(paddlers.dataset): Training dataset.
          train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 2.
          eval_dataset(paddlers.dataset, optional):
              Evaluation dataset. If None, the model will not be evaluated during training process. Defaults to None.
          optimizer(paddle.optimizer.Optimizer or None, optional):
              Optimizer used in training. If None, a default optimizer is used. Defaults to None.
          save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
          log_interval_steps(int, optional): Step interval for printing training information. Defaults to 2.
          save_dir(str, optional): Directory to save the model. Defaults to 'output'.
          learning_rate(float, optional): Learning rate for training. Defaults to .0001.
          lr_decay_power(float, optional): Learning decay power. Defaults to .9.
          early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
          early_stop_patience(int, optional): Early stop patience. Defaults to 5.
          use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
          resume_checkpoint(str or None, optional): The path of the checkpoint to resume quantization-aware training
              from. If None, no training checkpoint will be resumed. Defaults to None.
          quant_config(dict or None, optional): Quantization configuration. If None, a default rule of thumb
              configuration will be used. Defaults to None.
      """
      self._prepare_qat(quant_config)
      train_kwargs = dict(
          num_epochs=num_epochs,
          train_dataset=train_dataset,
          train_batch_size=train_batch_size,
          eval_dataset=eval_dataset,
          optimizer=optimizer,
          save_interval_epochs=save_interval_epochs,
          log_interval_steps=log_interval_steps,
          save_dir=save_dir,
          # Pretrained weights are never loaded for quantization-aware training.
          pretrain_weights=None,
          learning_rate=learning_rate,
          lr_decay_power=lr_decay_power,
          early_stop=early_stop,
          early_stop_patience=early_stop_patience,
          use_vdl=use_vdl,
          resume_checkpoint=resume_checkpoint)
      self.train(**train_kwargs)
  def evaluate(self, eval_dataset, batch_size=1, return_details=False):
      """
      Evaluate the model.

      Args:
          eval_dataset(paddlers.dataset): Evaluation dataset.
          batch_size(int, optional): Total batch size among all cards used for evaluation. Defaults to 1.
          return_details(bool, optional): Whether to return evaluation details. Defaults to False.

      Returns:
          collections.OrderedDict with key-value pairs:
              {"miou": `mean intersection over union`,
               "category_iou": `category-wise mean intersection over union`,
               "oacc": `overall accuracy`,
               "category_acc": `category-wise accuracy`,
               "kappa": ` kappa coefficient`,
               "category_F1-score": `F1 score`}.
          If `return_details` is True, a tuple of that dict and a dict holding the
          summed confusion matrix under 'confusion_matrix' is returned instead.
      """
      arrange_transforms(
          model_type=self.model_type,
          transforms=eval_dataset.transforms,
          mode='eval')

      self.net.eval()
      nranks = paddle.distributed.get_world_size()
      local_rank = paddle.distributed.get_rank()
      if nranks > 1:
          # Initialize parallel environment if not done.
          parallel_helper = paddle.distributed.parallel.parallel_helper
          if not parallel_helper._is_parallel_ctx_initialized():
              paddle.distributed.init_parallel_env()

      batch_size_each_card = get_single_card_bs(batch_size)
      if batch_size_each_card > 1:
          # Only one sample per card is supported; shrink the total batch size.
          batch_size_each_card = 1
          batch_size = batch_size_each_card * paddlers.env_info['num']
          logging.warning(
              "Segmenter only supports batch_size=1 for each gpu/cpu card " \
              "during evaluation, so batch_size " \
              "is forcibly set to {}.".format(batch_size)
          )
      self.eval_data_loader = self.build_data_loader(
          eval_dataset, batch_size=batch_size, mode='eval')

      # Running totals over the whole dataset.
      intersect_area_all = 0
      pred_area_all = 0
      label_area_all = 0
      conf_mat_all = []
      total_steps = math.ceil(eval_dataset.num_samples * 1.0 / batch_size)
      logging.info(
          "Start to evaluate(total_samples={}, total_steps={})...".format(
              eval_dataset.num_samples, total_steps))

      with paddle.no_grad():
          for step, data in enumerate(self.eval_data_loader):
              # run() expects the transform operators as the last input.
              data.append(eval_dataset.transforms.transforms)
              outputs = self.run(self.net, data, 'eval')
              pred_area = outputs['pred_area']
              label_area = outputs['label_area']
              intersect_area = outputs['intersect_area']
              conf_mat = outputs['conf_mat']

              if nranks <= 1:
                  intersect_area_all = intersect_area_all + intersect_area
                  pred_area_all = pred_area_all + pred_area
                  label_area_all = label_area_all + label_area
                  conf_mat_all.append(conf_mat)
                  continue

              # Gather from all ranks
              intersect_area_list = []
              pred_area_list = []
              label_area_list = []
              conf_mat_list = []
              paddle.distributed.all_gather(intersect_area_list,
                                            intersect_area)
              paddle.distributed.all_gather(pred_area_list, pred_area)
              paddle.distributed.all_gather(label_area_list, label_area)
              paddle.distributed.all_gather(conf_mat_list, conf_mat)

              # Some image has been evaluated and should be eliminated in last iter
              if (step + 1) * nranks > len(eval_dataset):
                  valid = len(eval_dataset) - step * nranks
                  intersect_area_list = intersect_area_list[:valid]
                  pred_area_list = pred_area_list[:valid]
                  label_area_list = label_area_list[:valid]
                  conf_mat_list = conf_mat_list[:valid]

              intersect_area_all += sum(intersect_area_list)
              pred_area_all += sum(pred_area_list)
              label_area_all += sum(label_area_list)
              conf_mat_all.extend(conf_mat_list)

      class_iou, miou = paddleseg.utils.metrics.mean_iou(
          intersect_area_all, pred_area_all, label_area_all)
      # TODO: confirm whether overall accuracy (oacc) or mean accuracy (macc) should be used
      class_acc, oacc = paddleseg.utils.metrics.accuracy(intersect_area_all,
                                                         pred_area_all)
      kappa = paddleseg.utils.metrics.kappa(intersect_area_all, pred_area_all,
                                            label_area_all)
      category_f1score = metrics.f1_score(intersect_area_all, pred_area_all,
                                          label_area_all)

      metric_names = [
          'miou', 'category_iou', 'oacc', 'category_acc', 'kappa',
          'category_F1-score'
      ]
      metric_values = [
          miou, class_iou, oacc, class_acc, kappa, category_f1score
      ]
      eval_metrics = OrderedDict(zip(metric_names, metric_values))

      if not return_details:
          return eval_metrics

      conf_mat = sum(conf_mat_all)
      eval_details = {'confusion_matrix': conf_mat.tolist()}
      return eval_metrics, eval_details
  def predict(self, img_file, transforms=None):
      """
      Do inference.

      Args:
          img_file(List[np.ndarray or str], str or np.ndarray):
              Image path or decoded image data in a BGR format, which also could constitute a list,
              meaning all images to be predicted as a mini-batch.
          transforms(paddlers.transforms.Compose or None, optional):
              Transforms for inputs. If None, the transforms for evaluation process will be used. Defaults to None.

      Returns:
          If img_file is a string or np.array, the result is a dict with key-value pairs:
          {"label_map": `label map`, "score_map": `score map`}.
          If img_file is a list, the result is a list composed of dicts with the corresponding fields:
          label_map(np.ndarray): the predicted label map (HW)
          score_map(np.ndarray): the prediction score map (HWC)

      Raises:
          Exception: If `transforms` is None and the model has no `test_transforms`.
      """
      if transforms is None:
          if not hasattr(self, 'test_transforms'):
              raise Exception("transforms need to be defined, now is None.")
          transforms = self.test_transforms

      # Wrap a single input so that the rest of the pipeline always sees a batch.
      images = [img_file] if isinstance(img_file,
                                        (str, np.ndarray)) else img_file
      batch_im, batch_origin_shape = self._preprocess(images, transforms,
                                                      self.model_type)
      self.net.eval()
      # NOTE(review): run() in 'test' mode reads inputs[0] and inputs[1] as network
      # inputs, inputs[2] as the original shapes and inputs[3] as the transforms,
      # whereas only three items are packed here -- confirm the intended layout.
      data = (batch_im, batch_origin_shape, transforms.transforms)
      outputs = self.run(self.net, data, 'test')
      label_maps = outputs['label_map']
      score_maps = outputs['score_map']

      if not isinstance(img_file, list):
          return {'label_map': label_maps[0], 'score_map': score_maps[0]}
      return [{
          'label_map': label_map,
          'score_map': score_map
      } for label_map, score_map in zip(label_maps, score_maps)]
  def _preprocess(self, images, transforms, to_tensor=True):
      """
      Decode and transform a batch of images for inference.

      Args:
          images(list): Image paths (str) or already decoded images.
          transforms: Callable transform pipeline applied to each sample dict; the
              first element of its result is used as the network input.
          to_tensor(bool, optional): If truthy, the batch is converted with
              `paddle.to_tensor`; otherwise with `np.asarray`. Defaults to True.

      Returns:
          tuple: (batch of transformed images, list of original (H, W) shapes).
      """
      arrange_transforms(
          model_type=self.model_type, transforms=transforms, mode='test')

      transformed = []
      original_shapes = []
      for image in images:
          sample = {'image': image}
          if isinstance(image, str):
              # Paths are decoded first; channel order is left untouched.
              sample = ImgDecoder(to_rgb=False)(sample)
          # Record the size before any transform so it can be restored later.
          shape_hw = sample['image'].shape[:2]
          transformed.append(transforms(sample)[0])
          original_shapes.append(shape_hw)

      if to_tensor:
          batch_im = paddle.to_tensor(transformed)
      else:
          batch_im = np.asarray(transformed)
      return batch_im, original_shapes
  514. @staticmethod
  515. def get_transforms_shape_info(batch_ori_shape, transforms):
  516. batch_restore_list = list()
  517. for ori_shape in batch_ori_shape:
  518. restore_list = list()
  519. h, w = ori_shape[0], ori_shape[1]
  520. for op in transforms:
  521. if op.__class__.__name__ == 'Resize':
  522. restore_list.append(('resize', (h, w)))
  523. h, w = op.target_size
  524. elif op.__class__.__name__ == 'ResizeByShort':
  525. restore_list.append(('resize', (h, w)))
  526. im_short_size = min(h, w)
  527. im_long_size = max(h, w)
  528. scale = float(op.short_size) / float(im_short_size)
  529. if 0 < op.max_size < np.round(scale * im_long_size):
  530. scale = float(op.max_size) / float(im_long_size)
  531. h = int(round(h * scale))
  532. w = int(round(w * scale))
  533. elif op.__class__.__name__ == 'ResizeByLong':
  534. restore_list.append(('resize', (h, w)))
  535. im_long_size = max(h, w)
  536. scale = float(op.long_size) / float(im_long_size)
  537. h = int(round(h * scale))
  538. w = int(round(w * scale))
  539. elif op.__class__.__name__ == 'Padding':
  540. if op.target_size:
  541. target_h, target_w = op.target_size
  542. else:
  543. target_h = int(
  544. (np.ceil(h / op.size_divisor) * op.size_divisor))
  545. target_w = int(
  546. (np.ceil(w / op.size_divisor) * op.size_divisor))
  547. if op.pad_mode == -1:
  548. offsets = op.offsets
  549. elif op.pad_mode == 0:
  550. offsets = [0, 0]
  551. elif op.pad_mode == 1:
  552. offsets = [(target_h - h) // 2, (target_w - w) // 2]
  553. else:
  554. offsets = [target_h - h, target_w - w]
  555. restore_list.append(('padding', (h, w), offsets))
  556. h, w = target_h, target_w
  557. batch_restore_list.append(restore_list)
  558. return batch_restore_list
  def _postprocess(self, batch_pred, batch_origin_shape, transforms):
      """
      Undo the resize and padding transforms on a batch of predictions.

      Args:
          batch_pred: Batch of predictions. When it is a tuple or list and the model
              status is 'Infer', it is treated as (label maps, score maps) and handed
              to `_infer_postprocess()`.
          batch_origin_shape(list): Original shape of each sample.
          transforms(list): Transform operators that were applied to the inputs.

      Returns:
          list: One restored prediction per sample, each with a leading batch axis of
              size 1; or the result of `_infer_postprocess()` in the 'Infer' case.
      """
      batch_restore_list = BaseChangeDetector.get_transforms_shape_info(
          batch_origin_shape, transforms)
      if isinstance(batch_pred, (tuple, list)) and self.status == 'Infer':
          return self._infer_postprocess(
              batch_label_map=batch_pred[0],
              batch_score_map=batch_pred[1],
              batch_restore_list=batch_restore_list)

      # Float predictions are interpolated bilinearly; anything else uses nearest.
      interp_mode = 'bilinear' if batch_pred.dtype == paddle.float32 else 'nearest'
      results = []
      for pred, restore_list in zip(batch_pred, batch_restore_list):
          pred = paddle.unsqueeze(pred, axis=0)
          # Walk the recorded steps backwards to undo them in reverse order.
          for item in restore_list[::-1]:
              kind = item[0]
              h, w = item[1][0], item[1][1]
              if kind == 'resize':
                  pred = F.interpolate(
                      pred, (h, w), mode=interp_mode, data_format='NCHW')
              elif kind == 'padding':
                  x, y = item[2]
                  pred = pred[:, :, y:y + h, x:x + w]
          results.append(pred)
      return results
  586. def _infer_postprocess(self, batch_label_map, batch_score_map,
  587. batch_restore_list):
  588. label_maps = []
  589. score_maps = []
  590. for label_map, score_map, restore_list in zip(
  591. batch_label_map, batch_score_map, batch_restore_list):
  592. if not isinstance(label_map, np.ndarray):
  593. label_map = paddle.unsqueeze(label_map, axis=[0, 3])
  594. score_map = paddle.unsqueeze(score_map, axis=0)
  595. for item in restore_list[::-1]:
  596. h, w = item[1][0], item[1][1]
  597. if item[0] == 'resize':
  598. if isinstance(label_map, np.ndarray):
  599. label_map = cv2.resize(
  600. label_map, (w, h), interpolation=cv2.INTER_NEAREST)
  601. score_map = cv2.resize(
  602. score_map, (w, h), interpolation=cv2.INTER_LINEAR)
  603. else:
  604. label_map = F.interpolate(
  605. label_map, (h, w),
  606. mode='nearest',
  607. data_format='NHWC')
  608. score_map = F.interpolate(
  609. score_map, (h, w),
  610. mode='bilinear',
  611. data_format='NHWC')
  612. elif item[0] == 'padding':
  613. x, y = item[2]
  614. if isinstance(label_map, np.ndarray):
  615. label_map = label_map[..., y:y + h, x:x + w]
  616. score_map = score_map[..., y:y + h, x:x + w]
  617. else:
  618. label_map = label_map[:, :, y:y + h, x:x + w]
  619. score_map = score_map[:, :, y:y + h, x:x + w]
  620. else:
  621. pass
  622. label_map = label_map.squeeze()
  623. score_map = score_map.squeeze()
  624. if not isinstance(label_map, np.ndarray):
  625. label_map = label_map.numpy()
  626. score_map = score_map.numpy()
  627. label_maps.append(label_map.squeeze())
  628. score_maps.append(score_map.squeeze())
  629. return label_maps, score_maps
  class CDNet(BaseChangeDetector):
      """Trainer for the 'CDNet' change detection network."""

      def __init__(self,
                   num_classes=2,
                   use_mixed_loss=False,
                   in_channels=6,
                   **params):
          """
          Args:
              num_classes(int, optional): Number of classes. Defaults to 2.
              use_mixed_loss(bool or list, optional): Loss selection, see `default_loss()`.
                  Defaults to False.
              in_channels(int, optional): Forwarded to the network. Defaults to 6.
              **params: Extra keyword arguments forwarded to the base class.
          """
          params['in_channels'] = in_channels
          super().__init__(
              model_name='CDNet',
              num_classes=num_classes,
              use_mixed_loss=use_mixed_loss,
              **params)
  class UNetEarlyFusion(BaseChangeDetector):
      """Trainer for the 'UNetEarlyFusion' change detection network."""

      def __init__(self,
                   num_classes=2,
                   use_mixed_loss=False,
                   in_channels=6,
                   use_dropout=False,
                   **params):
          """
          Args:
              num_classes(int, optional): Number of classes. Defaults to 2.
              use_mixed_loss(bool or list, optional): Loss selection, see `default_loss()`.
                  Defaults to False.
              in_channels(int, optional): Forwarded to the network. Defaults to 6.
              use_dropout(bool, optional): Forwarded to the network. Defaults to False.
              **params: Extra keyword arguments forwarded to the base class.
          """
          params.update(in_channels=in_channels, use_dropout=use_dropout)
          super().__init__(
              model_name='UNetEarlyFusion',
              num_classes=num_classes,
              use_mixed_loss=use_mixed_loss,
              **params)
  class UNetSiamConc(BaseChangeDetector):
      """Trainer for the 'UNetSiamConc' change detection network."""

      def __init__(self,
                   num_classes=2,
                   use_mixed_loss=False,
                   in_channels=3,
                   use_dropout=False,
                   **params):
          """
          Args:
              num_classes(int, optional): Number of classes. Defaults to 2.
              use_mixed_loss(bool or list, optional): Loss selection, see `default_loss()`.
                  Defaults to False.
              in_channels(int, optional): Forwarded to the network. Defaults to 3.
              use_dropout(bool, optional): Forwarded to the network. Defaults to False.
              **params: Extra keyword arguments forwarded to the base class.
          """
          params.update(in_channels=in_channels, use_dropout=use_dropout)
          super().__init__(
              model_name='UNetSiamConc',
              num_classes=num_classes,
              use_mixed_loss=use_mixed_loss,
              **params)
  class UNetSiamDiff(BaseChangeDetector):
      """Trainer for the 'UNetSiamDiff' change detection network."""

      def __init__(self,
                   num_classes=2,
                   use_mixed_loss=False,
                   in_channels=3,
                   use_dropout=False,
                   **params):
          """
          Args:
              num_classes(int, optional): Number of classes. Defaults to 2.
              use_mixed_loss(bool or list, optional): Loss selection, see `default_loss()`.
                  Defaults to False.
              in_channels(int, optional): Forwarded to the network. Defaults to 3.
              use_dropout(bool, optional): Forwarded to the network. Defaults to False.
              **params: Extra keyword arguments forwarded to the base class.
          """
          params.update(in_channels=in_channels, use_dropout=use_dropout)
          super().__init__(
              model_name='UNetSiamDiff',
              num_classes=num_classes,
              use_mixed_loss=use_mixed_loss,
              **params)
  class STANet(BaseChangeDetector):
      """Trainer for the 'STANet' change detection network."""

      def __init__(self,
                   num_classes=2,
                   use_mixed_loss=False,
                   in_channels=3,
                   att_type='BAM',
                   ds_factor=1,
                   **params):
          """
          Args:
              num_classes(int, optional): Number of classes. Defaults to 2.
              use_mixed_loss(bool or list, optional): Loss selection, see `default_loss()`.
                  Defaults to False.
              in_channels(int, optional): Forwarded to the network. Defaults to 3.
              att_type(str, optional): Forwarded to the network. Defaults to 'BAM'.
              ds_factor(int, optional): Forwarded to the network. Defaults to 1.
              **params: Extra keyword arguments forwarded to the base class.
          """
          params.update(
              in_channels=in_channels, att_type=att_type, ds_factor=ds_factor)
          super().__init__(
              model_name='STANet',
              num_classes=num_classes,
              use_mixed_loss=use_mixed_loss,
              **params)
  class BIT(BaseChangeDetector):
      """Trainer for the 'BIT' change detection network."""

      def __init__(self,
                   num_classes=2,
                   use_mixed_loss=False,
                   in_channels=3,
                   backbone='resnet18',
                   n_stages=4,
                   use_tokenizer=True,
                   token_len=4,
                   pool_mode='max',
                   pool_size=2,
                   enc_with_pos=True,
                   enc_depth=1,
                   enc_head_dim=64,
                   dec_depth=8,
                   dec_head_dim=8,
                   **params):
          """
          Args:
              num_classes(int, optional): Number of classes. Defaults to 2.
              use_mixed_loss(bool or list, optional): Loss selection, see `default_loss()`.
                  Defaults to False.
              in_channels, backbone, n_stages, use_tokenizer, token_len, pool_mode,
              pool_size, enc_with_pos, enc_depth, enc_head_dim, dec_depth, dec_head_dim:
                  Network hyper-parameters, forwarded unchanged to the network constructor.
              **params: Extra keyword arguments forwarded to the base class.
          """
          params.update(
              in_channels=in_channels,
              backbone=backbone,
              n_stages=n_stages,
              use_tokenizer=use_tokenizer,
              token_len=token_len,
              pool_mode=pool_mode,
              pool_size=pool_size,
              enc_with_pos=enc_with_pos,
              enc_depth=enc_depth,
              enc_head_dim=enc_head_dim,
              dec_depth=dec_depth,
              dec_head_dim=dec_head_dim)
          super().__init__(
              model_name='BIT',
              num_classes=num_classes,
              use_mixed_loss=use_mixed_loss,
              **params)
  class SNUNet(BaseChangeDetector):
      """Trainer for the 'SNUNet' change detection network."""

      def __init__(self,
                   num_classes=2,
                   use_mixed_loss=False,
                   in_channels=3,
                   width=32,
                   **params):
          """
          Args:
              num_classes(int, optional): Number of classes. Defaults to 2.
              use_mixed_loss(bool or list, optional): Loss selection, see `default_loss()`.
                  Defaults to False.
              in_channels(int, optional): Forwarded to the network. Defaults to 3.
              width(int, optional): Forwarded to the network. Defaults to 32.
              **params: Extra keyword arguments forwarded to the base class.
          """
          params.update(in_channels=in_channels, width=width)
          super().__init__(
              model_name='SNUNet',
              num_classes=num_classes,
              use_mixed_loss=use_mixed_loss,
              **params)
  class DSIFN(BaseChangeDetector):
      """Trainer for the 'DSIFN' change detection network."""

      def __init__(self,
                   num_classes=2,
                   use_mixed_loss=False,
                   use_dropout=False,
                   **params):
          """
          Args:
              num_classes(int, optional): Number of classes. Defaults to 2.
              use_mixed_loss(bool or list, optional): Loss selection, see `default_loss()`.
                  Defaults to False.
              use_dropout(bool, optional): Forwarded to the network. Defaults to False.
              **params: Extra keyword arguments forwarded to the base class.
          """
          params['use_dropout'] = use_dropout
          super().__init__(
              model_name='DSIFN',
              num_classes=num_classes,
              use_mixed_loss=use_mixed_loss,
              **params)

      def default_loss(self):
          """
          Return five equally weighted cross-entropy terms when `use_mixed_loss` is
          False; otherwise defer to the base class.
          """
          if self.use_mixed_loss is not False:
              return super().default_loss()
          # XXX: make sure the shallow copy works correctly here.
          # The same loss instance is referenced five times.
          shared_ce = paddleseg.models.CrossEntropyLoss()
          return {'types': [shared_ce] * 5, 'coef': [1.0] * 5}
  class DSAMNet(BaseChangeDetector):
      """Trainer for the 'DSAMNet' change detection network."""

      def __init__(self,
                   num_classes=2,
                   use_mixed_loss=False,
                   in_channels=3,
                   ca_ratio=8,
                   sa_kernel=7,
                   **params):
          """
          Args:
              num_classes(int, optional): Number of classes. Defaults to 2.
              use_mixed_loss(bool or list, optional): Loss selection, see `default_loss()`.
                  Defaults to False.
              in_channels(int, optional): Forwarded to the network. Defaults to 3.
              ca_ratio(int, optional): Forwarded to the network. Defaults to 8.
              sa_kernel(int, optional): Forwarded to the network. Defaults to 7.
              **params: Extra keyword arguments forwarded to the base class.
          """
          params.update(
              in_channels=in_channels, ca_ratio=ca_ratio, sa_kernel=sa_kernel)
          super().__init__(
              model_name='DSAMNet',
              num_classes=num_classes,
              use_mixed_loss=use_mixed_loss,
              **params)

      def default_loss(self):
          """
          Return one cross-entropy term (weight 1.0) plus two Dice terms (weight 0.05
          each) when `use_mixed_loss` is False; otherwise defer to the base class.
          """
          if self.use_mixed_loss is not False:
              return super().default_loss()
          loss_types = [
              paddleseg.models.CrossEntropyLoss(),
              paddleseg.models.DiceLoss(),
              paddleseg.models.DiceLoss(),
          ]
          return {'types': loss_types, 'coef': [1.0, 0.05, 0.05]}
  class ChangeStar(BaseChangeDetector):
      """Trainer for the 'ChangeStar' change detection network."""

      def __init__(self,
                   num_classes=2,
                   use_mixed_loss=False,
                   mid_channels=256,
                   inner_channels=16,
                   num_convs=4,
                   scale_factor=4.0,
                   **params):
          """
          Args:
              num_classes(int, optional): Number of classes. Defaults to 2.
              use_mixed_loss(bool or list, optional): Loss selection, see `default_loss()`.
                  Defaults to False.
              mid_channels(int, optional): Forwarded to the network. Defaults to 256.
              inner_channels(int, optional): Forwarded to the network. Defaults to 16.
              num_convs(int, optional): Forwarded to the network. Defaults to 4.
              scale_factor(float, optional): Forwarded to the network. Defaults to 4.0.
              **params: Extra keyword arguments forwarded to the base class.
          """
          params.update(
              mid_channels=mid_channels,
              inner_channels=inner_channels,
              num_convs=num_convs,
              scale_factor=scale_factor)
          super().__init__(
              model_name='ChangeStar',
              num_classes=num_classes,
              use_mixed_loss=use_mixed_loss,
              **params)

      def default_loss(self):
          """
          Return four equally weighted cross-entropy terms when `use_mixed_loss` is
          False; otherwise defer to the base class.
          """
          if self.use_mixed_loss is not False:
              return super().default_loss()
          # XXX: make sure the shallow copy works correctly here.
          # The same loss instance is referenced four times.
          shared_ce = paddleseg.models.CrossEntropyLoss()
          return {'types': [shared_ce] * 4, 'coef': [1.0] * 4}