Explorar o código

Change model structure

Bobholamovic %!s(int64=2) %!d(string=hai) anos
pai
achega
e521a1d373

+ 12 - 4
examples/rs_research/README.md

@@ -73,7 +73,8 @@ class IterativeBIT(nn.Layer):
         super().__init__()
 
         if num_iters <= 0:
-            raise ValueError(f"`num_iters` should have positive value, but got {num_iters}.")
+            raise ValueError(
+                f"`num_iters` should have positive value, but got {num_iters}.")
 
         self.num_iters = num_iters
         self.gamma = gamma
@@ -97,8 +98,7 @@ class IterativeBIT(nn.Layer):
             # Get logits
             logits_list = self.bit(x1, x2)
             # Construct rate map
-            prob_map = F.softmax(logits_list[0], axis=1)
-            rate_map = self._constr_rate_map(prob_map)
+            rate_map = self._constr_rate_map(logits_list[0])
 
         return logits_list
     ...
@@ -157,6 +157,8 @@ class IterativeBIT(BaseChangeDetector):
 
 #### 3.4.3 实验结果
 
+VisualDL、定量指标
+
 ### 3.5 \*Magic Behind
 
 本小节涉及技术细节,对于本案例来说属于进阶内容,您可以选择性了解。
@@ -179,9 +181,15 @@ PaddleRS提供了,只需要。`attach_tools.Attach`对象自动。
 
 #### 4.3.1 LEVIR-CD数据集上的对比结果
 
+**目视效果对比**
+
+**定量指标对比**
+
 #### 4.3.2 SVCD数据集上的对比结果
 
-精度
+**目视效果对比**
+
+**定量指标对比**
 
 ## 5 总结与展望
 

+ 11 - 0
examples/rs_research/configs/levircd/custom_model/iterative_bit_iter2.yaml

@@ -0,0 +1,11 @@
+_base_: ../levircd.yaml
+
+save_dir: ./exp/levircd/custom_model/iter2/
+
+model: !Node
+    type: IterativeBIT
+    args:
+        num_iters: 2
+        num_classes: 2
+        bit_kwargs:
+            in_channels: 3

+ 11 - 0
examples/rs_research/configs/levircd/custom_model/iterative_bit_iter3.yaml

@@ -0,0 +1,11 @@
+_base_: ../levircd.yaml
+
+save_dir: ./exp/levircd/custom_model/iter3/
+
+model: !Node
+    type: IterativeBIT
+    args:
+        num_iters: 3
+        num_classes: 2
+        bit_kwargs:
+            in_channels: 3

+ 11 - 0
examples/rs_research/configs/levircd/custom_model/iterative_bit_iter4.yaml

@@ -0,0 +1,11 @@
+_base_: ../levircd.yaml
+
+save_dir: ./exp/levircd/custom_model/iter4/
+
+model: !Node
+    type: IterativeBIT
+    args:
+        num_iters: 4
+        num_classes: 2
+        bit_kwargs:
+            in_channels: 3

+ 2 - 0
examples/rs_research/configs/levircd/fc_ef.yaml

@@ -4,3 +4,5 @@ save_dir: ./exp/levircd/fc_ef/
 
 model: !Node
     type: FCEarlyFusion
+    args:
+        use_dropout: True

+ 2 - 0
examples/rs_research/configs/levircd/fc_siam_conc.yaml

@@ -4,3 +4,5 @@ save_dir: ./exp/levircd/fc_siam_conc/
 
 model: !Node
     type: FCSiamConc
+    args:
+        use_dropout: True

+ 2 - 0
examples/rs_research/configs/levircd/fc_siam_diff.yaml

@@ -4,3 +4,5 @@ save_dir: ./exp/levircd/fc_siam_diff/
 
 model: !Node
     type: FCSiamDiff
+    args:
+        use_dropout: True

+ 1 - 2
examples/rs_research/configs/svcd/custom_model.yaml

@@ -6,7 +6,6 @@ model: !Node
     type: IterativeBIT
     args:
         num_iters: 3
-        gamma: 0.5
         num_classes: 2
         bit_kwargs:
-            in_channels: 4
+            in_channels: 3

+ 2 - 0
examples/rs_research/configs/svcd/fc_ef.yaml

@@ -4,3 +4,5 @@ save_dir: ./exp/svcd/fc_ef/
 
 model: !Node
     type: FCEarlyFusion
+    args:
+        use_dropout: True

+ 2 - 0
examples/rs_research/configs/svcd/fc_siam_conc.yaml

@@ -4,3 +4,5 @@ save_dir: ./exp/svcd/fc_siam_conc/
 
 model: !Node
     type: FCSiamConc
+    args:
+        use_dropout: True

+ 2 - 0
examples/rs_research/configs/svcd/fc_siam_diff.yaml

@@ -4,3 +4,5 @@ save_dir: ./exp/svcd/fc_siam_diff/
 
 model: !Node
     type: FCSiamDiff
+    args:
+        use_dropout: True

+ 53 - 22
examples/rs_research/custom_model.py

@@ -9,16 +9,17 @@ attach = Attach.to(paddlers.rs_models.cd)
 
 
 @attach
-class IterativeBIT(nn.Layer):
-    def __init__(self, num_iters=1, gamma=0.1, num_classes=2, bit_kwargs=None):
-        super().__init__()
-
+class IterativeBIT(BIT):
+    def __init__(self,
+                 num_iters=1,
+                 feat_channels=32,
+                 num_classes=2,
+                 bit_kwargs=None):
         if num_iters <= 0:
             raise ValueError(
                 f"`num_iters` should have positive value, but got {num_iters}.")
 
         self.num_iters = num_iters
-        self.gamma = gamma
 
         if bit_kwargs is None:
             bit_kwargs = dict()
@@ -27,32 +28,62 @@ class IterativeBIT(nn.Layer):
             raise KeyError("'num_classes' should not be set in `bit_kwargs`.")
         bit_kwargs['num_classes'] = num_classes
 
-        self.bit = BIT(**bit_kwargs)
+        super().__init__(**bit_kwargs)
+
+        self.conv_fuse = nn.Sequential(
+            nn.Conv2D(feat_channels + 1, feat_channels, 1), nn.Sigmoid())
 
     def forward(self, t1, t2):
-        rate_map = self._init_rate_map(t1.shape)
+        # Extract features via shared backbone.
+        x1 = self.backbone(t1)
+        x2 = self.backbone(t2)
+
+        # Tokenization
+        if self.use_tokenizer:
+            token1 = self._get_semantic_tokens(x1)
+            token2 = self._get_semantic_tokens(x2)
+        else:
+            token1 = self._get_reshaped_tokens(x1)
+            token2 = self._get_reshaped_tokens(x2)
+
+        # Transformer encoder forward
+        token = paddle.concat([token1, token2], axis=1)
+        token = self.encode(token)
+        token1, token2 = paddle.chunk(token, 2, axis=1)
+
+        # Get initial rate map
+        rate_map = self._init_rate_map(x1.shape)
 
         for it in range(self.num_iters):
             # Construct inputs
-            x1 = self._constr_iter_input(t1, rate_map)
-            x2 = self._constr_iter_input(t2, rate_map)
-            # Get logits
-            logits_list = self.bit(x1, x2)
+            x1_iter = self._constr_iter_input(x1, rate_map)
+            x2_iter = self._constr_iter_input(x2, rate_map)
+
+            # Transformer decoder forward
+            y1 = self.decode(x1_iter, token1)
+            y2 = self.decode(x2_iter, token2)
+
+            # Feature differencing
+            y = paddle.abs(y1 - y2)
+
             # Construct rate map
-            prob_map = F.softmax(logits_list[0], axis=1)
-            rate_map = self._constr_rate_map(prob_map)
+            rate_map = self._constr_rate_map(y)
 
-        return logits_list
+        y = self.upsample(y)
+        pred = self.conv_out(y)
 
-    def _constr_iter_input(self, im, rate_map):
-        return paddle.concat([im, rate_map], axis=1)
+        return [pred]
 
     def _init_rate_map(self, im_shape):
         b, _, h, w = im_shape
-        return paddle.zeros((b, 1, h, w))
+        return paddle.full((b, 1, h, w), 0.5)
 
-    def _constr_rate_map(self, prob_map):
-        if prob_map.shape[1] != 2:
-            raise ValueError(
-                f"`prob_map.shape[1]` must be 2, but got {prob_map.shape[1]}.")
-        return (prob_map[:, 1:2] * self.gamma)
+    def _constr_iter_input(self, x, rate_map):
+        return self.conv_fuse(paddle.concat([x, rate_map], axis=1))
+
+    def _constr_rate_map(self, x):
+        rate_map = x.mean(1, keepdim=True).detach()  # Cut off gradient workflow
+        # min-max normalization
+        rate_map -= rate_map.min()
+        rate_map /= rate_map.max()
+        return rate_map

+ 2 - 2
examples/rs_research/custom_trainer.py

@@ -13,12 +13,12 @@ class IterativeBIT(BaseChangeDetector):
                  use_mixed_loss=False,
                  losses=None,
                  num_iters=1,
-                 gamma=0.1,
+                 feat_channels=32,
                  bit_kwargs=None,
                  **params):
         params.update({
             'num_iters': num_iters,
-            'gamma': gamma,
+            'feat_channels': feat_channels,
             'bit_kwargs': bit_kwargs
         })
         super().__init__(

+ 6 - 2
examples/rs_research/scripts/run_benchmark.sh

@@ -8,11 +8,15 @@ for dataset in levircd svcd; do
 
     mkdir -p "${log_dir}"
 
-    for config_file in $(ls ${config_dir}); do
+    for config_file in $(ls "${config_dir}"/*.yaml); do
+        filename="$(basename ${config_file})"
+        if [ "${filename}" = "${dataset}.yaml" ]; then
+            continue
+        fi
         printf '=%.0s' {1..100} && echo
         echo -e "\033[33m ${config_file} \033[0m"
         printf '=%.0s' {1..100} && echo
-        python run_task.py train cd --config "${config_dir}/${config_file}" 2>&1 | tee "${log_dir}/${config_file%.*}"
+        python run_task.py train cd --config "${config_file}" 2>&1 | tee "${log_dir}/${filename%.*}.log"
         echo
     done
 done

+ 3 - 2
examples/rs_research/scripts/run_parameter_analysis.sh

@@ -7,10 +7,11 @@ LOG_DIR='exp/logs/parameter_analysis'
 
 mkdir -p "${LOG_DIR}"
 
-for config_file in $(ls ${CONFIG_DIR}); do
+for config_file in $(ls "${CONFIG_DIR}"/*.yaml); do
+    filename="$(basename ${config_file})"
     printf '=%.0s' {1..100} && echo
     echo -e "\033[33m ${config_file} \033[0m"
     printf '=%.0s' {1..100} && echo
-    python run_task.py train cd --config "${CONFIG_DIR}/${config_file}" 2>&1 | tee "${LOG_DIR}/${config_file%.*}"
+    python run_task.py train cd --config "${config_file}" 2>&1 | tee "${LOG_DIR}/${filename%.*}.log"
     echo
 done

+ 4 - 1
test_tipc/common_func.sh

@@ -87,7 +87,10 @@ function download_and_unzip_dataset() {
     fi
 
     wget -nc -P "${ds_dir}" "${url}" --no-check-certificate
-    cd "${ds_dir}" && unzip "${zip_name}" && cd - \
+    
+    # The extracted file/directory must have the same name as the zip file.
+    cd "${ds_dir}" && unzip "${zip_name}" \
+        && mv "${zip_name%.*}" ${ds_name} && cd - \
         && echo "Successfully downloaded ${zip_name} from ${url}. File saved in ${ds_path}. "
 }
 

+ 62 - 0
test_tipc/configs/cd/_base_/levircd.yaml

@@ -0,0 +1,62 @@
+# Basic configurations of LEVIR-CD dataset
+
+datasets:
+    train: !Node
+        type: CDDataset
+        args: 
+            data_dir: ./test_tipc/data/levircd/
+            file_list: ./test_tipc/data/levircd/train.txt
+            label_list: null
+            num_workers: 0
+            shuffle: True
+            with_seg_labels: False
+            binarize_labels: True
+    eval: !Node
+        type: CDDataset
+        args:
+            data_dir: ./test_tipc/data/levircd/
+            file_list: ./test_tipc/data/levircd/val.txt
+            label_list: null
+            num_workers: 0
+            shuffle: False
+            with_seg_labels: False
+            binarize_labels: True
+transforms:
+    train:
+        - !Node
+          type: DecodeImg
+        - !Node
+          type: RandomHorizontalFlip
+          args:
+            prob: 0.5
+        - !Node
+          type: Normalize
+          args:
+            mean: [0.5, 0.5, 0.5]
+            std: [0.5, 0.5, 0.5]
+        - !Node
+          type: ArrangeChangeDetector
+          args: ['train']
+    eval:
+        - !Node
+          type: DecodeImg
+        - !Node
+          type: Normalize
+          args:
+            mean: [0.5, 0.5, 0.5]
+            std: [0.5, 0.5, 0.5]
+        - !Node
+          type: ArrangeChangeDetector
+          args: ['eval']
+download_on: False
+
+num_epochs: 10
+train_batch_size: 8
+save_interval_epochs: 5
+log_interval_steps: 50
+save_dir: ./test_tipc/output/cd/
+learning_rate: 0.002
+early_stop: False
+early_stop_patience: 5
+use_vdl: False
+resume_checkpoint: ''

+ 8 - 0
test_tipc/configs/cd/bit/bit_airchange.yaml

@@ -0,0 +1,8 @@
+# Basic configurations of BIT with AirChange dataset
+
+_base_: ../_base_/airchange.yaml
+
+save_dir: ./test_tipc/output/cd/bit/
+
+model: !Node
+    type: BIT

+ 8 - 0
test_tipc/configs/cd/bit/bit_levircd.yaml

@@ -0,0 +1,8 @@
+# Basic configurations of BIT with LEVIR-CD dataset
+
+_base_: ../_base_/levircd.yaml
+
+save_dir: ./test_tipc/output/cd/bit/
+
+model: !Node
+    type: BIT

+ 3 - 3
test_tipc/configs/cd/bit/train_infer_python.txt

@@ -8,12 +8,12 @@ use_gpu:null|null
 --save_dir:adaptive
 --train_batch_size:lite_train_lite_infer=4|lite_train_whole_infer=4|whole_train_whole_infer=4
 --model_path:null
+--config:lite_train_lite_infer=./test_tipc/configs/cd/bit/bit_airchange.yaml|lite_train_whole_infer=./test_tipc/configs/cd/bit/bit_airchange.yaml|whole_train_whole_infer=./test_tipc/configs/cd/bit/bit_levircd.yaml
 train_model_name:best_model
-train_infer_file_list:./test_tipc/data/airchange/:./test_tipc/data/airchange/eval.txt
 null:null
 ##
 trainer:norm
-norm_train:test_tipc/run_task.py train cd --config ./test_tipc/configs/cd/bit/bit.yaml
+norm_train:test_tipc/run_task.py train cd
 pact_train:null
 fpgm_train:null
 distill_train:null
@@ -46,7 +46,7 @@ inference:test_tipc/infer.py
 --use_trt:False
 --precision:fp32
 --model_dir:null
---file_list:null:null
+--config:null
 --save_log_path:null
 --benchmark:True
 --model_name:bit

+ 11 - 1
test_tipc/prepare.sh

@@ -27,7 +27,6 @@ DATA_DIR='./test_tipc/data/'
 mkdir -p "${DATA_DIR}"
 if [[ ${MODE} == 'lite_train_lite_infer' \
     || ${MODE} == 'lite_train_whole_infer' \
-    || ${MODE} == 'whole_train_whole_infer' \
     || ${MODE} == 'whole_infer' ]]; then
 
     if [[ ${task_name} == 'cd' ]]; then
@@ -40,4 +39,15 @@ if [[ ${MODE} == 'lite_train_lite_infer' \
         download_and_unzip_dataset "${DATA_DIR}" rsseg https://paddlers.bj.bcebos.com/datasets/rsseg_mini.zip
     fi
 
+elif [[ ${MODE} == 'whole_train_whole_infer' ]]; then
+
+    if [[ ${task_name} == 'cd' ]]; then
+        download_and_unzip_dataset "${DATA_DIR}" raw_levircd https://paddlers.bj.bcebos.com/datasets/raw/LEVIR-CD.zip \
+        && python tools/prepare_dataset/prepare_levircd.py \
+            --in_dataset_dir "${DATA_DIR}/raw_levircd" \
+            --out_dataset_dir "${DATA_DIR}/levircd" \
+            --crop_size 256 \
+            --crop_stride 256
+    fi
+
 fi

+ 40 - 46
test_tipc/test_train_inference_python.sh

@@ -22,15 +22,15 @@ train_use_gpu_value=$(func_parser_value "${lines[4]}")
 autocast_list=$(func_parser_value "${lines[5]}")
 autocast_key=$(func_parser_key "${lines[5]}")
 epoch_key=$(func_parser_key "${lines[6]}")
-epoch_num=$(func_parser_params "${lines[6]}")
+epoch_value=$(func_parser_params "${lines[6]}")
 save_model_key=$(func_parser_key "${lines[7]}")
 train_batch_key=$(func_parser_key "${lines[8]}")
 train_batch_value=$(func_parser_params "${lines[8]}")
 pretrain_model_key=$(func_parser_key "${lines[9]}")
 pretrain_model_value=$(func_parser_value "${lines[9]}")
-train_model_name=$(func_parser_value "${lines[10]}")
-train_infer_img_dir=$(parse_first_value "${lines[11]}")
-train_infer_img_file_list=$(parse_second_value "${lines[11]}")
+train_config_key=$(func_parser_key "${lines[10]}")
+train_config_value=$(func_parser_params "${lines[10]}")
+train_model_name=$(func_parser_value "${lines[11]}")
 train_param_key1=$(func_parser_key "${lines[12]}")
 train_param_value1=$(func_parser_value "${lines[12]}")
 
@@ -85,9 +85,8 @@ use_trt_list=$(func_parser_value "${lines[45]}")
 precision_key=$(func_parser_key "${lines[46]}")
 precision_list=$(func_parser_value "${lines[46]}")
 infer_model_key=$(func_parser_key "${lines[47]}")
-file_list_key=$(func_parser_key "${lines[48]}")
-infer_img_dir=$(parse_first_value "${lines[48]}")
-infer_img_file_list=$(parse_second_value "${lines[48]}")
+infer_config_key=$(func_parser_key "${lines[48]}")
+infer_config_value=$(func_parser_value "${lines[48]}")
 save_log_key=$(func_parser_key "${lines[49]}")
 benchmark_key=$(func_parser_key "${lines[50]}")
 benchmark_value=$(func_parser_value "${lines[50]}")
@@ -117,37 +116,37 @@ function func_inference() {
     local _script="$2"
     local _model_dir="$3"
     local _log_path="$4"
-    local _img_dir="$5"
-    local _file_list="$6"
+    local _config="$5"
+
+    local set_infer_config=$(func_set_params "${infer_config_key}" "${_config}")
+    local set_model_dir=$(func_set_params "${infer_model_key}" "${_model_dir}")
+    local set_benchmark=$(func_set_params "${benchmark_key}" "${benchmark_value}")
+    local set_infer_params1=$(func_set_params "${infer_key1}" "${infer_value1}")
+    local set_infer_params2=$(func_set_params "${infer_key2}" "${infer_value2}")
 
     # Do inference
     for use_gpu in ${use_gpu_list[*]}; do
+        local set_device=$(func_set_params "${use_gpu_key}" "${use_gpu}")
         if [ ${use_gpu} = 'False' ] || [ ${use_gpu} = 'cpu' ]; then
             for use_mkldnn in ${use_mkldnn_list[*]}; do
                 if [ ${use_mkldnn} = 'False' ]; then
                     continue
                 fi
-                for threads in ${cpu_threads_list[*]}; do
-                    for batch_size in ${batch_size_list[*]}; do
-                        for precision in ${precision_list[*]}; do
-                            if [ ${use_mkldnn} = 'False' ] && [ ${precision} = 'fp16' ]; then
-                                continue
-                            fi # Skip when enable fp16 but disable mkldnn
-
-                            set_precision=$(func_set_params "${precision_key}" "${precision}")
+                for precision in ${precision_list[*]}; do
+                    if [ ${use_mkldnn} = 'False' ] && [ ${precision} = 'fp16' ]; then
+                        continue
+                    fi # Skip when enable fp16 but disable mkldnn
 
-                            _save_log_path="${_log_path}/python_infer_cpu_usemkldnn_${use_mkldnn}_threads_${threads}_precision_${precision}_batchsize_${batch_size}.log"
-                            infer_value1="${_log_path}/python_infer_cpu_usemkldnn_${use_mkldnn}_threads_${threads}_precision_${precision}_batchsize_${batch_size}_results"
-                            set_device=$(func_set_params "${use_gpu_key}" "${use_gpu}")
-                            set_mkldnn=$(func_set_params "${use_mkldnn_key}" "${use_mkldnn}")
-                            set_benchmark=$(func_set_params "${benchmark_key}" "${benchmark_value}")
-                            set_batchsize=$(func_set_params "${batch_size_key}" "${batch_size}")
-                            set_cpu_threads=$(func_set_params "${cpu_threads_key}" "${threads}")
-                            set_model_dir=$(func_set_params "${infer_model_key}" "${_model_dir}")
-                            set_infer_params1=$(func_set_params "${infer_key1}" "${infer_value1}")
-                            set_infer_params2=$(func_set_params "${infer_key2}" "${infer_value2}")
+                    for threads in ${cpu_threads_list[*]}; do
+                        for batch_size in ${batch_size_list[*]}; do
+                            local _save_log_path="${_log_path}/python_infer_cpu_usemkldnn_${use_mkldnn}_threads_${threads}_precision_${precision}_batchsize_${batch_size}.log"
+                            local infer_value1="${_log_path}/python_infer_cpu_usemkldnn_${use_mkldnn}_threads_${threads}_precision_${precision}_batchsize_${batch_size}_results"
+                            local set_mkldnn=$(func_set_params "${use_mkldnn_key}" "${use_mkldnn}")
+                            local set_precision=$(func_set_params "${precision_key}" "${precision}")
+                            local set_cpu_threads=$(func_set_params "${cpu_threads_key}" "${threads}")
+                            local set_batchsize=$(func_set_params "${batch_size_key}" "${batch_size}")
                             
-                            cmd="${_python} ${_script} ${file_list_key} ${_img_dir} ${_file_list} ${set_device} ${set_mkldnn} ${set_cpu_threads} ${set_model_dir} ${set_batchsize} ${set_benchmark} ${set_precision} ${set_infer_params1} ${set_infer_params2}"
+                            local cmd="${_python} ${_script} ${set_config} ${set_device} ${set_mkldnn} ${set_cpu_threads} ${set_model_dir} ${set_batchsize} ${set_benchmark} ${set_precision} ${set_infer_params1} ${set_infer_params2}"
                             echo ${cmd}
                             run_command "${cmd}" "${_save_log_path}"
                             
@@ -165,24 +164,18 @@ function func_inference() {
                     fi # Skip when enable fp16 but disable trt
 
                     for batch_size in ${batch_size_list[*]}; do
-                        _save_log_path="${_log_path}/python_infer_gpu_usetrt_${use_trt}_precision_${precision}_batchsize_${batch_size}.log"
-                        infer_value1="${_log_path}/python_infer_gpu_usetrt_${use_trt}_precision_${precision}_batchsize_${batch_size}_results"
-                        set_device=$(func_set_params "${use_gpu_key}" "${use_gpu}")
-                        set_benchmark=$(func_set_params "${benchmark_key}" "${benchmark_value}")
-                        set_batchsize=$(func_set_params "${batch_size_key}" "${batch_size}")
-                        set_tensorrt=$(func_set_params "${use_trt_key}" "${use_trt}")
-                        set_precision=$(func_set_params "${precision_key}" "${precision}")
-                        set_model_dir=$(func_set_params "${infer_model_key}" "${_model_dir}")
-                        set_infer_params1=$(func_set_params "${infer_key1}" "${infer_value1}")
-                        set_infer_params2=$(func_set_params "${infer_key2}" "${infer_value2}")
+                        local _save_log_path="${_log_path}/python_infer_gpu_usetrt_${use_trt}_precision_${precision}_batchsize_${batch_size}.log"
+                        local infer_value1="${_log_path}/python_infer_gpu_usetrt_${use_trt}_precision_${precision}_batchsize_${batch_size}_results"
+                        local set_tensorrt=$(func_set_params "${use_trt_key}" "${use_trt}")
+                        local set_precision=$(func_set_params "${precision_key}" "${precision}")
+                        local set_batchsize=$(func_set_params "${batch_size_key}" "${batch_size}")
                         
-                        cmd="${_python} ${_script} ${file_list_key} ${_img_dir} ${_file_list} ${set_device} ${set_tensorrt} ${set_precision} ${set_model_dir} ${set_batchsize} ${set_benchmark} ${set_infer_params2}"
+                        local cmd="${_python} ${_script} ${set_config} ${set_device} ${set_tensorrt} ${set_precision} ${set_model_dir} ${set_batchsize} ${set_benchmark} ${set_infer_params2}"
                         echo ${cmd}
                         run_command "${cmd}" "${_save_log_path}"
 
                         last_status=${PIPESTATUS[0]}
                         status_check $last_status "${cmd}" "${status_log}" "${model_name}"
-
                     done
                 done
             done
@@ -226,7 +219,7 @@ if [ ${MODE} = 'whole_infer' ]; then
             save_infer_dir=${infer_model}
         fi
         # Run inference
-        func_inference "${python}" "${inference_py}" "${save_infer_dir}" "${OUT_PATH}" "${infer_img_dir}" "${infer_img_file_list}"
+        func_inference "${python}" "${inference_py}" "${save_infer_dir}" "${OUT_PATH}" "${infer_config_value}"
         count=$((${count} + 1))
     done
 else
@@ -285,8 +278,9 @@ else
                 if [ ${run_train} = 'null' ]; then
                     continue
                 fi
+                set_config=$(func_set_params "${train_config_key}" "${train_config_value}")
                 set_autocast=$(func_set_params "${autocast_key}" "${autocast}")
-                set_epoch=$(func_set_params "${epoch_key}" "${epoch_num}")
+                set_epoch=$(func_set_params "${epoch_key}" "${epoch_value}")
                 set_pretrain=$(func_set_params "${pretrain_model_key}" "${pretrain_model_value}")
                 set_batchsize=$(func_set_params "${train_batch_key}" "${train_batch_value}")
                 set_train_params1=$(func_set_params "${train_param_key1}" "${train_param_value1}")
@@ -312,11 +306,11 @@ else
 
                 set_save_model=$(func_set_params "${save_model_key}" "${save_dir}")
                 if [ ${#gpu} -le 2 ]; then  # Train with cpu or single gpu
-                    cmd="${python} ${run_train} ${set_use_gpu}  ${set_save_model} ${set_epoch} ${set_pretrain} ${set_autocast} ${set_batchsize} ${set_train_params1} ${set_amp_config}"
+                    cmd="${python} ${run_train} ${set_config} ${set_use_gpu} ${set_save_model} ${set_epoch} ${set_pretrain} ${set_autocast} ${set_batchsize} ${set_train_params1} ${set_amp_config}"
                 elif [ ${#ips} -le 15 ]; then  # Train with multi-gpu
-                    cmd="${python} -m paddle.distributed.launch --gpus=${gpu} ${run_train} ${set_use_gpu} ${set_save_model} ${set_epoch} ${set_pretrain} ${set_autocast} ${set_batchsize} ${set_train_params1} ${set_amp_config}"
+                    cmd="${python} -m paddle.distributed.launch --gpus=${gpu} ${run_train} ${set_config} ${set_use_gpu} ${set_save_model} ${set_epoch} ${set_pretrain} ${set_autocast} ${set_batchsize} ${set_train_params1} ${set_amp_config}"
                 else     # Train with multi-machine
-                    cmd="${python} -m paddle.distributed.launch --ips=${ips} --gpus=${gpu} ${run_train} ${set_use_gpu} ${set_save_model} ${set_pretrain} ${set_epoch} ${set_autocast} ${set_batchsize} ${set_train_params1} ${set_amp_config}"
+                    cmd="${python} -m paddle.distributed.launch --ips=${ips} --gpus=${gpu} ${run_train} ${set_config} ${set_use_gpu} ${set_save_model} ${set_pretrain} ${set_epoch} ${set_autocast} ${set_batchsize} ${set_train_params1} ${set_amp_config}"
                 fi
 
                 echo ${cmd}
@@ -359,7 +353,7 @@ else
                     else
                         infer_model_dir=${save_infer_path}
                     fi
-                    func_inference "${python}" "${inference_py}" "${infer_model_dir}" "${OUT_PATH}" "${train_infer_img_dir}" "${train_infer_img_file_list}"
+                    func_inference "${python}" "${inference_py}" "${infer_model_dir}" "${OUT_PATH}" "${train_config_value}"
 
                     eval "unset CUDA_VISIBLE_DEVICES"
                 fi