QPD/src/qpd/utils/configuration.py

92 lines
2.8 KiB
Python

import copy
import importlib
import sys
from argparse import ArgumentParser
def convert_nested_insert(dct, lst):
    """Insert one ``[key1, ..., keyN, value]`` path into a nested dict, in place.

    Every element except the last two names an intermediate dict; a missing
    intermediate is created as an empty dict, an existing one is reused.
    The second-to-last element is the final key and the last element is the
    value stored under it.

    Returns None; ``dct`` is modified in place.
    """
    node = dct
    for part in lst[:-2]:
        # Reuse the existing child if present, otherwise create and attach one.
        child = node.get(part, {})
        node[part] = child
        node = child
    leaf_key, leaf_value = lst[-2], lst[-1]
    node.update({leaf_key: leaf_value})
def convert_nested(dct):
    """Expand a flat mapping with dotted keys into nested dicts.

    For example ``{"a.b": 1, "c": 2}`` becomes ``{"a": {"b": 1}, "c": 2}``.
    Entries are processed in the mapping's iteration order, each inserted
    with ``convert_nested_insert``.

    Returns:
        A new dict; ``dct`` itself is not written to.
    """
    nested = {}
    for dotted_key, value in dct.items():
        # Path of key segments followed by the value, e.g. ["a", "b", 1].
        path = dotted_key.split(".")
        path.append(value)
        convert_nested_insert(nested, path)
    return nested
def merge_config(source, destination):
    """Recursively merge ``source`` into a deep copy of ``destination``.

    Dict values in ``source`` are merged key by key into the matching entry
    of ``destination``. If that entry is absent or not a dict, an empty
    dict is used as the base. Any other value in ``source`` overwrites the
    destination entry.

    String leaf values, e.g. raw command-line overrides, are evaluated as
    Python expressions when possible: ``"0.1"`` becomes ``0.1``, ``"True"``
    becomes ``True`` and ``"[1, 2]"`` becomes ``[1, 2]``. A string that
    fails to evaluate is stored unchanged.

    Builtins and module globals are hidden from that evaluation. Plain words
    such as ``"max"``, ``"sum"`` or ``"copy"`` therefore stay strings instead
    of silently becoming Python functions or modules.

    Args:
        source: Overrides to apply (possibly nested).
        destination: Base configuration; it is deep-copied, never mutated.

    Returns:
        A new object: ``destination`` with ``source`` applied.
    """
    destination = copy.deepcopy(destination)
    for key, value in source.items():
        if isinstance(value, dict):
            if key in destination and isinstance(destination[key], dict):
                destination_value = destination[key]
            else:
                destination_value = {}
            destination[key] = merge_config(value, destination_value)
            continue
        if isinstance(value, str):
            try:
                # NOTE(review): eval of command-line text. The empty
                # __builtins__ only prevents accidental name lookups
                # ("max" -> builtin max). It is NOT a security sandbox.
                # If arithmetic such as "1/255" is not needed,
                # ast.literal_eval would be the safer choice.
                value = eval(value, {"__builtins__": {}}, {})
            except Exception:
                pass  # not a Python expression: keep the raw string
        destination[key] = value
    return destination
def get_config_mods(known_args):
    """Build a nested config-override dict from argparse's unrecognised args.

    ``known_args`` is the ``(namespace, extras)`` tuple returned by
    ``ArgumentParser.parse_known_args``; only ``extras`` is used. An
    override may be written as two tokens (``--a.b value``) or as one token
    (``--a.b=value``). Dotted keys become nested dicts via
    ``convert_nested``; values are kept as raw strings here.

    If a key appears more than once, the last occurrence wins. A trailing
    key token with no value after it is ignored.
    """
    extras = known_args[1]
    flat = {}
    i = 0
    while i < len(extras):
        token = extras[i]
        if token.startswith("--") and "=" in token:
            # Single-token form --key=value. Split on the first '=' only,
            # so the value itself may contain '='.
            key, value = token[2:].split("=", 1)
            i += 1
        elif i + 1 < len(extras):
            # Two-token form --key value; the two leading characters are
            # stripped from the key token.
            key, value = token[2:], extras[i + 1]
            i += 2
        else:
            # Dangling key with no value: ignored.
            break
        flat[key] = value
    return convert_nested(flat)
def parse_args():
    """Parse the command line into the ``--config`` name plus config overrides.

    Returns:
        A ``(args, config_mods)`` tuple:

        - ``args`` is the argparse namespace holding ``config``, which
          defaults to ``"experiment"``.
        - ``config_mods`` is the override dict produced by
          ``get_config_mods`` from the ``parse_known_args`` result.
    """
    cli = ArgumentParser()
    cli.add_argument('--config', type=str, default="experiment")
    parsed = cli.parse_known_args()
    namespace, _unrecognised = parsed
    overrides = get_config_mods(parsed)
    return namespace, overrides
def get_config():
    """Load the experiment config selected by ``--config`` and apply CLI overrides.

    The config is the ``config`` attribute of module ``configs.<name>``.
    ``<name>`` comes from ``--config`` and defaults to ``"experiment"``.
    Any other command-line overrides returned by ``parse_args`` are merged
    on top with ``merge_config``.

    Returns:
        The merged configuration.

    Exits the process with status 1 (after printing an error to stderr) if
    the config module cannot be imported.
    """
    args, config_mods = parse_args()
    try:
        config_module = importlib.import_module(f"configs.{args.config}")
    except ModuleNotFoundError as ex:
        print(f"ERROR: Could not find the experiment config file: {ex}", file=sys.stderr)
        # sys.exit is always available, unlike the site-provided quit().
        # A non-zero status lets shell scripts and schedulers detect the failure.
        sys.exit(1)
    configs = config_module.config
    if config_mods is not None:
        configs = merge_config(config_mods, configs)
    return configs
def import_config(config, experiment_config):
    """Copy experiment-specific settings from ``experiment_config`` into ``config``.

    ``config`` is modified in place and also returned. These entries are
    always copied, and a missing key raises ``KeyError``:

    - the network entries ``conv1``, ``conv2``, ``conv3``, ``linear``, ``teacher``;
    - ``testing.randomness``;
    - ``quantization.enabled``;
    - ``compression.epochs``;
    - ``run_name``, stored under ``config["train_run"]``.

    ``compression.critic_importance`` is copied only when present in
    ``experiment_config``.
    """
    for layer in ("conv1", "conv2", "conv3", "linear", "teacher"):
        config["network"][layer] = experiment_config["network"][layer]

    exp_compression = experiment_config["compression"]
    if "critic_importance" in exp_compression:
        config["compression"]["critic_importance"] = exp_compression["critic_importance"]

    # Remaining (section, key) pairs, copied in their original order.
    for section, key in (
        ("testing", "randomness"),
        ("quantization", "enabled"),
        ("compression", "epochs"),
    ):
        config[section][key] = experiment_config[section][key]

    config["train_run"] = experiment_config["run_name"]
    return config