"""Export a trained student DQN model for CartPole-v1 to ONNX.

Loads the pretrained SB3 DQN ("teacher") from the Hugging Face Hub, wraps it
so the project ``Config`` can initialise its env/model-based parameters,
loads the student checkpoint, and exports the student to ONNX using an
observation batch from ``env.reset()`` as the tracing input.
"""
from pathlib import Path

import torch
from huggingface_sb3 import load_from_hub
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from stable_baselines3.dqn.dqn import DQN

from qpd.config import Config
from qpd.networks.models.student_six_model_dqn_ import StudentSixModelDQN
from qpd.networks.wrapper.model_wrapper import ModelWrapper
from run import config

# Directory of the compression run whose "best" student is exported. Both the
# input checkpoint and the ONNX output live here, so derive them from one place.
RUN_DIR = Path(
    "/home/ian/projects-idlab/euROBIN/code/framework/test/Data/compression/"
    "2024-06-14_15:56:55/best"
)
STUDENT_CHECKPOINT = RUN_DIR / "student.pth"
ONNX_OUTPUT = RUN_DIR / "clean_student.onnx"


def get_environment(config: Config):
    """Create the vectorized CartPole-v1 environment.

    Args:
        config: Project configuration; ``config.evaluator_config.env_workers``
            sets the number of environments in the vector.

    Returns:
        A ``DummyVecEnv`` of ``CartPole-v1`` instances created with
        ``render_mode="human"``.
    """
    return make_vec_env(
        "CartPole-v1",
        n_envs=config.evaluator_config.env_workers,
        vec_env_cls=DummyVecEnv,
        env_kwargs={"render_mode": "human"},
    )


c = Config(get_environment, config)
env = get_environment(c)

# Pretrained original (teacher) model.
checkpoint = load_from_hub(
    repo_id="sb3/dqn-CartPole-v1",
    filename="dqn-CartPole-v1.zip",
)
# Optional overrides for DQN.load(..., custom_objects=custom_objects), e.g. when
# the checkpoint's pickled schedules cannot be deserialized. Currently not used.
custom_objects = {
    "learning_rate": 0.0,
    "lr_schedule": lambda _: 0.0,
    "clip_range": lambda _: 0.0,
}
model = DQN.load(checkpoint, env)
teacher = ModelWrapper.construct_wrapper(c, model)
c.init_env_model_based_params(teacher)

# Quantized student.
student = StudentSixModelDQN(c)
student.load(str(STUDENT_CHECKPOINT))

# A real observation batch is used as the example input for tracing.
# NOTE(review): its leading dim is presumably env_workers — confirm this is
# the input shape the exported graph should have.
state = env.reset()

# ``qat`` is not a torch.onnx.export option (older torch releases reject it
# with a TypeError, so the export never ran), hence it is not passed.
torch.onnx.export(
    student,
    torch.from_numpy(state),
    str(ONNX_OUTPUT),
    verbose=True,
    input_names=["input_1"],
    output_names=["output_1"],
    opset_version=14,
)