|  | @@ -0,0 +1,253 @@
 | 
	
		
			
				|  |  | +#!/usr/bin/env python
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +import argparse
 | 
	
		
			
				|  |  | +import os.path as osp
 | 
	
		
			
				|  |  | +from collections.abc import Mapping
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +import yaml
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +def _chain_maps(*maps):
 | 
	
		
			
				|  |  | +    chained = dict()
 | 
	
		
			
				|  |  | +    keys = set().union(*maps)
 | 
	
		
			
				|  |  | +    for key in keys:
 | 
	
		
			
				|  |  | +        vals = [m[key] for m in maps if key in m]
 | 
	
		
			
				|  |  | +        if isinstance(vals[0], Mapping):
 | 
	
		
			
				|  |  | +            chained[key] = _chain_maps(*vals)
 | 
	
		
			
				|  |  | +        else:
 | 
	
		
			
				|  |  | +            chained[key] = vals[0]
 | 
	
		
			
				|  |  | +    return chained
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +def read_config(config_path):
 | 
	
		
			
				|  |  | +    with open(config_path, 'r', encoding='utf-8') as f:
 | 
	
		
			
				|  |  | +        cfg = yaml.safe_load(f)
 | 
	
		
			
				|  |  | +    return cfg or {}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +def parse_configs(cfg_path, inherit=True):
 | 
	
		
			
				|  |  | +    if inherit:
 | 
	
		
			
				|  |  | +        cfgs = []
 | 
	
		
			
				|  |  | +        cfgs.append(read_config(cfg_path))
 | 
	
		
			
				|  |  | +        while cfgs[-1].get('_base_'):
 | 
	
		
			
				|  |  | +            base_path = cfgs[-1].pop('_base_')
 | 
	
		
			
				|  |  | +            curr_dir = osp.dirname(cfg_path)
 | 
	
		
			
				|  |  | +            cfgs.append(
 | 
	
		
			
				|  |  | +                read_config(osp.normpath(osp.join(curr_dir, base_path))))
 | 
	
		
			
				|  |  | +        return _chain_maps(*cfgs)
 | 
	
		
			
				|  |  | +    else:
 | 
	
		
			
				|  |  | +        return read_config(cfg_path)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +def _cfg2args(cfg, parser, prefix=''):
 | 
	
		
			
				|  |  | +    node_keys = set()
 | 
	
		
			
				|  |  | +    for k, v in cfg.items():
 | 
	
		
			
				|  |  | +        opt = prefix + k
 | 
	
		
			
				|  |  | +        if isinstance(v, list):
 | 
	
		
			
				|  |  | +            if len(v) == 0:
 | 
	
		
			
				|  |  | +                parser.add_argument(
 | 
	
		
			
				|  |  | +                    '--' + opt, type=object, nargs='*', default=v)
 | 
	
		
			
				|  |  | +            else:
 | 
	
		
			
				|  |  | +                # Only apply to homogeneous lists
 | 
	
		
			
				|  |  | +                if isinstance(v[0], CfgNode):
 | 
	
		
			
				|  |  | +                    node_keys.add(opt)
 | 
	
		
			
				|  |  | +                parser.add_argument(
 | 
	
		
			
				|  |  | +                    '--' + opt, type=type(v[0]), nargs='*', default=v)
 | 
	
		
			
				|  |  | +        elif isinstance(v, dict):
 | 
	
		
			
				|  |  | +            # Recursively parse a dict
 | 
	
		
			
				|  |  | +            _, new_node_keys = _cfg2args(v, parser, opt + '.')
 | 
	
		
			
				|  |  | +            node_keys.update(new_node_keys)
 | 
	
		
			
				|  |  | +        elif isinstance(v, CfgNode):
 | 
	
		
			
				|  |  | +            node_keys.add(opt)
 | 
	
		
			
				|  |  | +            _, new_node_keys = _cfg2args(v.to_dict(), parser, opt + '.')
 | 
	
		
			
				|  |  | +            node_keys.update(new_node_keys)
 | 
	
		
			
				|  |  | +        elif isinstance(v, bool):
 | 
	
		
			
				|  |  | +            parser.add_argument('--' + opt, action='store_true', default=v)
 | 
	
		
			
				|  |  | +        else:
 | 
	
		
			
				|  |  | +            parser.add_argument('--' + opt, type=type(v), default=v)
 | 
	
		
			
				|  |  | +    return parser, node_keys
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +def _args2cfg(cfg, args, node_keys):
 | 
	
		
			
				|  |  | +    args = vars(args)
 | 
	
		
			
				|  |  | +    for k, v in args.items():
 | 
	
		
			
				|  |  | +        pos = k.find('.')
 | 
	
		
			
				|  |  | +        if pos != -1:
 | 
	
		
			
				|  |  | +            # Iteratively parse a dict
 | 
	
		
			
				|  |  | +            dict_ = cfg
 | 
	
		
			
				|  |  | +            while pos != -1:
 | 
	
		
			
				|  |  | +                dict_.setdefault(k[:pos], {})
 | 
	
		
			
				|  |  | +                dict_ = dict_[k[:pos]]
 | 
	
		
			
				|  |  | +                k = k[pos + 1:]
 | 
	
		
			
				|  |  | +                pos = k.find('.')
 | 
	
		
			
				|  |  | +            dict_[k] = v
 | 
	
		
			
				|  |  | +        else:
 | 
	
		
			
				|  |  | +            cfg[k] = v
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    for k in node_keys:
 | 
	
		
			
				|  |  | +        pos = k.find('.')
 | 
	
		
			
				|  |  | +        if pos != -1:
 | 
	
		
			
				|  |  | +            # Iteratively parse a dict
 | 
	
		
			
				|  |  | +            dict_ = cfg
 | 
	
		
			
				|  |  | +            while pos != -1:
 | 
	
		
			
				|  |  | +                dict_.setdefault(k[:pos], {})
 | 
	
		
			
				|  |  | +                dict_ = dict_[k[:pos]]
 | 
	
		
			
				|  |  | +                k = k[pos + 1:]
 | 
	
		
			
				|  |  | +                pos = k.find('.')
 | 
	
		
			
				|  |  | +            v = dict_[k]
 | 
	
		
			
				|  |  | +            dict_[k] = [CfgNode(v_) for v_ in v] if isinstance(
 | 
	
		
			
				|  |  | +                v, list) else CfgNode(v)
 | 
	
		
			
				|  |  | +        else:
 | 
	
		
			
				|  |  | +            v = cfg[k]
 | 
	
		
			
				|  |  | +            cfg[k] = [CfgNode(v_) for v_ in v] if isinstance(
 | 
	
		
			
				|  |  | +                v, list) else CfgNode(v)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    return cfg
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +def parse_args(*args, **kwargs):
 | 
	
		
			
				|  |  | +    cfg_parser = argparse.ArgumentParser(add_help=False)
 | 
	
		
			
				|  |  | +    cfg_parser.add_argument('--config', type=str, default='')
 | 
	
		
			
				|  |  | +    cfg_parser.add_argument('--inherit_off', action='store_true')
 | 
	
		
			
				|  |  | +    cfg_args = cfg_parser.parse_known_args()[0]
 | 
	
		
			
				|  |  | +    cfg_path = cfg_args.config
 | 
	
		
			
				|  |  | +    inherit_on = not cfg_args.inherit_off
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    # Main parser
 | 
	
		
			
				|  |  | +    parser = argparse.ArgumentParser(
 | 
	
		
			
				|  |  | +        conflict_handler='resolve', parents=[cfg_parser])
 | 
	
		
			
				|  |  | +    # Global settings
 | 
	
		
			
				|  |  | +    parser.add_argument('cmd', choices=['train', 'eval'])
 | 
	
		
			
				|  |  | +    parser.add_argument('task', choices=['cd', 'clas', 'det', 'seg'])
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    # Data
 | 
	
		
			
				|  |  | +    parser.add_argument('--datasets', type=dict, default={})
 | 
	
		
			
				|  |  | +    parser.add_argument('--transforms', type=dict, default={})
 | 
	
		
			
				|  |  | +    parser.add_argument('--download_on', action='store_true')
 | 
	
		
			
				|  |  | +    parser.add_argument('--download_url', type=str, default='')
 | 
	
		
			
				|  |  | +    parser.add_argument('--download_path', type=str, default='./')
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    # Optimizer
 | 
	
		
			
				|  |  | +    parser.add_argument('--optimizer', type=dict, default={})
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    # Training related
 | 
	
		
			
				|  |  | +    parser.add_argument('--num_epochs', type=int, default=100)
 | 
	
		
			
				|  |  | +    parser.add_argument('--train_batch_size', type=int, default=8)
 | 
	
		
			
				|  |  | +    parser.add_argument('--save_interval_epochs', type=int, default=1)
 | 
	
		
			
				|  |  | +    parser.add_argument('--log_interval_steps', type=int, default=1)
 | 
	
		
			
				|  |  | +    parser.add_argument('--save_dir', default='../exp/')
 | 
	
		
			
				|  |  | +    parser.add_argument('--learning_rate', type=float, default=0.01)
 | 
	
		
			
				|  |  | +    parser.add_argument('--early_stop', action='store_true')
 | 
	
		
			
				|  |  | +    parser.add_argument('--early_stop_patience', type=int, default=5)
 | 
	
		
			
				|  |  | +    parser.add_argument('--use_vdl', action='store_true')
 | 
	
		
			
				|  |  | +    parser.add_argument('--resume_checkpoint', type=str)
 | 
	
		
			
				|  |  | +    parser.add_argument('--train', type=dict, default={})
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    # Loss
 | 
	
		
			
				|  |  | +    parser.add_argument('--losses', type=dict, nargs='+', default={})
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    # Model
 | 
	
		
			
				|  |  | +    parser.add_argument('--model', type=dict, default={})
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    if osp.exists(cfg_path):
 | 
	
		
			
				|  |  | +        cfg = parse_configs(cfg_path, inherit_on)
 | 
	
		
			
				|  |  | +        parser, node_keys = _cfg2args(cfg, parser, '')
 | 
	
		
			
				|  |  | +        node_keys = sorted(node_keys, reverse=True)
 | 
	
		
			
				|  |  | +        args = parser.parse_args(*args, **kwargs)
 | 
	
		
			
				|  |  | +        return _args2cfg(dict(), args, node_keys)
 | 
	
		
			
				|  |  | +    elif cfg_path != '':
 | 
	
		
			
				|  |  | +        raise FileNotFoundError
 | 
	
		
			
				|  |  | +    else:
 | 
	
		
			
				|  |  | +        args = parser.parse_args()
 | 
	
		
			
				|  |  | +        return _args2cfg(dict(), args, set())
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +class _CfgNodeMeta(yaml.YAMLObjectMetaclass):
 | 
	
		
			
				|  |  | +    def __call__(cls, obj):
 | 
	
		
			
				|  |  | +        if isinstance(obj, CfgNode):
 | 
	
		
			
				|  |  | +            return obj
 | 
	
		
			
				|  |  | +        return super(_CfgNodeMeta, cls).__call__(obj)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +class CfgNode(yaml.YAMLObject, metaclass=_CfgNodeMeta):
 | 
	
		
			
				|  |  | +    yaml_tag = u'!Node'
 | 
	
		
			
				|  |  | +    yaml_loader = yaml.SafeLoader
 | 
	
		
			
				|  |  | +    # By default use a lexical scope
 | 
	
		
			
				|  |  | +    ctx = globals()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def __init__(self, dict_):
 | 
	
		
			
				|  |  | +        super().__init__()
 | 
	
		
			
				|  |  | +        self.type = dict_['type']
 | 
	
		
			
				|  |  | +        self.args = dict_.get('args', [])
 | 
	
		
			
				|  |  | +        self.module = dict_.get('module', '')
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    @classmethod
 | 
	
		
			
				|  |  | +    def set_context(cls, ctx):
 | 
	
		
			
				|  |  | +        # TODO: Implement dynamic scope with inspect.stack()
 | 
	
		
			
				|  |  | +        old_ctx = cls.ctx
 | 
	
		
			
				|  |  | +        cls.ctx = ctx
 | 
	
		
			
				|  |  | +        return old_ctx
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def build_object(self, mod=None):
 | 
	
		
			
				|  |  | +        if mod is None:
 | 
	
		
			
				|  |  | +            mod = self._get_module(self.module)
 | 
	
		
			
				|  |  | +        cls = getattr(mod, self.type)
 | 
	
		
			
				|  |  | +        if isinstance(self.args, list):
 | 
	
		
			
				|  |  | +            args = build_objects(self.args)
 | 
	
		
			
				|  |  | +            obj = cls(*args)
 | 
	
		
			
				|  |  | +        elif isinstance(self.args, dict):
 | 
	
		
			
				|  |  | +            args = build_objects(self.args)
 | 
	
		
			
				|  |  | +            obj = cls(**args)
 | 
	
		
			
				|  |  | +        else:
 | 
	
		
			
				|  |  | +            raise NotImplementedError
 | 
	
		
			
				|  |  | +        return obj
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def _get_module(self, s):
 | 
	
		
			
				|  |  | +        mod = None
 | 
	
		
			
				|  |  | +        while s:
 | 
	
		
			
				|  |  | +            idx = s.find('.')
 | 
	
		
			
				|  |  | +            if idx == -1:
 | 
	
		
			
				|  |  | +                next_ = s
 | 
	
		
			
				|  |  | +                s = ''
 | 
	
		
			
				|  |  | +            else:
 | 
	
		
			
				|  |  | +                next_ = s[:idx]
 | 
	
		
			
				|  |  | +                s = s[idx + 1:]
 | 
	
		
			
				|  |  | +            if mod is None:
 | 
	
		
			
				|  |  | +                mod = self.ctx[next_]
 | 
	
		
			
				|  |  | +            else:
 | 
	
		
			
				|  |  | +                mod = getattr(mod, next_)
 | 
	
		
			
				|  |  | +        return mod
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    @staticmethod
 | 
	
		
			
				|  |  | +    def build_objects(cfg, mod=None):
 | 
	
		
			
				|  |  | +        if isinstance(cfg, list):
 | 
	
		
			
				|  |  | +            return [CfgNode.build_objects(c, mod=mod) for c in cfg]
 | 
	
		
			
				|  |  | +        elif isinstance(cfg, CfgNode):
 | 
	
		
			
				|  |  | +            return cfg.build_object(mod=mod)
 | 
	
		
			
				|  |  | +        elif isinstance(cfg, dict):
 | 
	
		
			
				|  |  | +            return {
 | 
	
		
			
				|  |  | +                k: CfgNode.build_objects(
 | 
	
		
			
				|  |  | +                    v, mod=mod)
 | 
	
		
			
				|  |  | +                for k, v in cfg.items()
 | 
	
		
			
				|  |  | +            }
 | 
	
		
			
				|  |  | +        else:
 | 
	
		
			
				|  |  | +            return cfg
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def __repr__(self):
 | 
	
		
			
				|  |  | +        return f"(type={self.type}, args={self.args}, module={self.module or ' '})"
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    @classmethod
 | 
	
		
			
				|  |  | +    def from_yaml(cls, loader, node):
 | 
	
		
			
				|  |  | +        map_ = loader.construct_mapping(node)
 | 
	
		
			
				|  |  | +        return cls(map_)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def items(self):
 | 
	
		
			
				|  |  | +        yield from [('type', self.type), ('args', self.args), ('module',
 | 
	
		
			
				|  |  | +                                                               self.module)]
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def to_dict(self):
 | 
	
		
			
				|  |  | +        return dict(self.items())
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +def build_objects(cfg, mod=None):
 | 
	
		
			
				|  |  | +    return CfgNode.build_objects(cfg, mod=mod)
 |