from datetime import datetime

import gym
import numpy as np
from huggingface_sb3 import load_from_hub
from stable_baselines3 import A2C
from stable_baselines3.common.env_util import make_atari_env, make_vec_env
from stable_baselines3.common.vec_env.base_vec_env import VecEnv
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
from stable_baselines3.common.vec_env.vec_frame_stack import VecFrameStack

from qpd.compressor import Compressor
from qpd.config import Config
from qpd.networks.wrapper.student.cnn_student import CNNStudentNet

# NOTE(review): gym, datetime, make_vec_env, SubprocVecEnv and VecEnv are not
# referenced below; kept in case other tooling relies on them.

# Student-driven run; no memory saving.
config = {
    "memory": {
        "size": 54000,  # Size of memory used for distillation
        "update_frequency": 1,  # Epoch frequency for updating the memory
        "update_size": 5400,  # Minimum update size in steps
        "device": "cpu",
        # Only used with frame-stacked environments
        "frame_stack_optimization": True,  # Only store last frame
        "check_consistency": True,
    },
    "evaluator": {
        "student_driven": True,  # Student decides the transitions in the environment
        "student_test_frequency": 10,  # Epoch frequency
        "episodes": 50,  # Minimum episodes for testing student
        "initialize": 30,  # Number of actions to skip at beginning of episode
        "ray_workers": 10,  # Parallel ray workers used for updating and testing
        "device": "cpu",
        "deterministic": False,
        "exploration_rate": 0,
    },
    "compression": {
        "checkpoint_frequency": 2,  # Epoch frequency for saving students
        "epochs": 600,
        "learning_rate": 1e-4,
        "batch_size": 256,
        "device": "cuda",
        # Only used in discrete action spaces
        "T": 0.01,  # Softmax hyperparameter
        "categorical": False,
        "critic_importance": 0.5,
        # Only used in continuous action spaces
        "distribution": "Std",  # Std, Mean
        "loss": "KL",  # KL, Huber, MSE
    },
    "quantization": {
        "enabled": False,
        "bits": 8,
    },
    "data_directory": "./data",
    "run_name": "test",  # Change this for every run (also used as the save name)
}


def env_info_done_filter(info: dict) -> bool:
    """Return whether the step ``info`` reports zero remaining lives.

    The lives counter is read from ``info["ale.lives"]`` when that key is
    present, otherwise from ``info["lives"]``. The value is printed and then
    compared with 0.

    Raises:
        KeyError: if ``info`` contains neither ``"ale.lives"`` nor ``"lives"``.
    """
    key = "ale.lives" if "ale.lives" in info else "lives"
    lives = info[key]
    print(f"Lives: {lives}")
    return lives == 0


def get_environment(config: Config):
    """Build the frame-stacked Breakout vectorized environment.

    Creates ``config.evaluator_config.env_workers`` copies of
    ``BreakoutNoFrameskip-v4`` via ``make_atari_env``, seeded with a random
    integer in ``[0, 1000)``, and wraps them in a 4-frame ``VecFrameStack``.

    Note: the ``config`` parameter here is the ``qpd`` ``Config`` object, not
    the module-level ``config`` dict it shadows.
    """
    # DummyVecEnv steps all envs in this process; SubprocVecEnv is the
    # imported alternative if subprocess parallelism is wanted.
    vec_env_cls = DummyVecEnv
    # NOTE(review): the "evaluator" section of the config dict sets
    # "ray_workers", not "env_workers"; presumably Config supplies a default
    # for env_workers -- confirm this is the intended worker count.
    env = make_atari_env(
        "BreakoutNoFrameskip-v4",
        n_envs=config.evaluator_config.env_workers,
        seed=np.random.randint(0, 1000),
        vec_env_cls=vec_env_cls,
    )  # type: ignore
    return VecFrameStack(env, n_stack=4)


def main() -> None:
    """Distill the pretrained SB3 A2C Breakout agent into a CNN student.

    Downloads the ``sb3/a2c-BreakoutNoFrameskip-v4`` checkpoint, runs the
    ``qpd`` Compressor with ``CNNStudentNet`` and ``env_info_done_filter``,
    and saves the result under ``config["run_name"]``.
    """
    checkpoint = load_from_hub(
        repo_id="sb3/a2c-BreakoutNoFrameskip-v4",
        filename="a2c-BreakoutNoFrameskip-v4.zip",
    )
    print(checkpoint)
    model = A2C.load(checkpoint)

    c = Config(get_environment, config)
    comp = (
        Compressor(model, get_environment, c)
        .student_network(CNNStudentNet)
        .set_environment_done_filter(env_info_done_filter)
    )
    compressed_model = comp.compress()
    # Save under run_name so changing it per run (as the config comment asks)
    # no longer overwrites the previous run; the default "test" matches the
    # previously hard-coded name.
    compressed_model.save(config["run_name"])


if __name__ == "__main__":
    main()