# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import time
from collections import deque
import shutil

import paddle
import paddle.nn.functional as F

from paddleseg.utils import (TimeAverager, calculate_eta, resume, logger,
                             worker_init_fn, train_profiler, op_flops_funs)
from paddleseg.core.val import evaluate


def check_logits_losses(logits_list, losses):
    len_logits = len(logits_list)
    len_losses = len(losses['types'])
    if len_logits != len_losses:
        raise RuntimeError(
            'The length of logits_list should equal to the types of loss config: {} != {}.'
            .format(len_logits, len_losses))


def loss_computation(logits_list, labels, edges, losses):
    check_logits_losses(logits_list, losses)
    loss_list = []
    for i in range(len(logits_list)):
        logits = logits_list[i]
        loss_i = losses['types'][i]
        coef_i = losses['coef'][i]
        if loss_i.__class__.__name__ in ('BCELoss', ) and loss_i.edge_label:
            # Use edges as labels according to the loss type.
            loss_list.append(coef_i * loss_i(logits, edges))
        elif loss_i.__class__.__name__ == 'MixedLoss':
            mixed_loss_list = loss_i(logits, labels)
            for mixed_loss in mixed_loss_list:
                loss_list.append(coef_i * mixed_loss)
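        # KLLoss compares the first output with the detached second output,
        # e.g. for distillation-style supervision between two prediction heads.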
        elif loss_i.__class__.__name__ in ("KLLoss", ):
            loss_list.append(coef_i *
                             loss_i(logits_list[0], logits_list[1].detach()))
        else:
            loss_list.append(coef_i * loss_i(logits, labels))
    return loss_list


def train(model,
          train_dataset,
          val_dataset=None,
          optimizer=None,
          save_dir='output',
          iters=10000,
          batch_size=2,
          resume_model=None,
          save_interval=1000,
          log_iters=10,
          num_workers=0,
          use_vdl=False,
          losses=None,
          keep_checkpoint_max=5,
          test_config=None,
          precision='fp32',
          amp_level='O1',
          profiler_options=None,
          to_static_training=False):
    """
    Launch training.

    Args:
        model (nn.Layer): A semantic segmentation model.
        train_dataset (paddle.io.Dataset): Used to read and process training datasets.
        val_dataset (paddle.io.Dataset, optional): Used to read and process validation datasets.
        optimizer (paddle.optimizer.Optimizer): The optimizer.
        save_dir (str, optional): The directory for saving the model snapshot. Default: 'output'.
        iters (int, optional): How many iters to train the model. Default: 10000.
        batch_size (int, optional): Mini batch size of one gpu or cpu. Default: 2.
        resume_model (str, optional): The path of the model to resume training from.
        save_interval (int, optional): How many iters between saving a model snapshot during training. Default: 1000.
        log_iters (int, optional): Display logging information every log_iters. Default: 10.
        num_workers (int, optional): Number of workers for the data loader. Default: 0.
        use_vdl (bool, optional): Whether to record the data to VisualDL during training. Default: False.
        losses (dict, optional): A dict including 'types' and 'coef'. The length of coef should equal to 1 or len(losses['types']).
            The 'types' item is a list of loss objects from paddleseg.models.losses, and the 'coef' item is a list of the corresponding coefficients.
        keep_checkpoint_max (int, optional): Maximum number of checkpoints to save. Default: 5.
        test_config (dict, optional): Evaluation config.
        precision (str, optional): Use AMP if precision='fp16'. If precision='fp32', training runs normally.
        amp_level (str, optional): Auto mixed precision level. Accepted values are 'O1' and 'O2': O1 means mixed precision,
            where the input data type of each operator is cast according to the white_list and black_list; O2 means pure fp16,
            where all operator parameters and input data are cast to fp16, except operators in the black_list, operators without
            an fp16 kernel, and batch norm parameters. Default: 'O1'.
        profiler_options (str, optional): The option of the training profiler.
        to_static_training (bool, optional): Whether to use @to_static for training.
    """
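    # An illustrative example of the `losses` config described above (the loss
    # classes come from paddleseg.models.losses):
    #     losses = {'types': [CrossEntropyLoss()], 'coef': [1.0]}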
    model.train()
    nranks = paddle.distributed.ParallelEnv().nranks
    local_rank = paddle.distributed.ParallelEnv().local_rank

    start_iter = 0
    if resume_model is not None:
        start_iter = resume(model, optimizer, resume_model)

    if not os.path.isdir(save_dir):
        if os.path.exists(save_dir):
            os.remove(save_dir)
        os.makedirs(save_dir, exist_ok=True)

    # use amp
    if precision == 'fp16':
        logger.info('use AMP to train. AMP level = {}'.format(amp_level))
        scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
        if amp_level == 'O2':
            model, optimizer = paddle.amp.decorate(
                models=model,
                optimizers=optimizer,
                level='O2',
                save_dtype='float32')

    if nranks > 1:
        paddle.distributed.fleet.init(is_collective=True)
        optimizer = paddle.distributed.fleet.distributed_optimizer(
            optimizer)  # The return is a Fleet object
        ddp_model = paddle.distributed.fleet.distributed_model(model)
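
    # DistributedBatchSampler shards the dataset across ranks in multi-GPU
    # training; with a single rank it behaves like a regular batch sampler.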
    batch_sampler = paddle.io.DistributedBatchSampler(
        train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

    loader = paddle.io.DataLoader(
        train_dataset,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        return_list=True,
        worker_init_fn=worker_init_fn, )

    if use_vdl:
        from visualdl import LogWriter
        log_writer = LogWriter(save_dir)

    if to_static_training:
        model = paddle.jit.to_static(model)
        logger.info("Successfully applied @to_static")

    avg_loss = 0.0
    avg_loss_list = []
    iters_per_epoch = len(batch_sampler)
    best_mean_iou = -1.0
    best_model_iter = -1
    reader_cost_averager = TimeAverager()
    batch_cost_averager = TimeAverager()
    save_models = deque()
    batch_start = time.time()

    iter = start_iter
    while iter < iters:
        for data in loader:
            iter += 1
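            # When the requested number of iterations is reached mid-epoch, stop
            # training. On Paddle 2.1.2 the loader is drained with `continue`
            # instead of `break`, presumably to work around a dataloader
            # shutdown issue in that release.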
            if iter > iters:
                version = paddle.__version__
                if version == '2.1.2':
                    continue
                else:
                    break
            reader_cost_averager.record(time.time() - batch_start)
            images = data['img']
            labels = data['label'].astype('int64')
            edges = None
            if 'edge' in data.keys():
                edges = data['edge'].astype('int64')
            if hasattr(model, 'data_format') and model.data_format == 'NHWC':
                images = images.transpose((0, 2, 3, 1))
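
            # Forward/backward pass. With fp16, the forward pass runs under
            # paddle.amp.auto_cast and the loss is scaled by GradScaler before
            # the backward pass.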
            if precision == 'fp16':
                with paddle.amp.auto_cast(
                        level=amp_level,
                        enable=True,
                        custom_white_list={
                            "elementwise_add", "batch_norm", "sync_batch_norm"
                        },
                        custom_black_list={'bilinear_interp_v2'}):
                    logits_list = ddp_model(images) if nranks > 1 else model(
                        images)
                    loss_list = loss_computation(
                        logits_list=logits_list,
                        labels=labels,
                        edges=edges,
                        losses=losses)
                    loss = sum(loss_list)

                scaled = scaler.scale(loss)  # scale the loss
                scaled.backward()  # do backward
                if isinstance(optimizer, paddle.distributed.fleet.Fleet):
                    scaler.minimize(optimizer.user_defined_optimizer, scaled)
                else:
                    scaler.minimize(optimizer, scaled)  # update parameters
            else:
                logits_list = ddp_model(images) if nranks > 1 else model(images)
                loss_list = loss_computation(
                    logits_list=logits_list,
                    labels=labels,
                    edges=edges,
                    losses=losses)
                loss = sum(loss_list)
                loss.backward()
                # If the optimizer is ReduceOnPlateau, the loss has to be passed into step().
                if isinstance(optimizer, paddle.optimizer.lr.ReduceOnPlateau):
                    optimizer.step(loss)
                else:
                    optimizer.step()

            lr = optimizer.get_lr()

            # update lr
            if isinstance(optimizer, paddle.distributed.fleet.Fleet):
                lr_sche = optimizer.user_defined_optimizer._learning_rate
            else:
                lr_sche = optimizer._learning_rate
            if isinstance(lr_sche, paddle.optimizer.lr.LRScheduler):
                lr_sche.step()

            train_profiler.add_profiler_step(profiler_options)

            model.clear_gradients()
            avg_loss += loss.numpy()[0]
            if not avg_loss_list:
                avg_loss_list = [l.numpy() for l in loss_list]
            else:
                for i in range(len(loss_list)):
                    avg_loss_list[i] += loss_list[i].numpy()
            batch_cost_averager.record(
                time.time() - batch_start, num_samples=batch_size)

            if iter % log_iters == 0 and local_rank == 0:
                avg_loss /= log_iters
                avg_loss_list = [l[0] / log_iters for l in avg_loss_list]
                remain_iters = iters - iter
                avg_train_batch_cost = batch_cost_averager.get_average()
                avg_train_reader_cost = reader_cost_averager.get_average()
                eta = calculate_eta(remain_iters, avg_train_batch_cost)
                logger.info(
                    "[TRAIN] epoch: {}, iter: {}/{}, loss: {:.4f}, lr: {:.6f}, batch_cost: {:.4f}, reader_cost: {:.5f}, ips: {:.4f} samples/sec | ETA {}"
                    .format((iter - 1) // iters_per_epoch + 1, iter, iters,
                            avg_loss, lr, avg_train_batch_cost,
                            avg_train_reader_cost,
                            batch_cost_averager.get_ips_average(), eta))
                if use_vdl:
                    log_writer.add_scalar('Train/loss', avg_loss, iter)
                    # Record each loss separately when there is more than one loss.
                    if len(avg_loss_list) > 1:
                        avg_loss_dict = {}
                        for i, value in enumerate(avg_loss_list):
                            avg_loss_dict['loss_' + str(i)] = value
                        for key, value in avg_loss_dict.items():
                            log_tag = 'Train/' + key
                            log_writer.add_scalar(log_tag, value, iter)

                    log_writer.add_scalar('Train/lr', lr, iter)
                    log_writer.add_scalar('Train/batch_cost',
                                          avg_train_batch_cost, iter)
                    log_writer.add_scalar('Train/reader_cost',
                                          avg_train_reader_cost, iter)
                avg_loss = 0.0
                avg_loss_list = []
                reader_cost_averager.reset()
                batch_cost_averager.reset()

            if (iter % save_interval == 0 or
                    iter == iters) and (val_dataset is not None):
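                # Run validation; evaluation uses at most one dataloader worker.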
                num_workers = 1 if num_workers > 0 else 0
                if test_config is None:
                    test_config = {}
                mean_iou, acc, _, _, _ = evaluate(
                    model,
                    val_dataset,
                    num_workers=num_workers,
                    precision=precision,
                    amp_level=amp_level,
                    **test_config)

                model.train()

            if (iter % save_interval == 0 or iter == iters) and local_rank == 0:
                current_save_dir = os.path.join(save_dir,
                                                "iter_{}".format(iter))
                if not os.path.isdir(current_save_dir):
                    os.makedirs(current_save_dir)
                paddle.save(model.state_dict(),
                            os.path.join(current_save_dir, 'model.pdparams'))
                paddle.save(optimizer.state_dict(),
                            os.path.join(current_save_dir, 'model.pdopt'))
                save_models.append(current_save_dir)
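                # Keep only the most recent `keep_checkpoint_max` checkpoints;
                # the oldest one is deleted when the limit is exceeded.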
                if len(save_models) > keep_checkpoint_max > 0:
                    model_to_remove = save_models.popleft()
                    shutil.rmtree(model_to_remove)
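
                # Separately track the checkpoint with the best validation mIoU
                # and save it under save_dir/best_model.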
                if val_dataset is not None:
                    if mean_iou > best_mean_iou:
                        best_mean_iou = mean_iou
                        best_model_iter = iter
                        best_model_dir = os.path.join(save_dir, "best_model")
                        paddle.save(
                            model.state_dict(),
                            os.path.join(best_model_dir, 'model.pdparams'))
                    logger.info(
                        '[EVAL] The model with the best validation mIoU ({:.4f}) was saved at iter {}.'
                        .format(best_mean_iou, best_model_iter))

                    if use_vdl:
                        log_writer.add_scalar('Evaluate/mIoU', mean_iou, iter)
                        log_writer.add_scalar('Evaluate/Acc', acc, iter)
            batch_start = time.time()

    # Calculate flops.
    if local_rank == 0 and not (precision == 'fp16' and amp_level == 'O2'):
        _, c, h, w = images.shape
        _ = paddle.flops(
            model, [1, c, h, w],
            custom_ops={paddle.nn.SyncBatchNorm: op_flops_funs.count_syncbn})

    # Sleep for half a second to let dataloader release resources.
    time.sleep(0.5)
    if use_vdl:
        log_writer.close()
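

# An illustrative sketch of how train() is typically invoked. The model, loss
# config, and optimizer below are assumptions for illustration, not part of
# this file; train_dataset / val_dataset would be paddleseg dataset objects
# built elsewhere.
#
#     from paddleseg.models import BiSeNetV2
#     from paddleseg.models.losses import CrossEntropyLoss
#
#     model = BiSeNetV2(num_classes=19)
#     optimizer = paddle.optimizer.SGD(learning_rate=0.01,
#                                      parameters=model.parameters())
#     losses = {'types': [CrossEntropyLoss()] * 5, 'coef': [1.0] * 5}
#     train(model, train_dataset, val_dataset=val_dataset,
#           optimizer=optimizer, losses=losses, iters=1000, batch_size=4,
#           save_dir='output')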