Browse Source

Refactor run_task.py

Bobholamovic 2 years ago
parent
commit
5a24513136
2 changed files with 28 additions and 36 deletions
  1. 14 18
      examples/rs_research/run_task.py
  2. 14 18
      test_tipc/run_task.py

+ 14 - 18
examples/rs_research/run_task.py

@@ -53,6 +53,17 @@ if __name__ == '__main__':
         paddlers.utils.download_and_decompress(
         paddlers.utils.download_and_decompress(
             cfg['download_url'], path=cfg['download_path'])
             cfg['download_url'], path=cfg['download_path'])
 
 
+    if not isinstance(cfg['datasets']['eval'].args, dict):
+        raise ValueError("args of eval dataset must be a dict!")
+    if cfg['datasets']['eval'].args.get('transforms', None) is not None:
+        raise ValueError(
+            "Found key 'transforms' in args of eval dataset and the value is not None."
+        )
+    eval_transforms = T.Compose(build_objects(cfg['transforms']['eval'], mod=T))
+    # Inplace modification
+    cfg['datasets']['eval'].args['transforms'] = eval_transforms
+    eval_dataset = build_objects(cfg['datasets']['eval'], mod=paddlers.datasets)
+
     if cfg['cmd'] == 'train':
     if cfg['cmd'] == 'train':
         if not isinstance(cfg['datasets']['train'].args, dict):
         if not isinstance(cfg['datasets']['train'].args, dict):
             raise ValueError("args of train dataset must be a dict!")
             raise ValueError("args of train dataset must be a dict!")
@@ -67,21 +78,8 @@ if __name__ == '__main__':
         cfg['datasets']['train'].args['transforms'] = train_transforms
         cfg['datasets']['train'].args['transforms'] = train_transforms
         train_dataset = build_objects(
         train_dataset = build_objects(
             cfg['datasets']['train'], mod=paddlers.datasets)
             cfg['datasets']['train'], mod=paddlers.datasets)
-    if not isinstance(cfg['datasets']['eval'].args, dict):
-        raise ValueError("args of eval dataset must be a dict!")
-    if cfg['datasets']['eval'].args.get('transforms', None) is not None:
-        raise ValueError(
-            "Found key 'transforms' in args of eval dataset and the value is not None."
-        )
-    eval_transforms = T.Compose(build_objects(cfg['transforms']['eval'], mod=T))
-    # Inplace modification
-    cfg['datasets']['eval'].args['transforms'] = eval_transforms
-    eval_dataset = build_objects(cfg['datasets']['eval'], mod=paddlers.datasets)
-
-    model = build_objects(
-        cfg['model'], mod=getattr(paddlers.tasks, cfg['task']))
-
-    if cfg['cmd'] == 'train':
+        model = build_objects(
+            cfg['model'], mod=getattr(paddlers.tasks, cfg['task']))
         if cfg['optimizer']:
         if cfg['optimizer']:
             if len(cfg['optimizer'].args) == 0:
             if len(cfg['optimizer'].args) == 0:
                 cfg['optimizer'].args = {}
                 cfg['optimizer'].args = {}
@@ -112,8 +110,6 @@ if __name__ == '__main__':
             resume_checkpoint=cfg['resume_checkpoint'] or None,
             resume_checkpoint=cfg['resume_checkpoint'] or None,
             **cfg['train'])
             **cfg['train'])
     elif cfg['cmd'] == 'eval':
     elif cfg['cmd'] == 'eval':
-        state_dict = paddle.load(
-            os.path.join(cfg['resume_checkpoint'], 'model.pdparams'))
-        model.net.set_state_dict(state_dict)
+        model = paddlers.tasks.load_model(cfg['resume_checkpoint'])
         res = model.evaluate(eval_dataset)
         res = model.evaluate(eval_dataset)
         print(res)
         print(res)

+ 14 - 18
test_tipc/run_task.py

@@ -51,6 +51,17 @@ if __name__ == '__main__':
         paddlers.utils.download_and_decompress(
         paddlers.utils.download_and_decompress(
             cfg['download_url'], path=cfg['download_path'])
             cfg['download_url'], path=cfg['download_path'])
 
 
+    if not isinstance(cfg['datasets']['eval'].args, dict):
+        raise ValueError("args of eval dataset must be a dict!")
+    if cfg['datasets']['eval'].args.get('transforms', None) is not None:
+        raise ValueError(
+            "Found key 'transforms' in args of eval dataset and the value is not None."
+        )
+    eval_transforms = T.Compose(build_objects(cfg['transforms']['eval'], mod=T))
+    # Inplace modification
+    cfg['datasets']['eval'].args['transforms'] = eval_transforms
+    eval_dataset = build_objects(cfg['datasets']['eval'], mod=paddlers.datasets)
+
     if cfg['cmd'] == 'train':
     if cfg['cmd'] == 'train':
         if not isinstance(cfg['datasets']['train'].args, dict):
         if not isinstance(cfg['datasets']['train'].args, dict):
             raise ValueError("args of train dataset must be a dict!")
             raise ValueError("args of train dataset must be a dict!")
@@ -65,21 +76,8 @@ if __name__ == '__main__':
         cfg['datasets']['train'].args['transforms'] = train_transforms
         cfg['datasets']['train'].args['transforms'] = train_transforms
         train_dataset = build_objects(
         train_dataset = build_objects(
             cfg['datasets']['train'], mod=paddlers.datasets)
             cfg['datasets']['train'], mod=paddlers.datasets)
-    if not isinstance(cfg['datasets']['eval'].args, dict):
-        raise ValueError("args of eval dataset must be a dict!")
-    if cfg['datasets']['eval'].args.get('transforms', None) is not None:
-        raise ValueError(
-            "Found key 'transforms' in args of eval dataset and the value is not None."
-        )
-    eval_transforms = T.Compose(build_objects(cfg['transforms']['eval'], mod=T))
-    # Inplace modification
-    cfg['datasets']['eval'].args['transforms'] = eval_transforms
-    eval_dataset = build_objects(cfg['datasets']['eval'], mod=paddlers.datasets)
-
-    model = build_objects(
-        cfg['model'], mod=getattr(paddlers.tasks, cfg['task']))
-
-    if cfg['cmd'] == 'train':
+        model = build_objects(
+            cfg['model'], mod=getattr(paddlers.tasks, cfg['task']))
         if cfg['optimizer']:
         if cfg['optimizer']:
             if len(cfg['optimizer'].args) == 0:
             if len(cfg['optimizer'].args) == 0:
                 cfg['optimizer'].args = {}
                 cfg['optimizer'].args = {}
@@ -110,8 +108,6 @@ if __name__ == '__main__':
             resume_checkpoint=cfg['resume_checkpoint'] or None,
             resume_checkpoint=cfg['resume_checkpoint'] or None,
             **cfg['train'])
             **cfg['train'])
     elif cfg['cmd'] == 'eval':
     elif cfg['cmd'] == 'eval':
-        state_dict = paddle.load(
-            os.path.join(cfg['resume_checkpoint'], 'model.pdparams'))
-        model.net.set_state_dict(state_dict)
+        model = paddlers.tasks.load_model(cfg['resume_checkpoint'])
         res = model.evaluate(eval_dataset)
         res = model.evaluate(eval_dataset)
         print(res)
         print(res)