46 lines
1.5 KiB
Python
46 lines
1.5 KiB
Python
from stable_baselines3.common.env_util import make_vec_env
|
|
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
|
|
from huggingface_sb3 import load_from_hub
|
|
from stable_baselines3.dqn.dqn import DQN
|
|
import torch
|
|
|
|
from qpd.networks.models.student_six_model_dqn_ import StudentSixModelDQN
|
|
from run import config
|
|
from qpd.config import Config
|
|
from qpd.networks.wrapper.model_wrapper import ModelWrapper
|
|
|
|
def get_environment(config: Config):
    """Create the vectorised CartPole-v1 environment used by this script.

    The number of environment copies is read from
    ``config.evaluator_config.env_workers``. The copies are wrapped in a
    ``DummyVecEnv`` and each one is created with ``render_mode="human"``.

    Args:
        config: Project configuration; only
            ``evaluator_config.env_workers`` is read here.

    Returns:
        The vectorised environment produced by ``make_vec_env``.
    """
    worker_count = config.evaluator_config.env_workers
    env_options = {"render_mode": "human"}
    return make_vec_env(
        "CartPole-v1",
        n_envs=worker_count,
        vec_env_cls=DummyVecEnv,
        env_kwargs=env_options,
    )
|
|
|
|
# Build the project configuration from the environment factory and the
# settings object imported from `run`, then create the environment itself.
c = Config(get_environment, config)
env = get_environment(c)

# Pretrained original model
# Fetch the published DQN CartPole-v1 checkpoint from the Hugging Face Hub.
checkpoint = load_from_hub(
    repo_id="sb3/dqn-CartPole-v1",
    filename="dqn-CartPole-v1.zip",
)

# Load the checkpoint against our environment and wrap it as the teacher.
model = DQN.load(checkpoint, env)
teacher = ModelWrapper.construct_wrapper(c, model)

# Let the config finish its environment/model dependent setup using the
# teacher (see Config.init_env_model_based_params for what is derived).
c.init_env_model_based_params(teacher)
|
|
|
|
# Loading quantized student
# Directory of the compression run being exported: the student checkpoint is
# read from it and the ONNX file is written into the same place.
BEST_RUN_DIR = (
    "/home/ian/projects-idlab/euROBIN/code/framework/test/Data/"
    "compression/2024-06-14_15:56:55/best"
)

student = StudentSixModelDQN(c)
student.load(f"{BEST_RUN_DIR}/student.pth")

# A single reset provides a real observation batch that serves as the
# example input for the export.
state = env.reset()

try:
    # `torch.onnx.export` documents no `qat` option, so none is passed.
    # NOTE(review): any quantisation behaviour therefore has to come from the
    # student module itself -- confirm the exported graph looks as expected.
    torch.onnx.export(
        student,
        torch.from_numpy(state),
        f"{BEST_RUN_DIR}/clean_student.onnx",
        verbose=True,
        input_names=["input_1"],
        output_names=["output_1"],
        opset_version=14,
    )
finally:
    # The environment was created with render_mode="human"; release it even
    # if the export fails.
    env.close()