import os
import time
from collections import deque
import shutil

import paddle
import paddle.nn.functional as F

from paddleseg.utils import (TimeAverager, calculate_eta, resume, logger,
                             worker_init_fn, train_profiler, op_flops_funs)
from paddleseg.core.val import evaluate


def check_logits_losses(logits_list, losses):
    len_logits = len(logits_list)
    len_losses = len(losses['types'])
    if len_logits != len_losses:
        raise RuntimeError(
            'The length of logits_list should equal the number of loss types in the config: {} != {}.'
            .format(len_logits, len_losses))
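

# `loss_computation` pairs each element of `logits_list` with the loss type and
# coefficient at the same index in the `losses` dict and returns the list of
# weighted per-loss values.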
def loss_computation(logits_list, labels, edges, losses):
    check_logits_losses(logits_list, losses)
    loss_list = []
    for i in range(len(logits_list)):
        logits = logits_list[i]
        loss_i = losses['types'][i]
        coef_i = losses['coef'][i]
        if loss_i.__class__.__name__ in ('BCELoss', ) and loss_i.edge_label:
            # A BCE loss configured with edge_label is supervised with the edge
            # map instead of the segmentation label.
            loss_list.append(coef_i * loss_i(logits, edges))
        elif loss_i.__class__.__name__ == 'MixedLoss':
            mixed_loss_list = loss_i(logits, labels)
            for mixed_loss in mixed_loss_list:
                loss_list.append(coef_i * mixed_loss)
        elif loss_i.__class__.__name__ in ("KLLoss", ):
            loss_list.append(coef_i *
                             loss_i(logits_list[0], logits_list[1].detach()))
        else:
            loss_list.append(coef_i * loss_i(logits, labels))
    return loss_list


def train(model,
          train_dataset,
          val_dataset=None,
          optimizer=None,
          save_dir='output',
          iters=10000,
          batch_size=2,
          resume_model=None,
          save_interval=1000,
          log_iters=10,
          num_workers=0,
          use_vdl=False,
          losses=None,
          keep_checkpoint_max=5,
          test_config=None,
          precision='fp32',
          amp_level='O1',
          profiler_options=None,
          to_static_training=False):
- """
- Launch training.
- Args:
- model(nn.Layer): A semantic segmentation model.
- train_dataset (paddle.io.Dataset): Used to read and process training datasets.
- val_dataset (paddle.io.Dataset, optional): Used to read and process validation datasets.
- optimizer (paddle.optimizer.Optimizer): The optimizer.
- save_dir (str, optional): The directory for saving the model snapshot. Default: 'output'.
- iters (int, optional): How may iters to train the model. Defualt: 10000.
- batch_size (int, optional): Mini batch size of one gpu or cpu. Default: 2.
- resume_model (str, optional): The path of resume model.
- save_interval (int, optional): How many iters to save a model snapshot once during training. Default: 1000.
- log_iters (int, optional): Display logging information at every log_iters. Default: 10.
- num_workers (int, optional): Num workers for data loader. Default: 0.
- use_vdl (bool, optional): Whether to record the data to VisualDL during training. Default: False.
- losses (dict, optional): A dict including 'types' and 'coef'. The length of coef should equal to 1 or len(losses['types']).
- The 'types' item is a list of object of paddleseg.models.losses while the 'coef' item is a list of the relevant coefficient.
- keep_checkpoint_max (int, optional): Maximum number of checkpoints to save. Default: 5.
- test_config(dict, optional): Evaluation config.
- precision (str, optional): Use AMP if precision='fp16'. If precision='fp32', the training is normal.
- amp_level (str, optional): Auto mixed precision level. Accepted values are “O1” and “O2”: O1 represent mixed precision,
- the input data type of each operator will be casted by white_list and black_list; O2 represent Pure fp16, all operators
- parameters and input data will be casted to fp16, except operators in black_list, don’t support fp16 kernel and batchnorm. Default is O1(amp)
- profiler_options (str, optional): The option of train profiler.
- to_static_training (bool, optional): Whether to use @to_static for training.
- """
    model.train()
    nranks = paddle.distributed.ParallelEnv().nranks
    local_rank = paddle.distributed.ParallelEnv().local_rank

    start_iter = 0
    if resume_model is not None:
        start_iter = resume(model, optimizer, resume_model)

    if not os.path.isdir(save_dir):
        if os.path.exists(save_dir):
            os.remove(save_dir)
        os.makedirs(save_dir, exist_ok=True)
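
    # When fp16 precision is requested, set up automatic mixed precision: a
    # GradScaler scales the loss, and for the 'O2' level the model and optimizer
    # are additionally decorated by paddle.amp.decorate.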
    if precision == 'fp16':
        logger.info('use AMP to train. AMP level = {}'.format(amp_level))
        scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
        if amp_level == 'O2':
            model, optimizer = paddle.amp.decorate(
                models=model,
                optimizers=optimizer,
                level='O2',
                save_dtype='float32')
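
    # For multi-card training, initialize the fleet collective environment and
    # wrap the optimizer and model so gradients are synchronized across ranks.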
    if nranks > 1:
        paddle.distributed.fleet.init(is_collective=True)
        optimizer = paddle.distributed.fleet.distributed_optimizer(optimizer)
        ddp_model = paddle.distributed.fleet.distributed_model(model)

    batch_sampler = paddle.io.DistributedBatchSampler(
        train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    loader = paddle.io.DataLoader(
        train_dataset,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        return_list=True,
        worker_init_fn=worker_init_fn, )

    if use_vdl:
        from visualdl import LogWriter
        log_writer = LogWriter(save_dir)

    if to_static_training:
        model = paddle.jit.to_static(model)
        logger.info("Successfully applied @to_static")

    avg_loss = 0.0
    avg_loss_list = []
    iters_per_epoch = len(batch_sampler)
    best_mean_iou = -1.0
    best_model_iter = -1
    reader_cost_averager = TimeAverager()
    batch_cost_averager = TimeAverager()
    save_models = deque()
    batch_start = time.time()

    iter = start_iter
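
    # Main training loop: keep iterating over the data loader until the target
    # number of iterations is reached.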
    while iter < iters:
        for data in loader:
            iter += 1
            if iter > iters:
                # Under Paddle 2.1.2 the loader loop is drained with `continue`
                # instead of being broken out of.
                version = paddle.__version__
                if version == '2.1.2':
                    continue
                else:
                    break
            reader_cost_averager.record(time.time() - batch_start)
            images = data['img']
            labels = data['label'].astype('int64')
            edges = None
            if 'edge' in data.keys():
                edges = data['edge'].astype('int64')
            if hasattr(model, 'data_format') and model.data_format == 'NHWC':
                images = images.transpose((0, 2, 3, 1))
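
            # Forward and backward pass. Under fp16 the forward pass runs inside
            # paddle.amp.auto_cast and the loss is scaled before backward;
            # otherwise plain fp32 training is used.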
            if precision == 'fp16':
                with paddle.amp.auto_cast(
                        level=amp_level,
                        enable=True,
                        custom_white_list={
                            "elementwise_add", "batch_norm", "sync_batch_norm"
                        },
                        custom_black_list={'bilinear_interp_v2'}):
                    logits_list = ddp_model(images) if nranks > 1 else model(
                        images)
                    loss_list = loss_computation(
                        logits_list=logits_list,
                        labels=labels,
                        edges=edges,
                        losses=losses)
                    loss = sum(loss_list)

                scaled = scaler.scale(loss)
                scaled.backward()
                if isinstance(optimizer, paddle.distributed.fleet.Fleet):
                    scaler.minimize(optimizer.user_defined_optimizer, scaled)
                else:
                    scaler.minimize(optimizer, scaled)
            else:
                logits_list = ddp_model(images) if nranks > 1 else model(images)
                loss_list = loss_computation(
                    logits_list=logits_list,
                    labels=labels,
                    edges=edges,
                    losses=losses)
                loss = sum(loss_list)
                loss.backward()
                # ReduceOnPlateau needs the current loss passed to step(); the
                # fp16 branch above steps via scaler.minimize instead.
                if isinstance(optimizer, paddle.optimizer.lr.ReduceOnPlateau):
                    optimizer.step(loss)
                else:
                    optimizer.step()

            lr = optimizer.get_lr()

            # Step the learning-rate scheduler once per iteration.
            if isinstance(optimizer, paddle.distributed.fleet.Fleet):
                lr_sche = optimizer.user_defined_optimizer._learning_rate
            else:
                lr_sche = optimizer._learning_rate
            if isinstance(lr_sche, paddle.optimizer.lr.LRScheduler):
                lr_sche.step()
            train_profiler.add_profiler_step(profiler_options)
            model.clear_gradients()

            avg_loss += loss.numpy()[0]
            if not avg_loss_list:
                avg_loss_list = [l.numpy() for l in loss_list]
            else:
                for i in range(len(loss_list)):
                    avg_loss_list[i] += loss_list[i].numpy()
            batch_cost_averager.record(
                time.time() - batch_start, num_samples=batch_size)
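
            # Every `log_iters` iterations, rank 0 logs the averaged loss, learning
            # rate, and timing statistics (and writes them to VisualDL if enabled).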
            if iter % log_iters == 0 and local_rank == 0:
                avg_loss /= log_iters
                avg_loss_list = [l[0] / log_iters for l in avg_loss_list]
                remain_iters = iters - iter
                avg_train_batch_cost = batch_cost_averager.get_average()
                avg_train_reader_cost = reader_cost_averager.get_average()
                eta = calculate_eta(remain_iters, avg_train_batch_cost)
                logger.info(
                    "[TRAIN] epoch: {}, iter: {}/{}, loss: {:.4f}, lr: {:.6f}, batch_cost: {:.4f}, reader_cost: {:.5f}, ips: {:.4f} samples/sec | ETA {}"
                    .format((iter - 1) // iters_per_epoch + 1, iter, iters,
                            avg_loss, lr, avg_train_batch_cost,
                            avg_train_reader_cost,
                            batch_cost_averager.get_ips_average(), eta))
                if use_vdl:
                    log_writer.add_scalar('Train/loss', avg_loss, iter)
                    # Record each individual loss as well when more than one is used.
                    if len(avg_loss_list) > 1:
                        avg_loss_dict = {}
                        for i, value in enumerate(avg_loss_list):
                            avg_loss_dict['loss_' + str(i)] = value
                        for key, value in avg_loss_dict.items():
                            log_tag = 'Train/' + key
                            log_writer.add_scalar(log_tag, value, iter)

                    log_writer.add_scalar('Train/lr', lr, iter)
                    log_writer.add_scalar('Train/batch_cost',
                                          avg_train_batch_cost, iter)
                    log_writer.add_scalar('Train/reader_cost',
                                          avg_train_reader_cost, iter)
                avg_loss = 0.0
                avg_loss_list = []
                reader_cost_averager.reset()
                batch_cost_averager.reset()
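
            # At every `save_interval` iterations (and at the final iteration),
            # evaluate on the validation set if one was provided.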
            if (iter % save_interval == 0 or
                    iter == iters) and (val_dataset is not None):
                num_workers = 1 if num_workers > 0 else 0
                if test_config is None:
                    test_config = {}
                mean_iou, acc, _, _, _ = evaluate(
                    model,
                    val_dataset,
                    num_workers=num_workers,
                    precision=precision,
                    amp_level=amp_level,
                    **test_config)
                model.train()
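
            # On rank 0, save a snapshot of the model and optimizer, keep at most
            # `keep_checkpoint_max` recent checkpoints, and track the checkpoint
            # with the best validation mIoU as `best_model`.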
            if (iter % save_interval == 0 or iter == iters) and local_rank == 0:
                current_save_dir = os.path.join(save_dir,
                                                "iter_{}".format(iter))
                if not os.path.isdir(current_save_dir):
                    os.makedirs(current_save_dir)
                paddle.save(model.state_dict(),
                            os.path.join(current_save_dir, 'model.pdparams'))
                paddle.save(optimizer.state_dict(),
                            os.path.join(current_save_dir, 'model.pdopt'))
                save_models.append(current_save_dir)
                if len(save_models) > keep_checkpoint_max > 0:
                    model_to_remove = save_models.popleft()
                    shutil.rmtree(model_to_remove)

                if val_dataset is not None:
                    if mean_iou > best_mean_iou:
                        best_mean_iou = mean_iou
                        best_model_iter = iter
                        best_model_dir = os.path.join(save_dir, "best_model")
                        paddle.save(
                            model.state_dict(),
                            os.path.join(best_model_dir, 'model.pdparams'))
                    logger.info(
                        '[EVAL] The model with the best validation mIoU ({:.4f}) was saved at iter {}.'
                        .format(best_mean_iou, best_model_iter))

                    if use_vdl:
                        log_writer.add_scalar('Evaluate/mIoU', mean_iou, iter)
                        log_writer.add_scalar('Evaluate/Acc', acc, iter)
            batch_start = time.time()
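
    # After training, report the model FLOPs for a single input with the last
    # seen spatial size (skipped on non-zero ranks and for pure-fp16 'O2' training).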
    if local_rank == 0 and not (precision == 'fp16' and amp_level == 'O2'):
        _, c, h, w = images.shape
        _ = paddle.flops(
            model, [1, c, h, w],
            custom_ops={paddle.nn.SyncBatchNorm: op_flops_funs.count_syncbn})

    # Sleep for half a second to let the data loader release its resources.
    time.sleep(0.5)
    if use_vdl:
        log_writer.close()