92 lines
2.8 KiB
Python
92 lines
2.8 KiB
Python
import ast
import copy
import importlib
import sys
from argparse import ArgumentParser
|
|
|
|
def convert_nested_insert(dct, lst):
    """Insert one path/value sequence into a nested dict, in place.

    Every element of ``lst`` except the last two names an intermediate level
    (created as an empty dict when missing), ``lst[-2]`` is the final key and
    ``lst[-1]`` is the value stored under it.

    Args:
        dct: Nested dict that is modified in place.
        lst: Sequence ``[key_1, ..., key_n, value]`` with at least one key.
    """
    parents, leaf_key, leaf_value = lst[:-2], lst[-2], lst[-1]
    node = dct
    for part in parents:
        # Descend one level, creating it on first use.
        node = node.setdefault(part, {})
    node.update({leaf_key: leaf_value})
|
|
|
|
|
|
def convert_nested(dct):
    """Expand a flat mapping with dotted keys into nested dictionaries.

    For example ``{"a.b": 1, "a.c": 2, "d": 3}`` becomes
    ``{"a": {"b": 1, "c": 2}, "d": 3}``.  The input mapping is not modified.

    Args:
        dct: Mapping whose keys are ``"."``-separated paths.

    Returns:
        A new nested dict.
    """
    nested = {}
    for dotted_key, leaf in dct.items():
        path = dotted_key.split(".")
        # The helper expects the path segments followed by the value.
        convert_nested_insert(nested, path + [leaf])
    return nested
|
|
|
|
|
|
def merge_config(source, destination):
    """Return a copy of *destination* with the entries of *source* merged in.

    Nested dicts in *source* are merged recursively into the matching dicts of
    *destination*. A missing or non-dict destination entry is treated as an
    empty dict. Any other value in *source* overwrites the destination entry.

    String values are parsed as Python literals when possible, for example
    ``"3"`` -> ``3``, ``"True"`` -> ``True`` and ``"[1, 2]"`` -> ``[1, 2]``.
    Strings that are not literals are kept unchanged.

    Args:
        source: Overrides to apply, e.g. the nested dict built from the
            command line.
        destination: Base configuration; it is deep-copied, never mutated.

    Returns:
        A new dict holding the merged configuration.
    """
    destination = copy.deepcopy(destination)

    for key, value in source.items():
        if isinstance(value, dict):
            if key in destination and isinstance(destination[key], dict):
                destination_value = destination[key]
            else:
                destination_value = {}
            destination[key] = merge_config(value, destination_value)
        else:
            if isinstance(value, str):
                # ast.literal_eval only evaluates literals. Unlike eval(), it
                # never runs code from the command line. It also leaves words
                # such as "max" or "copy" as strings instead of turning them
                # into builtins or module globals.
                try:
                    value = ast.literal_eval(value)
                except (ValueError, SyntaxError, TypeError,
                        MemoryError, RecursionError):
                    pass  # not a literal: keep the raw string
            destination[key] = value

    return destination
|
|
|
|
|
|
def get_config_mods(known_args):
    """Turn unrecognised command-line options into a nested override dict.

    Both ``--key value`` and ``--key=value`` forms are accepted. Dotted keys
    describe nesting, so ``--network.conv1 32`` becomes
    ``{"network": {"conv1": "32"}}``. Values are kept as raw strings here.
    A trailing option with no value after it is ignored.

    Args:
        known_args: The ``(namespace, unknown_args)`` tuple returned by
            ``ArgumentParser.parse_known_args``; only the second item is used.

    Returns:
        A nested dict of overrides (empty when there are none).
    """
    tokens = known_args[1]
    flat_mods = {}
    index = 0
    while index < len(tokens):
        option = tokens[index]
        name, sep, inline_value = option.partition("=")
        if option.startswith("--") and sep:
            # argparse passes "--key=value" through as a single token; pairing
            # it with the next token would shift every later key/value pair.
            flat_mods[name[2:]] = inline_value
            index += 1
        elif index + 1 < len(tokens):
            # "--key value": drop the two-character "--" prefix.
            flat_mods[option[2:]] = tokens[index + 1]
            index += 2
        else:
            break  # dangling option without a value
    return convert_nested(flat_mods)
|
|
|
|
|
|
def parse_args():
    """Parse the command line into known arguments and config overrides.

    Only ``--config`` (default ``"experiment"``) is declared. Every other
    option is left to ``get_config_mods``, which builds the override dict.

    Returns:
        A ``(namespace, config_mods)`` tuple.
    """
    arg_parser = ArgumentParser()
    arg_parser.add_argument('--config', type=str, default="experiment")

    parsed = arg_parser.parse_known_args()
    namespace, _unknown = parsed
    return namespace, get_config_mods(parsed)
|
|
|
|
|
|
def get_config():
    """Load the experiment config chosen on the command line and apply overrides.

    Imports ``configs.<name>``, where ``<name>`` is the ``--config`` argument.
    Its module-level ``config`` object is then merged with any command-line
    overrides.

    Returns:
        The merged configuration produced by ``merge_config``.

    Exits:
        With status 1, after printing an error to stderr, when the config
        module cannot be found.
    """
    args, config_mods = parse_args()

    # Keep the try to the import only, so nothing else is reported as a
    # missing config file.
    try:
        config_module = importlib.import_module(f"configs.{args.config}")
    except ModuleNotFoundError as ex:
        print(f"ERROR: Could not find the experiment config file: {ex}",
              file=sys.stderr)
        # Exit non-zero so callers and scripts see the failure. quit() exited
        # with status 0 and is a site-module helper that is not always present.
        sys.exit(1)

    configs = config_module.config
    if config_mods is not None:
        configs = merge_config(config_mods, configs)
    return configs
|
|
|
|
|
|
def import_config(config, experiment_config):
    """Copy selected settings from *experiment_config* into *config*.

    The following are always copied:

    - the network entries ``conv1``, ``conv2``, ``conv3``, ``linear`` and
      ``teacher``;
    - ``testing.randomness``;
    - ``quantization.enabled``;
    - ``compression.epochs``.

    ``compression.critic_importance`` is copied only when the experiment
    defines it. The experiment's ``run_name`` is stored as
    ``config["train_run"]``.

    *config* is modified in place and also returned. A required key missing
    from either dict raises ``KeyError``.
    """
    for layer in ("conv1", "conv2", "conv3", "linear", "teacher"):
        config["network"][layer] = experiment_config["network"][layer]

    # Optional setting: only present in some experiment configs.
    if "critic_importance" in experiment_config["compression"]:
        config["compression"]["critic_importance"] = (
            experiment_config["compression"]["critic_importance"]
        )

    for section, key in (
        ("testing", "randomness"),
        ("quantization", "enabled"),
        ("compression", "epochs"),
    ):
        config[section][key] = experiment_config[section][key]

    config["train_run"] = experiment_config["run_name"]
    return config
|
|
|