from argparse import ArgumentParser
import copy
import sys


def convert_nested_insert(dct, lst):
    """Insert one dotted-path entry into a nested dict, in place.

    ``lst`` is ``[key1, key2, ..., leaf_key, value]``: every element except
    the last two names an intermediate dict (created if missing), and the
    final pair is stored as ``leaf_key: value`` in the innermost dict.

    If an intermediate key already holds a non-dict value (e.g. ``--a 1``
    followed by ``--a.b 2``), it is replaced by a dict so the later, more
    specific entry wins instead of crashing.
    """
    for key in lst[:-2]:
        child = dct.get(key)
        if not isinstance(child, dict):
            child = {}
            dct[key] = child
        dct = child
    dct[lst[-2]] = lst[-1]


def convert_nested(dct):
    """Expand a flat ``{"a.b.c": value}`` mapping into nested dicts.

    Keys are split on ``"."``. Entries are applied in iteration order, so a
    later entry overwrites an earlier one at the same path.

    >>> convert_nested({"net.lr": "0.1", "seed": "1"})
    {'net': {'lr': '0.1'}, 'seed': '1'}
    """
    result = {}
    for key, value in dct.items():
        convert_nested_insert(result, [*key.split("."), value])
    return result


def merge_config(source, destination):
    """Return a deep copy of ``destination`` with ``source`` merged on top.

    Nested dicts are merged recursively. Any other value in ``source``
    replaces the one in ``destination``. ``destination`` itself is not
    modified.

    Non-dict source values are passed through ``eval`` so command-line
    strings such as ``"0.1"``, ``"True"`` or ``"[1, 2]"`` become Python
    objects. If evaluation fails, the value is kept as-is (e.g. a bare word
    stays a string).

    SECURITY NOTE(review): ``eval`` executes arbitrary code. That is only
    acceptable while ``source`` comes from the local user's own command
    line. If it can ever come from an untrusted source, switch to
    ``ast.literal_eval``.
    """
    destination = copy.deepcopy(destination)
    for key, value in source.items():
        if isinstance(value, dict):
            if key in destination and isinstance(destination[key], dict):
                destination_value = destination[key]
            else:
                destination_value = {}
            destination[key] = merge_config(value, destination_value)
        else:
            try:
                value = eval(value)
            except Exception:
                # Best effort: not a valid Python expression -> keep raw value.
                pass
            destination[key] = value
    return destination


def get_config_mods(known_args):
    """Turn unrecognised CLI arguments into a nested override dict.

    ``known_args`` is the ``(namespace, extras)`` tuple returned by
    ``ArgumentParser.parse_known_args``. ``extras`` is read as a sequence of
    ``--dotted.key value`` pairs. The single-token form ``--dotted.key=value``
    is also accepted. The leading two characters of each key token are
    stripped.

    A trailing key token with no value after it is ignored. Values are left
    as strings here; ``merge_config`` converts them later.
    """
    tokens = known_args[1]
    flat = {}
    i = 0
    while i < len(tokens):
        token = tokens[i]
        if token.startswith("--") and "=" in token:
            key, _, value = token[2:].partition("=")
            i += 1
        elif i + 1 < len(tokens):
            key, value = token[2:], tokens[i + 1]
            i += 2
        else:
            # Dangling key without a value: dropped (previous behaviour).
            break
        flat[key] = value
    return convert_nested(flat)


def parse_args():
    """Parse ``--config`` (default ``"experiment"``) plus free-form overrides.

    Returns ``(args, config_mods)``, where ``config_mods`` is the nested dict
    built by ``get_config_mods`` from all arguments argparse did not know.
    """
    parser = ArgumentParser()
    parser.add_argument('--config', type=str, default="experiment")
    known_args = parser.parse_known_args()
    args = known_args[0]
    config_mods = get_config_mods(known_args)
    return args, config_mods


def get_config():
    """Load ``configs.<--config>`` and return its ``config`` with CLI overrides.

    Prints an error to stderr and exits with status 1 if the requested config
    module does not exist.
    """
    args, config_mods = parse_args()
    module_name = f"configs.{args.config}"
    try:
        config_module = __import__(module_name, fromlist=[''])
    except ModuleNotFoundError as ex:
        # Only report "config not found" when the missing module is the config
        # itself or one of its parent packages. A ModuleNotFoundError raised by
        # an import *inside* the config file (e.g. a missing dependency) is a
        # different problem and must keep its traceback.
        missing = ex.name
        if missing is not None and not (
            module_name == missing or module_name.startswith(missing + ".")
        ):
            raise
        print(f"ERROR: Could not find the experiment config file: {ex}", file=sys.stderr)
        sys.exit(1)
    configs = config_module.config
    if config_mods is not None:
        configs = merge_config(config_mods, configs)
    return configs


def import_config(config, experiment_config):
    """Copy selected settings from ``experiment_config`` into ``config``.

    ``config`` is modified in place and also returned. Copied fields:
    ``network`` conv1/conv2/conv3/linear/teacher,
    ``compression.critic_importance`` (only if present in
    ``experiment_config``), ``testing.randomness``,
    ``quantization.enabled``, ``compression.epochs``, and ``run_name``
    (stored as ``config["train_run"]``). Missing keys are not handled and
    propagate from the indexing.
    """
    for layer in ("conv1", "conv2", "conv3", "linear", "teacher"):
        config["network"][layer] = experiment_config["network"][layer]
    if "critic_importance" in experiment_config["compression"]:
        config["compression"]["critic_importance"] = (
            experiment_config["compression"]["critic_importance"]
        )
    config["testing"]["randomness"] = experiment_config["testing"]["randomness"]
    config["quantization"]["enabled"] = experiment_config["quantization"]["enabled"]
    config["compression"]["epochs"] = experiment_config["compression"]["epochs"]
    config["train_run"] = experiment_config["run_name"]
    return config