QPD/examples/dqn_to_onnx_cartpole.py

import torch
from huggingface_sb3 import load_from_hub
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from stable_baselines3.dqn.dqn import DQN

from qpd.config import Config
from qpd.networks.models.student_six_model_dqn_ import StudentSixModelDQN
from qpd.networks.wrapper.model_wrapper import ModelWrapper
from run import config


def get_environment(config: Config):
    # Vectorised CartPole env; render_mode="human" opens a window during evaluation
    return make_vec_env(
        "CartPole-v1",
        n_envs=config.evaluator_config.env_workers,
        vec_env_cls=DummyVecEnv,
        env_kwargs={"render_mode": "human"},
    )

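# Build the QPD config around the environment factory and create the evaluation env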
c = Config(get_environment, config)
env = get_environment(c)

# Pretrained teacher model from the Hugging Face Hub
checkpoint = load_from_hub(
    repo_id="sb3/dqn-CartPole-v1",
    filename="dqn-CartPole-v1.zip",
)

# Optional overrides for checkpoints saved with older SB3 versions;
# kept for reference and left disabled because the hub model loads as-is.
custom_objects = {
    "learning_rate": 0.0,
    "lr_schedule": lambda _: 0.0,
    "clip_range": lambda _: 0.0,
}

model = DQN.load(checkpoint, env)  # pass custom_objects=custom_objects if loading fails
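
# Wrap the teacher for QPD and derive environment/model-based parameters from it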
teacher = ModelWrapper.construct_wrapper(c, model)
c.init_env_model_based_params(teacher)

# Load the quantized student produced by the compression run
student = StudentSixModelDQN(c)
student.load("/home/ian/projects-idlab/euROBIN/code/framework/test/Data/compression/2024-06-14_15:56:55/best/student.pth")  # adjust to your checkpoint path

# Reset the environment to get one batched observation as the ONNX example input
state = env.reset()

# Export the student policy to ONNX (opset 14), tracing it with the example observation
onnx_path = "/home/ian/projects-idlab/euROBIN/code/framework/test/Data/compression/2024-06-14_15:56:55/best/clean_student.onnx"
torch.onnx.export(
    student,
    torch.from_numpy(state),
    onnx_path,
    verbose=True,
    input_names=["input_1"],
    output_names=["output_1"],
    opset_version=14,
)
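
# Optional sanity check, a minimal sketch assuming onnxruntime and numpy are
# available and that the student's forward pass returns a single tensor of
# Q-values: run the exported graph on the same observation and compare it
# against the PyTorch output.
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession(onnx_path)
onnx_out = session.run(["output_1"], {"input_1": state.astype(np.float32)})[0]
with torch.no_grad():
    torch_out = student(torch.from_numpy(state)).cpu().numpy()
print("max abs difference between ONNX and PyTorch outputs:", np.abs(onnx_out - torch_out).max())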