92 lines
3.1 KiB
Python
92 lines
3.1 KiB
Python
import gym
|
|
import numpy as np
|
|
|
|
from qpd.config import Config
|
|
from qpd.compressor import Compressor
|
|
from huggingface_sb3 import load_from_hub
|
|
from stable_baselines3 import A2C
|
|
from stable_baselines3.common.env_util import make_atari_env, make_vec_env
|
|
from stable_baselines3.common.vec_env.vec_frame_stack import VecFrameStack
|
|
from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
|
|
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
|
|
from stable_baselines3.common.vec_env.base_vec_env import VecEnv
|
|
|
|
from qpd.networks.wrapper.student.cnn_student import CNNStudentNet
|
|
|
|
from datetime import datetime
|
|
|
|
# Student-driven, no memory saving!
|
|
|
|
# Run settings handed to ``Config`` in the ``__main__`` section below.
config = {
    "memory": {
        # Size of memory used for distillation
        "size": 54_000,
        # Epoch frequency for updating the memory
        "update_frequency": 1,
        # Minimum update size in steps
        "update_size": 5_400,
        "device": "cpu",
        # Only used with framestacked environments: only store last frame
        "frame_stack_optimization": True,
        "check_consistency": True,
    },
    "evaluator": {
        # Student decides the transitions in the environment
        "student_driven": True,
        # Epoch frequency
        "student_test_frequency": 10,
        # Minimum episodes for testing student
        "episodes": 50,
        # Amount of actions to skip at beginning of episode
        "initialize": 30,
        # Parallel ray workers used for updating and testing
        "ray_workers": 10,
        "device": "cpu",
        "deterministic": False,
        "exploration_rate": 0,
    },
    "compression": {
        # Epoch frequency for saving students
        "checkpoint_frequency": 2,
        "epochs": 600,
        "learning_rate": 1e-4,
        "batch_size": 256,
        "device": "cuda",
        # --- Only used in discrete action spaces ---
        # Softmax hyperparameter
        "T": 0.01,
        "categorical": False,
        "critic_importance": 0.5,
        # --- Only used in continuous action spaces ---
        # One of: Std, Mean
        "distribution": "Std",
        # One of: KL, Huber, MSE
        "loss": "KL",
    },
    "quantization": {
        "enabled": False,
        "bits": 8,
    },
    "data_directory": "./data",
    # Change this for every run
    "run_name": "test",
}
|
|
|
|
|
|
def env_info_done_filter(info):
    """Report whether the lives counter in *info* has reached zero.

    The counter is read from ``info["ale.lives"]`` when that key is present,
    otherwise from ``info["lives"]``. The current value is printed on every
    call.

    Raises:
        KeyError: if neither ``"ale.lives"`` nor ``"lives"`` is in *info*.
    """
    if "ale.lives" in info:
        lives_key = "ale.lives"
    else:
        lives_key = "lives"

    lives_left = info[lives_key]
    print(f"Lives: {lives_left}")
    return lives_left == 0
|
|
|
|
def get_environment(config: Config):
    """Create the frame-stacked ``BreakoutNoFrameskip-v4`` vector environment.

    Args:
        config: Run configuration; ``config.evaluator_config.env_workers``
            is passed to ``make_atari_env`` as the number of environments.

    Returns:
        The ``make_atari_env`` result wrapped in a ``VecFrameStack`` with
        ``n_stack=4``.
    """
    # DummyVecEnv is always used here; SubprocVecEnv was the alternative
    # considered in the original commented-out expression.
    atari_env = make_atari_env(  # type: ignore
        "BreakoutNoFrameskip-v4",
        n_envs=config.evaluator_config.env_workers,
        # A fresh random seed in [0, 1000) is drawn on every call.
        seed=np.random.randint(0, 1000),
        vec_env_cls=DummyVecEnv,
    )
    return VecFrameStack(atari_env, n_stack=4)
|
|
|
|
|
|
if __name__ == "__main__":
    # Download the A2C Breakout checkpoint through huggingface_sb3 and load it.
    checkpoint_path = load_from_hub(
        repo_id="sb3/a2c-BreakoutNoFrameskip-v4",
        filename="a2c-BreakoutNoFrameskip-v4.zip",
    )
    print(checkpoint_path)
    a2c_model = A2C.load(checkpoint_path)

    # Wrap the module-level settings dict together with the env factory.
    run_config = Config(get_environment, config)

    # Configure the compressor: CNN student network plus the lives-based
    # done filter defined in this file.
    compressor = (
        Compressor(a2c_model, get_environment, run_config)
        .student_network(CNNStudentNet)
        .set_environment_done_filter(env_info_done_filter)
    )

    # Run the compression and save the result under the name "test".
    compressed = compressor.compress()
    compressed.save("test")