# QPD/examples/run_cartpole.py
from qpd.config import Config
from qpd.compressor import Compressor
from huggingface_sb3 import load_from_hub
from stable_baselines3 import A2C
from stable_baselines3.dqn.dqn import DQN
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from qpd.networks.models.student_six_model import StudentSixModel
from qpd.networks.models.student_six_model_dqn_ import StudentSixModelDQN
from qpd.networks.models.student_tiny_dqn import TinyStudentDQN
from qpd.networks.wrapper.student.fully_connected_student import FCStudentNet
# NOTE: memory saving is not applied for student-driven runs.
config = {
    "memory": {
        "size": 100000,               # Transition-buffer capacity used for distillation
        "update_frequency": 1,        # Refresh the memory every N epochs
        "update_size": 10000,         # Minimum number of steps added per refresh
        "device": "cpu",
        # The two options below matter only for frame-stacked environments
        "frame_stack_optimization": False,  # Store just the newest frame of a stack
        "check_consistency": True,
    },
    "evaluator": {
        "student_driven": False,      # Let the student pick actions during rollout
        "student_test_frequency": 10, # Evaluate the student every N epochs
        "episodes": 20,               # Minimum evaluation episodes per test
        "initialize": 0,              # Actions skipped at the start of each episode
        "ray_workers": 10,            # Parallel Ray workers for updates/testing
        "device": "cpu",
        "deterministic": False,
    },
    "compression": {
        "checkpoint_frequency": 2,    # Save a student checkpoint every N epochs
        "epochs": 600,
        "learning_rate": 5e-4,
        "batch_size": 64,
        "device": "cuda",
        # Discrete-action-space settings
        "T": 0.01,                    # Softmax temperature
        "categorical": False,
        "critic_importance": 0.5,
        # Continuous-action-space settings
        "distribution": "Std",        # One of: Std, Mean
        "loss": "KL",                 # One of: KL, Huber, MSE
    },
    "quantization": {
        "enabled": True,
        "bits": 8,
    },
    "data_directory": "./data",
    "run_name": "test",               # Use a unique name for every run
}
def get_environment(config: Config):
    """Build a vectorised CartPole-v1 environment.

    The number of parallel copies comes from the evaluator configuration;
    ``DummyVecEnv`` runs them sequentially in a single process.
    """
    return make_vec_env(
        "CartPole-v1",
        n_envs=config.evaluator_config.env_workers,
        vec_env_cls=DummyVecEnv,
    )
if __name__ == "__main__":
    # Fetch the pretrained teacher policy from the Hugging Face hub.
    # (An A2C teacher also exists under repo_id "sb3/a2c-CartPole-v1".)
    checkpoint = load_from_hub(
        repo_id="sb3/dqn-CartPole-v1",
        filename="dqn-CartPole-v1.zip",
    )
    print(checkpoint)

    cfg = Config(get_environment, config)
    model = DQN.load(checkpoint, env=get_environment(cfg))

    # Distil the DQN teacher into the tiny student architecture.
    compressor = Compressor(model, get_environment, cfg).student_model(TinyStudentDQN)
    compressor.compress()