
update amp train (#143)

* update amp train

* Update and polish

* polish amp train

* Reuse scaler and fix bug

---------

Co-authored-by: Bobholamovic <mhlin425@whu.edu.cn>
huilin, 2 years ago
Commit 3ee1c7bb42

+ 45 - 5
docs/apis/train_cn.md

@@ -64,7 +64,11 @@ def train(self,
           early_stop=False,
           early_stop_patience=5,
           use_vdl=True,
-          resume_checkpoint=None):
+          resume_checkpoint=None,
+          precision='fp32',
+          amp_level='O1',
+          custom_white_list=None,
+          custom_black_list=None):
 ```
 
 The meaning of each parameter is as follows:
@@ -86,6 +90,10 @@ def train(self,
 |`early_stop_patience`|`int`|`patience` parameter used when the early stopping policy is enabled (see [`EarlyStop`](https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/utils/utils.py)).|`5`|
 |`use_vdl`|`bool`|Whether to enable VisualDL logging.|`True`|
 |`resume_checkpoint`|`str` \| `None`|Checkpoint path. PaddleRS supports resuming training from a checkpoint (containing the model weights and optimizer weights saved in a previous training run), but note that `resume_checkpoint` and `pretrain_weights` must not both be set to values other than `None`.|`None`|
+|`precision`|`str`|When set to `'fp16'`, automatic mixed precision (AMP) training is enabled.|`'fp32'`|
+|`amp_level`|`str`|Automatic mixed precision level. At the O1 level, a white list and a black list determine whether each operator is computed in FP16 or FP32. At the O2 level, all operators are computed in FP16, except those specified in the custom black list and some operators that do not support FP16.|`'O1'`|
+|`custom_white_list`|`set` \| `list` \| `tuple` \| `None`|Custom white list used in AMP training.|`None`|
+|`custom_black_list`|`set` \| `list` \| `tuple` \| `None`|Custom black list used in AMP training.|`None`|
 
 ### `BaseClassifier.train()`
 
@@ -107,7 +115,11 @@ def train(self,
           early_stop=False,
           early_stop_patience=5,
           use_vdl=True,
-          resume_checkpoint=None):
+          resume_checkpoint=None,
+          precision='fp32',
+          amp_level='O1',
+          custom_white_list=None,
+          custom_black_list=None):
 ```
 
 The meaning of each parameter is as follows:
@@ -129,6 +141,10 @@ def train(self,
 |`early_stop_patience`|`int`|`patience` parameter used when the early stopping policy is enabled (see [`EarlyStop`](https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/utils/utils.py)).|`5`|
 |`use_vdl`|`bool`|Whether to enable VisualDL logging.|`True`|
 |`resume_checkpoint`|`str` \| `None`|Checkpoint path. PaddleRS supports resuming training from a checkpoint (containing the model weights and optimizer weights saved in a previous training run), but note that `resume_checkpoint` and `pretrain_weights` must not both be set to values other than `None`.|`None`|
+|`precision`|`str`|When set to `'fp16'`, automatic mixed precision (AMP) training is enabled.|`'fp32'`|
+|`amp_level`|`str`|Automatic mixed precision level. At the O1 level, a white list and a black list determine whether each operator is computed in FP16 or FP32. At the O2 level, all operators are computed in FP16, except those specified in the custom black list and some operators that do not support FP16.|`'O1'`|
+|`custom_white_list`|`set` \| `list` \| `tuple` \| `None`|Custom white list used in AMP training.|`None`|
+|`custom_black_list`|`set` \| `list` \| `tuple` \| `None`|Custom black list used in AMP training.|`None`|
 
 ### `BaseDetector.train()`
 
@@ -155,7 +171,11 @@ def train(self,
           early_stop=False,
           early_stop_patience=5,
           use_vdl=True,
-          resume_checkpoint=None):
+          resume_checkpoint=None,
+          precision='fp32',
+          amp_level='O1',
+          custom_white_list=None,
+          custom_black_list=None):
 ```
 
 The meaning of each parameter is as follows:
@@ -182,6 +202,10 @@ def train(self,
 |`early_stop_patience`|`int`|`patience` parameter used when the early stopping policy is enabled (see [`EarlyStop`](https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/utils/utils.py)).|`5`|
 |`use_vdl`|`bool`|Whether to enable VisualDL logging.|`True`|
 |`resume_checkpoint`|`str` \| `None`|Checkpoint path. PaddleRS supports resuming training from a checkpoint (containing the model weights and optimizer weights saved in a previous training run), but note that `resume_checkpoint` and `pretrain_weights` must not both be set to values other than `None`.|`None`|
+|`precision`|`str`|When set to `'fp16'`, automatic mixed precision (AMP) training is enabled.|`'fp32'`|
+|`amp_level`|`str`|Automatic mixed precision level. At the O1 level, a white list and a black list determine whether each operator is computed in FP16 or FP32. At the O2 level, all operators are computed in FP16, except those specified in the custom black list and some operators that do not support FP16.|`'O1'`|
+|`custom_white_list`|`set` \| `list` \| `tuple` \| `None`|Custom white list used in AMP training.|`None`|
+|`custom_black_list`|`set` \| `list` \| `tuple` \| `None`|Custom black list used in AMP training.|`None`|
 
 ### `BaseRestorer.train()`
 
@@ -203,7 +227,11 @@ def train(self,
           early_stop=False,
           early_stop_patience=5,
           use_vdl=True,
-          resume_checkpoint=None):
+          resume_checkpoint=None,
+          precision='fp32',
+          amp_level='O1',
+          custom_white_list=None,
+          custom_black_list=None):
 ```
 
 The meaning of each parameter is as follows:
@@ -225,6 +253,10 @@ def train(self,
 |`early_stop_patience`|`int`|`patience` parameter used when the early stopping policy is enabled (see [`EarlyStop`](https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/utils/utils.py)).|`5`|
 |`use_vdl`|`bool`|Whether to enable VisualDL logging.|`True`|
 |`resume_checkpoint`|`str` \| `None`|Checkpoint path. PaddleRS supports resuming training from a checkpoint (containing the model weights and optimizer weights saved in a previous training run), but note that `resume_checkpoint` and `pretrain_weights` must not both be set to values other than `None`.|`None`|
+|`precision`|`str`|When set to `'fp16'`, automatic mixed precision (AMP) training is enabled.|`'fp32'`|
+|`amp_level`|`str`|Automatic mixed precision level. At the O1 level, a white list and a black list determine whether each operator is computed in FP16 or FP32. At the O2 level, all operators are computed in FP16, except those specified in the custom black list and some operators that do not support FP16.|`'O1'`|
+|`custom_white_list`|`set` \| `list` \| `tuple` \| `None`|Custom white list used in AMP training.|`None`|
+|`custom_black_list`|`set` \| `list` \| `tuple` \| `None`|Custom black list used in AMP training.|`None`|
 
 ### `BaseSegmenter.train()`
 
@@ -246,7 +278,11 @@ def train(self,
           early_stop=False,
           early_stop_patience=5,
           use_vdl=True,
-          resume_checkpoint=None):
+          resume_checkpoint=None,
+          precision='fp32',
+          amp_level='O1',
+          custom_white_list=None,
+          custom_black_list=None):
 ```
 
 The meaning of each parameter is as follows:
@@ -268,6 +304,10 @@ def train(self,
 |`early_stop_patience`|`int`|`patience` parameter used when the early stopping policy is enabled (see [`EarlyStop`](https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/utils/utils.py)).|`5`|
 |`use_vdl`|`bool`|Whether to enable VisualDL logging.|`True`|
 |`resume_checkpoint`|`str` \| `None`|Checkpoint path. PaddleRS supports resuming training from a checkpoint (containing the model weights and optimizer weights saved in a previous training run), but note that `resume_checkpoint` and `pretrain_weights` must not both be set to values other than `None`.|`None`|
+|`precision`|`str`|When set to `'fp16'`, automatic mixed precision (AMP) training is enabled.|`'fp32'`|
+|`amp_level`|`str`|Automatic mixed precision level. At the O1 level, a white list and a black list determine whether each operator is computed in FP16 or FP32. At the O2 level, all operators are computed in FP16, except those specified in the custom black list and some operators that do not support FP16.|`'O1'`|
+|`custom_white_list`|`set` \| `list` \| `tuple` \| `None`|Custom white list used in AMP training.|`None`|
+|`custom_black_list`|`set` \| `list` \| `tuple` \| `None`|Custom black list used in AMP training.|`None`|
 
 ## `evaluate()`
 

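For context, here is a minimal usage sketch of the new arguments documented above. The trainer class, dataset layout, transforms, and paths are illustrative assumptions rather than part of this change; only the `precision` and `amp_level` keyword arguments come from the API described in the tables.

```python
# Hypothetical end-to-end sketch: enable AMP (O1) training via the new train() arguments.
# The model, dataset, and file paths below are placeholders for illustration only.
import paddlers as pdrs
from paddlers import transforms as T

train_transforms = T.Compose([T.Resize(target_size=256), T.Normalize()])

train_dataset = pdrs.datasets.SegDataset(
    data_dir='data/my_dataset',
    file_list='data/my_dataset/train_list.txt',
    label_list='data/my_dataset/labels.txt',
    transforms=train_transforms,
    shuffle=True)

model = pdrs.tasks.seg.UNet(num_classes=2)

model.train(
    num_epochs=10,
    train_dataset=train_dataset,
    train_batch_size=4,
    save_dir='output/unet_amp',
    precision='fp16',   # enable automatic mixed precision training
    amp_level='O1')     # per-operator FP16/FP32 selection via white/black lists
```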
+ 45 - 5
docs/apis/train_en.md

@@ -64,7 +64,11 @@ def train(self,
           early_stop=False,
           early_stop_patience=5,
           use_vdl=True,
-          resume_checkpoint=None):
+          resume_checkpoint=None,
+          precision='fp32',
+          amp_level='O1',
+          custom_white_list=None,
+          custom_black_list=None):
 ```
 
 The meaning of each parameter is as follows:
@@ -86,6 +90,10 @@ The meaning of each parameter is as follows:
 |`early_stop_patience`|`int`|`patience` parameter when the early stopping policy is enabled. Please refer to [`EarlyStop`](https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/utils/utils.py) for more details.|`5`|
 |`use_vdl`|`bool`|Whether to enable VisualDL.|`True`|
 |`resume_checkpoint`|`str` \| `None`|Checkpoint path. PaddleRS supports resuming training from checkpoints (including model weights and optimizer weights stored during previous training), but note that `resume_checkpoint` and `pretrain_weights` must not be set to values other than `None` at the same time.|`None`|
+|`precision`|`str`|Use AMP (auto mixed precision) training if `precision` is set to `'fp16'`.|`'fp32'`|
+|`amp_level`|`str`|Auto mixed precision level. Accepted values are 'O1' and 'O2'. At the O1 level, the input data type of each operator is cast according to a white list and a black list. At the O2 level, all parameters and input data are cast to FP16, except those used by operators in the black list, those used by operators without FP16 kernel support, and those used by batch normalization layers.|`'O1'`|
+|`custom_white_list`|`set` \| `list` \| `tuple` \| `None` |Custom white list to use when `amp_level` is set to `'O1'`.|`None`|
+|`custom_black_list`|`set` \| `list` \| `tuple` \| `None` |Custom black list to use in AMP training.|`None`|
 
 ### `BaseClassifier.train()`
 
@@ -107,7 +115,11 @@ def train(self,
           early_stop=False,
           early_stop_patience=5,
           use_vdl=True,
-          resume_checkpoint=None):
+          resume_checkpoint=None,
+          precision='fp32',
+          amp_level='O1',
+          custom_white_list=None,
+          custom_black_list=None):
 ```
 
 The meaning of each parameter is as follows:
@@ -129,6 +141,10 @@ The meaning of each parameter is as follows:
 |`early_stop_patience`|`int`|`patience` parameter when the early stopping policy is enabled. Please refer to [`EarlyStop`](https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/utils/utils.py) for more details.|`5`|
 |`use_vdl`|`bool`|Whether to enable VisualDL.|`True`|
 |`resume_checkpoint`|`str` \| `None`|Checkpoint path. PaddleRS supports resuming training from checkpoints (including model weights and optimizer weights stored during previous training), but note that `resume_checkpoint` and `pretrain_weights` must not be set to values other than `None` at the same time.|`None`|
+|`precision`|`str`|Use AMP (auto mixed precision) training if `precision` is set to `'fp16'`.|`'fp32'`|
+|`amp_level`|`str`|Auto mixed precision level. Accepted values are 'O1' and 'O2'. At the O1 level, the input data type of each operator is cast according to a white list and a black list. At the O2 level, all parameters and input data are cast to FP16, except those used by operators in the black list, those used by operators without FP16 kernel support, and those used by batch normalization layers.|`'O1'`|
+|`custom_white_list`|`set` \| `list` \| `tuple` \| `None` |Custom white list to use when `amp_level` is set to `'O1'`.|`None`|
+|`custom_black_list`|`set` \| `list` \| `tuple` \| `None` |Custom black list to use in AMP training.|`None`|
 
 ### `BaseDetector.train()`
 
@@ -155,7 +171,11 @@ def train(self,
           early_stop=False,
           early_stop_patience=5,
           use_vdl=True,
-          resume_checkpoint=None):
+          resume_checkpoint=None,
+          precision='fp32',
+          amp_level='O1',
+          custom_white_list=None,
+          custom_black_list=None):
 ```
 
 The meaning of each parameter is as follows:
@@ -182,6 +202,10 @@ The meaning of each parameter is as follows:
 |`early_stop_patience`|`int`|`patience` parameter when the early stopping policy is enabled. Please refer to [`EarlyStop`](https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/utils/utils.py) for more details.|`5`|
 |`use_vdl`|`bool`|Whether to enable VisualDL.|`True`|
 |`resume_checkpoint`|`str` \| `None`|Checkpoint path. PaddleRS supports resuming training from checkpoints (including model weights and optimizer weights stored during previous training), but note that `resume_checkpoint` and `pretrain_weights` must not be set to values other than `None` at the same time.|`None`|
+|`precision`|`str`|Use AMP (auto mixed precision) training if `precision` is set to `'fp16'`.|`'fp32'`|
+|`amp_level`|`str`|Auto mixed precision level. Accepted values are 'O1' and 'O2'. At the O1 level, the input data type of each operator is cast according to a white list and a black list. At the O2 level, all parameters and input data are cast to FP16, except those used by operators in the black list, those used by operators without FP16 kernel support, and those used by batch normalization layers.|`'O1'`|
+|`custom_white_list`|`set` \| `list` \| `tuple` \| `None` |Custom white list to use when `amp_level` is set to `'O1'`.|`None`|
+|`custom_black_list`|`set` \| `list` \| `tuple` \| `None` |Custom black list to use in AMP training.|`None`|
 
 ### `BaseRestorer.train()`
 
@@ -203,7 +227,11 @@ def train(self,
           early_stop=False,
           early_stop_patience=5,
           use_vdl=True,
-          resume_checkpoint=None):
+          resume_checkpoint=None,
+          precision='fp32',
+          amp_level='O1',
+          custom_white_list=None,
+          custom_black_list=None):
 ```
 
 The meaning of each parameter is as follows:
@@ -225,6 +253,10 @@ The meaning of each parameter is as follows:
 |`early_stop_patience`|`int`|`patience` parameter when the early stopping policy is enabled. Please refer to [`EarlyStop`](https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/utils/utils.py) for more details.|`5`|
 |`use_vdl`|`bool`|Whether to enable VisualDL.|`True`|
 |`resume_checkpoint`|`str` \| `None`|Checkpoint path. PaddleRS supports resuming training from checkpoints (including model weights and optimizer weights stored during previous training), but note that `resume_checkpoint` and `pretrain_weights` must not be set to values other than `None` at the same time.|`None`|
+|`precision`|`str`|Use AMP (auto mixed precision) training if `precision` is set to `'fp16'`.|`'fp32'`|
+|`amp_level`|`str`|Auto mixed precision level. Accepted values are 'O1' and 'O2'. At the O1 level, the input data type of each operator is cast according to a white list and a black list. At the O2 level, all parameters and input data are cast to FP16, except those used by operators in the black list, those used by operators without FP16 kernel support, and those used by batch normalization layers.|`'O1'`|
+|`custom_white_list`|`set` \| `list` \| `tuple` \| `None` |Custom white list to use when `amp_level` is set to `'O1'`.|`None`|
+|`custom_black_list`|`set` \| `list` \| `tuple` \| `None` |Custom black list to use in AMP training.|`None`|
 
 ### `BaseSegmenter.train()`
 
@@ -246,7 +278,11 @@ def train(self,
           early_stop=False,
           early_stop_patience=5,
           use_vdl=True,
-          resume_checkpoint=None):
+          resume_checkpoint=None,
+          precision='fp32',
+          amp_level='O1',
+          custom_white_list=None,
+          custom_black_list=None):
 ```
 
 The meaning of each parameter is as follows:
@@ -268,6 +304,10 @@ The meaning of each parameter is as follows:
 |`early_stop_patience`|`int`|`patience` parameter when the early stopping policy is enabled. Please refer to [`EarlyStop`](https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/utils/utils.py) for more details.|`5`|
 |`use_vdl`|`bool`|Whether to enable VisualDL.|`True`|
 |`resume_checkpoint`|`str` \| `None`|Checkpoint path. PaddleRS supports resuming training from checkpoints (including model weights and optimizer weights stored during previous training), but note that `resume_checkpoint` and `pretrain_weights` must not be set to values other than `None` at the same time.|`None`|
+|`precision`|`str`|Use AMP (auto mixed precision) training if `precision` is set to `'fp16'`.|`'fp32'`|
+|`amp_level`|`str`|Auto mixed precision level. Accepted values are 'O1' and 'O2'. At the O1 level, the input data type of each operator is cast according to a white list and a black list. At the O2 level, all parameters and input data are cast to FP16, except those used by operators in the black list, those used by operators without FP16 kernel support, and those used by batch normalization layers.|`'O1'`|
+|`custom_white_list`|`set` \| `list` \| `tuple` \| `None` |Custom white list to use when `amp_level` is set to `'O1'`.|`None`|
+|`custom_black_list`|`set` \| `list` \| `tuple` \| `None` |Custom black list to use in AMP training.|`None`|
 
 ## `evaluate()`
 

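Building on the sketch shown earlier, the custom list arguments let callers pin specific operators to FP16 or FP32 under the O1 level. The snippet below reuses the illustrative `model` and `train_dataset` from that sketch; the operator names are ordinary PaddlePaddle op types chosen only as examples, not recommended settings.

```python
# Sketch: fine-tune operator precision under AMP O1 with custom white/black lists.
model.train(
    num_epochs=10,
    train_dataset=train_dataset,
    train_batch_size=4,
    save_dir='output/unet_amp_custom',
    precision='fp16',
    amp_level='O1',
    # Run these op types in FP16 even if they are not in the default white list.
    custom_white_list={'elementwise_add'},
    # Keep numerically sensitive op types in FP32.
    custom_black_list={'reduce_mean', 'softmax'})
```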
+ 46 - 16
paddlers/tasks/base.py

@@ -78,6 +78,11 @@ class BaseModel(metaclass=ModelMeta):
         self.eval_metrics = None
         self.best_accuracy = -1.
         self.best_model_epoch = -1
+        self.precision = 'fp32'
+        self.amp_level = None
+        self.custom_white_list = None
+        self.custom_black_list = None
+        self.scaler = None
         # Whether to use synchronized BN
         self.sync_bn = False
         self.status = 'Normal'
@@ -312,6 +317,20 @@ class BaseModel(metaclass=ModelMeta):
                    use_vdl=True):
         self._check_transforms(train_dataset.transforms)
 
+        net, optimizer = self.net, self.optimizer
+        # Use AMP
+        if self.precision == 'fp16':
+            logging.info("Use AMP training. AMP level = {}.".format(
+                self.amp_level))
+            # XXX: Hard-code init loss scaling
+            self.scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
+            if self.amp_level == 'O2':
+                net, optimizer = paddle.amp.decorate(
+                    models=self.net,
+                    optimizers=self.optimizer,
+                    level=self.amp_level,
+                    save_dtype='float32')
+
         # XXX: Hard-coding
         if self.model_type == 'detector' and 'RCNN' in self.__class__.__name__ and train_dataset.pos_num < len(
                 train_dataset.file_list):
@@ -325,12 +344,10 @@ class BaseModel(metaclass=ModelMeta):
             ):
                 paddle.distributed.init_parallel_env()
                 ddp_net = to_data_parallel(
-                    self.net,
-                    find_unused_parameters=self.find_unused_parameters)
+                    net, find_unused_parameters=self.find_unused_parameters)
             else:
                 ddp_net = to_data_parallel(
-                    self.net,
-                    find_unused_parameters=self.find_unused_parameters)
+                    net, find_unused_parameters=self.find_unused_parameters)
 
         if use_vdl:
             from visualdl import LogWriter
@@ -361,7 +378,7 @@ class BaseModel(metaclass=ModelMeta):
 
         current_step = 0
         for i in range(start_epoch, num_epochs):
-            self.net.train()
+            net.train()
             if callable(
                     getattr(self.train_data_loader.dataset, 'set_epoch', None)):
                 self.train_data_loader.dataset.set_epoch(i)
@@ -370,14 +387,14 @@ class BaseModel(metaclass=ModelMeta):
 
             for step, data in enumerate(self.train_data_loader()):
                 if nranks > 1:
-                    outputs = self.train_step(step, data, ddp_net)
+                    outputs = self.train_step(step, data, ddp_net, optimizer)
                 else:
-                    outputs = self.train_step(step, data, self.net)
+                    outputs = self.train_step(step, data, net, optimizer)
 
-                scheduler_step(self.optimizer, outputs['loss'])
+                scheduler_step(optimizer, outputs['loss'])
 
                 train_avg_metrics.update(outputs)
-                lr = self.optimizer.get_lr()
+                lr = optimizer.get_lr()
                 outputs['lr'] = lr
                 if ema is not None:
                     ema.update(self.net)
@@ -666,13 +683,26 @@ class BaseModel(metaclass=ModelMeta):
         logging.info("The inference model for deployment is saved in {}.".
                      format(save_dir))
 
-    def train_step(self, step, data, net):
-        outputs = self.run(net, data, mode='train')
-
-        loss = outputs['loss']
-        loss.backward()
-        self.optimizer.step()
-        self.optimizer.clear_grad()
+    def train_step(self, step, data, net, optimizer):
+        if self.precision == 'fp16':
+            with paddle.amp.auto_cast(
+                    level=self.amp_level,
+                    enable=True,
+                    custom_white_list=self.custom_white_list,
+                    custom_black_list=self.custom_black_list):
+                outputs = self.run(net, data, mode='train')
+            scaled = self.scaler.scale(outputs['loss'])
+            scaled.backward()
+            if isinstance(optimizer, paddle.distributed.fleet.Fleet):
+                self.scaler.minimize(optimizer.user_defined_optimizer, scaled)
+            else:
+                self.scaler.minimize(optimizer, scaled)
+        else:
+            outputs = self.run(net, data, mode='train')
+            loss = outputs['loss']
+            loss.backward()
+            optimizer.step()
+            optimizer.clear_grad()
 
         return outputs
 

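The O1 branch added to `train_step()` above follows the standard PaddlePaddle dynamic-graph AMP recipe: run the forward pass under `auto_cast`, scale the loss with a `GradScaler`, then let the scaler drive the optimizer step. A self-contained sketch of that recipe, using a toy model rather than any PaddleRS class (a float16-capable GPU is assumed for AMP to take effect):

```python
import paddle

# Toy AMP (O1) training step mirroring the pattern used in BaseModel.train_step().
net = paddle.nn.Linear(16, 4)
optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=net.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

x = paddle.randn([8, 16])
label = paddle.randint(0, 4, shape=[8])

with paddle.amp.auto_cast(level='O1'):
    logits = net(x)
    loss = paddle.nn.functional.cross_entropy(logits, label)

scaled = scaler.scale(loss)         # scale the loss to avoid FP16 gradient underflow
scaled.backward()
scaler.minimize(optimizer, scaled)  # unscale gradients and apply the optimizer step
optimizer.clear_grad()
```

At the O2 level, `paddle.amp.decorate` additionally converts the model parameters (and possibly the optimizer) before training, which is why `train()` rebinds `net` and `optimizer` when `amp_level == 'O2'`.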
+ 31 - 3
paddlers/tasks/change_detector.py

@@ -231,7 +231,11 @@ class BaseChangeDetector(BaseModel):
               early_stop=False,
               early_stop_patience=5,
               use_vdl=True,
-              resume_checkpoint=None):
+              resume_checkpoint=None,
+              precision='fp32',
+              amp_level='O1',
+              custom_white_list=None,
+              custom_black_list=None):
         """
         Train the model.
 
@@ -263,7 +267,23 @@ class BaseChangeDetector(BaseModel):
                 training from. If None, no training checkpoint will be resumed. At most
                 one of `resume_checkpoint` and `pretrain_weights` can be set simultaneously.
                 Defaults to None.
+            precision (str, optional): Use AMP (auto mixed precision) training if `precision`
+                is set to 'fp16'. Defaults to 'fp32'.
+            amp_level (str, optional): Auto mixed precision level. Accepted values are 'O1'
+                and 'O2'. At the O1 level, the input data type of each operator is cast
+                according to a white list and a black list. At the O2 level, all parameters
+                and input data are cast to FP16, except those used by operators in the black
+                list, those used by operators without FP16 kernel support, and those used by
+                batch normalization layers. Defaults to 'O1'.
+            custom_white_list (set|list|tuple|None, optional): Custom white list to use when
+                `amp_level` is set to 'O1'. Defaults to None.
+            custom_black_list (set|list|tuple|None, optional): Custom black list to use in
+                AMP training. Defaults to None.
         """
+        self.precision = precision
+        self.amp_level = amp_level
+        self.custom_white_list = custom_white_list
+        self.custom_black_list = custom_black_list
 
         if self.status == 'Infer':
             logging.error(
@@ -454,12 +474,20 @@ class BaseChangeDetector(BaseModel):
         label_area_all = 0
         conf_mat_all = []
         logging.info(
-            "Start to evaluate(total_samples={}, total_steps={})...".format(
+            "Start to evaluate (total_samples={}, total_steps={})...".format(
                 eval_dataset.num_samples,
                 math.ceil(eval_dataset.num_samples * 1.0 / batch_size)))
         with paddle.no_grad():
             for step, data in enumerate(self.eval_data_loader):
-                outputs = self.run(self.net, data, 'eval')
+                if self.precision == 'fp16':
+                    with paddle.amp.auto_cast(
+                            level=self.amp_level,
+                            enable=True,
+                            custom_white_list=self.custom_white_list,
+                            custom_black_list=self.custom_black_list):
+                        outputs = self.run(self.net, data, 'eval')
+                else:
+                    outputs = self.run(self.net, data, 'eval')
                 pred_area = outputs['pred_area']
                 label_area = outputs['label_area']
                 intersect_area = outputs['intersect_area']

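The evaluation path above now mirrors the training configuration: when `precision` is `'fp16'`, the forward pass runs under `auto_cast` inside `no_grad`, which matters especially at the O2 level where the decorated parameters are already FP16. A minimal, self-contained sketch of that pattern with a toy model and fake data:

```python
import paddle

# Evaluation-time counterpart of the AMP training step: forward under auto_cast
# inside no_grad, using the same AMP level as training.
net = paddle.nn.Linear(16, 4)
net.eval()

fake_batches = [paddle.randn([8, 16]) for _ in range(3)]  # stand-in for an eval data loader

with paddle.no_grad():
    for x in fake_batches:
        with paddle.amp.auto_cast(level='O1'):
            logits = net(x)
        # metrics would be accumulated from `logits` here
```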
+ 32 - 4
paddlers/tasks/classifier.py

@@ -194,7 +194,11 @@ class BaseClassifier(BaseModel):
               early_stop=False,
               early_stop_patience=5,
               use_vdl=True,
-              resume_checkpoint=None):
+              resume_checkpoint=None,
+              precision='fp32',
+              amp_level='O1',
+              custom_white_list=None,
+              custom_black_list=None):
         """
         Train the model.
 
@@ -228,7 +232,23 @@ class BaseClassifier(BaseModel):
                 training from. If None, no training checkpoint will be resumed. At most
                 one of `resume_checkpoint` and `pretrain_weights` can be set simultaneously.
                 Defaults to None.
+            precision (str, optional): Use AMP (auto mixed precision) training if `precision`
+                is set to 'fp16'. Defaults to 'fp32'.
+            amp_level (str, optional): Auto mixed precision level. Accepted values are 'O1'
+                and 'O2'. At the O1 level, the input data type of each operator is cast
+                according to a white list and a black list. At the O2 level, all parameters
+                and input data are cast to FP16, except those used by operators in the black
+                list, those used by operators without FP16 kernel support, and those used by
+                batch normalization layers. Defaults to 'O1'.
+            custom_white_list (set|list|tuple|None, optional): Custom white list to use when
+                `amp_level` is set to 'O1'. Defaults to None.
+            custom_black_list (set|list|tuple|None, optional): Custom black list to use in
+                AMP training. Defaults to None.
         """
+        self.precision = precision
+        self.amp_level = amp_level
+        self.custom_white_list = custom_white_list
+        self.custom_black_list = custom_black_list
 
         if self.status == 'Infer':
             logging.error(
@@ -402,14 +422,22 @@ class BaseClassifier(BaseModel):
             self.eval_data_loader = self.build_data_loader(
                 eval_dataset, batch_size=batch_size, mode='eval')
             logging.info(
-                "Start to evaluate(total_samples={}, total_steps={})...".format(
-                    eval_dataset.num_samples, eval_dataset.num_samples))
+                "Start to evaluate (total_samples={}, total_steps={})...".
+                format(eval_dataset.num_samples, eval_dataset.num_samples))
 
             top1s = []
             top5s = []
             with paddle.no_grad():
                 for step, data in enumerate(self.eval_data_loader):
-                    outputs = self.run(self.net, data, 'eval')
+                    if self.precision == 'fp16':
+                        with paddle.amp.auto_cast(
+                                level=self.amp_level,
+                                enable=True,
+                                custom_white_list=self.custom_white_list,
+                                custom_black_list=self.custom_black_list):
+                            outputs = self.run(self.net, data, 'eval')
+                    else:
+                        outputs = self.run(self.net, data, 'eval')
                     top1s.append(outputs["top1"])
                     top5s.append(outputs["top5"])
 

+ 43 - 11
paddlers/tasks/object_detector.py

@@ -210,7 +210,11 @@ class BaseDetector(BaseModel):
               early_stop=False,
               early_stop_patience=5,
               use_vdl=True,
-              resume_checkpoint=None):
+              resume_checkpoint=None,
+              precision='fp32',
+              amp_level='O1',
+              custom_white_list=None,
+              custom_black_list=None):
         """
         Train the model.
 
@@ -256,8 +260,22 @@ class BaseDetector(BaseModel):
                 training from. If None, no training checkpoint will be resumed. At most
                 one of `resume_checkpoint` and `pretrain_weights` can be set simultaneously.
                 Defaults to None.
+            precision (str, optional): Use AMP (auto mixed precision) training if `precision`
+                is set to 'fp16'. Defaults to 'fp32'.
+            amp_level (str, optional): Auto mixed precision level. Accepted values are 'O1'
+                and 'O2'. At the O1 level, the input data type of each operator is cast
+                according to a white list and a black list. At the O2 level, all parameters
+                and input data are cast to FP16, except those used by operators in the black
+                list, those used by operators without FP16 kernel support, and those used by
+                batch normalization layers. Defaults to 'O1'.
+            custom_white_list (set|list|tuple|None, optional): Custom white list to use when
+                `amp_level` is set to 'O1'. Defaults to None.
+            custom_black_list (set|list|tuple|None, optional): Custom black list to use in
+                AMP training. Defaults to None.
         """
-
+        if precision != 'fp32':
+            raise ValueError("Currently, {} does not support AMP training.".
+                             format(self.__class__.__name__))
         args = self._pre_train(locals())
         args.pop('self')
         return self._real_train(**args)
@@ -265,12 +283,18 @@ class BaseDetector(BaseModel):
     def _pre_train(self, in_args):
         return in_args
 
-    def _real_train(
-            self, num_epochs, train_dataset, train_batch_size, eval_dataset,
-            optimizer, save_interval_epochs, log_interval_steps, save_dir,
-            pretrain_weights, learning_rate, warmup_steps, warmup_start_lr,
-            lr_decay_epochs, lr_decay_gamma, metric, use_ema, early_stop,
-            early_stop_patience, use_vdl, resume_checkpoint):
+    def _real_train(self, num_epochs, train_dataset, train_batch_size,
+                    eval_dataset, optimizer, save_interval_epochs,
+                    log_interval_steps, save_dir, pretrain_weights,
+                    learning_rate, warmup_steps, warmup_start_lr,
+                    lr_decay_epochs, lr_decay_gamma, metric, use_ema,
+                    early_stop, early_stop_patience, use_vdl, resume_checkpoint,
+                    precision, amp_level, custom_white_list, custom_black_list):
+
+        self.precision = precision
+        self.amp_level = amp_level
+        self.custom_white_list = custom_white_list
+        self.custom_black_list = custom_black_list
 
         if self.status == 'Infer':
             logging.error(
@@ -574,11 +598,19 @@ class BaseDetector(BaseModel):
             scores = collections.OrderedDict()
 
             logging.info(
-                "Start to evaluate(total_samples={}, total_steps={})...".format(
-                    eval_dataset.num_samples, eval_dataset.num_samples))
+                "Start to evaluate (total_samples={}, total_steps={})...".
+                format(eval_dataset.num_samples, eval_dataset.num_samples))
             with paddle.no_grad():
                 for step, data in enumerate(self.eval_data_loader):
-                    outputs = self.run(self.net, data, 'eval')
+                    if self.precision == 'fp16':
+                        with paddle.amp.auto_cast(
+                                level=self.amp_level,
+                                enable=True,
+                                custom_white_list=self.custom_white_list,
+                                custom_black_list=self.custom_black_list):
+                            outputs = self.run(self.net, data, 'eval')
+                    else:
+                        outputs = self.run(self.net, data, 'eval')
                     eval_metric.update(data, outputs)
                 eval_metric.accumulate()
                 self.eval_details = eval_metric.details

+ 41 - 10
paddlers/tasks/restorer.py

@@ -195,7 +195,11 @@ class BaseRestorer(BaseModel):
               early_stop=False,
               early_stop_patience=5,
               use_vdl=True,
-              resume_checkpoint=None):
+              resume_checkpoint=None,
+              precision='fp32',
+              amp_level='O1',
+              custom_white_list=None,
+              custom_black_list=None):
         """
         Train the model.
 
@@ -228,7 +232,26 @@ class BaseRestorer(BaseModel):
                 training from. If None, no training checkpoint will be resumed. At most
                 one of `resume_checkpoint` and `pretrain_weights` can be set simultaneously.
                 Defaults to None.
+            precision (str, optional): Use AMP (auto mixed precision) training if `precision`
+                is set to 'fp16'. Defaults to 'fp32'.
+            amp_level (str, optional): Auto mixed precision level. Accepted values are 'O1'
+                and 'O2'. At the O1 level, the input data type of each operator is cast
+                according to a white list and a black list. At the O2 level, all parameters
+                and input data are cast to FP16, except those used by operators in the black
+                list, those used by operators without FP16 kernel support, and those used by
+                batch normalization layers. Defaults to 'O1'.
+            custom_white_list (set|list|tuple|None, optional): Custom white list to use when
+                `amp_level` is set to 'O1'. Defaults to None.
+            custom_black_list (set|list|tuple|None, optional): Custom black list to use in
+                AMP training. Defaults to None.
         """
+        if precision != 'fp32':
+            raise ValueError("Currently, {} does not support AMP training.".
+                             format(self.__class__.__name__))
+        self.precision = precision
+        self.amp_level = amp_level
+        self.custom_white_list = custom_white_list
+        self.custom_black_list = custom_black_list
 
         if self.status == 'Infer':
             logging.error(
@@ -415,11 +438,19 @@ class BaseRestorer(BaseModel):
             psnr = metrics.PSNR(crop_border=4, test_y_channel=True)
             ssim = metrics.SSIM(crop_border=4, test_y_channel=True)
             logging.info(
-                "Start to evaluate(total_samples={}, total_steps={})...".format(
-                    eval_dataset.num_samples, eval_dataset.num_samples))
+                "Start to evaluate (total_samples={}, total_steps={})...".
+                format(eval_dataset.num_samples, eval_dataset.num_samples))
             with paddle.no_grad():
                 for step, data in enumerate(self.eval_data_loader):
-                    outputs = self.run(self.net, data, 'eval')
+                    if self.precision == 'fp16':
+                        with paddle.amp.auto_cast(
+                                level=self.amp_level,
+                                enable=True,
+                                custom_white_list=self.custom_white_list,
+                                custom_black_list=self.custom_black_list):
+                            outputs = self.run(self.net, data, 'eval')
+                    else:
+                        outputs = self.run(self.net, data, 'eval')
                     psnr.update(outputs['pred'], outputs['tar'])
                     ssim.update(outputs['pred'], outputs['tar'])
 
@@ -699,7 +730,7 @@ class DRN(BaseRestorer):
             raise ValueError("Invalid `gan_mode`!")
         return outputs
 
-    def train_step(self, step, data, net):
+    def train_step(self, step, data, net, optimizer):
         outputs = self.run_gan(
             net, (data[0]['image'], data[0]['target']),
             mode='train',
@@ -709,9 +740,9 @@ class DRN(BaseRestorer):
                 net, (outputs['sr'], outputs['lr']),
                 mode='train',
                 gan_mode='forward_dual'))
-        self.optimizer.clear_grad()
+        optimizer.clear_grad()
         (outputs['loss_prim'] + outputs['loss_dual']).backward()
-        self.optimizer.step()
+        optimizer.step()
         return {
             'loss': outputs['loss_prim'] + outputs['loss_dual'],
             'loss_prim': outputs['loss_prim'],
@@ -858,9 +889,9 @@ class ESRGAN(BaseRestorer):
             raise ValueError("Invalid `gan_mode`!")
         return outputs
 
-    def train_step(self, step, data, net):
+    def train_step(self, step, data, net, optimizer):
         if self.use_gan:
-            optim_g, optim_d = self.optimizer
+            optim_g, optim_d = optimizer
 
             outputs = self.run_gan(
                 net, (data[0]['image'], data[0]['target']),
@@ -889,7 +920,7 @@ class ESRGAN(BaseRestorer):
                 'loss_d': outputs['loss_d']
             }
         else:
-            return super(ESRGAN, self).train_step(step, data, net)
+            return super(ESRGAN, self).train_step(step, data, net, optimizer)
 
     def _set_requires_grad(self, net, requires_grad):
         for p in net.parameters():

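A note on why `train_step()` now receives `optimizer` as an argument: at the O2 level, `paddle.amp.decorate` may return new decorated model and optimizer objects, so subclasses such as DRN and ESRGAN must operate on the objects passed in rather than on `self.optimizer` (which, in ESRGAN's GAN mode, is a pair of optimizers that gets unpacked as `optim_g, optim_d = optimizer`). A minimal sketch of the decorate call, with a toy model standing in for a PaddleRS network (a float16-capable GPU is assumed):

```python
import paddle

# AMP O2 preparation mirroring the branch added in BaseModel.train():
# decorate casts the parameters and may wrap the optimizer, so downstream code
# must use the returned objects rather than the originals.
net = paddle.nn.Linear(16, 4)
optimizer = paddle.optimizer.Adam(learning_rate=1e-3, parameters=net.parameters())

net, optimizer = paddle.amp.decorate(
    models=net, optimizers=optimizer, level='O2', save_dtype='float32')
```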
+ 31 - 3
paddlers/tasks/segmenter.py

@@ -220,7 +220,11 @@ class BaseSegmenter(BaseModel):
               early_stop=False,
               early_stop_patience=5,
               use_vdl=True,
-              resume_checkpoint=None):
+              resume_checkpoint=None,
+              precision='fp32',
+              amp_level='O1',
+              custom_white_list=None,
+              custom_black_list=None):
         """
         Train the model.
 
@@ -253,7 +257,23 @@ class BaseSegmenter(BaseModel):
                 training from. If None, no training checkpoint will be resumed. At most
                 one of `resume_checkpoint` and `pretrain_weights` can be set simultaneously.
                 Defaults to None.
+            precision (str, optional): Use AMP (auto mixed precision) training if `precision`
+                is set to 'fp16'. Defaults to 'fp32'.
+            amp_level (str, optional): Auto mixed precision level. Accepted values are 'O1'
+                and 'O2'. At the O1 level, the input data type of each operator is cast
+                according to a white list and a black list. At the O2 level, all parameters
+                and input data are cast to FP16, except those used by operators in the black
+                list, those used by operators without FP16 kernel support, and those used by
+                batch normalization layers. Defaults to 'O1'.
+            custom_white_list (set|list|tuple|None, optional): Custom white list to use when
+                `amp_level` is set to 'O1'. Defaults to None.
+            custom_black_list (set|list|tuple|None, optional): Custom black list to use in
+                AMP training. Defaults to None.
         """
+        self.precision = precision
+        self.amp_level = amp_level
+        self.custom_white_list = custom_white_list
+        self.custom_black_list = custom_black_list
 
         if self.status == 'Infer':
             logging.error(
@@ -434,12 +454,20 @@ class BaseSegmenter(BaseModel):
         label_area_all = 0
         conf_mat_all = []
         logging.info(
-            "Start to evaluate(total_samples={}, total_steps={})...".format(
+            "Start to evaluate (total_samples={}, total_steps={})...".format(
                 eval_dataset.num_samples,
                 math.ceil(eval_dataset.num_samples * 1.0 / batch_size)))
         with paddle.no_grad():
             for step, data in enumerate(self.eval_data_loader):
-                outputs = self.run(self.net, data, 'eval')
+                if self.precision == 'fp16':
+                    with paddle.amp.auto_cast(
+                            level=self.amp_level,
+                            enable=True,
+                            custom_white_list=self.custom_white_list,
+                            custom_black_list=self.custom_black_list):
+                        outputs = self.run(self.net, data, 'eval')
+                else:
+                    outputs = self.run(self.net, data, 'eval')
                 pred_area = outputs['pred_area']
                 label_area = outputs['label_area']
                 intersect_area = outputs['intersect_area']

+ 1 - 1
paddlers/tasks/utils/det_metrics/coco_utils.py

@@ -135,7 +135,7 @@ def cocoapi_eval(anns,
         results_flatten = list(itertools.chain(*results_per_category))
         headers = ['category', 'AP'] * (num_columns // 2)
         results_2d = itertools.zip_longest(
-            *[results_flatten[i::num_columns] for i in range(num_columns)])
+            * [results_flatten[i::num_columns] for i in range(num_columns)])
         table_data = [headers]
         table_data += [result for result in results_2d]
         table = AsciiTable(table_data)