Przeglądaj źródła

[Fix] Fix RS research example bugs

Bobholamovic 2 lat temu
rodzic
commit
6479c51d4f

+ 15 - 3
examples/rs_research/custom_trainer.py

@@ -23,12 +23,12 @@ from attach_tools import Attach
 attach = Attach.to(paddlers.tasks.change_detector)
 
 
-def make_trainer(net_type, *args, **kwargs):
+def make_trainer(net_type, attach_trainer=True):
     def _init_func(self,
                    num_classes=2,
                    use_mixed_loss=False,
                    losses=None,
-                   **params):
+                   **_params_):
         sig = inspect.signature(net_type.__init__)
         net_params = {
             k: p.default
@@ -36,7 +36,13 @@ def make_trainer(net_type, *args, **kwargs):
         }
         net_params.pop('self', None)
         net_params.pop('num_classes', None)
-        net_params.update(params)
+        # Special rule to parse arguments from `_params_`.
+        # When using pdrs.tasks.load_model, `_params_`` is a dict with the key '_params_'.
+        # This bypasses the dynamic modification/creation of function signature.
+        if '_params_' not in _params_:
+            net_params.update(_params_)
+        else:
+            net_params.update(_params_['_params_'])
 
         super(trainer_type, self).__init__(
             model_name=net_type.__name__,
@@ -52,7 +58,13 @@ def make_trainer(net_type, *args, **kwargs):
 
     trainer_type = type(trainer_name, (BaseChangeDetector, ),
                         {'__init__': _init_func})
+    if attach_trainer:
+        trainer_type = attach(trainer_type)
+    return trainer_type
 
+
+def make_trainer_and_build(net_type, *args, **kwargs):
+    trainer_type = make_trainer(net_type, attach_trainer=True)
     return trainer_type(*args, **kwargs)
 
 

+ 5 - 2
examples/rs_research/predict_cd.py

@@ -23,8 +23,8 @@ import paddle
 import paddlers
 from tqdm import tqdm
 
-import custom_model
-import custom_trainer
+from custom_model import CustomModel
+from custom_trainer import make_trainer
 
 
 def read_file_list(file_list, sep=' '):
@@ -57,6 +57,9 @@ def parse_args():
 if __name__ == '__main__':
     args = parse_args()
 
+    # 注册训练器
+    make_trainer(CustomModel)
+
     model = paddlers.tasks.load_model(args.model_dir)
 
     if not osp.exists(args.save_dir):

+ 5 - 4
examples/rs_research/train_cd.py

@@ -7,7 +7,7 @@ import paddlers as pdrs
 from paddlers import transforms as T
 
 from custom_model import CustomModel
-from custom_trainer import make_trainer
+from custom_trainer import make_trainer_and_build
 
 # 数据集路径
 DATA_DIR = 'data/levircd/'
@@ -75,8 +75,8 @@ test_dataset = pdrs.datasets.CDDataset(
     binarize_labels=True)
 
 # 构建自定义模型CustomModel并为其自动生成训练器
-# make_trainer()的首个参数为模型类型,剩余参数为模型构造所需参数
-model = make_trainer(CustomModel, in_channels=3)
+# make_trainer_and_build()的首个参数为模型类型,剩余参数为模型构造所需参数
+model = make_trainer_and_build(CustomModel, in_channels=3)
 
 # 构建学习率调度器
 # 使用定步长学习率衰减策略
@@ -108,4 +108,5 @@ model.train(
 # 加载验证集上效果最好的模型
 model = pdrs.tasks.load_model(osp.join(EXP_DIR, 'best_model'))
 # 在测试集上计算精度指标
-model.evaluate(test_dataset)
+res = model.evaluate(test_dataset)
+print(res)

+ 0 - 2
paddlers/tasks/change_detector.py

@@ -630,8 +630,6 @@ class BaseChangeDetector(BaseModel):
             if isinstance(im1, str) or isinstance(im2, str):
                 im1 = decode_image(im1, read_raw=True)
                 im2 = decode_image(im2, read_raw=True)
-                np.save('im1_whole.npy', im1)
-                np.save('im2_whole.npy', im2)
             ori_shape = im1.shape[:2]
             # XXX: sample do not contain 'image_t1' and 'image_t2'.
             sample = {'image': im1, 'image2': im2}