|
@@ -53,6 +53,17 @@ if __name__ == '__main__':
|
|
|
paddlers.utils.download_and_decompress(
|
|
|
cfg['download_url'], path=cfg['download_path'])
|
|
|
|
|
|
+ if not isinstance(cfg['datasets']['eval'].args, dict):
|
|
|
+ raise ValueError("args of eval dataset must be a dict!")
|
|
|
+ if cfg['datasets']['eval'].args.get('transforms', None) is not None:
|
|
|
+ raise ValueError(
|
|
|
+ "Found key 'transforms' in args of eval dataset and the value is not None."
|
|
|
+ )
|
|
|
+ eval_transforms = T.Compose(build_objects(cfg['transforms']['eval'], mod=T))
|
|
|
+ # Inplace modification
|
|
|
+ cfg['datasets']['eval'].args['transforms'] = eval_transforms
|
|
|
+ eval_dataset = build_objects(cfg['datasets']['eval'], mod=paddlers.datasets)
|
|
|
+
|
|
|
if cfg['cmd'] == 'train':
|
|
|
if not isinstance(cfg['datasets']['train'].args, dict):
|
|
|
raise ValueError("args of train dataset must be a dict!")
|
|
@@ -67,21 +78,8 @@ if __name__ == '__main__':
|
|
|
cfg['datasets']['train'].args['transforms'] = train_transforms
|
|
|
train_dataset = build_objects(
|
|
|
cfg['datasets']['train'], mod=paddlers.datasets)
|
|
|
- if not isinstance(cfg['datasets']['eval'].args, dict):
|
|
|
- raise ValueError("args of eval dataset must be a dict!")
|
|
|
- if cfg['datasets']['eval'].args.get('transforms', None) is not None:
|
|
|
- raise ValueError(
|
|
|
- "Found key 'transforms' in args of eval dataset and the value is not None."
|
|
|
- )
|
|
|
- eval_transforms = T.Compose(build_objects(cfg['transforms']['eval'], mod=T))
|
|
|
- # Inplace modification
|
|
|
- cfg['datasets']['eval'].args['transforms'] = eval_transforms
|
|
|
- eval_dataset = build_objects(cfg['datasets']['eval'], mod=paddlers.datasets)
|
|
|
-
|
|
|
- model = build_objects(
|
|
|
- cfg['model'], mod=getattr(paddlers.tasks, cfg['task']))
|
|
|
-
|
|
|
- if cfg['cmd'] == 'train':
|
|
|
+ model = build_objects(
|
|
|
+ cfg['model'], mod=getattr(paddlers.tasks, cfg['task']))
|
|
|
if cfg['optimizer']:
|
|
|
if len(cfg['optimizer'].args) == 0:
|
|
|
cfg['optimizer'].args = {}
|
|
@@ -112,8 +110,6 @@ if __name__ == '__main__':
|
|
|
resume_checkpoint=cfg['resume_checkpoint'] or None,
|
|
|
**cfg['train'])
|
|
|
elif cfg['cmd'] == 'eval':
|
|
|
- state_dict = paddle.load(
|
|
|
- os.path.join(cfg['resume_checkpoint'], 'model.pdparams'))
|
|
|
- model.net.set_state_dict(state_dict)
|
|
|
+ model = paddlers.tasks.load_model(cfg['resume_checkpoint'])
|
|
|
res = model.evaluate(eval_dataset)
|
|
|
print(res)
|