# change_detector.py
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. import os
  16. import os.path as osp
  17. from collections import OrderedDict
  18. from operator import attrgetter
  19. import cv2
  20. import numpy as np
  21. import paddle
  22. import paddle.nn.functional as F
  23. from paddle.static import InputSpec
  24. import paddlers
  25. import paddlers.models.ppseg as ppseg
  26. import paddlers.rs_models.cd as cmcd
  27. import paddlers.utils.logging as logging
  28. from paddlers.models import seg_losses
  29. from paddlers.transforms import Resize, decode_image
  30. from paddlers.utils import get_single_card_bs
  31. from paddlers.utils.checkpoint import seg_pretrain_weights_dict
  32. from .base import BaseModel
  33. from .utils import seg_metrics as metrics
  34. from .utils.infer_nets import InferCDNet
# Public API of this module: the supported change detection trainers.
# NOTE: `BaseChangeDetector.__init__` also validates `model_name` against
# this list, so every entry must match a network class exported by
# `paddlers.rs_models.cd`.
__all__ = [
    "CDNet", "FCEarlyFusion", "FCSiamConc", "FCSiamDiff", "STANet", "BIT",
    "SNUNet", "DSIFN", "DSAMNet", "ChangeStar", "ChangeFormer"
]
  39. class BaseChangeDetector(BaseModel):
    def __init__(self,
                 model_name,
                 num_classes=2,
                 use_mixed_loss=False,
                 losses=None,
                 **params):
        """
        Base trainer for change detection models.

        Args:
            model_name (str): Name of the model to build; must be one of the
                names listed in `__all__`.
            num_classes (int, optional): Number of target classes. Defaults to 2.
            use_mixed_loss (bool|list, optional): Loss-mixing configuration
                consumed by `default_loss()`. Defaults to False.
            losses (dict|None, optional): Loss configuration with keys 'types'
                and 'coef'. If None, a default is built in `train()`.
                Defaults to None.
            **params: Extra keyword arguments forwarded to the network
                constructor. The special key 'with_net' (default True)
                controls whether the network is instantiated immediately.
        """
        # Snapshot of the constructor arguments; used elsewhere to re-create
        # the model. NOTE: because this captures locals(), local variable
        # names in this method must not be changed.
        self.init_params = locals()
        # NOTE(review): 'with_net' arrives nested inside `params`, so this
        # top-level key check appears to never fire — presumably kept for
        # parity with sibling trainers; confirm before removing.
        if 'with_net' in self.init_params:
            del self.init_params['with_net']
        super(BaseChangeDetector, self).__init__('change_detector')
        if model_name not in __all__:
            raise ValueError("ERROR: There is no model named {}.".format(
                model_name))
        self.model_name = model_name
        self.num_classes = num_classes
        self.use_mixed_loss = use_mixed_loss
        self.losses = losses
        self.labels = None
        if params.get('with_net', True):
            params.pop('with_net', None)
            self.net = self.build_net(**params)
        # Hint for distributed training: some networks have parameters that
        # do not receive gradients every step — presumably consumed by the
        # DataParallel setup in the base class; TODO confirm.
        self.find_unused_parameters = True
  62. def build_net(self, **params):
  63. # TODO: add other model
  64. net = cmcd.__dict__[self.model_name](num_classes=self.num_classes,
  65. **params)
  66. return net
  67. def _build_inference_net(self):
  68. infer_net = InferCDNet(self.net)
  69. infer_net.eval()
  70. return infer_net
  71. def _fix_transforms_shape(self, image_shape):
  72. if hasattr(self, 'test_transforms'):
  73. if self.test_transforms is not None:
  74. has_resize_op = False
  75. resize_op_idx = -1
  76. normalize_op_idx = len(self.test_transforms.transforms)
  77. for idx, op in enumerate(self.test_transforms.transforms):
  78. name = op.__class__.__name__
  79. if name == 'Normalize':
  80. normalize_op_idx = idx
  81. if 'Resize' in name:
  82. has_resize_op = True
  83. resize_op_idx = idx
  84. if not has_resize_op:
  85. self.test_transforms.transforms.insert(
  86. normalize_op_idx, Resize(target_size=image_shape))
  87. else:
  88. self.test_transforms.transforms[resize_op_idx] = Resize(
  89. target_size=image_shape)
  90. def _get_test_inputs(self, image_shape):
  91. if image_shape is not None:
  92. if len(image_shape) == 2:
  93. image_shape = [1, 3] + image_shape
  94. self._fix_transforms_shape(image_shape[-2:])
  95. else:
  96. image_shape = [None, 3, -1, -1]
  97. self.fixed_input_shape = image_shape
  98. return [
  99. InputSpec(
  100. shape=image_shape, name='image', dtype='float32'), InputSpec(
  101. shape=image_shape, name='image2', dtype='float32')
  102. ]
    def run(self, net, inputs, mode):
        """
        Feed one batch through `net` and post-process according to `mode`.

        Args:
            net (paddle.nn.Layer): Network to execute.
            inputs (list|tuple): `(im1, im2, ...)`. The trailing entries depend
                on `mode`: for 'test', inputs[2] holds the original image
                shapes and inputs[3] the transform ops; for 'eval', inputs[2]
                is the label tensor and inputs[3] the transform ops; for
                'train', inputs[2:] are the training targets.
            mode (str): One of 'test', 'eval', or 'train'.

        Returns:
            collections.OrderedDict: Mode-dependent outputs — label/score maps
                for 'test', area statistics and confusion matrix for 'eval',
                the summed loss for 'train'.
        """
        net_out = net(inputs[0], inputs[1])
        logit = net_out[0]
        outputs = OrderedDict()
        if mode == 'test':
            origin_shape = inputs[2]
            if self.status == 'Infer':
                # Exported inference models already emit (label_map, score_map).
                label_map_list, score_map_list = self._postprocess(
                    net_out, origin_shape, transforms=inputs[3])
            else:
                logit_list = self._postprocess(
                    logit, origin_shape, transforms=inputs[3])
                label_map_list = []
                score_map_list = []
                for logit in logit_list:
                    logit = paddle.transpose(logit, perm=[0, 2, 3, 1])  # NHWC
                    label_map_list.append(
                        paddle.argmax(
                            logit, axis=-1, keepdim=False, dtype='int32')
                        .squeeze().numpy())
                    score_map_list.append(
                        F.softmax(
                            logit, axis=-1).squeeze().numpy().astype('float32'))
            outputs['label_map'] = label_map_list
            outputs['score_map'] = score_map_list
        if mode == 'eval':
            if self.status == 'Infer':
                pred = paddle.unsqueeze(net_out[0], axis=1)  # NCHW
            else:
                pred = paddle.argmax(logit, axis=1, keepdim=True, dtype='int32')
            label = inputs[2]
            # Normalize labels to NCHW (add the channel axis in place if missing).
            if label.ndim == 3:
                paddle.unsqueeze_(label, axis=1)
            if label.ndim != 4:
                raise ValueError("Expected label.ndim == 4 but got {}".format(
                    label.ndim))
            origin_shape = [label.shape[-2:]]
            pred = self._postprocess(
                pred, origin_shape, transforms=inputs[3])[0]  # NCHW
            intersect_area, pred_area, label_area = ppseg.utils.metrics.calculate_area(
                pred, label, self.num_classes)
            outputs['intersect_area'] = intersect_area
            outputs['pred_area'] = pred_area
            outputs['label_area'] = label_area
            outputs['conf_mat'] = metrics.confusion_matrix(pred, label,
                                                           self.num_classes)
        if mode == 'train':
            if hasattr(net, 'USE_MULTITASK_DECODER') and \
                net.USE_MULTITASK_DECODER is True:
                # CD+Seg: the multitask decoder predicts several outputs, each
                # paired with its own label according to net.OUT_TYPES.
                if len(inputs) != 5:
                    raise ValueError(
                        "Cannot perform loss computation with {} inputs.".
                        format(len(inputs)))
                labels_list = [
                    inputs[2 + idx]
                    for idx in map(attrgetter('value'), net.OUT_TYPES)
                ]
                loss_list = metrics.multitask_loss_computation(
                    logits_list=net_out,
                    labels_list=labels_list,
                    losses=self.losses)
            else:
                loss_list = metrics.loss_computation(
                    logits_list=net_out, labels=inputs[2], losses=self.losses)
            loss = sum(loss_list)
            outputs['loss'] = loss
        return outputs
  171. def default_loss(self):
  172. if isinstance(self.use_mixed_loss, bool):
  173. if self.use_mixed_loss:
  174. losses = [
  175. seg_losses.CrossEntropyLoss(),
  176. seg_losses.LovaszSoftmaxLoss()
  177. ]
  178. coef = [.8, .2]
  179. loss_type = [seg_losses.MixedLoss(losses=losses, coef=coef), ]
  180. else:
  181. loss_type = [seg_losses.CrossEntropyLoss()]
  182. else:
  183. losses, coef = list(zip(*self.use_mixed_loss))
  184. if not set(losses).issubset(
  185. ['CrossEntropyLoss', 'DiceLoss', 'LovaszSoftmaxLoss']):
  186. raise ValueError(
  187. "Only 'CrossEntropyLoss', 'DiceLoss', 'LovaszSoftmaxLoss' are supported."
  188. )
  189. losses = [getattr(seg_losses, loss)() for loss in losses]
  190. loss_type = [seg_losses.MixedLoss(losses=losses, coef=list(coef))]
  191. loss_coef = [1.0]
  192. losses = {'types': loss_type, 'coef': loss_coef}
  193. return losses
  194. def default_optimizer(self,
  195. parameters,
  196. learning_rate,
  197. num_epochs,
  198. num_steps_each_epoch,
  199. lr_decay_power=0.9):
  200. decay_step = num_epochs * num_steps_each_epoch
  201. lr_scheduler = paddle.optimizer.lr.PolynomialDecay(
  202. learning_rate, decay_step, end_lr=0, power=lr_decay_power)
  203. optimizer = paddle.optimizer.Momentum(
  204. learning_rate=lr_scheduler,
  205. parameters=parameters,
  206. momentum=0.9,
  207. weight_decay=4e-5)
  208. return optimizer
    def train(self,
              num_epochs,
              train_dataset,
              train_batch_size=2,
              eval_dataset=None,
              optimizer=None,
              save_interval_epochs=1,
              log_interval_steps=2,
              save_dir='output',
              pretrain_weights=None,
              learning_rate=0.01,
              lr_decay_power=0.9,
              early_stop=False,
              early_stop_patience=5,
              use_vdl=True,
              resume_checkpoint=None):
        """
        Train the model.

        Args:
            num_epochs (int): Number of epochs.
            train_dataset (paddlers.datasets.CDDataset): Training dataset.
            train_batch_size (int, optional): Total batch size among all cards used in
                training. Defaults to 2.
            eval_dataset (paddlers.datasets.CDDataset|None, optional): Evaluation dataset.
                If None, the model will not be evaluated during training process.
                Defaults to None.
            optimizer (paddle.optimizer.Optimizer|None, optional): Optimizer used in
                training. If None, a default optimizer will be used. Defaults to None.
            save_interval_epochs (int, optional): Epoch interval for saving the model.
                Defaults to 1.
            log_interval_steps (int, optional): Step interval for printing training
                information. Defaults to 2.
            save_dir (str, optional): Directory to save the model. Defaults to 'output'.
            pretrain_weights (str|None, optional): None or name/path of pretrained
                weights. If None, no pretrained weights will be loaded. Defaults to None.
            learning_rate (float, optional): Learning rate for training. Defaults to .01.
            lr_decay_power (float, optional): Learning decay power. Defaults to .9.
            early_stop (bool, optional): Whether to adopt early stop strategy. Defaults
                to False.
            early_stop_patience (int, optional): Early stop patience. Defaults to 5.
            use_vdl (bool, optional): Whether to use VisualDL to monitor the training
                process. Defaults to True.
            resume_checkpoint (str|None, optional): Path of the checkpoint to resume
                training from. If None, no training checkpoint will be resumed. At most
                one of `resume_checkpoint` and `pretrain_weights` can be set simultaneously.
                Defaults to None.
        """
        if self.status == 'Infer':
            logging.error(
                "Exported inference model does not support training.",
                exit=True)
        if pretrain_weights is not None and resume_checkpoint is not None:
            logging.error(
                "pretrain_weights and resume_checkpoint cannot be set simultaneously.",
                exit=True)
        self.labels = train_dataset.labels
        if self.losses is None:
            self.losses = self.default_loss()
        if optimizer is None:
            num_steps_each_epoch = train_dataset.num_samples // train_batch_size
            self.optimizer = self.default_optimizer(
                self.net.parameters(), learning_rate, num_epochs,
                num_steps_each_epoch, lr_decay_power)
        else:
            self.optimizer = optimizer
        # Resolve `pretrain_weights`: a non-existent path falls back to the
        # first known weight name for this model; an existing path must be a
        # '.pdparams' file.
        if pretrain_weights is not None and not osp.exists(pretrain_weights):
            if pretrain_weights not in seg_pretrain_weights_dict[
                    self.model_name]:
                logging.warning(
                    "Path of pretrain_weights('{}') does not exist!".format(
                        pretrain_weights))
                logging.warning("Pretrain_weights is forcibly set to '{}'. "
                                "If don't want to use pretrain weights, "
                                "set pretrain_weights to be None.".format(
                                    seg_pretrain_weights_dict[self.model_name][
                                        0]))
                pretrain_weights = seg_pretrain_weights_dict[self.model_name][0]
        elif pretrain_weights is not None and osp.exists(pretrain_weights):
            if osp.splitext(pretrain_weights)[-1] != '.pdparams':
                logging.error(
                    "Invalid pretrain weights. Please specify a '.pdparams' file.",
                    exit=True)
        pretrained_dir = osp.join(save_dir, 'pretrain')
        # 'IMAGENET' means only the backbone is initialized from pretrained
        # weights.
        is_backbone_weights = pretrain_weights == 'IMAGENET'
        self.net_initialize(
            pretrain_weights=pretrain_weights,
            save_dir=pretrained_dir,
            resume_checkpoint=resume_checkpoint,
            is_backbone_weights=is_backbone_weights)
        self.train_loop(
            num_epochs=num_epochs,
            train_dataset=train_dataset,
            train_batch_size=train_batch_size,
            eval_dataset=eval_dataset,
            save_interval_epochs=save_interval_epochs,
            log_interval_steps=log_interval_steps,
            save_dir=save_dir,
            early_stop=early_stop,
            early_stop_patience=early_stop_patience,
            use_vdl=use_vdl)
    def quant_aware_train(self,
                          num_epochs,
                          train_dataset,
                          train_batch_size=2,
                          eval_dataset=None,
                          optimizer=None,
                          save_interval_epochs=1,
                          log_interval_steps=2,
                          save_dir='output',
                          learning_rate=0.0001,
                          lr_decay_power=0.9,
                          early_stop=False,
                          early_stop_patience=5,
                          use_vdl=True,
                          resume_checkpoint=None,
                          quant_config=None):
        """
        Quantization-aware training.

        Args:
            num_epochs (int): Number of epochs.
            train_dataset (paddlers.datasets.CDDataset): Training dataset.
            train_batch_size (int, optional): Total batch size among all cards used in
                training. Defaults to 2.
            eval_dataset (paddlers.datasets.CDDataset, optional): Evaluation dataset.
                If None, the model will not be evaluated during training process.
                Defaults to None.
            optimizer (paddle.optimizer.Optimizer|None, optional): Optimizer used in
                training. If None, a default optimizer will be used. Defaults to None.
            save_interval_epochs (int, optional): Epoch interval for saving the model.
                Defaults to 1.
            log_interval_steps (int, optional): Step interval for printing training
                information. Defaults to 2.
            save_dir (str, optional): Directory to save the model. Defaults to 'output'.
            learning_rate (float, optional): Learning rate for training.
                Defaults to .0001.
            lr_decay_power (float, optional): Learning decay power. Defaults to .9.
            early_stop (bool, optional): Whether to adopt early stop strategy.
                Defaults to False.
            early_stop_patience (int, optional): Early stop patience. Defaults to 5.
            use_vdl (bool, optional): Whether to use VisualDL to monitor the training
                process. Defaults to True.
            quant_config (dict|None, optional): Quantization configuration. If None,
                a default rule of thumb configuration will be used. Defaults to None.
            resume_checkpoint (str|None, optional): Path of the checkpoint to resume
                quantization-aware training from. If None, no training checkpoint will
                be resumed. Defaults to None.
        """
        # Inject quantization hooks into the network, then reuse the regular
        # training loop. Pretrained weights are deliberately disabled here:
        # QAT is expected to start from the current (already trained) weights.
        self._prepare_qat(quant_config)
        self.train(
            num_epochs=num_epochs,
            train_dataset=train_dataset,
            train_batch_size=train_batch_size,
            eval_dataset=eval_dataset,
            optimizer=optimizer,
            save_interval_epochs=save_interval_epochs,
            log_interval_steps=log_interval_steps,
            save_dir=save_dir,
            pretrain_weights=None,
            learning_rate=learning_rate,
            lr_decay_power=lr_decay_power,
            early_stop=early_stop,
            early_stop_patience=early_stop_patience,
            use_vdl=use_vdl,
            resume_checkpoint=resume_checkpoint)
    def evaluate(self, eval_dataset, batch_size=1, return_details=False):
        """
        Evaluate the model.

        Args:
            eval_dataset (paddlers.datasets.CDDataset): Evaluation dataset.
            batch_size (int, optional): Total batch size among all cards used for
                evaluation. Defaults to 1.
            return_details (bool, optional): Whether to return evaluation details.
                Defaults to False.

        Returns:
            If `return_details` is False, return collections.OrderedDict with
                key-value pairs:
                For binary change detection (number of classes == 2), the key-value
                pairs are like:
                {"iou": `intersection over union for the change class`,
                "f1": `F1 score for the change class`,
                "oacc": `overall accuracy`,
                "kappa": ` kappa coefficient`}.
                For multi-class change detection (number of classes > 2), the key-value
                pairs are like:
                {"miou": `mean intersection over union`,
                "category_iou": `category-wise mean intersection over union`,
                "oacc": `overall accuracy`,
                "category_acc": `category-wise accuracy`,
                "kappa": ` kappa coefficient`,
                "category_F1-score": `F1 score`}.
        """
        self._check_transforms(eval_dataset.transforms, 'eval')
        self.net.eval()
        nranks = paddle.distributed.get_world_size()
        local_rank = paddle.distributed.get_rank()
        if nranks > 1:
            # Initialize parallel environment if not done.
            if not (paddle.distributed.parallel.parallel_helper.
                    _is_parallel_ctx_initialized()):
                paddle.distributed.init_parallel_env()
        # Evaluation only supports one sample per card; clamp the per-card
        # batch size and recompute the effective total batch size.
        batch_size_each_card = get_single_card_bs(batch_size)
        if batch_size_each_card > 1:
            batch_size_each_card = 1
            batch_size = batch_size_each_card * paddlers.env_info['num']
            logging.warning(
                "ChangeDetector only supports batch_size=1 for each gpu/cpu card " \
                "during evaluation, so batch_size " \
                "is forcibly set to {}.".format(batch_size)
            )
        self.eval_data_loader = self.build_data_loader(
            eval_dataset, batch_size=batch_size, mode='eval')
        intersect_area_all = 0
        pred_area_all = 0
        label_area_all = 0
        conf_mat_all = []
        logging.info(
            "Start to evaluate(total_samples={}, total_steps={})...".format(
                eval_dataset.num_samples,
                math.ceil(eval_dataset.num_samples * 1.0 / batch_size)))
        with paddle.no_grad():
            for step, data in enumerate(self.eval_data_loader):
                data.append(eval_dataset.transforms.transforms)
                outputs = self.run(self.net, data, 'eval')
                pred_area = outputs['pred_area']
                label_area = outputs['label_area']
                intersect_area = outputs['intersect_area']
                conf_mat = outputs['conf_mat']
                # Gather from all ranks
                if nranks > 1:
                    intersect_area_list = []
                    pred_area_list = []
                    label_area_list = []
                    conf_mat_list = []
                    paddle.distributed.all_gather(intersect_area_list,
                                                  intersect_area)
                    paddle.distributed.all_gather(pred_area_list, pred_area)
                    paddle.distributed.all_gather(label_area_list, label_area)
                    paddle.distributed.all_gather(conf_mat_list, conf_mat)
                    # Some image has been evaluated and should be eliminated in last iter
                    if (step + 1) * nranks > len(eval_dataset):
                        valid = len(eval_dataset) - step * nranks
                        intersect_area_list = intersect_area_list[:valid]
                        pred_area_list = pred_area_list[:valid]
                        label_area_list = label_area_list[:valid]
                        conf_mat_list = conf_mat_list[:valid]
                    intersect_area_all += sum(intersect_area_list)
                    pred_area_all += sum(pred_area_list)
                    label_area_all += sum(label_area_list)
                    conf_mat_all.extend(conf_mat_list)
                else:
                    intersect_area_all = intersect_area_all + intersect_area
                    pred_area_all = pred_area_all + pred_area
                    label_area_all = label_area_all + label_area
                    conf_mat_all.append(conf_mat)
        class_iou, miou = ppseg.utils.metrics.mean_iou(
            intersect_area_all, pred_area_all, label_area_all)
        # TODO: confirm whether overall accuracy (oacc) or mean accuracy
        # (macc) should be reported here.
        class_acc, oacc = ppseg.utils.metrics.accuracy(intersect_area_all,
                                                       pred_area_all)
        kappa = ppseg.utils.metrics.kappa(intersect_area_all, pred_area_all,
                                          label_area_all)
        category_f1score = metrics.f1_score(intersect_area_all, pred_area_all,
                                            label_area_all)
        # Multi-class vs. binary reporting: for the binary case only the
        # change class (index 1) metrics are reported.
        if len(class_acc) > 2:
            eval_metrics = OrderedDict(
                zip([
                    'miou', 'category_iou', 'oacc', 'category_acc', 'kappa',
                    'category_F1-score'
                ], [miou, class_iou, oacc, class_acc, kappa, category_f1score]))
        else:
            eval_metrics = OrderedDict(
                zip(['iou', 'f1', 'oacc', 'kappa'],
                    [class_iou[1], category_f1score[1], oacc, kappa]))
        if return_details:
            conf_mat = sum(conf_mat_all)
            eval_details = {'confusion_matrix': conf_mat.tolist()}
            return eval_metrics, eval_details
        return eval_metrics
  487. def predict(self, img_file, transforms=None):
  488. """
  489. Do inference.
  490. Args:
  491. img_file (list[tuple] | tuple[str|np.ndarray]): Tuple of image paths or
  492. decoded image data for bi-temporal images, which also could constitute
  493. a list, meaning all image pairs to be predicted as a mini-batch.
  494. transforms (paddlers.transforms.Compose|None, optional): Transforms for
  495. inputs. If None, the transforms for evaluation process will be used.
  496. Defaults to None.
  497. Returns:
  498. If `img_file` is a tuple of string or np.array, the result is a dict with
  499. the following key-value pairs:
  500. label_map (np.ndarray): Predicted label map (HW).
  501. score_map (np.ndarray): Prediction score map (HWC).
  502. If `img_file` is a list, the result is a list composed of dicts with the
  503. above keys.
  504. """
  505. if transforms is None and not hasattr(self, 'test_transforms'):
  506. raise ValueError("transforms need to be defined, now is None.")
  507. if transforms is None:
  508. transforms = self.test_transforms
  509. if isinstance(img_file, tuple):
  510. if not len(img_file) == 2 and any(
  511. map(lambda obj: not isinstance(obj, (str, np.ndarray)),
  512. img_file)):
  513. raise TypeError
  514. images = [img_file]
  515. else:
  516. images = img_file
  517. batch_im1, batch_im2, batch_origin_shape = self._preprocess(
  518. images, transforms, self.model_type)
  519. self.net.eval()
  520. data = (batch_im1, batch_im2, batch_origin_shape, transforms.transforms)
  521. outputs = self.run(self.net, data, 'test')
  522. label_map_list = outputs['label_map']
  523. score_map_list = outputs['score_map']
  524. if isinstance(img_file, list):
  525. prediction = [{
  526. 'label_map': l,
  527. 'score_map': s
  528. } for l, s in zip(label_map_list, score_map_list)]
  529. else:
  530. prediction = {
  531. 'label_map': label_map_list[0],
  532. 'score_map': score_map_list[0]
  533. }
  534. return prediction
    def slider_predict(self,
                       img_file,
                       save_dir,
                       block_size,
                       overlap=36,
                       transforms=None):
        """
        Do inference using a sliding window, writing the result as a GeoTiff.

        Args:
            img_file (tuple[str]): Tuple of image paths.
            save_dir (str): Directory that contains saved geotiff file.
            block_size (list[int] | tuple[int] | int, optional): Size of block.
            overlap (list[int] | tuple[int] | int, optional): Overlap between two blocks.
                Defaults to 36.
            transforms (paddlers.transforms.Compose|None, optional): Transforms for inputs.
                If None, the transforms for evaluation process will be used. Defaults to None.

        Raises:
            ValueError: If `img_file` is not a 2-tuple, or `block_size` /
                `overlap` are neither ints nor length-2 sequences.
        """
        try:
            from osgeo import gdal
        except:
            # NOTE(review): bare except kept as-is; older GDAL installs expose
            # a top-level `gdal` module instead of `osgeo.gdal`.
            import gdal
        if not isinstance(img_file, tuple) or len(img_file) != 2:
            raise ValueError("`img_file` must be a tuple of length 2.")
        # Normalize block_size and overlap to (x, y) tuples.
        if isinstance(block_size, int):
            block_size = (block_size, block_size)
        elif isinstance(block_size, (tuple, list)) and len(block_size) == 2:
            block_size = tuple(block_size)
        else:
            raise ValueError(
                "`block_size` must be a tuple/list of length 2 or an integer.")
        if isinstance(overlap, int):
            overlap = (overlap, overlap)
        elif isinstance(overlap, (tuple, list)) and len(overlap) == 2:
            overlap = tuple(overlap)
        else:
            raise ValueError(
                "`overlap` must be a tuple/list of length 2 or an integer.")
        src1_data = gdal.Open(img_file[0])
        src2_data = gdal.Open(img_file[1])
        # Raster geometry is taken from the first temporal image; the second
        # is assumed to be co-registered with it — TODO confirm.
        width = src1_data.RasterXSize
        height = src1_data.RasterYSize
        bands = src1_data.RasterCount
        # Output file: single-band byte GeoTiff named after the first image,
        # inheriting its geotransform and projection.
        driver = gdal.GetDriverByName("GTiff")
        file_name = osp.splitext(osp.normpath(img_file[0]).split(os.sep)[-1])[
            0] + ".tif"
        if not osp.exists(save_dir):
            os.makedirs(save_dir)
        save_file = osp.join(save_dir, file_name)
        dst_data = driver.Create(save_file, width, height, 1, gdal.GDT_Byte)
        dst_data.SetGeoTransform(src1_data.GetGeoTransform())
        dst_data.SetProjection(src1_data.GetProjection())
        band = dst_data.GetRasterBand(1)
        # 255 marks "not yet predicted" pixels.
        band.WriteArray(255 * np.ones((height, width), dtype="uint8"))
        step = np.array(block_size) - np.array(overlap)
        for yoff in range(0, height, step[1]):
            for xoff in range(0, width, step[0]):
                xsize, ysize = block_size
                # Clip the window at the right/bottom borders.
                if xoff + xsize > width:
                    xsize = int(width - xoff)
                if yoff + ysize > height:
                    ysize = int(height - yoff)
                im1 = src1_data.ReadAsArray(
                    int(xoff), int(yoff), xsize, ysize).transpose((1, 2, 0))
                im2 = src2_data.ReadAsArray(
                    int(xoff), int(yoff), xsize, ysize).transpose((1, 2, 0))
                # Fill: zero-pad border windows up to the full block size.
                h, w = im1.shape[:2]
                im1_fill = np.zeros(
                    (block_size[1], block_size[0], bands), dtype=im1.dtype)
                im2_fill = im1_fill.copy()
                im1_fill[:h, :w, :] = im1
                im2_fill[:h, :w, :] = im2
                im_fill = (im1_fill, im2_fill)
                # Predict
                pred = self.predict(im_fill,
                                    transforms)["label_map"].astype("uint8")
                # Overlap: keep the previous value only where it agrees with
                # the new prediction or was never written (255); disagreeing
                # overlap pixels are reset to 0.
                rd_block = band.ReadAsArray(int(xoff), int(yoff), xsize, ysize)
                mask = (rd_block == pred[:h, :w]) | (rd_block == 255)
                temp = pred[:h, :w].copy()
                temp[mask == False] = 0
                band.WriteArray(temp, int(xoff), int(yoff))
                dst_data.FlushCache()
        dst_data = None
        print("GeoTiff saved in {}.".format(save_file))
    def _preprocess(self, images, transforms, to_tensor=True):
        """
        Apply `transforms` to each bi-temporal image pair and batch the results.

        Args:
            images (list[tuple]): List of (im1, im2) pairs; each element may be
                a file path or an already decoded array.
            transforms (paddlers.transforms.Compose): Transform pipeline.
            to_tensor (bool, optional): If True, return paddle tensors;
                otherwise numpy arrays. Defaults to True.

        Returns:
            tuple: (batch_im1, batch_im2, batch_ori_shape) where the last item
                is a list of the (h, w) shapes before transformation.
        """
        self._check_transforms(transforms, 'test')
        batch_im1, batch_im2 = list(), list()
        batch_ori_shape = list()
        for im1, im2 in images:
            if isinstance(im1, str) or isinstance(im2, str):
                im1 = decode_image(im1, to_rgb=False)
                im2 = decode_image(im2, to_rgb=False)
            ori_shape = im1.shape[:2]
            # XXX: sample do not contain 'image_t1' and 'image_t2'.
            sample = {'image': im1, 'image2': im2}
            im1, im2 = transforms(sample)[:2]
            batch_im1.append(im1)
            batch_im2.append(im2)
            batch_ori_shape.append(ori_shape)
        if to_tensor:
            batch_im1 = paddle.to_tensor(batch_im1)
            batch_im2 = paddle.to_tensor(batch_im2)
        else:
            batch_im1 = np.asarray(batch_im1)
            batch_im2 = np.asarray(batch_im2)
        return batch_im1, batch_im2, batch_ori_shape
    @staticmethod
    def get_transforms_shape_info(batch_ori_shape, transforms):
        """
        Replay the shape-changing transforms to build per-sample restore lists.

        For each original shape, walk the transform ops in order and record the
        inverse operations ('resize' back to the pre-op size, or 'padding' with
        the crop offsets) needed by `_postprocess` to map predictions back to
        the original resolution.

        Args:
            batch_ori_shape (list): List of (h, w) shapes before transformation.
            transforms (list): Transform ops that were applied.

        Returns:
            list[list[tuple]]: One restore list per sample; entries are
                ('resize', (h, w)) or ('padding', (h, w), [y_off, x_off]).
        """
        batch_restore_list = list()
        for ori_shape in batch_ori_shape:
            restore_list = list()
            h, w = ori_shape[0], ori_shape[1]
            for op in transforms:
                if op.__class__.__name__ == 'Resize':
                    restore_list.append(('resize', (h, w)))
                    h, w = op.target_size
                elif op.__class__.__name__ == 'ResizeByShort':
                    restore_list.append(('resize', (h, w)))
                    im_short_size = min(h, w)
                    im_long_size = max(h, w)
                    scale = float(op.short_size) / float(im_short_size)
                    # Cap the scale so the long side does not exceed max_size.
                    if 0 < op.max_size < np.round(scale * im_long_size):
                        scale = float(op.max_size) / float(im_long_size)
                    h = int(round(h * scale))
                    w = int(round(w * scale))
                elif op.__class__.__name__ == 'ResizeByLong':
                    restore_list.append(('resize', (h, w)))
                    im_long_size = max(h, w)
                    scale = float(op.long_size) / float(im_long_size)
                    h = int(round(h * scale))
                    w = int(round(w * scale))
                elif op.__class__.__name__ == 'Pad':
                    if op.target_size:
                        target_h, target_w = op.target_size
                    else:
                        # Pad up to the next multiple of size_divisor.
                        target_h = int(
                            (np.ceil(h / op.size_divisor) * op.size_divisor))
                        target_w = int(
                            (np.ceil(w / op.size_divisor) * op.size_divisor))
                    # pad_mode: -1 explicit offsets, 0 top-left, 1 centered,
                    # otherwise bottom-right.
                    if op.pad_mode == -1:
                        offsets = op.offsets
                    elif op.pad_mode == 0:
                        offsets = [0, 0]
                    elif op.pad_mode == 1:
                        offsets = [(target_h - h) // 2, (target_w - w) // 2]
                    else:
                        offsets = [target_h - h, target_w - w]
                    restore_list.append(('padding', (h, w), offsets))
                    h, w = target_h, target_w
            batch_restore_list.append(restore_list)
        return batch_restore_list
    def _postprocess(self, batch_pred, batch_origin_shape, transforms):
        """
        Undo the spatial transforms on predictions so they match the original
        image sizes.

        Args:
            batch_pred (paddle.Tensor|tuple|list): Raw predictions. For an
                exported inference model ('Infer' status) this is the
                (label_map, score_map) pair, which is delegated to
                `_infer_postprocess`.
            batch_origin_shape (list): Original (h, w) per sample.
            transforms (list): Transform ops that were applied.

        Returns:
            list[paddle.Tensor]: One restored NCHW prediction per sample.
        """
        batch_restore_list = BaseChangeDetector.get_transforms_shape_info(
            batch_origin_shape, transforms)
        if isinstance(batch_pred, (tuple, list)) and self.status == 'Infer':
            return self._infer_postprocess(
                batch_label_map=batch_pred[0],
                batch_score_map=batch_pred[1],
                batch_restore_list=batch_restore_list)
        results = []
        # Float predictions (logits/scores) are resized bilinearly; integer
        # label maps use nearest-neighbor to keep class ids intact.
        if batch_pred.dtype == paddle.float32:
            mode = 'bilinear'
        else:
            mode = 'nearest'
        for pred, restore_list in zip(batch_pred, batch_restore_list):
            pred = paddle.unsqueeze(pred, axis=0)
            # Apply the recorded restore ops in reverse order.
            for item in restore_list[::-1]:
                h, w = item[1][0], item[1][1]
                if item[0] == 'resize':
                    pred = F.interpolate(
                        pred, (h, w), mode=mode, data_format='NCHW')
                elif item[0] == 'padding':
                    x, y = item[2]
                    pred = pred[:, :, y:y + h, x:x + w]
                else:
                    pass
            results.append(pred)
        return results
    def _infer_postprocess(self, batch_label_map, batch_score_map,
                           batch_restore_list):
        """Restore inference-mode label/score maps to original image shapes.

        Replays each sample's recorded transform ops in reverse. Supports two
        input flavors: ``np.ndarray`` maps (resized with OpenCV) and paddle
        tensors (expanded to 4-D and resized with ``F.interpolate`` in NHWC).

        Args:
            batch_label_map: Per-sample predicted label maps.
            batch_score_map: Per-sample score maps.
            batch_restore_list: Per-sample list of ('resize'/'padding', ...)
                records produced by ``get_transforms_shape_info``.

        Returns:
            tuple: (label_maps, score_maps) — lists of np.ndarray, squeezed.
        """
        label_maps = []
        score_maps = []
        for label_map, score_map, restore_list in zip(
                batch_label_map, batch_score_map, batch_restore_list):
            if not isinstance(label_map, np.ndarray):
                # Tensors are lifted to 4-D (batch and channel dims) so the
                # NHWC interpolation calls below apply.
                label_map = paddle.unsqueeze(label_map, axis=[0, 3])
                score_map = paddle.unsqueeze(score_map, axis=0)
            for item in restore_list[::-1]:
                h, w = item[1][0], item[1][1]
                if item[0] == 'resize':
                    if isinstance(label_map, np.ndarray):
                        # cv2.resize takes (width, height); nearest keeps
                        # label values discrete, linear smooths scores.
                        label_map = cv2.resize(
                            label_map, (w, h), interpolation=cv2.INTER_NEAREST)
                        score_map = cv2.resize(
                            score_map, (w, h), interpolation=cv2.INTER_LINEAR)
                    else:
                        label_map = F.interpolate(
                            label_map, (h, w),
                            mode='nearest',
                            data_format='NHWC')
                        score_map = F.interpolate(
                            score_map, (h, w),
                            mode='bilinear',
                            data_format='NHWC')
                elif item[0] == 'padding':
                    x, y = item[2]
                    # NOTE(review): `x` receives offsets[0] and `y` offsets[1];
                    # combined with the slice axis order below this looks
                    # inconsistent between branches — confirm offset semantics
                    # against get_transforms_shape_info before relying on
                    # non-symmetric padding.
                    if isinstance(label_map, np.ndarray):
                        label_map = label_map[..., y:y + h, x:x + w]
                        score_map = score_map[..., y:y + h, x:x + w]
                    else:
                        label_map = label_map[:, :, y:y + h, x:x + w]
                        score_map = score_map[:, :, y:y + h, x:x + w]
                else:
                    pass
            label_map = label_map.squeeze()
            score_map = score_map.squeeze()
            if not isinstance(label_map, np.ndarray):
                # Convert tensors to numpy so both branches return ndarrays.
                label_map = label_map.numpy()
                score_map = score_map.numpy()
            label_maps.append(label_map.squeeze())
            score_maps.append(score_map.squeeze())
        return label_maps, score_maps
  758. def _check_transforms(self, transforms, mode):
  759. super()._check_transforms(transforms, mode)
  760. if not isinstance(transforms.arrange,
  761. paddlers.transforms.ArrangeChangeDetector):
  762. raise TypeError(
  763. "`transforms.arrange` must be an ArrangeChangeDetector object.")
  764. def set_losses(self, losses, weights=None):
  765. if weights is None:
  766. weights = [1. for _ in range(len(losses))]
  767. self.losses = {'types': losses, 'coef': weights}
  768. class CDNet(BaseChangeDetector):
  769. def __init__(self,
  770. num_classes=2,
  771. use_mixed_loss=False,
  772. losses=None,
  773. in_channels=6,
  774. **params):
  775. params.update({'in_channels': in_channels})
  776. super(CDNet, self).__init__(
  777. model_name='CDNet',
  778. num_classes=num_classes,
  779. use_mixed_loss=use_mixed_loss,
  780. losses=losses,
  781. **params)
  782. class FCEarlyFusion(BaseChangeDetector):
  783. def __init__(self,
  784. num_classes=2,
  785. use_mixed_loss=False,
  786. losses=None,
  787. in_channels=6,
  788. use_dropout=False,
  789. **params):
  790. params.update({'in_channels': in_channels, 'use_dropout': use_dropout})
  791. super(FCEarlyFusion, self).__init__(
  792. model_name='FCEarlyFusion',
  793. num_classes=num_classes,
  794. use_mixed_loss=use_mixed_loss,
  795. losses=losses,
  796. **params)
  797. class FCSiamConc(BaseChangeDetector):
  798. def __init__(self,
  799. num_classes=2,
  800. use_mixed_loss=False,
  801. losses=None,
  802. in_channels=3,
  803. use_dropout=False,
  804. **params):
  805. params.update({'in_channels': in_channels, 'use_dropout': use_dropout})
  806. super(FCSiamConc, self).__init__(
  807. model_name='FCSiamConc',
  808. num_classes=num_classes,
  809. use_mixed_loss=use_mixed_loss,
  810. losses=losses,
  811. **params)
  812. class FCSiamDiff(BaseChangeDetector):
  813. def __init__(self,
  814. num_classes=2,
  815. use_mixed_loss=False,
  816. losses=None,
  817. in_channels=3,
  818. use_dropout=False,
  819. **params):
  820. params.update({'in_channels': in_channels, 'use_dropout': use_dropout})
  821. super(FCSiamDiff, self).__init__(
  822. model_name='FCSiamDiff',
  823. num_classes=num_classes,
  824. use_mixed_loss=use_mixed_loss,
  825. losses=losses,
  826. **params)
  827. class STANet(BaseChangeDetector):
  828. def __init__(self,
  829. num_classes=2,
  830. use_mixed_loss=False,
  831. losses=None,
  832. in_channels=3,
  833. att_type='BAM',
  834. ds_factor=1,
  835. **params):
  836. params.update({
  837. 'in_channels': in_channels,
  838. 'att_type': att_type,
  839. 'ds_factor': ds_factor
  840. })
  841. super(STANet, self).__init__(
  842. model_name='STANet',
  843. num_classes=num_classes,
  844. use_mixed_loss=use_mixed_loss,
  845. losses=losses,
  846. **params)
  847. class BIT(BaseChangeDetector):
  848. def __init__(self,
  849. num_classes=2,
  850. use_mixed_loss=False,
  851. losses=None,
  852. in_channels=3,
  853. backbone='resnet18',
  854. n_stages=4,
  855. use_tokenizer=True,
  856. token_len=4,
  857. pool_mode='max',
  858. pool_size=2,
  859. enc_with_pos=True,
  860. enc_depth=1,
  861. enc_head_dim=64,
  862. dec_depth=8,
  863. dec_head_dim=8,
  864. **params):
  865. params.update({
  866. 'in_channels': in_channels,
  867. 'backbone': backbone,
  868. 'n_stages': n_stages,
  869. 'use_tokenizer': use_tokenizer,
  870. 'token_len': token_len,
  871. 'pool_mode': pool_mode,
  872. 'pool_size': pool_size,
  873. 'enc_with_pos': enc_with_pos,
  874. 'enc_depth': enc_depth,
  875. 'enc_head_dim': enc_head_dim,
  876. 'dec_depth': dec_depth,
  877. 'dec_head_dim': dec_head_dim
  878. })
  879. super(BIT, self).__init__(
  880. model_name='BIT',
  881. num_classes=num_classes,
  882. use_mixed_loss=use_mixed_loss,
  883. losses=losses,
  884. **params)
  885. class SNUNet(BaseChangeDetector):
  886. def __init__(self,
  887. num_classes=2,
  888. use_mixed_loss=False,
  889. losses=None,
  890. in_channels=3,
  891. width=32,
  892. **params):
  893. params.update({'in_channels': in_channels, 'width': width})
  894. super(SNUNet, self).__init__(
  895. model_name='SNUNet',
  896. num_classes=num_classes,
  897. use_mixed_loss=use_mixed_loss,
  898. losses=losses,
  899. **params)
  900. class DSIFN(BaseChangeDetector):
  901. def __init__(self,
  902. num_classes=2,
  903. use_mixed_loss=False,
  904. losses=None,
  905. use_dropout=False,
  906. **params):
  907. params.update({'use_dropout': use_dropout})
  908. super(DSIFN, self).__init__(
  909. model_name='DSIFN',
  910. num_classes=num_classes,
  911. use_mixed_loss=use_mixed_loss,
  912. losses=losses,
  913. **params)
  914. def default_loss(self):
  915. if self.use_mixed_loss is False:
  916. return {
  917. # XXX: make sure the shallow copy works correctly here.
  918. 'types': [seg_losses.CrossEntropyLoss()] * 5,
  919. 'coef': [1.0] * 5
  920. }
  921. else:
  922. raise ValueError(
  923. f"Currently `use_mixed_loss` must be set to False for {self.__class__}"
  924. )
  925. class DSAMNet(BaseChangeDetector):
  926. def __init__(self,
  927. num_classes=2,
  928. use_mixed_loss=False,
  929. losses=None,
  930. in_channels=3,
  931. ca_ratio=8,
  932. sa_kernel=7,
  933. **params):
  934. params.update({
  935. 'in_channels': in_channels,
  936. 'ca_ratio': ca_ratio,
  937. 'sa_kernel': sa_kernel
  938. })
  939. super(DSAMNet, self).__init__(
  940. model_name='DSAMNet',
  941. num_classes=num_classes,
  942. use_mixed_loss=use_mixed_loss,
  943. losses=losses,
  944. **params)
  945. def default_loss(self):
  946. if self.use_mixed_loss is False:
  947. return {
  948. 'types': [
  949. seg_losses.CrossEntropyLoss(), seg_losses.DiceLoss(),
  950. seg_losses.DiceLoss()
  951. ],
  952. 'coef': [1.0, 0.05, 0.05]
  953. }
  954. else:
  955. raise ValueError(
  956. f"Currently `use_mixed_loss` must be set to False for {self.__class__}"
  957. )
  958. class ChangeStar(BaseChangeDetector):
  959. def __init__(self,
  960. num_classes=2,
  961. use_mixed_loss=False,
  962. losses=None,
  963. mid_channels=256,
  964. inner_channels=16,
  965. num_convs=4,
  966. scale_factor=4.0,
  967. **params):
  968. params.update({
  969. 'mid_channels': mid_channels,
  970. 'inner_channels': inner_channels,
  971. 'num_convs': num_convs,
  972. 'scale_factor': scale_factor
  973. })
  974. super(ChangeStar, self).__init__(
  975. model_name='ChangeStar',
  976. num_classes=num_classes,
  977. use_mixed_loss=use_mixed_loss,
  978. losses=losses,
  979. **params)
  980. def default_loss(self):
  981. if self.use_mixed_loss is False:
  982. return {
  983. # XXX: make sure the shallow copy works correctly here.
  984. 'types': [seglosses.CrossEntropyLoss()] * 4,
  985. 'coef': [1.0] * 4
  986. }
  987. else:
  988. raise ValueError(
  989. f"Currently `use_mixed_loss` must be set to False for {self.__class__}"
  990. )
  991. class ChangeFormer(BaseChangeDetector):
  992. def __init__(self,
  993. in_channels=3,
  994. num_classes=2,
  995. decoder_softmax=False,
  996. embed_dim=256,
  997. use_mixed_loss=False,
  998. **params):
  999. params.update({
  1000. 'in_channels': in_channels,
  1001. 'embed_dim': embed_dim,
  1002. 'decoder_softmax': decoder_softmax
  1003. })
  1004. super(ChangeFormer, self).__init__(
  1005. model_name='ChangeFormer',
  1006. num_classes=num_classes,
  1007. use_mixed_loss=use_mixed_loss,
  1008. **params)