Quellcode durchsuchen

[TIPC] Initialize TIPC (#8)

* Init TIPC

* Update gitignore

* Fix bugs

* Add docs and more models

* Fix bugs
Lin Manhui vor 2 Jahren
Ursprung
Commit
bddddc5164

+ 3 - 1
.gitignore

@@ -132,9 +132,11 @@ dmypy.json
 # Pyre type checker
 .pyre/
 
-# testdata
+# test data
 tutorials/train/change_detection/DataSet/
 tutorials/train/classification/DataSet/
 optic_disc_seg.tar
 optic_disc_seg/
 output/
+
+/log

+ 12 - 5
paddlers/tasks/__init__.py

@@ -12,9 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import paddlers.tasks.object_detector as det
-import paddlers.tasks.segmenter as seg
-import paddlers.tasks.change_detector as cd
-import paddlers.tasks.classifier as clas
-import paddlers.tasks.image_restorer as res
+import paddlers.tasks.object_detector as detector
+import paddlers.tasks.segmenter as segmenter
+import paddlers.tasks.change_detector as change_detector
+import paddlers.tasks.classifier as classifier
+import paddlers.tasks.image_restorer as restorer
 from .load_model import load_model
+
+# Shorter aliases
+det = detector
+seg = segmenter
+cd = change_detector
+clas = classifier
+res = restorer

+ 7 - 2
paddlers/tasks/base.py

@@ -194,10 +194,15 @@ class BaseModel(metaclass=ModelMeta):
                 info['Transforms'] = list()
                 for op in self.test_transforms.transforms:
                     name = op.__class__.__name__
-                    if name.startswith('Arrange'):
-                        continue
                     attr = op.__dict__
                     info['Transforms'].append({name: attr})
+                arrange = self.test_transforms.arrange
+                if arrange is not None:
+                    info['Transforms'].append({
+                        arrange.__class__.__name__: {
+                            'mode': 'test'
+                        }
+                    })
         info['completed_epochs'] = self.completed_epochs
         return info
 

+ 4 - 1
paddlers/transforms/batch_operators.py

@@ -62,7 +62,10 @@ class BatchCompose(Transform):
                 for i in range(len(samples)):
                     tmp_data.append(samples[i][k])
                 if not 'gt_' in k and not 'is_crowd' in k and not 'difficult' in k:
-                    tmp_data = np.stack(tmp_data, axis=0)
+                    # This if assumes that all elements in tmp_data has the same type.
+                    if len(tmp_data) == 0 or not isinstance(tmp_data[0],
+                                                            (str, bytes)):
+                        tmp_data = np.stack(tmp_data, axis=0)
                 batch_data[k] = tmp_data
         return batch_data
 

+ 3 - 0
test_tipc/.gitignore

@@ -0,0 +1,3 @@
+/scripts/
+/output/
+/data/

+ 62 - 0
test_tipc/README.md

@@ -0,0 +1,62 @@
+# 飞桨训推一体全流程(TIPC)
+
+## 1. 简介
+
+飞桨除了基本的模型训练和预测,还提供了支持多端多平台的高性能推理部署工具。本文档提供了飞桨训推一体全流程(Training and Inference Pipeline Criterion(TIPC))信息和测试工具,方便用户查阅每种模型的训练推理部署打通情况,并可以进行一键测试。
+
+<div align="center">
+    <img src="docs/overview.png" width="1000">
+</div>
+
+## 2. 汇总信息
+
+打通情况汇总如下,已填写的部分表示可以使用本工具进行一键测试,未填写的表示正在支持中。
+
+**字段说明:**
+- 基础训练预测:指Linux GPU/CPU环境下的模型训练、Paddle Inference Python预测。
+- 更多训练方式:包括多机多卡、混合精度训练。
+- 更多部署方式:包括C++预测、Serving服务化部署、ARM端侧部署等多种部署方式,具体列表见[3.3节](#3.3)
+- Slim训练部署:包括PACT在线量化、离线量化。
+- 更多训练环境:包括Windows GPU/CPU、Linux NPU、Linux DCU等多种环境。
+
+
+| 任务类别 | 模型名称 | 基础<br>训练预测 | 更多<br>训练方式 | 更多<br>部署方式 | Slim<br>训练部署 |  更多<br>训练环境  |
+| :--- | :--- |  :----:  | :--------: |  :----:  |   :----:  |   :----:  |   :----:  |
+| 变化检测 | BIT | 支持 | - | - | - | - |
+| 场景分类 | HRNet | 支持 | - | - | - | - |
+| 目标检测 | PP-YOLO | 支持 | - | - | - | - |
+| 图像分割 | UNet | 支持 | - | - | - | - |
+
+
+## 3. 测试工具简介
+
+### 3.1 目录介绍
+
+```
+test_tipc
+    |--configs                                      # 配置目录
+    |    |--task_name                               # 任务名称
+    |           |--model_name                       # 模型名称
+    |                   |--train_infer_python.txt   # 基础训练推理测试配置文件
+    |--docs                                         # 文档目录
+    |   |--test_train_inference_python.md           # 基础训练推理测试说明文档
+    |----README.md                                  # TIPC说明文档
+    |----prepare.sh                                 # TIPC基础训练推理测试数据准备脚本
+    |----test_train_inference_python.sh             # TIPC基础训练推理测试解析脚本
+    |----common_func.sh                             # TIPC基础训练推理测试常用函数
+```
+
+### 3.2 测试流程概述
+
+使用本工具,可以测试不同功能的支持情况。测试过程包含:
+
+1. 准备数据与环境
+2. 运行测试脚本,观察不同配置是否运行成功。
+
+<a name="3.3"></a>
+### 3.3 开始测试
+
+请参考相应文档,完成指定功能的测试。
+
+- 基础训练预测测试:
+    - [Linux GPU/CPU 基础训练推理测试](docs/test_train_inference_python.md)

+ 131 - 0
test_tipc/common_func.sh

@@ -0,0 +1,131 @@
+#!/bin/bash
+
+function func_parser_key() {
+    strs=$1
+    IFS=":"
+    array=(${strs})
+    tmp=${array[0]}
+    echo ${tmp}
+}
+
+function func_parser_value() {
+    strs=$1
+    IFS=":"
+    array=(${strs})
+    tmp=${array[1]}
+    echo ${tmp}
+}
+
+function func_parser_value_lite() {
+    strs=$1
+    IFS=$2
+    array=(${strs})
+    tmp=${array[1]}
+    echo ${tmp}
+}
+
+function func_set_params() {
+    key=$1
+    value=$2
+    if [ ${key}x = "null"x ];then
+        echo " "
+    elif [[ ${value} = "null" ]] || [[ ${value} = " " ]] || [ ${#value} -le 0 ];then
+        echo " "
+    else 
+        echo "${key}=${value}"
+    fi
+}
+
+function func_parser_params() {
+    strs=$1
+    IFS=":"
+    array=(${strs})
+    key=${array[0]}
+    tmp=${array[1]}
+    IFS="|"
+    res=""
+    for _params in ${tmp[*]}; do
+        IFS="="
+        array=(${_params})
+        mode=${array[0]}
+        value=${array[1]}
+        if [[ ${mode} = ${MODE} ]]; then
+            IFS="|"
+            #echo $(func_set_params "${mode}" "${value}")
+            echo $value
+            break
+        fi
+        IFS="|"
+    done
+    echo ${res}
+}
+
+function status_check() {
+    local last_status=$1   # the exit code
+    local run_command=$2
+    local run_log=$3
+    local model_name=$4
+
+    if [ $last_status -eq 0 ]; then
+        echo -e "\033[33m Run successfully with command - ${model_name} - ${run_command}!  \033[0m" | tee -a ${run_log}
+    else
+        echo -e "\033[33m Run failed with command - ${model_name} - ${run_command}!  \033[0m" | tee -a ${run_log}
+    fi
+}
+
+function download_and_unzip_dataset() {
+    local ds_dir="$1"
+    local ds_name="$2"
+    local url="$3"
+    local clear="${4-True}"
+
+    local ds_path="${ds_dir}/${ds_name}"
+    local zip_name="${url##*/}"
+
+    if [ ${clear} = 'True' ]; then
+        rm -rf "${ds_path}"
+    fi
+
+    wget -nc -P "${ds_dir}" "${url}" --no-check-certificate
+    cd "${ds_dir}" && unzip "${zip_name}" && cd - \
+        && echo "Successfully downloaded ${zip_name} from ${url}. File saved in ${ds_path}. "
+}
+
+function parse_extra_args() {
+    local lines=("$@")
+    local last_idx=$((${#lines[@]}-1))
+    local IFS=';'
+    extra_args=(${lines[last_idx]})
+}
+
+function add_suffix() {
+    local ori_path="$1"
+    local suffix=$2
+    local ext="${ori_path##*.}"
+    echo "${ori_path%.*}${suffix}.${ext}"
+}
+
+function parse_first_value() {
+    local key_values=$1
+    local IFS=":"
+    local arr=(${key_values})
+    echo ${arr[1]}
+}
+
+function parse_second_value() {
+    local key_values=$1
+    local IFS=":"
+    local arr=(${key_values})
+    echo ${arr[2]}
+}
+
+function run_command() {
+    local cmd="$1"
+    local log_path="$2"
+    if [ -n "${log_path}" ]; then
+        eval ${cmd} | tee "${log_path}"
+        test ${PIPESTATUS[0]} -eq 0
+    else
+        eval ${cmd}
+    fi
+}

+ 252 - 0
test_tipc/config_utils.py

@@ -0,0 +1,252 @@
+#!/usr/bin/env python
+
+import argparse
+import os.path as osp
+from collections.abc import Mapping
+
+import yaml
+
+
+def _chain_maps(*maps):
+    chained = dict()
+    keys = set().union(*maps)
+    for key in keys:
+        vals = [m[key] for m in maps if key in m]
+        if isinstance(vals[0], Mapping):
+            chained[key] = _chain_maps(*vals)
+        else:
+            chained[key] = vals[0]
+    return chained
+
+
+def read_config(config_path):
+    with open(config_path, 'r', encoding='utf-8') as f:
+        cfg = yaml.safe_load(f)
+    return cfg or {}
+
+
+def parse_configs(cfg_path, inherit=True):
+    if inherit:
+        cfgs = []
+        cfgs.append(read_config(cfg_path))
+        while cfgs[-1].get('_base_'):
+            base_path = cfgs[-1].pop('_base_')
+            curr_dir = osp.dirname(cfg_path)
+            cfgs.append(
+                read_config(osp.normpath(osp.join(curr_dir, base_path))))
+        return _chain_maps(*cfgs)
+    else:
+        return read_config(cfg_path)
+
+
+def _cfg2args(cfg, parser, prefix=''):
+    node_keys = set()
+    for k, v in cfg.items():
+        opt = prefix + k
+        if isinstance(v, list):
+            if len(v) == 0:
+                parser.add_argument(
+                    '--' + opt, type=object, nargs='*', default=v)
+            else:
+                # Only apply to homogeneous lists
+                if isinstance(v[0], CfgNode):
+                    node_keys.add(opt)
+                parser.add_argument(
+                    '--' + opt, type=type(v[0]), nargs='*', default=v)
+        elif isinstance(v, dict):
+            # Recursively parse a dict
+            _, new_node_keys = _cfg2args(v, parser, opt + '.')
+            node_keys.update(new_node_keys)
+        elif isinstance(v, CfgNode):
+            node_keys.add(opt)
+            _, new_node_keys = _cfg2args(v.to_dict(), parser, opt + '.')
+            node_keys.update(new_node_keys)
+        elif isinstance(v, bool):
+            parser.add_argument('--' + opt, action='store_true', default=v)
+        else:
+            parser.add_argument('--' + opt, type=type(v), default=v)
+    return parser, node_keys
+
+
+def _args2cfg(cfg, args, node_keys):
+    args = vars(args)
+    for k, v in args.items():
+        pos = k.find('.')
+        if pos != -1:
+            # Iteratively parse a dict
+            dict_ = cfg
+            while pos != -1:
+                dict_.setdefault(k[:pos], {})
+                dict_ = dict_[k[:pos]]
+                k = k[pos + 1:]
+                pos = k.find('.')
+            dict_[k] = v
+        else:
+            cfg[k] = v
+
+    for k in node_keys:
+        pos = k.find('.')
+        if pos != -1:
+            # Iteratively parse a dict
+            dict_ = cfg
+            while pos != -1:
+                dict_.setdefault(k[:pos], {})
+                dict_ = dict_[k[:pos]]
+                k = k[pos + 1:]
+                pos = k.find('.')
+            v = dict_[k]
+            dict_[k] = [CfgNode(v_) for v_ in v] if isinstance(
+                v, list) else CfgNode(v)
+        else:
+            v = cfg[k]
+            cfg[k] = [CfgNode(v_) for v_ in v] if isinstance(
+                v, list) else CfgNode(v)
+
+    return cfg
+
+
+def parse_args(*args, **kwargs):
+    cfg_parser = argparse.ArgumentParser(add_help=False)
+    cfg_parser.add_argument('--config', type=str, default='')
+    cfg_parser.add_argument('--inherit_off', action='store_true')
+    cfg_args = cfg_parser.parse_known_args()[0]
+    cfg_path = cfg_args.config
+    inherit_on = not cfg_args.inherit_off
+
+    # Main parser
+    parser = argparse.ArgumentParser(
+        conflict_handler='resolve', parents=[cfg_parser])
+    # Global settings
+    parser.add_argument('cmd', choices=['train', 'eval'])
+    parser.add_argument('task', choices=['cd', 'clas', 'det', 'seg'])
+
+    # Data
+    parser.add_argument('--datasets', type=dict, default={})
+    parser.add_argument('--transforms', type=dict, default={})
+    parser.add_argument('--download_on', action='store_true')
+    parser.add_argument('--download_url', type=str, default='')
+    parser.add_argument('--download_path', type=str, default='./')
+
+    # Optimizer
+    parser.add_argument('--optimizer', type=dict, default={})
+
+    # Training related
+    parser.add_argument('--num_epochs', type=int, default=100)
+    parser.add_argument('--train_batch_size', type=int, default=8)
+    parser.add_argument('--save_interval_epochs', type=int, default=1)
+    parser.add_argument('--log_interval_steps', type=int, default=1)
+    parser.add_argument('--save_dir', default='../exp/')
+    parser.add_argument('--learning_rate', type=float, default=0.01)
+    parser.add_argument('--early_stop', action='store_true')
+    parser.add_argument('--early_stop_patience', type=int, default=5)
+    parser.add_argument('--use_vdl', action='store_true')
+    parser.add_argument('--resume_checkpoint', type=str)
+    parser.add_argument('--train', type=dict, default={})
+
+    # Loss
+    parser.add_argument('--losses', type=dict, nargs='+', default={})
+
+    # Model
+    parser.add_argument('--model', type=dict, default={})
+
+    if osp.exists(cfg_path):
+        cfg = parse_configs(cfg_path, inherit_on)
+        parser, node_keys = _cfg2args(cfg, parser, '')
+        args = parser.parse_args(*args, **kwargs)
+        return _args2cfg(dict(), args, node_keys)
+    elif cfg_path != '':
+        raise FileNotFoundError
+    else:
+        args = parser.parse_args()
+        return _args2cfg(dict(), args, set())
+
+
+class _CfgNodeMeta(yaml.YAMLObjectMetaclass):
+    def __call__(cls, obj):
+        if isinstance(obj, CfgNode):
+            return obj
+        return super(_CfgNodeMeta, cls).__call__(obj)
+
+
+class CfgNode(yaml.YAMLObject, metaclass=_CfgNodeMeta):
+    yaml_tag = u'!Node'
+    yaml_loader = yaml.SafeLoader
+    # By default use a lexical scope
+    ctx = globals()
+
+    def __init__(self, dict_):
+        super().__init__()
+        self.type = dict_['type']
+        self.args = dict_.get('args', [])
+        self.module = self._get_module(dict_.get('module', ''))
+
+    @classmethod
+    def set_context(cls, ctx):
+        # TODO: Implement dynamic scope with inspect.stack()
+        old_ctx = cls.ctx
+        cls.ctx = ctx
+        return old_ctx
+
+    def build_object(self, mod=None):
+        if mod is None:
+            mod = self.module
+        cls = getattr(mod, self.type)
+        if isinstance(self.args, list):
+            args = build_objects(self.args)
+            obj = cls(*args)
+        elif isinstance(self.args, dict):
+            args = build_objects(self.args)
+            obj = cls(**args)
+        else:
+            raise NotImplementedError
+        return obj
+
+    def _get_module(self, s):
+        mod = None
+        while s:
+            idx = s.find('.')
+            if idx == -1:
+                next_ = s
+                s = ''
+            else:
+                next_ = s[:idx]
+                s = s[idx + 1:]
+            if mod is None:
+                mod = self.ctx[next_]
+            else:
+                mod = getattr(mod, next_)
+        return mod
+
+    @staticmethod
+    def build_objects(cfg, mod=None):
+        if isinstance(cfg, list):
+            return [CfgNode.build_objects(c, mod=mod) for c in cfg]
+        elif isinstance(cfg, CfgNode):
+            return cfg.build_object(mod=mod)
+        elif isinstance(cfg, dict):
+            return {
+                k: CfgNode.build_objects(
+                    v, mod=mod)
+                for k, v in cfg.items()
+            }
+        else:
+            return cfg
+
+    def __repr__(self):
+        return f"(type={self.type}, args={self.args}, module={self.module or ' '})"
+
+    @classmethod
+    def from_yaml(cls, loader, node):
+        map_ = loader.construct_mapping(node)
+        return cls(map_)
+
+    def items(self):
+        yield from [('type', self.type), ('args', self.args), ('module',
+                                                               self.module)]
+
+    def to_dict(self):
+        return dict(self.items())
+
+
+def build_objects(cfg, mod=None):
+    return CfgNode.build_objects(cfg, mod=mod)

+ 70 - 0
test_tipc/configs/cd/_base_/airchange.yaml

@@ -0,0 +1,70 @@
+# Basic configurations of AirChange dataset
+
+datasets:
+    train: !Node
+        type: CDDataset
+        args: 
+            data_dir: ./test_tipc/data/airchange/
+            file_list: ./test_tipc/data/airchange/train.txt
+            label_list: null
+            num_workers: 0
+            shuffle: True
+            with_seg_labels: False
+            binarize_labels: True
+    eval: !Node
+        type: CDDataset
+        args:
+            data_dir: ./test_tipc/data/airchange/
+            file_list: ./test_tipc/data/airchange/eval.txt
+            label_list: null
+            num_workers: 0
+            shuffle: False
+            with_seg_labels: False
+            binarize_labels: True
+transforms:
+    train:
+        - !Node
+          type: DecodeImg
+        - !Node
+          type: RandomCrop
+          args:
+            crop_size: 256
+            aspect_ratio: [0.5, 2.0]
+            scaling: [0.2, 1.0]
+        - !Node
+          type: RandomHorizontalFlip
+          args:
+            prob: 0.5
+        - !Node
+          type: Normalize
+          args:
+            mean: [0.5, 0.5, 0.5]
+            std: [0.5, 0.5, 0.5]
+        - !Node
+          type: ArrangeChangeDetector
+          args: ['train']
+    eval:
+        - !Node
+          type: DecodeImg
+        - !Node
+          type: Normalize
+          args:
+            mean: [0.5, 0.5, 0.5]
+            std: [0.5, 0.5, 0.5]
+        - !Node
+          type: ArrangeChangeDetector
+          args: ['eval']
+download_on: False
+download_url: https://paddlers.bj.bcebos.com/datasets/airchange.zip
+download_path: ./test_tipc/data/
+
+num_epochs: 5
+train_batch_size: 4
+save_interval_epochs: 3
+log_interval_steps: 50
+save_dir: ./test_tipc/output/cd/
+learning_rate: 0.01
+early_stop: False
+early_stop_patience: 5
+use_vdl: False
+resume_checkpoint: ''

+ 8 - 0
test_tipc/configs/cd/bit/bit.yaml

@@ -0,0 +1,8 @@
+# Basic configurations of BIT
+
+_base_: ../_base_/airchange.yaml
+
+save_dir: ./test_tipc/output/cd/bit/
+
+model: !Node
+       type: BIT

+ 53 - 0
test_tipc/configs/cd/bit/train_infer_python.txt

@@ -0,0 +1,53 @@
+===========================train_params===========================
+model_name:cd:bit
+python:python
+gpu_list:0|0,1
+use_gpu:null|null
+--precision:null
+--num_epochs:lite_train_lite_infer=5|lite_train_whole_infer=5|whole_train_whole_infer=10
+--save_dir:adaptive
+--train_batch_size:lite_train_lite_infer=4|lite_train_whole_infer=4|whole_train_whole_infer=4
+--model_path:null
+train_model_name:best_model
+train_infer_file_list:./test_tipc/data/airchange/:./test_tipc/data/airchange/eval.txt
+null:null
+##
+trainer:norm
+norm_train:test_tipc/run_task.py train cd --config ./test_tipc/configs/cd/bit/bit.yaml
+pact_train:null
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:null
+null:null
+##
+===========================export_params===========================
+--save_dir:adaptive
+--model_dir:adaptive
+--fixed_input_shape:[-1,3,256,256]
+norm_export:deploy/export/export_model.py
+quant_export:null
+fpgm_export:null
+distill_export:null
+export1:null
+export2:null
+===========================infer_params===========================
+infer_model:null
+infer_export:null
+infer_quant:False
+inference:test_tipc/infer.py
+--device:cpu|gpu
+--enable_mkldnn:True
+--cpu_threads:6
+--batch_size:1
+--use_trt:False
+--precision:fp32
+--model_dir:null
+--file_list:null:null
+--save_log_path:null
+--benchmark:True
+--model_name:bit
+null:null

+ 72 - 0
test_tipc/configs/clas/_base_/ucmerced.yaml

@@ -0,0 +1,72 @@
+# Basic configurations of UCMerced dataset
+
+datasets:
+    train: !Node
+        type: ClasDataset
+        args: 
+            data_dir: ./test_tipc/data/ucmerced/
+            file_list: ./test_tipc/data/ucmerced/train.txt
+            label_list: ./test_tipc/data/ucmerced/labels.txt
+            num_workers: 0
+            shuffle: True
+    eval: !Node
+        type: ClasDataset
+        args:
+            data_dir: ./test_tipc/data/ucmerced/
+            file_list: ./test_tipc/data/ucmerced/val.txt
+            label_list: ./test_tipc/data/ucmerced/labels.txt
+            num_workers: 0
+            shuffle: False
+transforms:
+    train:
+        - !Node
+          type: DecodeImg
+        - !Node
+          type: Resize
+          args:
+            target_size: 256
+        - !Node
+          type: RandomHorizontalFlip
+          args:
+            prob: 0.5
+        - !Node
+          type: RandomVerticalFlip
+          args:
+            prob: 0.5
+        - !Node
+          type: Normalize
+          args:
+            mean: [0.5, 0.5, 0.5]
+            std: [0.5, 0.5, 0.5]
+        - !Node
+          type: ArrangeClassifier
+          args: ['train']
+    eval:
+        - !Node
+          type: DecodeImg
+        - !Node
+          type: Resize
+          args:
+            target_size: 256
+        - !Node
+          type: Normalize
+          args:
+            mean: [0.5, 0.5, 0.5]
+            std: [0.5, 0.5, 0.5]
+        - !Node
+          type: ArrangeClassifier
+          args: ['eval']
+download_on: False
+download_url: https://paddlers.bj.bcebos.com/datasets/ucmerced.zip
+download_path: ./test_tipc/data/
+
+num_epochs: 2
+train_batch_size: 16
+save_interval_epochs: 5
+log_interval_steps: 50
+save_dir: e./test_tipc/output/clas/
+learning_rate: 0.01
+early_stop: False
+early_stop_patience: 5
+use_vdl: False
+resume_checkpoint: ''

+ 10 - 0
test_tipc/configs/clas/hrnet/hrnet.yaml

@@ -0,0 +1,10 @@
+# Basic configurations of HRNet
+
+_base_: ../_base_/ucmerced.yaml
+
+save_dir: ./test_tipc/output/clas/hrnet/
+
+model: !Node
+       type: HRNet_W18_C
+       args:
+           num_classes: 21

+ 53 - 0
test_tipc/configs/clas/hrnet/train_infer_python.txt

@@ -0,0 +1,53 @@
+===========================train_params===========================
+model_name:clas:hrnet
+python:python
+gpu_list:0|0,1
+use_gpu:null|null
+--precision:null
+--num_epochs:lite_train_lite_infer=3|lite_train_whole_infer=3|whole_train_whole_infer=10
+--save_dir:adaptive
+--train_batch_size:lite_train_lite_infer=16|lite_train_whole_infer=16|whole_train_whole_infer=16
+--model_path:null
+train_model_name:best_model
+train_infer_file_list:./test_tipc/data/ucmerced/:./test_tipc/data/ucmerced/val.txt
+null:null
+##
+trainer:norm
+norm_train:test_tipc/run_task.py train clas --config ./test_tipc/configs/clas/hrnet/hrnet.yaml
+pact_train:null
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:null
+null:null
+##
+===========================export_params===========================
+--save_dir:adaptive
+--model_dir:adaptive
+--fixed_input_shape:[-1,3,256,256]
+norm_export:deploy/export/export_model.py
+quant_export:null
+fpgm_export:null
+distill_export:null
+export1:null
+export2:null
+===========================infer_params===========================
+infer_model:null
+infer_export:null
+infer_quant:False
+inference:test_tipc/infer.py
+--device:cpu|gpu
+--enable_mkldnn:True
+--cpu_threads:6
+--batch_size:1
+--use_trt:False
+--precision:fp32
+--model_dir:null
+--file_list:null:null
+--save_log_path:null
+--benchmark:True
+--model_name:hrnet
+null:null

+ 74 - 0
test_tipc/configs/det/_base_/sarship.yaml

@@ -0,0 +1,74 @@
+# Basic configurations of SARShip dataset
+
+datasets:
+    train: !Node
+        type: VOCDetDataset
+        args: 
+            data_dir: ./test_tipc/data/sarship/
+            file_list: ./test_tipc/data/sarship/train.txt
+            label_list: ./test_tipc/data/sarship/labels.txt
+            shuffle: True
+    eval: !Node
+        type: VOCDetDataset
+        args:
+            data_dir: ./test_tipc/data/sarship/
+            file_list: ./test_tipc/data/sarship/eval.txt
+            label_list: ./test_tipc/data/sarship/labels.txt
+            shuffle: False
+transforms:
+    train:
+        - !Node
+          type: DecodeImg
+        - !Node
+          type: RandomDistort
+        - !Node
+          type: RandomExpand
+        - !Node
+          type: RandomCrop
+        - !Node
+          type: RandomHorizontalFlip
+        - !Node
+          type: BatchRandomResize
+          args:
+            target_sizes: [320, 352, 384, 416, 448, 480, 512, 544, 576, 608]
+            interp: RANDOM
+        - !Node
+          type: Normalize
+          args:
+            mean: [0.485, 0.456, 0.406]
+            std: [0.229, 0.224, 0.225]
+        - !Node
+          type: ArrangeDetector
+          args: ['train']
+    eval:
+        - !Node
+          type: DecodeImg
+        - !Node
+          type: Resize
+          args:
+            target_size: 608
+            interp: CUBIC
+        - !Node
+          type: Normalize
+          args:
+            mean: [0.485, 0.456, 0.406]
+            std: [0.229, 0.224, 0.225]
+        - !Node
+          type: ArrangeDetector
+          args: ['eval']
+download_on: False
+download_url: https://paddlers.bj.bcebos.com/datasets/sarship.zip
+download_path: ./test_tipc/data/
+
+num_epochs: 10
+train_batch_size: 4
+save_interval_epochs: 5
+log_interval_steps: 4
+save_dir: ./test_tipc/output/det/
+learning_rate: 0.0005
+use_vdl: False
+resume_checkpoint: ''
+train:
+    pretrain_weights: COCO
+    warmup_steps: 0
+    warmup_start_lr: 0.0

+ 10 - 0
test_tipc/configs/det/ppyolo/ppyolo.yaml

@@ -0,0 +1,10 @@
+# Basic configurations of PP-YOLO
+
+_base_: ../_base_/sarship.yaml
+
+save_dir: ./test_tipc/output/det/ppyolo/
+
+model: !Node
+       type: PPYOLO
+       args:
+           num_classes: 1

+ 53 - 0
test_tipc/configs/det/ppyolo/train_infer_python.txt

@@ -0,0 +1,53 @@
+===========================train_params===========================
+model_name:det:ppyolo
+python:python
+gpu_list:0|0,1
+use_gpu:null|null
+--precision:null
+--num_epochs:lite_train_lite_infer=3|lite_train_whole_infer=3|whole_train_whole_infer=10
+--save_dir:adaptive
+--train_batch_size:lite_train_lite_infer=4|lite_train_whole_infer=4|whole_train_whole_infer=4
+--model_path:null
+train_model_name:best_model
+train_infer_file_list:./test_tipc/data/sarship/:./test_tipc/data/sarship/eval.txt
+null:null
+##
+trainer:norm
+norm_train:test_tipc/run_task.py train det --config ./test_tipc/configs/det/ppyolo/ppyolo.yaml
+pact_train:null
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:null
+null:null
+##
+===========================export_params===========================
+--save_dir:adaptive
+--model_dir:adaptive
+--fixed_input_shape:[-1,3,608,608]
+norm_export:deploy/export/export_model.py
+quant_export:null
+fpgm_export:null
+distill_export:null
+export1:null
+export2:null
+===========================infer_params===========================
+infer_model:null
+infer_export:null
+infer_quant:False
+inference:test_tipc/infer.py
+--device:cpu|gpu
+--enable_mkldnn:True
+--cpu_threads:6
+--batch_size:1
+--use_trt:False
+--precision:fp32
+--model_dir:null
+--file_list:null:null
+--save_log_path:null
+--benchmark:True
+--model_name:ppyolo
+null:null

+ 68 - 0
test_tipc/configs/seg/_base_/rsseg.yaml

@@ -0,0 +1,68 @@
+# Basic configurations of RSSeg dataset
+
+datasets:
+    train: !Node
+        type: SegDataset
+        args: 
+            data_dir: ./test_tipc/data/rsseg/
+            file_list: ./test_tipc/data/rsseg/train.txt
+            label_list: ./test_tipc/data/rsseg/labels.txt
+            num_workers: 0
+            shuffle: True
+    eval: !Node
+        type: SegDataset
+        args:
+            data_dir: ./test_tipc/data/rsseg/
+            file_list: ./test_tipc/data/rsseg/val.txt
+            label_list: ./test_tipc/data/rsseg/labels.txt
+            num_workers: 0
+            shuffle: False
+transforms:
+    train:
+        - !Node
+          type: DecodeImg
+        - !Node
+          type: Resize
+          args:
+            target_size: 512
+        - !Node
+          type: RandomHorizontalFlip
+          args:
+            prob: 0.5
+        - !Node
+          type: Normalize
+          args:
+            mean: [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]
+            std: [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]
+        - !Node
+          type: ArrangeSegmenter
+          args: ['train']
+    eval:
+        - !Node
+          type: DecodeImg
+        - !Node
+          type: Resize
+          args:
+            target_size: 512
+        - !Node
+          type: Normalize
+          args:
+            mean: [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]
+            std: [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]
+        - !Node
+          type: ArrangeSegmenter
+          args: ['eval']
+download_on: False
+download_url: https://paddlers.bj.bcebos.com/datasets/rsseg.zip
+download_path: ./test_tipc/data/
+
+num_epochs: 10
+train_batch_size: 4
+save_interval_epochs: 5
+log_interval_steps: 4
+save_dir: ./test_tipc/output/seg/
+learning_rate: 0.001
+early_stop: False
+early_stop_patience: 5
+use_vdl: False
+resume_checkpoint: ''

+ 53 - 0
test_tipc/configs/seg/unet/train_infer_python.txt

@@ -0,0 +1,53 @@
+===========================train_params===========================
+model_name:seg:unet
+python:python
+gpu_list:0|0,1
+use_gpu:null|null
+--precision:null
+--num_epochs:lite_train_lite_infer=3|lite_train_whole_infer=3|whole_train_whole_infer=10
+--save_dir:adaptive
+--train_batch_size:lite_train_lite_infer=4|lite_train_whole_infer=4|whole_train_whole_infer=4
+--model_path:null
+train_model_name:best_model
+train_infer_file_list:./test_tipc/data/rsseg/:./test_tipc/data/rsseg/val.txt
+null:null
+##
+trainer:norm
+norm_train:test_tipc/run_task.py train seg --config ./test_tipc/configs/seg/unet/unet.yaml
+pact_train:null
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:null
+null:null
+##
+===========================export_params===========================
+--save_dir:adaptive
+--model_dir:adaptive
+--fixed_input_shape:[-1,10,512,512]
+norm_export:deploy/export/export_model.py
+quant_export:null
+fpgm_export:null
+distill_export:null
+export1:null
+export2:null
+===========================infer_params===========================
+infer_model:null
+infer_export:null
+infer_quant:False
+inference:test_tipc/infer.py
+--device:cpu|gpu
+--enable_mkldnn:True
+--cpu_threads:6
+--batch_size:1
+--use_trt:False
+--precision:fp32
+--model_dir:null
+--file_list:null:null
+--save_log_path:null
+--benchmark:True
+--model_name:unet
+null:null

+ 11 - 0
test_tipc/configs/seg/unet/unet.yaml

@@ -0,0 +1,11 @@
+# Basic configurations of UNet
+
+_base_: ../_base_/rsseg.yaml
+
+save_dir: ./test_tipc/output/seg/unet/
+
+model: !Node
+       type: UNet
+       args:
+           input_channel: 10
+           num_classes: 5

BIN
test_tipc/docs/overview.png


+ 89 - 0
test_tipc/docs/test_train_inference_python.md

@@ -0,0 +1,89 @@
+# Linux GPU/CPU 基础训练推理测试
+
+Linux GPU/CPU 基础训练推理测试的主程序为`test_train_inference_python.sh`,可以测试基于Python的模型训练、评估、推理等基本功能。
+
+## 1. 测试结论汇总
+
+- 训练相关:
+
+| 任务类别 | 模型名称 | 单机单卡 | 单机多卡 |
+|  :----: |   :----:  |    :----:  |  :----:   |
+|  变化检测  | BIT | 正常训练 | 正常训练 |
+|  场景分类  | HRNet | 正常训练 | 正常训练 |
+|  目标检测  | PP-YOLO | 正常训练 | 正常训练 |
+|  图像分割  | UNet | 正常训练 | 正常训练 |
+
+- 推理相关:
+
+| 任务类别 | 模型名称 | device_CPU | device_GPU | batchsize |
+|  :----:   |  :----: |   :----:   |  :----:  |   :----:   |
+|  变化检测  |  BIT |  支持 | 支持 | 1 |
+|  场景分类  |  HRNet |  支持 | 支持 | 1 |
+|  目标检测  |  YOLO |  支持 | 支持 | 1 |
+|  图像分割  |  UNet |  支持 | 支持 | 1 |
+
+
+## 2. 测试流程
+
+### 2.1 环境配置
+
+除了安装PaddleRS以外,您还需要安装规范化日志输出工具AutoLog:
+```
+pip install  https://paddleocr.bj.bcebos.com/libs/auto_log-1.2.0-py3-none-any.whl
+```
+
+### 2.2 功能测试
+
+先运行`test_tipc/prepare.sh`准备数据和模型,然后运行`test_tipc/test_train_inference_python.sh`进行测试。测试过程中生成的日志文件均存储在`test_tipc/output/`目录。
+
+`test_tipc/test_train_inference_python.sh`支持4种运行模式,分别是:
+
+- 模式1:lite_train_lite_infer,使用少量数据训练,用于快速验证训练到预测的流程是否能走通,不验证精度和速度;
+```shell
+bash ./test_tipc/prepare.sh test_tipc/configs/clas/hrnet/train_infer_python.txt lite_train_lite_infer
+bash ./test_tipc/test_train_inference.sh test_tipc/configs/clas/hrnet/train_infer_python.txt lite_train_lite_infer
+```  
+
+- 模式2:lite_train_whole_infer,使用少量数据训练,全量数据预测,用于验证训练后的模型执行预测时预测速度是否合理;
+```shell
+bash ./test_tipc/prepare.sh test_tipc/configs/clas/hrnet/train_infer_python.txt lite_train_whole_infer
+bash ./test_tipc/test_train_inference.sh test_tipc/configs/clas/hrnet/train_infer_python.txt lite_train_whole_infer
+```  
+
+- 模式3:whole_infer,不训练,使用全量数据预测,验证模型动转静是否正常,检查模型的预测时间和精度;
+```shell
+bash ./test_tipc/prepare.sh test_tipc/configs/clas/hrnet/train_infer_python.txt whole_infer
+# 用法1:
+bash ./test_tipc/test_train_inference.sh test_tipc/configs/clas/hrnet/train_infer_python.txt whole_infer
+# 用法2: 在指定GPU上执行预测,第三个传入参数为GPU编号
+bash ./test_tipc/test_train_inference.sh test_tipc/configs/clas/hrnet/train_infer_python.txt whole_infer '1'
+```  
+
+- 模式4:whole_train_whole_infer,CE: 全量数据训练,全量数据预测,验证模型训练精度、预测精度、预测速度;
+```shell
+bash ./test_tipc/prepare.sh test_tipc/configs/clas/hrnet/train_infer_python.txt whole_train_whole_infer
+bash ./test_tipc/test_train_inference.sh test_tipc/configs/clas/hrnet/train_infer_python.txt whole_train_whole_infer
+```  
+
+运行相应指令后,在`test_tipc/output`目录中会自动保存运行日志。如lite_train_lite_infer模式下,该目录中可能存在以下文件:
+```
+test_tipc/output/[task name]/[model name]/
+|- results_python.log    # 存储指令执行状态的日志
+|- norm_gpus_0_autocast_null/  # GPU 0号卡上的训练日志和模型保存目录
+......
+|- python_infer_cpu_usemkldnn_True_threads_6_precision_fp32_batchsize_1.log  # CPU上开启mkldnn,线程数设置为6,测试batch_size=1条件下的预测运行日志
+|- python_infer_gpu_usetrt_True_precision_fp16_batchsize_1.log # GPU上开启TensorRT,测试batch_size=1的半精度预测运行日志
+......
+```
+
+其中`results_python.log`中保存了每条指令的执行状态。如果指令运行成功,输出信息如下所示:
+```
+ Run successfully with command - hrnet - python test_tipc/infer.py --file_list ./test_tipc/data/ucmerced/ ./test_tipc/data/ucmerced/val.txt --device=gpu --use_trt=False --precision=fp32 --model_dir=./test_tipc/output/clas/hrnet/lite_train_lite_infer/norm_gpus_0,1_autocast_null/static/ --batch_size=1 --benchmark=True  !
+......
+```
+
+如果运行失败,输出信息如下所示:
+```
+ Run failed with command - hrnet - python test_tipc/infer.py --file_list ./test_tipc/data/ucmerced/ ./test_tipc/data/ucmerced/val.txt --device=gpu --use_trt=False --precision=fp32 --model_dir=./test_tipc/output/clas/hrnet/lite_train_lite_infer/norm_gpus_0,1_autocast_null/static/ --batch_size=1 --benchmark=True  !
+......
+```

+ 316 - 0
test_tipc/infer.py

@@ -0,0 +1,316 @@
+#!/usr/bin/env python
+
+import os
+import os.path as osp
+import argparse
+from operator import itemgetter
+
+import numpy as np
+import paddle
+from paddle.inference import Config
+from paddle.inference import create_predictor
+from paddle.inference import PrecisionType
+from paddlers.tasks import load_model
+from paddlers.utils import logging
+
+
+class _bool(object):
+    def __new__(cls, x):
+        if isinstance(x, str):
+            if x.lower() == 'false':
+                return False
+            elif x.lower() == 'true':
+                return True
+        return bool.__new__(x)
+
+
+class TIPCPredictor(object):
+    def __init__(self,
+                 model_dir,
+                 device='cpu',
+                 gpu_id=0,
+                 cpu_thread_num=1,
+                 use_mkl=True,
+                 mkl_thread_num=4,
+                 use_trt=False,
+                 memory_optimize=True,
+                 trt_precision_mode='fp32',
+                 benchmark=False,
+                 model_name='',
+                 batch_size=1):
+        self.model_dir = model_dir
+        self._model = load_model(model_dir, with_net=False)
+
+        if trt_precision_mode.lower() == 'fp32':
+            trt_precision_mode = PrecisionType.Float32
+        elif trt_precision_mode.lower() == 'fp16':
+            trt_precision_mode = PrecisionType.Float16
+        else:
+            logging.error(
+                "TensorRT precision mode {} is invalid. Supported modes are fp32 and fp16."
+                .format(trt_precision_mode),
+                exit=True)
+
+        self.config = self.get_config(
+            device=device,
+            gpu_id=gpu_id,
+            cpu_thread_num=cpu_thread_num,
+            use_mkl=use_mkl,
+            mkl_thread_num=mkl_thread_num,
+            use_trt=use_trt,
+            use_glog=False,
+            memory_optimize=memory_optimize,
+            max_trt_batch_size=1,
+            trt_precision_mode=trt_precision_mode)
+        self.predictor = create_predictor(self.config)
+
+        self.batch_size = batch_size
+
+        if benchmark:
+            import auto_log
+            pid = os.getpid()
+            self.autolog = auto_log.AutoLogger(
+                model_name=model_name,
+                model_precision=trt_precision_mode,
+                batch_size=batch_size,
+                data_shape='dynamic',
+                save_path=None,
+                inference_config=self.config,
+                pids=pid,
+                process_name=None,
+                gpu_ids=0,
+                time_keys=[
+                    'preprocess_time', 'inference_time', 'postprocess_time'
+                ],
+                warmup=0,
+                logger=logging)
+        self.benchmark = benchmark
+
+    def get_config(self, device, gpu_id, cpu_thread_num, use_mkl,
+                   mkl_thread_num, use_trt, use_glog, memory_optimize,
+                   max_trt_batch_size, trt_precision_mode):
+        config = Config(
+            osp.join(self.model_dir, 'model.pdmodel'),
+            osp.join(self.model_dir, 'model.pdiparams'))
+
+        if device == 'gpu':
+            config.enable_use_gpu(200, gpu_id)
+            config.switch_ir_optim(True)
+            if use_trt:
+                if self._model.model_type == 'segmenter':
+                    logging.warning(
+                        "Semantic segmentation models do not support TensorRT acceleration, "
+                        "TensorRT is forcibly disabled.")
+                elif 'RCNN' in self._model.__class__.__name__:
+                    logging.warning(
+                        "RCNN models do not support TensorRT acceleration, "
+                        "TensorRT is forcibly disabled.")
+                else:
+                    config.enable_tensorrt_engine(
+                        workspace_size=1 << 10,
+                        max_batch_size=max_trt_batch_size,
+                        min_subgraph_size=3,
+                        precision_mode=trt_precision_mode,
+                        use_static=False,
+                        use_calib_mode=False)
+        else:
+            config.disable_gpu()
+            config.set_cpu_math_library_num_threads(cpu_thread_num)
+            if use_mkl:
+                if self._model.__class__.__name__ == 'MaskRCNN':
+                    logging.warning(
+                        "MaskRCNN does not support MKL-DNN, MKL-DNN is forcibly disabled"
+                    )
+                else:
+                    try:
+                        # Cache 10 different shapes for mkldnn to avoid memory leak
+                        config.set_mkldnn_cache_capacity(10)
+                        config.enable_mkldnn()
+                        config.set_cpu_math_library_num_threads(mkl_thread_num)
+                    except Exception as e:
+                        logging.warning(
+                            "The current environment does not support MKL-DNN, MKL-DNN is disabled."
+                        )
+                        pass
+
+        if not use_glog:
+            config.disable_glog_info()
+        if memory_optimize:
+            config.enable_memory_optim()
+        config.switch_use_feed_fetch_ops(False)
+        return config
+
+    def preprocess(self, images, transforms):
+        preprocessed_samples = self._model._preprocess(
+            images, transforms, to_tensor=False)
+        if self._model.model_type == 'classifier':
+            preprocessed_samples = {'image': preprocessed_samples[0]}
+        elif self._model.model_type == 'segmenter':
+            preprocessed_samples = {
+                'image': preprocessed_samples[0],
+                'ori_shape': preprocessed_samples[1]
+            }
+        elif self._model.model_type == 'detector':
+            pass
+        elif self._model.model_type == 'change_detector':
+            preprocessed_samples = {
+                'image': preprocessed_samples[0],
+                'image2': preprocessed_samples[1],
+                'ori_shape': preprocessed_samples[2]
+            }
+        else:
+            logging.error(
+                "Invalid model type {}".format(self._model.model_type),
+                exit=True)
+        return preprocessed_samples
+
+    def postprocess(self, net_outputs, topk=1, ori_shape=None, transforms=None):
+        if self._model.model_type == 'classifier':
+            true_topk = min(self._model.num_classes, topk)
+            if self._model._postprocess is None:
+                self._model.build_postprocess_from_labels(topk)
+            # XXX: Convert ndarray to tensor as self._model._postprocess requires
+            assert len(net_outputs) == 1
+            net_outputs = paddle.to_tensor(net_outputs[0])
+            outputs = self._model._postprocess(net_outputs)
+            class_ids = map(itemgetter('class_ids'), outputs)
+            scores = map(itemgetter('scores'), outputs)
+            label_names = map(itemgetter('label_names'), outputs)
+            preds = [{
+                'class_ids_map': l,
+                'scores_map': s,
+                'label_names_map': n,
+            } for l, s, n in zip(class_ids, scores, label_names)]
+        elif self._model.model_type in ('segmenter', 'change_detector'):
+            label_map, score_map = self._model._postprocess(
+                net_outputs,
+                batch_origin_shape=ori_shape,
+                transforms=transforms.transforms)
+            preds = [{
+                'label_map': l,
+                'score_map': s
+            } for l, s in zip(label_map, score_map)]
+        elif self._model.model_type == 'detector':
+            net_outputs = {
+                k: v
+                for k, v in zip(['bbox', 'bbox_num', 'mask'], net_outputs)
+            }
+            preds = self._model._postprocess(net_outputs)
+        else:
+            logging.error(
+                "Invalid model type {}.".format(self._model.model_type),
+                exit=True)
+
+        return preds
+
+    def _run(self, images, topk=1, transforms=None, time_it=False):
+        if self.benchmark and time_it:
+            self.autolog.times.start()
+
+        preprocessed_input = self.preprocess(images, transforms)
+
+        input_names = self.predictor.get_input_names()
+        for name in input_names:
+            input_tensor = self.predictor.get_input_handle(name)
+            input_tensor.copy_from_cpu(preprocessed_input[name])
+
+        if self.benchmark and time_it:
+            self.autolog.times.stamp()
+
+        self.predictor.run()
+
+        output_names = self.predictor.get_output_names()
+        net_outputs = []
+        for name in output_names:
+            output_tensor = self.predictor.get_output_handle(name)
+            net_outputs.append(output_tensor.copy_to_cpu())
+
+        if self.benchmark and time_it:
+            self.autolog.times.stamp()
+
+        res = self.postprocess(
+            net_outputs,
+            topk,
+            ori_shape=preprocessed_input.get('ori_shape', None),
+            transforms=transforms)
+
+        if self.benchmark and time_it:
+            self.autolog.times.end(stamp=True)
+
+        return res
+
+    def predict(self, data_dir, file_list, topk=1, warmup_iters=5):
+        transforms = self._model.test_transforms
+
+        # Warm up
+        iters = 0
+        while True:
+            for images in self._parse_lines(data_dir, file_list):
+                if iters >= warmup_iters:
+                    break
+                self._run(
+                    images=images,
+                    topk=topk,
+                    transforms=transforms,
+                    time_it=False)
+                iters += 1
+            else:
+                continue
+            break
+
+        results = []
+        for images in self._parse_lines(data_dir, file_list):
+            res = self._run(
+                images=images, topk=topk, transforms=transforms, time_it=True)
+            results.append(res)
+        return results
+
+    def _parse_lines(self, data_dir, file_list):
+        with open(file_list, 'r') as f:
+            batch = []
+            for line in f:
+                items = line.strip().split()
+                items = [osp.join(data_dir, item) for item in items]
+                if self._model.model_type == 'change_detector':
+                    batch.append((items[0], items[1]))
+                else:
+                    batch.append(items[0])
+                if len(batch) == self.batch_size:
+                    yield batch
+                    batch.clear()
+            if 0 < len(batch) < self.batch_size:
+                yield batch
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument('--file_list', type=str, nargs=2)
+    parser.add_argument('--model_dir', type=str, default='./')
+    parser.add_argument(
+        '--device', type=str, choices=['cpu', 'gpu'], default='cpu')
+    parser.add_argument('--enable_mkldnn', type=_bool, default=False)
+    parser.add_argument('--cpu_threads', type=int, default=10)
+    parser.add_argument('--use_trt', type=_bool, default=False)
+    parser.add_argument(
+        '--precision', type=str, choices=['fp32', 'fp16'], default='fp16')
+    parser.add_argument('--batch_size', type=int, default=1)
+    parser.add_argument('--benchmark', type=_bool, default=False)
+    parser.add_argument('--model_name', type=str, default='')
+
+    args = parser.parse_args()
+
+    predictor = TIPCPredictor(
+        args.model_dir,
+        device=args.device,
+        cpu_thread_num=args.cpu_threads,
+        use_mkl=args.enable_mkldnn,
+        mkl_thread_num=args.cpu_threads,
+        use_trt=args.use_trt,
+        trt_precision_mode=args.precision,
+        benchmark=args.benchmark)
+
+    predictor.predict(args.file_list[0], args.file_list[1])
+
+    if args.benchmark:
+        predictor.autolog.report()

+ 43 - 0
test_tipc/prepare.sh

@@ -0,0 +1,43 @@
+#!/usr/bin/env bash
+
+source test_tipc/common_func.sh
+
+set -o errexit
+set -o nounset
+
+FILENAME=$1
+# $MODE must be one of ('lite_train_lite_infer' 'lite_train_whole_infer' 'whole_train_whole_infer', 'whole_infer')
+MODE=$2
+
+dataline=$(cat ${FILENAME})
+
+# Parse params
+IFS=$'\n'
+lines=(${dataline})
+task_name=$(parse_first_value "${lines[1]}")
+model_name=$(parse_second_value "${lines[1]}")
+
+# Download pretrained weights
+if [ ${MODE} = 'whole_infer' ]; then
+    :
+fi
+
+# Download datasets
+DATA_DIR='./test_tipc/data/'
+mkdir -p "${DATA_DIR}"
+if [[ ${MODE} == 'lite_train_lite_infer' \
+    || ${MODE} == 'lite_train_whole_infer' \
+    || ${MODE} == 'whole_train_whole_infer' \
+    || ${MODE} == 'whole_infer' ]]; then
+
+    if [[ ${task_name} == 'cd' ]]; then
+        download_and_unzip_dataset "${DATA_DIR}" airchange https://paddlers.bj.bcebos.com/datasets/airchange.zip
+    elif [[ ${task_name} == 'clas' ]]; then
+        download_and_unzip_dataset "${DATA_DIR}" ucmerced https://paddlers.bj.bcebos.com/datasets/ucmerced.zip
+    elif [[ ${task_name} == 'det' ]]; then
+        download_and_unzip_dataset "${DATA_DIR}" sarship https://paddlers.bj.bcebos.com/datasets/sarship.zip
+    elif [[ ${task_name} == 'seg' ]]; then
+        download_and_unzip_dataset "${DATA_DIR}" rsseg https://paddlers.bj.bcebos.com/datasets/rsseg_mini.zip
+    fi
+
+fi

+ 103 - 0
test_tipc/run_task.py

@@ -0,0 +1,103 @@
+#!/usr/bin/env python
+
+import os
+
+# Import cv2 and sklearn before paddlers to solve the
+# "ImportError: dlopen: cannot load any more object with static TLS" issue.
+import cv2
+import sklearn
+import paddle
+import paddlers
+from paddlers import transforms as T
+
+from config_utils import parse_args, build_objects, CfgNode
+
+
+def format_cfg(cfg, indent=0):
+    s = ''
+    if isinstance(cfg, dict):
+        for i, (k, v) in enumerate(sorted(cfg.items())):
+            s += ' ' * indent + str(k) + ': '
+            if isinstance(v, (dict, list, CfgNode)):
+                s += '\n' + format_cfg(v, indent=indent + 1)
+            else:
+                s += str(v)
+            if i != len(cfg) - 1:
+                s += '\n'
+    elif isinstance(cfg, list):
+        for i, v in enumerate(cfg):
+            s += ' ' * indent + '- '
+            if isinstance(v, (dict, list, CfgNode)):
+                s += '\n' + format_cfg(v, indent=indent + 1)
+            else:
+                s += str(v)
+            if i != len(cfg) - 1:
+                s += '\n'
+    elif isinstance(cfg, CfgNode):
+        s += ' ' * indent + f"type: {cfg.type}" + '\n'
+        s += ' ' * indent + f"module: {cfg.module}" + '\n'
+        s += ' ' * indent + 'args: \n' + format_cfg(cfg.args, indent + 1)
+    return s
+
+
+if __name__ == '__main__':
+    CfgNode.set_context(globals())
+
+    cfg = parse_args()
+    print(format_cfg(cfg))
+
+    # Automatically download data
+    if cfg['download_on']:
+        paddlers.utils.download_and_decompress(
+            cfg['download_url'], path=cfg['download_path'])
+
+    if cfg['cmd'] == 'train':
+        train_dataset = build_objects(
+            cfg['datasets']['train'], mod=paddlers.datasets)
+        train_transforms = T.Compose(
+            build_objects(
+                cfg['transforms']['train'], mod=T))
+        # XXX: Late binding of transforms
+        train_dataset.transforms = train_transforms
+    eval_dataset = build_objects(cfg['datasets']['eval'], mod=paddlers.datasets)
+    eval_transforms = T.Compose(build_objects(cfg['transforms']['eval'], mod=T))
+    # XXX: Late binding of transforms
+    eval_dataset.transforms = eval_transforms
+
+    model = build_objects(
+        cfg['model'], mod=getattr(paddlers.tasks, cfg['task']))
+
+    if cfg['cmd'] == 'train':
+        if cfg['optimizer']:
+            if len(cfg['optimizer'].args) == 0:
+                cfg['optimizer'].args = {}
+            if not isinstance(cfg['optimizer'].args, dict):
+                raise TypeError
+            if cfg['optimizer'].args.get('parameters', None) is not None:
+                raise ValueError
+            cfg['optimizer'].args['parameters'] = model.net.parameters()
+            optimizer = build_objects(cfg['optimizer'], mod=paddle.optimizer)
+        else:
+            optimizer = None
+
+        model.train(
+            num_epochs=cfg['num_epochs'],
+            train_dataset=train_dataset,
+            train_batch_size=cfg['train_batch_size'],
+            eval_dataset=eval_dataset,
+            optimizer=optimizer,
+            save_interval_epochs=cfg['save_interval_epochs'],
+            log_interval_steps=cfg['log_interval_steps'],
+            save_dir=cfg['save_dir'],
+            learning_rate=cfg['learning_rate'],
+            early_stop=cfg['early_stop'],
+            early_stop_patience=cfg['early_stop_patience'],
+            use_vdl=cfg['use_vdl'],
+            resume_checkpoint=cfg['resume_checkpoint'] or None,
+            **cfg['train'])
+    elif cfg['cmd'] == 'eval':
+        state_dict = paddle.load(
+            os.path.join(cfg['resume_checkpoint'], 'model.pdparams'))
+        model.net.set_state_dict(state_dict)
+        res = model.evaluate(eval_dataset)
+        print(res)

+ 369 - 0
test_tipc/test_train_inference.sh

@@ -0,0 +1,369 @@
+#!/usr/bin/env bash
+
+source test_tipc/common_func.sh
+
+FILENAME=$1
+# $MODE be one of {'lite_train_lite_infer' 'lite_train_whole_infer' 'whole_train_whole_infer', 'whole_infer'}
+MODE=$2
+
+dataline=$(awk 'NR>=1{print}'  $FILENAME)
+
+# Parse params
+IFS=$'\n'
+lines=(${dataline})
+
+# Training params
+task_name=$(parse_first_value "${lines[1]}")
+model_name=$(parse_second_value "${lines[1]}")
+python=$(func_parser_value "${lines[2]}")
+gpu_list=$(func_parser_value "${lines[3]}")
+train_use_gpu_key=$(func_parser_key "${lines[4]}")
+train_use_gpu_value=$(func_parser_value "${lines[4]}")
+autocast_list=$(func_parser_value "${lines[5]}")
+autocast_key=$(func_parser_key "${lines[5]}")
+epoch_key=$(func_parser_key "${lines[6]}")
+epoch_num=$(func_parser_params "${lines[6]}")
+save_model_key=$(func_parser_key "${lines[7]}")
+train_batch_key=$(func_parser_key "${lines[8]}")
+train_batch_value=$(func_parser_params "${lines[8]}")
+pretrain_model_key=$(func_parser_key "${lines[9]}")
+pretrain_model_value=$(func_parser_value "${lines[9]}")
+train_model_name=$(func_parser_value "${lines[10]}")
+train_infer_img_dir=$(parse_first_value "${lines[11]}")
+train_infer_img_file_list=$(parse_second_value "${lines[11]}")
+train_param_key1=$(func_parser_key "${lines[12]}")
+train_param_value1=$(func_parser_value "${lines[12]}")
+
+trainer_list=$(func_parser_value "${lines[14]}")
+trainer_norm=$(func_parser_key "${lines[15]}")
+norm_trainer=$(func_parser_value "${lines[15]}")
+pact_key=$(func_parser_key "${lines[16]}")
+pact_trainer=$(func_parser_value "${lines[16]}")
+fpgm_key=$(func_parser_key "${lines[17]}")
+fpgm_trainer=$(func_parser_value "${lines[17]}")
+distill_key=$(func_parser_key "${lines[18]}")
+distill_trainer=$(func_parser_value "${lines[18]}")
+trainer_key1=$(func_parser_key "${lines[19]}")
+trainer_value1=$(func_parser_value "${lines[19]}")
+trainer_key2=$(func_parser_key "${lines[20]}")
+trainer_value2=$(func_parser_value "${lines[20]}")
+
+eval_py=$(func_parser_value "${lines[23]}")
+eval_key1=$(func_parser_key "${lines[24]}")
+eval_value1=$(func_parser_value "${lines[24]}")
+
+save_infer_key=$(func_parser_key "${lines[27]}")
+export_weight=$(func_parser_key "${lines[28]}")
+export_shape_key=$(func_parser_key "${lines[29]}")
+export_shape_value=$(func_parser_value "${lines[29]}")
+norm_export=$(func_parser_value "${lines[30]}")
+pact_export=$(func_parser_value "${lines[31]}")
+fpgm_export=$(func_parser_value "${lines[32]}")
+distill_export=$(func_parser_value "${lines[33]}")
+export_key1=$(func_parser_key "${lines[34]}")
+export_value1=$(func_parser_value "${lines[34]}")
+export_key2=$(func_parser_key "${lines[35]}")
+export_value2=$(func_parser_value "${lines[35]}")
+inference_dir=$(func_parser_value "${lines[36]}")
+
+# Params of inference model
+infer_model_dir_list=$(func_parser_value "${lines[37]}")
+infer_export_list=$(func_parser_value "${lines[38]}")
+infer_is_quant=$(func_parser_value "${lines[39]}")
+# Inference params
+inference_py=$(func_parser_value "${lines[40]}")
+use_gpu_key=$(func_parser_key "${lines[41]}")
+use_gpu_list=$(func_parser_value "${lines[41]}")
+use_mkldnn_key=$(func_parser_key "${lines[42]}")
+use_mkldnn_list=$(func_parser_value "${lines[42]}")
+cpu_threads_key=$(func_parser_key "${lines[43]}")
+cpu_threads_list=$(func_parser_value "${lines[43]}")
+batch_size_key=$(func_parser_key "${lines[44]}")
+batch_size_list=$(func_parser_value "${lines[44]}")
+use_trt_key=$(func_parser_key "${lines[45]}")
+use_trt_list=$(func_parser_value "${lines[45]}")
+precision_key=$(func_parser_key "${lines[46]}")
+precision_list=$(func_parser_value "${lines[46]}")
+infer_model_key=$(func_parser_key "${lines[47]}")
+file_list_key=$(func_parser_key "${lines[48]}")
+infer_img_dir=$(parse_first_value "${lines[48]}")
+infer_img_file_list=$(parse_second_value "${lines[48]}")
+save_log_key=$(func_parser_key "${lines[49]}")
+benchmark_key=$(func_parser_key "${lines[50]}")
+benchmark_value=$(func_parser_value "${lines[50]}")
+infer_key1=$(func_parser_key "${lines[51]}")
+infer_value1=$(func_parser_value "${lines[51]}")
+infer_key2=$(func_parser_key "${lines[52]}")
+infer_value2=$(func_parser_value "${lines[52]}")
+
+OUT_PATH="./test_tipc/output/${task_name}/${model_name}/${MODE}"
+mkdir -p ${OUT_PATH}
+status_log="${OUT_PATH}/results_python.log"
+echo "------------------------ ${MODE} ------------------------" >> "${status_log}"
+
+# Parse extra args
+parse_extra_args "${lines[@]}"
+for params in ${extra_args[*]}; do
+    IFS=':'
+    arr=(${params})
+    key=${arr[0]}
+    value=${arr[1]}
+    :
+done
+
+function func_inference() {
+    local IFS='|'
+    local _python=$1
+    local _script="$2"
+    local _model_dir="$3"
+    local _log_path="$4"
+    local _img_dir="$5"
+    local _file_list="$6"
+
+    # Do inference
+    for use_gpu in ${use_gpu_list[*]}; do
+        if [ ${use_gpu} = 'False' ] || [ ${use_gpu} = 'cpu' ]; then
+            for use_mkldnn in ${use_mkldnn_list[*]}; do
+                if [ ${use_mkldnn} = 'False' ]; then
+                    continue
+                fi
+                for threads in ${cpu_threads_list[*]}; do
+                    for batch_size in ${batch_size_list[*]}; do
+                        for precision in ${precision_list[*]}; do
+                            if [ ${use_mkldnn} = 'False' ] && [ ${precision} = 'fp16' ]; then
+                                continue
+                            fi # Skip when enable fp16 but disable mkldnn
+
+                            set_precision=$(func_set_params "${precision_key}" "${precision}")
+
+                            _save_log_path="${_log_path}/python_infer_cpu_usemkldnn_${use_mkldnn}_threads_${threads}_precision_${precision}_batchsize_${batch_size}.log"
+                            infer_value1="${_log_path}/python_infer_cpu_usemkldnn_${use_mkldnn}_threads_${threads}_precision_${precision}_batchsize_${batch_size}_results"
+                            set_device=$(func_set_params "${use_gpu_key}" "${use_gpu}")
+                            set_mkldnn=$(func_set_params "${use_mkldnn_key}" "${use_mkldnn}")
+                            set_benchmark=$(func_set_params "${benchmark_key}" "${benchmark_value}")
+                            set_batchsize=$(func_set_params "${batch_size_key}" "${batch_size}")
+                            set_cpu_threads=$(func_set_params "${cpu_threads_key}" "${threads}")
+                            set_model_dir=$(func_set_params "${infer_model_key}" "${_model_dir}")
+                            set_infer_params1=$(func_set_params "${infer_key1}" "${infer_value1}")
+                            set_infer_params2=$(func_set_params "${infer_key2}" "${infer_value2}")
+                            
+                            cmd="${_python} ${_script} ${file_list_key} ${_img_dir} ${_file_list} ${set_device} ${set_mkldnn} ${set_cpu_threads} ${set_model_dir} ${set_batchsize} ${set_benchmark} ${set_precision} ${set_infer_params1} ${set_infer_params2}"
+                            echo ${cmd}
+                            run_command "${cmd}" "${_save_log_path}"
+                            
+                            last_status=${PIPESTATUS[0]}
+                            status_check ${last_status} "${cmd}" "${status_log}" "${model_name}"
+                        done
+                    done
+                done
+            done
+        elif [ ${use_gpu} = 'True' ] || [ ${use_gpu} = 'gpu' ]; then
+            for use_trt in ${use_trt_list[*]}; do
+                for precision in ${precision_list[*]}; do
+                    if [ ${precision} = 'fp16' ] && [ ${use_trt} = 'False' ]; then
+                        continue
+                    fi # Skip when enable fp16 but disable trt
+
+                    for batch_size in ${batch_size_list[*]}; do
+                        _save_log_path="${_log_path}/python_infer_gpu_usetrt_${use_trt}_precision_${precision}_batchsize_${batch_size}.log"
+                        infer_value1="${_log_path}/python_infer_gpu_usetrt_${use_trt}_precision_${precision}_batchsize_${batch_size}_results"
+                        set_device=$(func_set_params "${use_gpu_key}" "${use_gpu}")
+                        set_benchmark=$(func_set_params "${benchmark_key}" "${benchmark_value}")
+                        set_batchsize=$(func_set_params "${batch_size_key}" "${batch_size}")
+                        set_tensorrt=$(func_set_params "${use_trt_key}" "${use_trt}")
+                        set_precision=$(func_set_params "${precision_key}" "${precision}")
+                        set_model_dir=$(func_set_params "${infer_model_key}" "${_model_dir}")
+                        set_infer_params1=$(func_set_params "${infer_key1}" "${infer_value1}")
+                        set_infer_params2=$(func_set_params "${infer_key2}" "${infer_value2}")
+                        
+                        cmd="${_python} ${_script} ${file_list_key} ${_img_dir} ${_file_list} ${set_device} ${set_tensorrt} ${set_precision} ${set_model_dir} ${set_batchsize} ${set_benchmark} ${set_infer_params2}"
+                        echo ${cmd}
+                        run_command "${cmd}" "${_save_log_path}"
+
+                        last_status=${PIPESTATUS[0]}
+                        status_check $last_status "${cmd}" "${status_log}" "${model_name}"
+
+                    done
+                done
+            done
+        else
+            echo "Currently, hardwares other than CPU and GPU are not supported!"
+        fi
+    done
+}
+
+if [ ${MODE} = 'whole_infer' ]; then
+    GPUID=$3
+    if [ ${#GPUID} -le 0 ]; then
+        env=""
+    else
+        env="export CUDA_VISIBLE_DEVICES=${GPUID}"
+    fi
+    if [ ${infer_model_dir_list} == 'null' ]; then
+        echo -e "\033[33m No inference model is specified! \033[0m"
+        exit 1
+    fi
+    # Set CUDA_VISIBLE_DEVICES
+    eval ${env}
+    export count=0
+    IFS='|'
+    infer_run_exports=(${infer_export_list})
+    for infer_model in ${infer_model_dir_list[*]}; do
+        # Run export
+        if [ ${infer_run_exports[count]} != 'null' ]; then
+            save_infer_dir="${infer_model}/static"
+            set_export_weight=$(func_set_params "${export_weight}" "${infer_model}")
+            set_export_shape=$(func_set_params "${export_shape_key}" "${export_shape_value}")
+            set_save_infer_key=$(func_set_params "${save_infer_key}" "${save_infer_dir}")
+            
+            export_cmd="${python} ${infer_run_exports[count]} ${set_export_weight} ${set_save_infer_key} ${set_export_shape}"
+            echo ${infer_run_exports[count]}
+            eval ${export_cmd}
+            
+            status_export=$?
+            status_check ${status_export} "${export_cmd}" "${status_log}" "${model_name}"
+        else
+            save_infer_dir=${infer_model}
+        fi
+        # Run inference
+        func_inference "${python}" "${inference_py}" "${save_infer_dir}" "${OUT_PATH}" "${infer_img_dir}" "${infer_img_file_list}"
+        count=$((${count} + 1))
+    done
+else
+    IFS='|'
+    export count=0
+    USE_GPU_KEY=(${train_use_gpu_value})
+    for gpu in ${gpu_list[*]}; do
+        train_use_gpu=${USE_GPU_KEY[count]}
+        count=$((${count} + 1))
+        ips=""
+        if [ ${gpu} = '-1' ]; then
+            env=""
+        elif [ ${#gpu} -le 1 ]; then
+            env="export CUDA_VISIBLE_DEVICES=${gpu}"
+            eval ${env}
+        elif [ ${#gpu} -le 15 ]; then
+            IFS=','
+            array=(${gpu})
+            env="export CUDA_VISIBLE_DEVICES=${array[0]}"
+            IFS='|'
+        else
+            IFS=';'
+            array=(${gpu})
+            ips=${array[0]}
+            gpu=${array[1]}
+            IFS='|'
+            env=""
+        fi
+        for autocast in ${autocast_list[*]}; do
+            if [ ${autocast} = 'amp' ]; then
+                set_amp_config="Global.use_amp=True Global.scale_loss=1024.0 Global.use_dynamic_loss_scaling=True"
+            else
+                set_amp_config=""
+            fi
+            for trainer in ${trainer_list[*]}; do
+                if [ ${trainer} = ${pact_key} ]; then
+                    run_train=${pact_trainer}
+                    run_export=${pact_export}
+                elif [ ${trainer} = "${fpgm_key}" ]; then
+                    run_train=${fpgm_trainer}
+                    run_export=${fpgm_export}
+                elif [ ${trainer} = "${distill_key}" ]; then
+                    run_train=${distill_trainer}
+                    run_export=${distill_export}
+                elif [ ${trainer} = ${trainer_key1} ]; then
+                    run_train=${trainer_value1}
+                    run_export=${export_value1}
+                elif [[ ${trainer} = ${trainer_key2} ]]; then
+                    run_train=${trainer_value2}
+                    run_export=${export_value2}
+                else
+                    run_train=${norm_trainer}
+                    run_export=${norm_export}
+                fi
+
+                if [ ${run_train} = 'null' ]; then
+                    continue
+                fi
+                set_autocast=$(func_set_params "${autocast_key}" "${autocast}")
+                set_epoch=$(func_set_params "${epoch_key}" "${epoch_num}")
+                set_pretrain=$(func_set_params "${pretrain_model_key}" "${pretrain_model_value}")
+                set_batchsize=$(func_set_params "${train_batch_key}" "${train_batch_value}")
+                set_train_params1=$(func_set_params "${train_param_key1}" "${train_param_value1}")
+                set_use_gpu=$(func_set_params "${train_use_gpu_key}" "${train_use_gpu}")
+                # If length of ips >= 15, then it is seen as multi-machine.
+                # 15 is the min length of ips info for multi-machine: 0.0.0.0,0.0.0.0
+                if [ ${#ips} -le 15 ]; then
+                    save_dir="${OUT_PATH}/${trainer}_gpus_${gpu}_autocast_${autocast}"
+                    nodes=1
+                else
+                    IFS=','
+                    ips_array=(${ips})
+                    IFS='|'
+                    nodes=${#ips_array[@]}
+                    save_dir="${OUT_PATH}/${trainer}_gpus_${gpu}_autocast_${autocast}_nodes_${nodes}"
+                fi
+                log_path="${OUT_PATH}/${trainer}_gpus_${gpu}_autocast_${autocast}_nodes_${nodes}.log"
+
+                # Load pretrained model from norm training if current trainer is pact or fpgm trainer.
+                if ([ ${trainer} = ${pact_key} ] || [ ${trainer} = ${fpgm_key} ]) && [ ${nodes} -le 1 ]; then
+                    set_pretrain="${load_norm_train_model}"
+                fi
+
+                set_save_model=$(func_set_params "${save_model_key}" "${save_dir}")
+                if [ ${#gpu} -le 2 ]; then  # Train with cpu or single gpu
+                    cmd="${python} ${run_train} ${set_use_gpu}  ${set_save_model} ${set_epoch} ${set_pretrain} ${set_autocast} ${set_batchsize} ${set_train_params1} ${set_amp_config}"
+                elif [ ${#ips} -le 15 ]; then  # Train with multi-gpu
+                    cmd="${python} -m paddle.distributed.launch --gpus=${gpu} ${run_train} ${set_use_gpu} ${set_save_model} ${set_epoch} ${set_pretrain} ${set_autocast} ${set_batchsize} ${set_train_params1} ${set_amp_config}"
+                else     # Train with multi-machine
+                    cmd="${python} -m paddle.distributed.launch --ips=${ips} --gpus=${gpu} ${run_train} ${set_use_gpu} ${set_save_model} ${set_pretrain} ${set_epoch} ${set_autocast} ${set_batchsize} ${set_train_params1} ${set_amp_config}"
+                fi
+
+                echo ${cmd}
+                # Run train
+                run_command "${cmd}" "${log_path}"
+                status_check $? "${cmd}" "${status_log}" "${model_name}"
+
+                if [[ "${cmd}" == *'paddle.distributed.launch'* ]]; then
+                    cat log/workerlog.0 >> ${log_path} 
+                fi
+
+                set_eval_pretrain=$(func_set_params "${pretrain_model_key}" "${save_dir}/${train_model_name}/model.pdparams")
+                # Save norm trained models to set pretrain for pact training and fpgm training
+                if [ ${trainer} = ${trainer_norm} ] && [ ${nodes} -le 1 ]; then
+                    load_norm_train_model=${set_eval_pretrain}
+                fi
+                # Run evaluation
+                if [ ${eval_py} != 'null' ]; then
+                    log_path="${OUT_PATH}/${trainer}_gpus_${gpu}_autocast_${autocast}_nodes_${nodes}_eval.log"
+                    set_eval_params1=$(func_set_params "${eval_key1}" "${eval_value1}")
+                    eval_cmd="${python} ${eval_py} ${set_eval_pretrain} ${set_use_gpu} ${set_eval_params1}"
+                    run_command "${eval_cmd}" "${log_path}"
+                    status_check $? "${eval_cmd}" "${status_log}" "${model_name}"
+                fi
+                # Run export model
+                if [ ${run_export} != 'null' ]; then
+                    log_path="${OUT_PATH}/${trainer}_gpus_${gpu}_autocast_${autocast}_nodes_${nodes}_export.log"
+                    save_infer_path="${save_dir}/static"
+                    set_export_weight=$(func_set_params "${export_weight}" "${save_dir}/${train_model_name}")
+                    set_export_shape=$(func_set_params "${export_shape_key}" "${export_shape_value}")
+                    set_save_infer_key=$(func_set_params "${save_infer_key}" "${save_infer_path}")
+                    export_cmd="${python} ${run_export} ${set_export_weight} ${set_save_infer_key} ${set_export_shape}"
+                    run_command "${export_cmd}" "${log_path}"
+                    status_check $? "${export_cmd}" "${status_log}" "${model_name}"
+
+                    # Run inference
+                    eval ${env}
+                    if [[ ${inference_dir} != 'null' ]] && [[ ${inference_dir} != '##' ]]; then
+                        infer_model_dir="${save_infer_path}/${inference_dir}"
+                    else
+                        infer_model_dir=${save_infer_path}
+                    fi
+                    func_inference "${python}" "${inference_py}" "${infer_model_dir}" "${OUT_PATH}" "${train_infer_img_dir}" "${train_infer_img_file_list}"
+
+                    eval "unset CUDA_VISIBLE_DEVICES"
+                fi
+            done  # Done with:    for trainer in ${trainer_list[*]}; do
+        done      # Done with:    for autocast in ${autocast_list[*]}; do
+    done          # Done with:    for gpu in ${gpu_list[*]}; do
+fi  # End if [ ${MODE} = 'infer' ]; then