|
@@ -52,17 +52,29 @@ if __name__ == '__main__':
|
|
|
cfg['download_url'], path=cfg['download_path'])
|
|
|
|
|
|
if cfg['cmd'] == 'train':
|
|
|
- train_dataset = build_objects(
|
|
|
- cfg['datasets']['train'], mod=paddlers.datasets)
|
|
|
+ if not isinstance(cfg['datasets']['train'].args, dict):
|
|
|
+ raise ValueError("args of train dataset must be a dict!")
|
|
|
+ if cfg['datasets']['train'].args.get('transforms', None) is not None:
|
|
|
+ raise ValueError(
|
|
|
+ "Found key 'transforms' in args of train dataset and the value is not None."
|
|
|
+ )
|
|
|
train_transforms = T.Compose(
|
|
|
build_objects(
|
|
|
cfg['transforms']['train'], mod=T))
|
|
|
- # XXX: Late binding of transforms
|
|
|
- train_dataset.transforms = train_transforms
|
|
|
- eval_dataset = build_objects(cfg['datasets']['eval'], mod=paddlers.datasets)
|
|
|
+ # Inplace modification
|
|
|
+ cfg['datasets']['train'].args['transforms'] = train_transforms
|
|
|
+ train_dataset = build_objects(
|
|
|
+ cfg['datasets']['train'], mod=paddlers.datasets)
|
|
|
+ if not isinstance(cfg['datasets']['eval'].args, dict):
|
|
|
+ raise ValueError("args of eval dataset must be a dict!")
|
|
|
+ if cfg['datasets']['eval'].args.get('transforms', None) is not None:
|
|
|
+ raise ValueError(
|
|
|
+ "Found key 'transforms' in args of eval dataset and the value is not None."
|
|
|
+ )
|
|
|
eval_transforms = T.Compose(build_objects(cfg['transforms']['eval'], mod=T))
|
|
|
- # XXX: Late binding of transforms
|
|
|
- eval_dataset.transforms = eval_transforms
|
|
|
+ # Inplace modification
|
|
|
+ cfg['datasets']['eval'].args['transforms'] = eval_transforms
|
|
|
+ eval_dataset = build_objects(cfg['datasets']['eval'], mod=paddlers.datasets)
|
|
|
|
|
|
model = build_objects(
|
|
|
cfg['model'], mod=getattr(paddlers.tasks, cfg['task']))
|
|
@@ -72,9 +84,11 @@ if __name__ == '__main__':
|
|
|
if len(cfg['optimizer'].args) == 0:
|
|
|
cfg['optimizer'].args = {}
|
|
|
if not isinstance(cfg['optimizer'].args, dict):
|
|
|
- raise TypeError
|
|
|
+ raise TypeError("args of optimizer must be a dict!")
|
|
|
if cfg['optimizer'].args.get('parameters', None) is not None:
|
|
|
- raise ValueError
|
|
|
+ raise ValueError(
|
|
|
+ "Found key 'parameters' in args of optimizer and the value is not None."
|
|
|
+ )
|
|
|
cfg['optimizer'].args['parameters'] = model.net.parameters()
|
|
|
optimizer = build_objects(cfg['optimizer'], mod=paddle.optimizer)
|
|
|
else:
|