From 862c55e03c513e135a29ba38a749b2bd32cad244 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thomas=20Av=C3=A9?= Date: Sat, 23 Nov 2024 21:35:13 +0100 Subject: [PATCH] Initial public commit --- README.md | 35 ++ examples/dqn_to_onnx_cartpole.py | 46 +++ examples/esp32_test/env_cartpole.py | 45 +++ examples/run_atari.py | 92 ++++++ examples/run_cartpole.py | 85 +++++ examples/run_cheetah.py | 86 +++++ pyproject.toml | 36 +++ src/qpd/__init__.py | 0 src/qpd/compression/__init__.py | 0 src/qpd/compression/memory.py | 193 +++++++++++ src/qpd/compression/policy_distillation.py | 275 ++++++++++++++++ src/qpd/compressor.py | 77 +++++ src/qpd/config.py | 165 ++++++++++ src/qpd/networks/__init__.py | 2 + src/qpd/networks/models/__init__.py | 0 src/qpd/networks/models/student_six_model.py | 45 +++ .../networks/models/student_six_model_dqn_.py | 37 +++ src/qpd/networks/models/student_tiny_dqn.py | 37 +++ src/qpd/networks/network_interface.py | 25 ++ src/qpd/networks/wrapper/__init__.py | 7 + src/qpd/networks/wrapper/model_wrapper.py | 19 ++ .../wrapper/stable_baselines/__init__.py | 16 + .../networks/wrapper/stable_baselines/a2c.py | 8 + .../networks/wrapper/stable_baselines/dqn.py | 18 ++ .../stable_baselines/off_actor_critic.py | 25 ++ .../stable_baselines/on_actor_critic.py | 62 ++++ .../networks/wrapper/stable_baselines/ppo.py | 10 + .../wrapper/stable_baselines/qrdqn.py | 12 + .../networks/wrapper/stable_baselines/sac.py | 8 + .../wrapper/stable_baselines/sb_wrapper.py | 67 ++++ .../wrapper/stable_baselines/student.py | 21 ++ .../stable_baselines/value_iteration.py | 22 ++ src/qpd/networks/wrapper/student/__init__.py | 0 .../networks/wrapper/student/cnn_student.py | 39 +++ .../student/fully_connected_student.py | 31 ++ .../networks/wrapper/student/student_net.py | 14 + src/qpd/networks/wrapper/wrapper_decorator.py | 5 + src/qpd/utils/__init__.py | 4 + src/qpd/utils/configuration.py | 91 ++++++ src/qpd/utils/evaluator.py | 303 ++++++++++++++++++ src/qpd/utils/model_utils.py | 47 +++ src/qpd/utils/onnx_helper.py | 52 +++ src/qpd/utils/summary.py | 40 +++ 43 files changed, 2202 insertions(+) create mode 100644 README.md create mode 100644 examples/dqn_to_onnx_cartpole.py create mode 100644 examples/esp32_test/env_cartpole.py create mode 100644 examples/run_atari.py create mode 100644 examples/run_cartpole.py create mode 100644 examples/run_cheetah.py create mode 100644 pyproject.toml create mode 100644 src/qpd/__init__.py create mode 100644 src/qpd/compression/__init__.py create mode 100644 src/qpd/compression/memory.py create mode 100644 src/qpd/compression/policy_distillation.py create mode 100644 src/qpd/compressor.py create mode 100644 src/qpd/config.py create mode 100644 src/qpd/networks/__init__.py create mode 100644 src/qpd/networks/models/__init__.py create mode 100644 src/qpd/networks/models/student_six_model.py create mode 100644 src/qpd/networks/models/student_six_model_dqn_.py create mode 100644 src/qpd/networks/models/student_tiny_dqn.py create mode 100644 src/qpd/networks/network_interface.py create mode 100644 src/qpd/networks/wrapper/__init__.py create mode 100644 src/qpd/networks/wrapper/model_wrapper.py create mode 100644 src/qpd/networks/wrapper/stable_baselines/__init__.py create mode 100644 src/qpd/networks/wrapper/stable_baselines/a2c.py create mode 100644 src/qpd/networks/wrapper/stable_baselines/dqn.py create mode 100644 src/qpd/networks/wrapper/stable_baselines/off_actor_critic.py create mode 100644 src/qpd/networks/wrapper/stable_baselines/on_actor_critic.py create mode 100644 
src/qpd/networks/wrapper/stable_baselines/ppo.py create mode 100644 src/qpd/networks/wrapper/stable_baselines/qrdqn.py create mode 100644 src/qpd/networks/wrapper/stable_baselines/sac.py create mode 100644 src/qpd/networks/wrapper/stable_baselines/sb_wrapper.py create mode 100644 src/qpd/networks/wrapper/stable_baselines/student.py create mode 100644 src/qpd/networks/wrapper/stable_baselines/value_iteration.py create mode 100644 src/qpd/networks/wrapper/student/__init__.py create mode 100644 src/qpd/networks/wrapper/student/cnn_student.py create mode 100644 src/qpd/networks/wrapper/student/fully_connected_student.py create mode 100644 src/qpd/networks/wrapper/student/student_net.py create mode 100644 src/qpd/networks/wrapper/wrapper_decorator.py create mode 100644 src/qpd/utils/__init__.py create mode 100644 src/qpd/utils/configuration.py create mode 100644 src/qpd/utils/evaluator.py create mode 100644 src/qpd/utils/model_utils.py create mode 100644 src/qpd/utils/onnx_helper.py create mode 100644 src/qpd/utils/summary.py diff --git a/README.md b/README.md new file mode 100644 index 0000000..af0c2df --- /dev/null +++ b/README.md @@ -0,0 +1,35 @@ +# tiny-rl +>**_NOTE_**: Currently, only Stable Baselines3 is supported out of the box. + +The goal of this repository is to provide an easy-to-use framework that wraps the QPD algorithm to compress the models of RL policies. + + +## Installation +First, install the required system libraries. +``` sh +sudo apt update +sudo apt install libpng-dev libjpeg-dev zlib1g-dev +``` +Afterwards, install the Python library. +>**_Important_**: Make sure you use Python 3.10.*, setuptools 65.0, and wheel 0.38.4. +``` sh +conda create -n compression python=3.10 +conda activate compression +python -m pip install setuptools==65.0 wheel==0.38.4 +python -m pip install -e .
+``` + +## Currently supported +- A2C +- PPO +- DQN +- QRDQN +- SAC + +## Cheetah env installation if used with conda +``` sh +conda install -c conda-forge xorg-libx11 +conda install -c conda-forge glew +conda install -c conda-forge mesalib +conda install -c menpo glfw3 +``` \ No newline at end of file diff --git a/examples/dqn_to_onnx_cartpole.py b/examples/dqn_to_onnx_cartpole.py new file mode 100644 index 0000000..ba69b83 --- /dev/null +++ b/examples/dqn_to_onnx_cartpole.py @@ -0,0 +1,46 @@ +from stable_baselines3.common.env_util import make_vec_env +from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv +from huggingface_sb3 import load_from_hub +from stable_baselines3.dqn.dqn import DQN +import torch + +from qpd.networks.models.student_six_model_dqn_ import StudentSixModelDQN +from run import config +from qpd.config import Config +from qpd.networks.wrapper.model_wrapper import ModelWrapper + +def get_environment(config: Config): + return make_vec_env("CartPole-v1", n_envs=config.evaluator_config.env_workers, vec_env_cls=DummyVecEnv, env_kwargs={"render_mode": "human"}) + +c = Config(get_environment, config) + +env = get_environment(c) + +# Pretrained original model +checkpoint = load_from_hub( + repo_id="sb3/dqn-CartPole-v1", + filename="dqn-CartPole-v1.zip", +) + +custom_objects = { + "learning_rate": 0.0, + "lr_schedule": lambda _: 0.0, + "clip_range": lambda _: 0.0, +} + +model = DQN.load(checkpoint, env) #, custom_objects=custom_objects) +teacher = ModelWrapper.construct_wrapper(c, model) + +c.init_env_model_based_params(teacher) + +# Loading quantized student +student = StudentSixModelDQN(c) +student.load("/home/ian/projects-idlab/euROBIN/code/framework/test/Data/compression/2024-06-14_15:56:55/best/student.pth") + + +rewards_list = [] +state = env.reset() +steps = 0 +rewards = 0 + +torch.onnx.export(student, torch.from_numpy(state), "/home/ian/projects-idlab/euROBIN/code/framework/test/Data/compression/2024-06-14_15:56:55/best/clean_student.onnx", verbose=True, input_names=["input_1"], output_names=["output_1"], opset_version=14, qat=True) \ No newline at end of file diff --git a/examples/esp32_test/env_cartpole.py b/examples/esp32_test/env_cartpole.py new file mode 100644 index 0000000..a3e0ac6 --- /dev/null +++ b/examples/esp32_test/env_cartpole.py @@ -0,0 +1,45 @@ +import time + +import serial +import numpy as np +from stable_baselines3.common.env_util import make_vec_env +from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv + +env = make_vec_env("CartPole-v1", n_envs=1, vec_env_cls=DummyVecEnv, env_kwargs={"render_mode": "human"}) + +rewards_list = [] +state = env.reset() +steps = 0 +rewards = 0 + +port = serial.Serial("/dev/ttyUSB0", 115200, timeout=1) + +def read_to_np(): + string = port.readline() + return np.array([int(string)]) + +def write_ser(cmd): + port.write(cmd) + +while(1): + tick = time.time() + write_ser(str(state).encode()) + + actions = read_to_np() + + tock = time.time() + + state, reward, dones, info = env.step(actions) + env.render() + #time.sleep(0.02) + rewards += reward + steps += 1 + + print(1/(tock-tick)) + if np.all(dones): + print(steps) + print(info[0]["episode"]["r"]) + print(rewards) + print(info) + rewards = 0 + steps = 0 \ No newline at end of file diff --git a/examples/run_atari.py b/examples/run_atari.py new file mode 100644 index 0000000..8fbacf2 --- /dev/null +++ b/examples/run_atari.py @@ -0,0 +1,92 @@ +import gym +import numpy as np + +from qpd.config import Config +from qpd.compressor import Compressor 
+from huggingface_sb3 import load_from_hub +from stable_baselines3 import A2C +from stable_baselines3.common.env_util import make_atari_env, make_vec_env +from stable_baselines3.common.vec_env.vec_frame_stack import VecFrameStack +from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv +from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv +from stable_baselines3.common.vec_env.base_vec_env import VecEnv + +from qpd.networks.wrapper.student.cnn_student import CNNStudentNet + +from datetime import datetime + +# Student driven no memory saving! + +config = { + "memory": { + "size": 54000, # Size of memory used for distillation + "update_frequency": 1, # Epoch frequency for updating the memory + "update_size": 5400, # Minimum update size in steps + "device": "cpu", + + # Only used with framestacked environments + "frame_stack_optimization": True, # Only store last frame + + "check_consistency": True + }, + "evaluator": { + "student_driven": True, # Student decide the transitions in the environment + "student_test_frequency": 10, # Epoch frequency + "episodes": 50, # Minimum episodes for testing student + "initialize": 30, # Amount of actions to skip at beginning of episode + "ray_workers": 10, # Parallel ray workers used for updating and testing + "device": "cpu", + "deterministic": False, + "exploration_rate": 0, + }, + "compression": { + "checkpoint_frequency": 2, # Epoch frequency for saving students + "epochs": 600, + "learning_rate": 1e-4, + "batch_size": 256, + "device": "cuda", + + # Only used in discrete action spaces + "T": 0.01, # Softmax hyperparameter + "categorical": False, + "critic_importance": 0.5, + + # Only used in continuous action spaces + "distribution": "Std", # Std, Mean + "loss": "KL" # KL, Huber, MSE + }, + "quantization": { + "enabled": False, + "bits": 8 + }, + "data_directory": "./data", + "run_name": "test", # Change this for every run +} + + +def env_info_done_filter(info): + key = "ale.lives" if "ale.lives" in info.keys() else "lives" + print(f"Lives: {info[key]}") + return info[key] == 0 + +def get_environment(config: Config): + vec_env_cls = DummyVecEnv# SubprocVecEnv if subprocenv else DummyVecEnv + env = make_atari_env("BreakoutNoFrameskip-v4", n_envs=config.evaluator_config.env_workers, seed=np.random.randint(0, 1000), vec_env_cls=vec_env_cls) #type: ignore + env = VecFrameStack(env, n_stack=4) + return env + + +if __name__ == "__main__": + checkpoint = load_from_hub(repo_id="sb3/a2c-BreakoutNoFrameskip-v4",filename="a2c-BreakoutNoFrameskip-v4.zip",) + print(checkpoint) + model = A2C.load(checkpoint) + + c = Config(get_environment, config) + + comp = Compressor(model, get_environment, c) \ + .student_network(CNNStudentNet) \ + .set_environment_done_filter(env_info_done_filter) + + compressed_model = comp.compress() + + compressed_model.save("test") \ No newline at end of file diff --git a/examples/run_cartpole.py b/examples/run_cartpole.py new file mode 100644 index 0000000..125d1c9 --- /dev/null +++ b/examples/run_cartpole.py @@ -0,0 +1,85 @@ +from qpd.config import Config + +from qpd.compressor import Compressor +from huggingface_sb3 import load_from_hub +from stable_baselines3 import A2C +from stable_baselines3.dqn.dqn import DQN +from stable_baselines3.common.env_util import make_vec_env +from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv + +from qpd.networks.models.student_six_model import StudentSixModel +from qpd.networks.models.student_six_model_dqn_ import StudentSixModelDQN +from 
qpd.networks.models.student_tiny_dqn import TinyStudentDQN +from qpd.networks.wrapper.student.fully_connected_student import FCStudentNet + +# Student driven no memory saving! + +config = { + "memory": { + "size": 100000, # Size of memory used for distillation + "update_frequency": 1, # Epoch frequency for updating the memory + "update_size": 10000, # Minimum update size in steps + "device": "cpu", + + # Only used with framestacked environments + "frame_stack_optimization": False, # Only store last frame + + "check_consistency": True + }, + "evaluator": { + "student_driven": False, # Student decide the transitions in the environment + "student_test_frequency": 10, # Epoch frequency + "episodes": 20, # Minimum episodes for testing student + "initialize": 0, # Amount of actions to skip at beginning of episode + "ray_workers": 10, # Parallel ray workers used for updating and testing + "device": "cpu", + "deterministic": False + }, + "compression": { + "checkpoint_frequency": 2, # Epoch frequency for saving students + "epochs": 600, + "learning_rate": 5e-4, + "batch_size": 64, + "device": "cuda", + + # Only used in discrete action spaces + "T": 0.01, # Softmax hyperparameter + "categorical": False, + "critic_importance": 0.5, + + # Only used in continuous action spaces + "distribution": "Std", # Std, Mean + "loss": "KL" # KL, Huber, MSE + }, + "quantization": { + "enabled": True, + "bits": 8 + }, + "data_directory": "./data", + "run_name": "test", # Change this for every run +} +#"/home/user/Workspace/University/PhD/Experiments/QPD", + + + +def get_environment(config: Config): + return make_vec_env("CartPole-v1", n_envs=config.evaluator_config.env_workers, vec_env_cls=DummyVecEnv) + + +if __name__ == "__main__": + #checkpoint = load_from_hub(repo_id="sb3/a2c-CartPole-v1",filename="a2c-CartPole-v1.zip",) + checkpoint = load_from_hub(repo_id="sb3/dqn-CartPole-v1",filename="dqn-CartPole-v1.zip",) + # custom_objects = { + # "learning_rate": 0.0, + # "lr_schedule": lambda _: 0.0, + # "clip_range": lambda _: 0.0, + # } + print(checkpoint) + + c = Config(get_environment, config) + + model = DQN.load(checkpoint, env=get_environment(c)) # , custom_objects=custom_objects) + + # comp = Compressor(model, get_environment, c).student_network(FCStudentNet) + comp = Compressor(model, get_environment, c).student_model(TinyStudentDQN) + comp.compress() \ No newline at end of file diff --git a/examples/run_cheetah.py b/examples/run_cheetah.py new file mode 100644 index 0000000..6bfadfa --- /dev/null +++ b/examples/run_cheetah.py @@ -0,0 +1,86 @@ +from qpd.config import Config +from qpd.networks.models.student_six_model import StudentSixModel +from qpd.compressor import Compressor +from huggingface_sb3 import load_from_hub +from stable_baselines3.sac.sac import SAC +from stable_baselines3.a2c.a2c import A2C +from stable_baselines3.common.env_util import make_vec_env +from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv +from stable_baselines3.common.vec_env.vec_normalize import VecNormalize + +from qpd.networks.wrapper.student.fully_connected_student import FCStudentNet + +from datetime import datetime + +config = { + "memory": { + "size": 100000, # Size of memory used for distillation + "update_frequency": 1, # Epoch frequency for updating the memory + "update_size": 10000, # Minimum update size in steps + "device": "cpu", + + # Only used with framestacked environments + "frame_stack_optimization": False, # Only store last frame + + "check_consistency": True + }, + "evaluator": { + 
"student_driven": True, # Student decide the transitions in the environment + "student_test_frequency": 10, # Epoch frequency + "episodes": 20, # Minimum episodes for testing student + "initialize": 0, # Amount of actions to skip at beginning of episode + "ray_workers": 10, # Parallel ray workers used for updating and testing + "device": "cpu", + "deterministic": False + }, + "compression": { + "checkpoint_frequency": 2, # Epoch frequency for saving students + "epochs": 600, + "learning_rate": 5e-4, + "batch_size": 64, + "device": "cuda", + + # Only used in discrete action spaces + "T": 0.01, # Softmax hyperparameter + "categorical": False, + "critic_importance": 0.5, + + # Only used in continuous action spaces + "distribution": "Std", # Std, Mean + "loss": "KL" # KL, Huber, MSE + }, + "quantization": { + "enabled": False, + "bits": 8 + }, + "data_directory": "./data", + "run_name": "cheetah_no_quant", # Change this for every run +} +#"/home/user/Workspace/University/PhD/Experiments/QPD", + + + +def get_environment_normalized(config: Config): + env = make_vec_env("HalfCheetah-v3", n_envs=config.evaluator_config.env_workers, vec_env_cls=DummyVecEnv) + + normalize = load_from_hub( + repo_id="sb3/a2c-HalfCheetah-v3", + filename="vec_normalize.pkl", + ) + return VecNormalize.load(normalize, env) + +def get_environment(config: Config): + env = make_vec_env("HalfCheetah-v3", n_envs=config.evaluator_config.env_workers, vec_env_cls=DummyVecEnv) + return env + + +if __name__ == "__main__": + checkpoint = load_from_hub(repo_id="sb3/sac-HalfCheetah-v3",filename="sac-HalfCheetah-v3.zip",) + print(checkpoint) + model = SAC.load(checkpoint) + + c = Config(get_environment, config) + + # comp = Compressor(model, get_environment, c).student_network(FCStudentNet) + comp = Compressor(model, get_environment, c).student_model(StudentSixModel) + compressed_model = comp.compress() \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..1d6d27d --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,36 @@ +[project] +name = "qpd" +version = "0.0.1" +authors = [ + { name="Thomas Avé", email="thomas.ave@uantwerpen.be" }, + { name="Ian Ravijts", email="ian.ravijts@uantwerpen.be" } +] +readme = "README.md" +requires-python = ">=3.9" +dependencies = [ + "wheel==0.38.4", + "shimmy>=0.2.1", + "numpy", + "atari-py==0.2.9", + "gymnasium[accept-rom-license]", + "ale-py", + "pillow", + "readchar", + "matplotlib", + "distiller@git+https://github.com/tiny-rl/distiller.git", + "wandb", + "stable-baselines3[extra]", + "sb3-contrib", + "tqdm", + "ray[default]", + "webdavclient3", + "prefetch_generator", + "pygame", + "pyglet", + "cython<3", +] + +[build-system] +requires = [ "setuptools>=61.0,<=66" ] +build-backend = "setuptools.build_meta" + diff --git a/src/qpd/__init__.py b/src/qpd/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/qpd/compression/__init__.py b/src/qpd/compression/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/qpd/compression/memory.py b/src/qpd/compression/memory.py new file mode 100644 index 0000000..4d9d814 --- /dev/null +++ b/src/qpd/compression/memory.py @@ -0,0 +1,193 @@ +from __future__ import annotations +import torch +import random +import numpy +import gzip +import json +from pathlib import Path +from typing import Iterator, Tuple, Dict, Optional + +from qpd.config import Config + + +class Memory: + def __init__(self, config: Config, size=None, meta_class = None, path: Optional[str] = None) -> None: + 
self.meta_name = str(meta_class.__class__) if meta_class else None + + self.max_size = size if size else config.memory_config.size + self.did_overflow = False + self.current_index = 0 + + self.output_n = config.output_shape[0] + self.device = config.memory_config.device + self.check_consistency = config.memory_config.check_consistency + self.critic = config.compression_config.critic_importance > 0 + self.frame_stack_optim = config.memory_config.frame_stack_optimization + self.framestack_size = config.observation_shape[0] if self.frame_stack_optim else 0 + self.student_driven = config.evaluator_config.student_driven + + if path: + print(f"Load memory on path: {path}") + self.load(path) + else: + self.states = torch.zeros((self.max_size,) + (config.observation_shape[1:] if self.frame_stack_optim else config.observation_shape), device=self.device) # [(Input frame, output actions)] + self.outputs = torch.zeros((self.max_size, self.output_n), device=self.device) # + 1 = critic + # self.repeats = torch.zeros(self.size, device=self.device) + self.start = set() + + + def update(self, state: torch.Tensor, outputs: torch.Tensor, start: bool) -> None: + assert not torch.any(outputs == 0), f"[ERROR]: Update input is not valid!" + + # update overflow + self.did_overflow = self.will_overflow() + + if start: + self.start.add(self.current_index) + else: + self.start.discard(self.current_index) + + state = state.to(self.device).data + self.states[self.current_index] = state[:, -1] if self.frame_stack_optim else state # Story only latest frame + + self.outputs[self.current_index] = outputs.to(self.device).data + self.current_index = (self.current_index + 1) % self.max_size + + if self.check_consistency: + self.verifyConsistency(f"Update current_index: {self.current_index} is start: {start} outputs: {self.outputs}") + + def will_overflow(self, update_size=1): + if (self.current_index + update_size) >= self.max_size: + return True + return False + + def real_size(self): + return (self.states.element_size() * self.states.nelement()) + (self.outputs.element_size() * self.outputs.nelement()) + + def sample(self, batch_size: int, shuffle: bool=True) -> Iterator[Tuple[torch.Tensor, torch.Tensor]]: + state_indexes = list(range(len(self))) + if shuffle: + random.shuffle(state_indexes) + splits = [state_indexes[i:i + batch_size] for i in range(0, len(state_indexes), batch_size)] + + if self.frame_stack_optim: + for split in splits: + current_states = [] + indexes = torch.tensor(split) + for i in range(self.framestack_size): + current_states.append(self.states[(indexes - (self.framestack_size - 1) + i) % len(self)].unsqueeze(1)) + yield (torch.cat(current_states, dim=1), self.outputs[split]) + else: + for split in splits: + yield (self.states[split], self.outputs[split]) + + + def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]: + if self.frame_stack_optim: + states_index = list(self.start)[index] + return (torch.cat([self.states[i % len(self)].unsqueeze(0) for i in range(states_index-3, states_index+1)]), self.outputs[states_index]) + else: + return (self.states[index], self.outputs[index]) + + + def __len__(self) -> int: + if self.did_overflow: + return len(self.outputs) + else: + return self.current_index + + + def prune(self) -> None: + if not self.did_overflow: + for item in ["states", "outputs"]: + setattr(self, item, getattr(self, item)[0: self.current_index]) + self.max_size = self.current_index + + if self.check_consistency: + self.verifyConsistency("Prune") + + + def grow(self, amount: 
int) -> int: + print(f"Growing with size: {amount}") + for item in ["states", "outputs"]: + obj = getattr(self, item) + tmp = torch.zeros((self.max_size + amount,) + obj.shape[1:], device = self.device, dtype=obj.dtype) + tmp[:len(self)] = obj + setattr(self, item, tmp) + + self.max_size += amount + self.did_overflow = False + + if self.check_consistency: + self.verifyConsistency("Grow") + return len(self) + + + def load_memory(self, other: Memory) -> None: + assert other.max_size <= self.max_size, f"Memory to load is bigger than current memory! {other.max_size} > {self.max_size}" + + self.did_overflow = self.will_overflow(len(other)) + + for item in ["states", "outputs"]: + getattr(self, item)[self.current_index:min(self.current_index+len(other), self.max_size)] = getattr(other, item)[:min(len(other), self.max_size - self.current_index)] + + if self.current_index + len(other) > self.max_size: + getattr(self, item)[:len(other) - (self.max_size - self.current_index)] = getattr(other, item)[self.max_size - self.current_index:] + + for idx in range(len(other)): + new_idx = (self.current_index + idx) % self.max_size + self.start.discard(new_idx) + if idx in other.start: + self.start.add(new_idx) + + self.current_index = (self.current_index + len(other)) % self.max_size + + if self.check_consistency: + self.verifyConsistency("Load memory") + + + def to(self, device: torch.device): + for item in ["states", "outputs"]: + setattr(self, item, getattr(self, item).to(device)) + + self.device = device + + + def load(self, path: str) -> None: + assert self.meta_name, "Cannot load this memory! No meta_class specified!" + with open(f"{path}/metadata.json") as f: + j = json.loads(f.read()) + assert j["teacher"] == self.meta_name + self.current_index = j["current_index"] + self.did_overflow = j["did_overflow"] + + for item in ["states", "outputs"]: + with gzip.GzipFile(f"{path}/{item}.npy.gz") as f: + setattr(self, item, torch.from_numpy(numpy.load(file=f))) + + with gzip.GzipFile(f"{path}/start.npy.gz") as f: + self.start = set(numpy.load(file=f)) + + self.max_size = self.states.shape[0] + + if self.check_consistency: + self.verifyConsistency("Load") + + + def save(self, path: str) -> None: + assert self.meta_name, "Cannot save this memory! No meta_class specified!" + Path(path).mkdir(parents=True, exist_ok=True) + for item in ["states", "outputs"]: + with gzip.GzipFile(f"{path}/{item}.npy.gz", "w") as f: + numpy.save(f, getattr(self, item).cpu().numpy()) + + with gzip.GzipFile(f"{path}/start.npy.gz", "w") as f: + numpy.save(f, numpy.array(list(self.start))) + with open(f"{path}/metadata.json", "w") as f: + f.write(json.dumps({"teacher": self.meta_name, "did_overflow": self.did_overflow, "current_index": self.current_index})) + + def verifyConsistency(self, mes) -> None: + assert not torch.any(self.outputs[:len(self)] == 0), f"[ERROR]: Memory is not consistent! ({mes})" + # assert not torch.any(torch.bitwise_or(self.outputs[:len(self)] > 100, self.outputs[:len(self)] < -100)), \ + # f"[ERROR]: Memory is not consistent in a weird way!
max: {self.outputs.max()} min: {self.outputs.min()} ({mes})" + diff --git a/src/qpd/compression/policy_distillation.py b/src/qpd/compression/policy_distillation.py new file mode 100644 index 0000000..fd17fc6 --- /dev/null +++ b/src/qpd/compression/policy_distillation.py @@ -0,0 +1,275 @@ +# import wandb +import random +import os +import gc + +# from ..utils import WebdavManager +from collections import defaultdict +from distiller.quantization import DorefaQuantizer +from gym.spaces import Box, Discrete +from pathlib import Path +from qpd.config import Config +from qpd.networks.wrapper.stable_baselines.sb_wrapper import SBWrapper +from qpd.utils import summary +from qpd.utils.evaluator import Evaluator +from stable_baselines3.common.base_class import BaseAlgorithm +from stable_baselines3.common.vec_env.base_vec_env import VecEnv +from torch.distributions import Categorical +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Dict, Optional, Union + +from .memory import Memory + +from ..networks import * +from ..networks.wrapper.model_wrapper import ModelWrapper +# from ..utils import test_network, merge_config + +class PolicyDistillation: + def __init__(self, teacher: ModelWrapper, student: ModelWrapper, config: Config): #, webdav_manager: Optional[WebdavManager] = None): + self.teacher = teacher + self.student = student + self.config = config + + self.epochs = self.config.compression_config.epochs + self.batch_size = self.config.compression_config.batch_size + self.device = self.config.compression_config.device + self.check_point_frequency = self.config.compression_config.checkpoint_frequency + self.mem_update_frequency = self.config.memory_config.update_frequency + + # self.webdav_manager = webdav_manager + self.results = defaultdict(list) + + self.student.to(self.device) + self.optimizer = torch.optim.RMSprop(self.student.parameters(), lr=config.student_config.learning_rate) + self.evaluator = Evaluator(config, verbose=True) + self.quantizer = self.get_quantizer() + + self.initial_environment = self.config.environment_constructor() + self.best_student = None + + self.memory_order = list(range(1, self.config.compression_config.epochs+1)) + self.memory = self.load_memory() + random.shuffle(self.memory_order) + + + def train(self): + best = -1 * float("inf") + for epoch in range(self.epochs): + print(f"Start epoch {epoch}") + epoch_loss = 0 + + self.student.train() + self.student.to(self.device) + + with torch.enable_grad(): + for states, outputs in self.memory.sample(self.batch_size): + loss = self.continuous_loss(states.to(self.device, dtype=torch.float), outputs.to(self.device)) if self.config.env_continuous_action_space else self.discrete_loss(states.to(self.device, dtype=torch.float), outputs.to(self.device)) + + assert not torch.isnan(loss), "Error: loss can't be nan!" 
+ epoch_loss += self.step(loss) + + # wandb.log({"loss": epoch_loss, "Epoch": epoch}) + print(f"Epoch {epoch} completed, loss: {epoch_loss}") + self.results["Loss"].append({"Epoch": epoch, "Value": epoch_loss}) + if (epoch and epoch % self.config.evaluator_config.student_test_frequency == 0) or self.config.evaluator_config.student_test_frequency == 1: + best, is_best = self.test(epoch, best) + if is_best: + self.save_student("best") + + if (epoch and epoch % self.check_point_frequency == 0) or self.check_point_frequency == 1: + self.save_student(epoch) + + if epoch and epoch % self.mem_update_frequency == 0: + self.load_memory(epoch, self.memory) + + return self.best_student + # if self.webdav_manager: + # self.webdav_manager.upload_results(self.results) + + def load_memory(self, epoch: int=0, memory: Optional[Memory]=None) -> Memory: + memory_index = (self.memory_order.pop() if epoch else 0) + # memory_index = (epoch if epoch else self.memory_order.pop()) + path = f"{self.config.data_directory}/Resources/{str(self.initial_environment.envs[0].env) if self.config.is_vec_env else str(self.initial_environment.env)}/Memory/{self.teacher.__class__.__name__}/sd_{str(self.config.evaluator_config.student_driven)}/{memory_index}" + if os.path.exists(path): + tmp_memory = Memory(self.config, meta_class=self.teacher, path=path) + + if memory: + memory.load_memory(tmp_memory) + return memory + return tmp_memory + else: + print() + print("= Memory ==========================================================") + print(f"= Could not find memory with index: {memory_index}, generating new data...") + if not memory: + memory = Memory(self.config, meta_class=self.teacher) + steps = memory.max_size + else: + steps = self.config.memory_config.update_size + + print(f"= Collecting data for {steps}") + results = self.evaluator.run(self.student, steps=steps, collect=True, teacher=self.teacher) if self.config.evaluator_config.student_driven else self.evaluator.run(self.teacher, steps=steps, collect=True) + print("= Current average steps:", results["Steps"]) + print("= Current average reward:", results["Reward"]) + print("= Average amount of episode:", results["EpisodeCount"]) + print("===================================================================") + print() + + if "GPULab" in os.environ or self.config.evaluator_config.student_driven: + for _ in range(len(results["Memory"])): + memory.load_memory(results["Memory"].pop()) + gc.collect() + + else: + total_size = sum([len(m) for m in results["Memory"]]) + tmp_memory = Memory(self.config, size=total_size, meta_class=self.teacher) + for _ in range(len(results["Memory"])): + tmp_memory.load_memory(results["Memory"].pop()) + gc.collect() + tmp_memory.save(path) + if epoch: + memory.load_memory(tmp_memory) + else: + return tmp_memory + return memory + + + def discrete_loss(self, states:torch.Tensor, teacher_outputs:torch.Tensor) -> torch.Tensor: + T = self.config.compression_config.T + student_policy, student_value, _ = self.student(states) + + teacher_policy, teacher_value_std = teacher_outputs[:, :student_policy.shape[1]], teacher_outputs[:, student_policy.shape[1]:] + teacher_value = teacher_value_std[:,:teacher_policy.shape[1]] + + if self.config.compression_config.categorical: + dist = Categorical(logits=student_policy) + student_policy = dist.probs + dist2 = Categorical(logits=teacher_policy) + teacher_policy = dist2.probs + + policy_loss = nn.KLDivLoss(reduction="batchmean")(F.log_softmax(student_policy, dim=1), F.softmax(teacher_policy/T, dim=1)) #type: ignore + 
critic_importance = self.config.compression_config.critic_importance + if critic_importance and not student_value == None: + value_loss = nn.HuberLoss(reduction="mean")(student_value.to(self.device), teacher_value.to(self.device)) + value_sum = value_loss.sum().detach() + policy_sum = policy_loss.sum().detach() + if policy_sum < 1e-6 or value_sum < 1e-6: # Avoid divide by 0 + return (1-critic_importance) * policy_loss + critic_importance * value_loss + else: + return ((1-critic_importance) * (policy_loss / policy_sum) + critic_importance * (value_loss / value_sum)) * (value_sum + policy_sum) + else: + return policy_loss + + + def continuous_loss(self, states:torch.Tensor, teacher_outputs:torch.Tensor) -> torch.Tensor: + self.student.to(self.device) + student_policy, _, student_std = self.student(states) + # student_policy, student_value = student_outputs[:, :actions], student_outputs[:, actions:] + teacher_policy, teacher_value_std = teacher_outputs[:, :student_policy.shape[1]], teacher_outputs[:, student_policy.shape[1]:] + + + if self.config.compression_config.distribution == "Std": + teacher_std = teacher_value_std[:, -teacher_policy.shape[1]:] # type: ignore + + if not teacher_std.gt(0).all(): + teacher_std[teacher_std==0] = 1 + print("Fixing teacher std") + if not student_std.gt(0).all(): + student_std[student_std==0] = 1 + print("Fixing student std") + + if self.config.compression_config.loss == "KL": + policy_loss = torch.log((teacher_std / student_std)) + (student_std ** 2 + (student_policy - teacher_policy) ** 2) / (2 * (teacher_std ** 2)) - 1/2 # type: ignore + elif self.config.compression_config.loss == "Huber": + policy_loss = nn.HuberLoss(reduction="none")(student_policy, teacher_policy) + std_loss = nn.HuberLoss(reduction="none")(student_std, teacher_std) + policy_loss += std_loss + elif self.config.compression_config.loss == "MSE": + policy_loss = nn.MSELoss(reduction="none")(student_policy, teacher_policy) + std_loss = nn.MSELoss(reduction="none")(student_std, teacher_std) + policy_loss += std_loss + else: + raise RuntimeError("Unknown loss type: " + self.config.compression_config.loss) + else: + if self.config.compression_config.loss == "Huber": + policy_loss = nn.HuberLoss(reduction="none")(student_policy, teacher_policy) #type: ignore + elif self.config.compression_config.loss == "MSE": + policy_loss = nn.MSELoss(reduction="none")(student_policy, teacher_policy) #type: ignore + else: + raise RuntimeError("Unsupported combination of distribution and loss!") + + return policy_loss.sum() + + + def step(self, loss: torch.Tensor) -> float: + loss.backward() + self.optimizer.step() + self.optimizer.zero_grad() + self.student.zero_grad() + + if self.quantizer: + self.quantizer.quantize_params() + + return loss.item() + + + def get_quantizer(self) -> Optional[DorefaQuantizer]: + if self.config.quantization_config.enabled: + bits = self.config.quantization_config.bits + # self.student = distiller.modules.convert_model_to_distiller_lstm(self.student) + quantizer = DorefaQuantizer(self.student, self.optimizer, bits_activations=bits, bits_weights=bits, bits_bias=bits) + quantizer.prepare_model() + quantizer.quantize_params() + return quantizer + return None + + + def save_student(self, epoch: Union[int, str], name: str="student"): + checkpoint_directory = self.config.data_directory + base_path = f"{checkpoint_directory}/Data/{self.config.run_name}/compression/{self.config.compression_config.starting_time}" + Path(f"{base_path}/{epoch}").mkdir(parents=True, exist_ok=True) + + 
self.student.save(f"{base_path}/{epoch}/{name}") + + # if self.webdav_manager: + # self.webdav_manager.upload_model(student_save_path, epoch, not self.config["testing"]["local_checkpoints"]) + + def test(self, epoch: int, best: Optional[float]=None) -> Optional[float]: + is_best = False + print() + print("= Testing =========================================================") + results = self.evaluator.run(self.student, episodes=self.config.evaluator_config.episodes) + self.student.to(self.config.compression_config.device) + log = {"Epoch": epoch} + for metric in ["Steps", "Reward"]: + self.results[metric].append({"Epoch": epoch, "Value": results[metric]}) + log[metric] = results[metric] + + # wandb.log(log) + + print("= Current average steps:", results["Steps"]) + print("= Current average reward:", results["Reward"]) + print("= Average amount of episode:", results["EpisodeCount"]) + print("===================================================================") + print() + + if best and results["Reward"] > best: + # wandb.log({"Best Reward": results["Reward"]}) + best = results["Reward"] + is_best = True + + return best, is_best + + def get_action_space_shape(self): + action_space = self.student.policy.action_space + is_box = isinstance(action_space, Box) + assert is_box or isinstance(action_space, Discrete), "[ERROR]: Unsupported action space!" + + if is_box: + return action_space.shape + else: + return action_space.n + + diff --git a/src/qpd/compressor.py b/src/qpd/compressor.py new file mode 100644 index 0000000..b4b2bac --- /dev/null +++ b/src/qpd/compressor.py @@ -0,0 +1,77 @@ +from typing import Type +import qpd.networks.network_interface as ni + +from qpd.config import Config +from qpd.networks.wrapper.student.student_net import StudentNet +from qpd.utils import summary +from .compression.policy_distillation import PolicyDistillation +from .networks.wrapper.model_wrapper import ModelWrapper + +import json +from pathlib import Path +import torch +import torch.nn as nn + +class Compressor: + def __init__(self, model, environment_constructor, config): + self.model = model + self.environment_constructor = environment_constructor + self.config: Config = config + + self.teacher: ModelWrapper = ModelWrapper.construct_wrapper(self.config, self.model).to(torch.device(self.config.compression_config.device)) + self.student = None + self.config.init_env_model_based_params(self.teacher) + + + def student_network(self, net: Type[StudentNet]): + """Add student network to compressor with existing wrappers. + + Args: + net (Type[nn.Module]): Student network + + Returns: + self + """ + assert self.student == None, "Can not set second student!" + self.config.student_config.student_network = net + self.student_model = self.teacher.construct_student() + self.student = ModelWrapper.construct_wrapper(self.config, self.student_model).to(torch.device(self.config.compression_config.device)) + return self + + def student_model(self, student: Type["ni.NetworkInterface"]): + """Add end-to-end student model to compressor + + Args: + student (Type["ni.NetworkInterface"]): _description_ + + Returns: + _type_: _description_ + """ + assert self.student == None, "Can not set second student!" + self.student = student(self.config) + return self + + def set_environment_done_filter(self, func): + self.config.evaluator_config.environment_is_done_filter = func + return self + + def compress(self): + assert self.student is not None, "No student assigned! 
Use Compressor(**).student_network(**) or Use Compressor(**).student_model(**)" + + Path(f"{self.config.data_directory}/Data/{self.config.run_name}").mkdir(parents=True, exist_ok=True) + config_dump = json.dumps(self.config.serializable(), indent=4, sort_keys=True) + with open(f"{self.config.data_directory}/Data/{self.config.run_name}/config.json", "w") as f: + f.write(config_dump) + + print("Current configuration:") + print(config_dump) + + distillation = PolicyDistillation( + self.teacher, + self.student, + self.config + ) + + distillation.train() + + return distillation.student diff --git a/src/qpd/config.py b/src/qpd/config.py new file mode 100644 index 0000000..cd5ca7b --- /dev/null +++ b/src/qpd/config.py @@ -0,0 +1,165 @@ +import datetime + +from gymnasium.spaces import Box, Discrete + +import qpd.networks.wrapper.model_wrapper as wrapper +import qpd.utils.model_utils as model_utils +import torch +import torch.nn as nn + +from qpd.networks.wrapper.student.fully_connected_student import FCStudentNet +from stable_baselines3.common.vec_env.base_vec_env import VecEnv + +class AbstractConfig: + key = None + + def init(self, config: dict): + if self.key in config.keys(): + self.update(config[self.key]) + return self + + def update(self, config): + self.__dict__.update((k, config[k]) for k in set(config).intersection(self.__dict__)) + + def serializable(self) -> dict: + tmp = {} + + for key, value in self.__dict__.items(): + v = None + + if isinstance(value, AbstractConfig): + v = value.serializable() + elif isinstance(value, type): + v = str(value) + elif callable(value): + continue + else: + v = value + + tmp[key] = v + + return tmp + +class MemoryConfig(AbstractConfig): + key = "memory" + def __init__(self): + self.size = 54000 + self.update_frequency = 1 + self.update_size = 5400 + self.device = "cpu" + self.check_consistency = True + self.frame_stack_optimization = False + +class StudentNetConfig(AbstractConfig): + key = "student" + def __init__(self): + self.student_network = FCStudentNet + self.learning_rate = 1e-4 + self.activation_func = nn.ReLU + self.extra_student_kwargs = {} + +class CompressionConfig(AbstractConfig): + key = "compression" + def __init__(self): + self.checkpoint_frequency = 2 + self.starting_time = datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S") + self.epochs = 600 + self.batch_size = 256 + self.T = 0.01 + self.device = torch.device("cuda") + self.critic_importance = 0.5 + self.categorical = False + self.loss = "MSE" + self.distribution = "Std" + +class EvaluatorConfig(AbstractConfig): + key = "evaluator" + def __init__(self): + self.environment_is_done_filter = lambda env_info: True + self.student_test_frequency = 1 + self.episodes = 50 + self.initialize = 30 + self.env_subproc = False + self.env_workers = 1 + self.ray_workers = 24 + self.device = "cpu" + self.deterministic = False + self.student_driven = False + self.exploration_rate = 0. 
+ +class QuantizationConfig(AbstractConfig): + key = "quantization" + def __init__(self): + self.enabled = False + self.bits = 8 + +class Config(AbstractConfig): + def __init__(self, environment_constructor, config={}): + self.environment_constructor = lambda: environment_constructor(self) + + self.memory_config = MemoryConfig().init(config) + self.student_config = StudentNetConfig().init(config) + self.compression_config = CompressionConfig().init(config) + self.evaluator_config = EvaluatorConfig().init(config) + self.quantization_config = QuantizationConfig().init(config) + + self.verbose = True + self.data_directory = "." + self.run_name = "Test" + + # AUTO-DEFINED + self.model_has_critic = False + self.observation_shape = None + self.action_shape = None + self.output_shape = None + self.is_vec_env = False + self.env_continuous_action_space = False + + self.update(config) + + def update_memory_config(self, **kwargs): + self.memory_config.update(**kwargs) + + def update_student_config(self, **kwargs): + self.student_config.update(**kwargs) + + def update_compressor_config(self, **kwargs): + self.compression_config.update(**kwargs) + + def update_evaluator_config(self, **kwargs): + self.evaluator_config.update(**kwargs) + + def update_quantization_config(self, **kwargs): + self.quantization_config.update(**kwargs) + + def init_env_model_based_params(self, wrapper: "wrapper.ModelWrapper"): + wrapper.to(self.compression_config.device) + initialized_environment = self.environment_constructor() + + action_space = initialized_environment.action_space + is_box = isinstance(action_space, Box) + assert is_box or isinstance(action_space, Discrete), f"[ERROR]: Unsupported action space of type {type(action_space)}!" + + state = initialized_environment.reset() # Get state to determin shape of observation + state = wrapper.observation_to_tensor(state).to(self.compression_config.device).float() # get usable state and determin if env is vec_env + + self.is_vec_env = isinstance(initialized_environment, VecEnv) + + if self.is_vec_env: + self.observation_shape = state.shape[1:] + else: + self.observation_shape = state.shape + + if is_box: + self.env_continuous_action_space = True + else: + self.env_continuous_action_space = False + + action, value, action_std = wrapper(state) + + self.model_has_critic = True if value is not None else False + + outputs = model_utils.get_output(action[0], value[0] if value is not None else None, action_std[0] if action_std is not None else None) if self.is_vec_env else model_utils.get_output(action, value, action_std) + + self.action_shape = action.shape[1:] if self.is_vec_env else action.shape + self.output_shape = outputs.shape \ No newline at end of file diff --git a/src/qpd/networks/__init__.py b/src/qpd/networks/__init__.py new file mode 100644 index 0000000..b3f932d --- /dev/null +++ b/src/qpd/networks/__init__.py @@ -0,0 +1,2 @@ +# from .student import Student +# from .teacher import Teacher, TeacherA2C, TeacherDQN, TeacherActorCritic, TeacherPPO, TeacherQRDQN #type: ignore diff --git a/src/qpd/networks/models/__init__.py b/src/qpd/networks/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/qpd/networks/models/student_six_model.py b/src/qpd/networks/models/student_six_model.py new file mode 100644 index 0000000..2d5b174 --- /dev/null +++ b/src/qpd/networks/models/student_six_model.py @@ -0,0 +1,45 @@ +from qpd.config import Config +from qpd.networks.network_interface import NetworkInterface +import torch +import torch.nn as nn +from 
torch.distributions import Normal + +class StudentSixModel(NetworkInterface): + def __init__(self, config: Config): + super(StudentSixModel, self).__init__(config) + + self.net = nn.Sequential( + nn.Linear(self.config.observation_shape[0], 64), + nn.ReLU(), + + nn.Linear(64, 64), + nn.ReLU(), + + nn.Linear(64, 64), + nn.ReLU() + ) + + self.action_head = nn.Linear(64, 6) + self.std_head = nn.Linear(64, 6) + self.critic_head = nn.Linear(64, 1) + + for module in self.children(): + if type(module) == nn.Linear: + module.bias.data.uniform_(0.0) # type: ignore + module.weight.data.uniform_(0, 0.01) # type: ignore + + def forward(self, observations): + x = self.net(observations) + mean_actions = self.action_head(x) + std = self.std_head(x)**2 + dist = Normal(mean_actions, std) + self.actions = dist.sample() + + value = self.critic_head(x) + + value = torch.cat([value, value], dim=0) + + return mean_actions, value, std + + def get_actions(self): + return self.actions.cpu().numpy() \ No newline at end of file diff --git a/src/qpd/networks/models/student_six_model_dqn_.py b/src/qpd/networks/models/student_six_model_dqn_.py new file mode 100644 index 0000000..86e6fa5 --- /dev/null +++ b/src/qpd/networks/models/student_six_model_dqn_.py @@ -0,0 +1,37 @@ +from qpd.config import Config +from qpd.networks.network_interface import NetworkInterface +import torch +import torch.nn as nn +from torch.distributions import Normal + +class StudentSixModelDQN(NetworkInterface): + def __init__(self, config: Config): + super(StudentSixModelDQN, self).__init__(config) + + self.net = nn.Sequential( + nn.Linear(self.config.observation_shape[0], 64), + nn.ReLU(), + + nn.Linear(64, 64), + nn.ReLU(), + + nn.Linear(64, 64), + nn.ReLU() + ) + + self.action_head = nn.Linear(64, 2) + + for module in self.children(): + if type(module) == nn.Linear: + module.bias.data.uniform_(0.0) # type: ignore + module.weight.data.uniform_(0, 0.01) # type: ignore + + def forward(self, observations): + x = self.net(observations) + self.mean_actions = self.action_head(x) + + return self.mean_actions, None, None + + def get_actions(self): + actions = self.mean_actions.argmax(dim=1).reshape(-1) + return actions.cpu().numpy() \ No newline at end of file diff --git a/src/qpd/networks/models/student_tiny_dqn.py b/src/qpd/networks/models/student_tiny_dqn.py new file mode 100644 index 0000000..557231f --- /dev/null +++ b/src/qpd/networks/models/student_tiny_dqn.py @@ -0,0 +1,37 @@ +from qpd.config import Config +from qpd.networks.network_interface import NetworkInterface +import torch +import torch.nn as nn +from torch.distributions import Normal + +class TinyStudentDQN(NetworkInterface): + def __init__(self, config: Config): + super(TinyStudentDQN, self).__init__(config) + + self.net = nn.Sequential( + nn.Linear(self.config.observation_shape[0], 16), + nn.ReLU(), + + nn.Linear(16, 16), + nn.ReLU(), + + # nn.Linear(16, 16), + # nn.ReLU() + ) + + self.action_head = nn.Linear(16, 2) + + for module in self.children(): + if type(module) == nn.Linear: + module.bias.data.uniform_(0.0) # type: ignore + module.weight.data.uniform_(0, 0.01) # type: ignore + + def forward(self, observations): + x = self.net(observations) + self.mean_actions = self.action_head(x) + + return self.mean_actions, None, None + + def get_actions(self): + actions = self.mean_actions.argmax(dim=1).reshape(-1) + return actions.cpu().numpy() \ No newline at end of file diff --git a/src/qpd/networks/network_interface.py b/src/qpd/networks/network_interface.py new file mode 100644 index 
0000000..4706e2f --- /dev/null +++ b/src/qpd/networks/network_interface.py @@ -0,0 +1,25 @@ +import torch +import torch.nn as nn + +import qpd.config as config + +class NetworkInterface(nn.Module): + def __init__(self, config: "config.Config"): + super(NetworkInterface, self).__init__() + self.config: "config.Config" = config + + def forward(self, x): + raise NotImplementedError(f"Forward method not implemented for {type(self)}!") + + def get_actions(self): + raise NotImplementedError(f"Actions method not implemented for {type(self)}!") + + def save(self, path): + torch.save(self.state_dict(), path + ".pth") + + def load(self, path): + self.load_state_dict(torch.load(path)) + self.eval() + + def observation_to_tensor(self, x): + return torch.from_numpy(x) \ No newline at end of file diff --git a/src/qpd/networks/wrapper/__init__.py b/src/qpd/networks/wrapper/__init__.py new file mode 100644 index 0000000..6186844 --- /dev/null +++ b/src/qpd/networks/wrapper/__init__.py @@ -0,0 +1,7 @@ +from typing import List, Type + +import qpd.networks.wrapper.model_wrapper as wrapper +registered_teachers: List[Type[wrapper.ModelWrapper]] = [] + +# from .stable_baselines import * +import qpd.networks.wrapper.stable_baselines diff --git a/src/qpd/networks/wrapper/model_wrapper.py b/src/qpd/networks/wrapper/model_wrapper.py new file mode 100644 index 0000000..c6998dc --- /dev/null +++ b/src/qpd/networks/wrapper/model_wrapper.py @@ -0,0 +1,19 @@ +import qpd.config as config +from qpd.networks.network_interface import NetworkInterface + +import qpd.networks.wrapper as wrapper + +class ModelWrapper(NetworkInterface): + algorithm_type = None + + def __init__(self, config: "config.Config", model): + super(ModelWrapper, self).__init__(config) + self.original_type = type(model) + self.policy = None + + def construct_student(self): + raise NotImplementedError("Construct student method not implemented!") + + @staticmethod + def construct_wrapper(config, model): + return next(teacher for teacher in wrapper.registered_teachers if teacher.algorithm_type == type(model))(config, model) \ No newline at end of file diff --git a/src/qpd/networks/wrapper/stable_baselines/__init__.py b/src/qpd/networks/wrapper/stable_baselines/__init__.py new file mode 100644 index 0000000..244d159 --- /dev/null +++ b/src/qpd/networks/wrapper/stable_baselines/__init__.py @@ -0,0 +1,16 @@ +# from .sb_wrapper import SBWrapper +# from .value_iteration import WrapperValueIteration +# from .on_actor_critic import WrapperONPActorCritic +# from .qrdqn import WrapperQRDQN +# from .ppo import WrapperPPO +# from .dqn import WrapperDQN +# from .a2c import WrapperA2C +# from .sac import WrapperSAC +import qpd.networks.wrapper.stable_baselines.sb_wrapper +import qpd.networks.wrapper.stable_baselines.value_iteration +import qpd.networks.wrapper.stable_baselines.on_actor_critic +import qpd.networks.wrapper.stable_baselines.qrdqn +import qpd.networks.wrapper.stable_baselines.ppo +import qpd.networks.wrapper.stable_baselines.dqn +import qpd.networks.wrapper.stable_baselines.a2c +import qpd.networks.wrapper.stable_baselines.sac \ No newline at end of file diff --git a/src/qpd/networks/wrapper/stable_baselines/a2c.py b/src/qpd/networks/wrapper/stable_baselines/a2c.py new file mode 100644 index 0000000..45caa1f --- /dev/null +++ b/src/qpd/networks/wrapper/stable_baselines/a2c.py @@ -0,0 +1,8 @@ +import qpd.networks.wrapper.stable_baselines.on_actor_critic as sboac + +from ..wrapper_decorator import wrapper +from stable_baselines3 import A2C + +@wrapper +class 
WrapperA2C(sboac.WrapperONPActorCritic): + algorithm_type = A2C \ No newline at end of file diff --git a/src/qpd/networks/wrapper/stable_baselines/dqn.py b/src/qpd/networks/wrapper/stable_baselines/dqn.py new file mode 100644 index 0000000..e5a0c6d --- /dev/null +++ b/src/qpd/networks/wrapper/stable_baselines/dqn.py @@ -0,0 +1,18 @@ +import numpy as np +import gym + +from stable_baselines3 import DQN +from stable_baselines3.common.preprocessing import maybe_transpose +from stable_baselines3.common.utils import is_vectorized_observation +import qpd.networks.wrapper.stable_baselines.value_iteration as sbvi + +from ..wrapper_decorator import wrapper + +@wrapper +class WrapperDQN(sbvi.WrapperValueIteration): + algorithm_type = DQN + + def forward(self, x): + self.observation = x + self.q_values = self.policy.q_net(self.observation) + return self.q_values, None, None \ No newline at end of file diff --git a/src/qpd/networks/wrapper/stable_baselines/off_actor_critic.py b/src/qpd/networks/wrapper/stable_baselines/off_actor_critic.py new file mode 100644 index 0000000..aa48348 --- /dev/null +++ b/src/qpd/networks/wrapper/stable_baselines/off_actor_critic.py @@ -0,0 +1,25 @@ +from typing import Type, Union + +import qpd.networks.wrapper.stable_baselines.sb_wrapper as sbw +from qpd.networks.wrapper.wrapper_decorator import wrapper + +from qpd.networks.wrapper.model_wrapper import ModelWrapper +from stable_baselines3.sac.sac import SAC +import torch + +class WrapperOFFPActorCritic(sbw.SBWrapper): + def forward(self, obs): + mean_actions, log_std, kwargs = self.policy.actor.get_action_dist_params(obs) + self.actions = self.policy.actor.action_dist.actions_from_params(mean_actions, log_std, deterministic=self.config.evaluator_config.deterministic, **kwargs) + + value = torch.cat(self.policy.critic(obs, self.actions), dim=1) + + action_std = None + + if self.config.compression_config.distribution == "Std": + action_std = torch.ones_like(mean_actions) * log_std.exp() + + return mean_actions, value, action_std + + def _get_actions(self): + return self.actions diff --git a/src/qpd/networks/wrapper/stable_baselines/on_actor_critic.py b/src/qpd/networks/wrapper/stable_baselines/on_actor_critic.py new file mode 100644 index 0000000..53704fe --- /dev/null +++ b/src/qpd/networks/wrapper/stable_baselines/on_actor_critic.py @@ -0,0 +1,62 @@ +import torch + +from stable_baselines3.common.base_class import BaseAlgorithm +from stable_baselines3.common.distributions import ( + BernoulliDistribution, + CategoricalDistribution, + DiagGaussianDistribution, + MultiCategoricalDistribution, + StateDependentNoiseDistribution +) +from stable_baselines3.common.policies import BasePolicy +import qpd.networks.wrapper.stable_baselines.sb_wrapper as sbw + +class WrapperONPActorCritic(sbw.SBWrapper): + def forward(self, obs): + features = self.policy.extract_features(obs) + if self.policy.share_features_extractor: + self.latent_pi, latent_vf = self.policy.mlp_extractor(features) + else: + pi_features, vf_features = features + self.latent_pi = self.policy.mlp_extractor.forward_actor(pi_features) + latent_vf = self.policy.mlp_extractor.forward_critic(vf_features) + + self.mean_actions = self.policy.action_net(self.latent_pi) + value = self.policy.value_net(latent_vf) + + action_std = None + + if self.config.compression_config.distribution == "Std": + if isinstance(self.policy.action_dist, StateDependentNoiseDistribution): + variance = torch.mm(self.latent_pi ** 2, self.policy.action_dist.get_std(self.policy.log_std) ** 2) + 
action_std = torch.sqrt(variance + self.policy.action_dist.epsilon) + + elif isinstance(self.policy.action_dist, DiagGaussianDistribution): + action_std = torch.ones_like(self.mean_actions) * self.policy.log_std.exp() + elif isinstance(self.policy.action_dist, CategoricalDistribution): + action_std = None # In case of discrete action space + else: + raise RuntimeError("Invalid distribution type!") + + return self.mean_actions, value, action_std + + def _get_actions(self): + distribution = None + + if isinstance(self.policy.action_dist, DiagGaussianDistribution): + distribution = self.policy.action_dist.proba_distribution(self.mean_actions, self.policy.log_std) + elif isinstance(self.policy.action_dist, CategoricalDistribution): + # Here mean_actions are the logits before the softmax + distribution = self.policy.action_dist.proba_distribution(action_logits=self.mean_actions) + elif isinstance(self.policy.action_dist, MultiCategoricalDistribution): + # Here mean_actions are the flattened logits + distribution = self.policy.action_dist.proba_distribution(action_logits=self.mean_actions) + elif isinstance(self.policy.action_dist, BernoulliDistribution): + # Here mean_actions are the logits (before rounding to get the binary actions) + distribution = self.policy.action_dist.proba_distribution(action_logits=self.mean_actions) + elif isinstance(self.policy.action_dist, StateDependentNoiseDistribution): + distribution = self.policy.action_dist.proba_distribution(self.mean_actions, self.policy.log_std, self.latent_pi) + else: + raise ValueError("Invalid action distribution") + + return distribution.get_actions(self.config.evaluator_config.deterministic) diff --git a/src/qpd/networks/wrapper/stable_baselines/ppo.py b/src/qpd/networks/wrapper/stable_baselines/ppo.py new file mode 100644 index 0000000..0dfdec7 --- /dev/null +++ b/src/qpd/networks/wrapper/stable_baselines/ppo.py @@ -0,0 +1,10 @@ +import qpd.networks.wrapper.stable_baselines.on_actor_critic as sboac + +from stable_baselines3.common.base_class import BaseAlgorithm +from stable_baselines3.common.policies import BasePolicy +from ..wrapper_decorator import wrapper +from stable_baselines3 import PPO + +@wrapper +class WrapperPPO(sboac.WrapperONPActorCritic): + algorithm_type = PPO \ No newline at end of file diff --git a/src/qpd/networks/wrapper/stable_baselines/qrdqn.py b/src/qpd/networks/wrapper/stable_baselines/qrdqn.py new file mode 100644 index 0000000..72b0787 --- /dev/null +++ b/src/qpd/networks/wrapper/stable_baselines/qrdqn.py @@ -0,0 +1,12 @@ +from ..wrapper_decorator import wrapper +import qpd.networks.wrapper.stable_baselines.value_iteration as sbvi +from sb3_contrib import QRDQN + +@wrapper +class WrapperQRDQN(sbvi.WrapperValueIteration): + algorithm_type = QRDQN + + def forward(self, obs): + self.observation = obs + self.q_values = self.policy.quantile_net(self.observation) + return self.q_values.view(-1, self.policy.n_quantiles, self.policy.action_space.n).mean(dim=1), None, None \ No newline at end of file diff --git a/src/qpd/networks/wrapper/stable_baselines/sac.py b/src/qpd/networks/wrapper/stable_baselines/sac.py new file mode 100644 index 0000000..14defec --- /dev/null +++ b/src/qpd/networks/wrapper/stable_baselines/sac.py @@ -0,0 +1,8 @@ +import qpd.networks.wrapper.stable_baselines.off_actor_critic as sbac +from qpd.networks.wrapper.wrapper_decorator import wrapper + +from stable_baselines3.sac.sac import SAC + +@wrapper +class WrapperSAC(sbac.WrapperOFFPActorCritic): + algorithm_type = SAC \ No newline at end of file 
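The per-algorithm wrappers above (A2C, DQN, PPO, QRDQN, SAC) share one pattern: subclass the matching base wrapper, set `algorithm_type`, and apply the `@wrapper` decorator so the class ends up in `registered_teachers`, which `ModelWrapper.construct_wrapper` searches by the loaded model's type. A minimal usage sketch follows, assuming a locally downloaded checkpoint and the `get_environment`/`config` objects defined in the examples above; it mirrors how `examples/dqn_to_onnx_cartpole.py` obtains its teacher:

``` python
# Sketch only: resolve the registered wrapper for a loaded Stable Baselines3 model.
from stable_baselines3 import DQN

from qpd.config import Config
from qpd.networks.wrapper.model_wrapper import ModelWrapper

c = Config(get_environment, config)                              # get_environment/config as in examples/run_cartpole.py
model = DQN.load("dqn-CartPole-v1.zip", env=get_environment(c))  # assumed local checkpoint path
teacher = ModelWrapper.construct_wrapper(c, model)               # resolves to WrapperDQN via its algorithm_type
```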
diff --git a/src/qpd/networks/wrapper/stable_baselines/sb_wrapper.py b/src/qpd/networks/wrapper/stable_baselines/sb_wrapper.py
new file mode 100644
index 0000000..c8369de
--- /dev/null
+++ b/src/qpd/networks/wrapper/stable_baselines/sb_wrapper.py
@@ -0,0 +1,67 @@
+from typing import Type, Union
+
+import gym
+import numpy as np
+import torch
+
+import qpd.networks.wrapper.stable_baselines.student as student
+import qpd.networks.wrapper.model_wrapper as model_wrapper
+from stable_baselines3.common.base_class import BaseAlgorithm
+from stable_baselines3.common.policies import BasePolicy
+
+class SBWrapper(model_wrapper.ModelWrapper):
+    algorithm_type: Type[BaseAlgorithm] = None
+
+    def __init__(self, config, model: Union[BaseAlgorithm, BasePolicy]):
+        super(SBWrapper, self).__init__(config, model)
+        self.policy: BasePolicy = model.policy if isinstance(model, BaseAlgorithm) else model
+        assert isinstance(self.policy, BasePolicy), "Incorrect input model type!"
+
+    def forward(self, x):
+        raise NotImplementedError("forward() is not implemented!")
+
+    def _get_actions(self):
+        raise NotImplementedError("_get_actions() is not implemented!")
+
+    def observation_to_tensor(self, x):
+        return self.policy.obs_to_tensor(x)[0]
+
+    def save(self, path):
+        student = self.construct_student()
+        student.policy = self.policy
+
+        student.save(path + ".zip")
+
+    def load(self, path):
+        self.policy = self.algorithm_type.load(path).policy
+
+    def get_actions(self):
+        with torch.no_grad():
+            actions = self._get_actions()
+
+        actions = actions.cpu().numpy()
+        if isinstance(self.policy.action_space, gym.spaces.Box):
+            if self.policy.squash_output:
+                # Rescale to proper domain when using squashing
+                actions = self.policy.unscale_action(actions)
+            else:
+                # Actions could be on arbitrary scale, so clip the actions to avoid
+                # out of bound error (e.g.
if sampling from a Gaussian distribution) + actions = np.clip(actions, self.policy.action_space.low, self.policy.action_space.high) + + return actions + + def construct_student(self): + tmp = { + "learning_rate": self.config.student_config.learning_rate, + "policy_kwargs": { + "net_arch": [], + "features_extractor_class": student.SBStudentNet, + "features_extractor_kwargs": { + "network": self.config.student_config.student_network + }, + "activation_fn": self.config.student_config.activation_func + } + } + tmp.update(self.config.student_config.extra_student_kwargs) + return self.original_type(type(self.policy), self.config.environment_constructor(), **tmp) \ No newline at end of file diff --git a/src/qpd/networks/wrapper/stable_baselines/student.py b/src/qpd/networks/wrapper/stable_baselines/student.py new file mode 100644 index 0000000..ee35e61 --- /dev/null +++ b/src/qpd/networks/wrapper/stable_baselines/student.py @@ -0,0 +1,21 @@ +from typing import Type +import torch as th +import torch.nn as nn +from gym import spaces + +from stable_baselines3.common.torch_layers import BaseFeaturesExtractor + +from qpd.networks.wrapper.student.student_net import StudentNet + +from ..student.cnn_student import CNNStudentNet + +class SBStudentNet(BaseFeaturesExtractor): + def __init__(self, observation_space: spaces.Space, features_dim: int = 16, network: Type[StudentNet]=CNNStudentNet): + super(SBStudentNet, self).__init__(observation_space, features_dim) + + self.network = network(observation_space, features_dim) + + self._features_dim = self.network.features_dim + + def forward(self, observations: th.Tensor) -> th.Tensor: + return self.network(observations) diff --git a/src/qpd/networks/wrapper/stable_baselines/value_iteration.py b/src/qpd/networks/wrapper/stable_baselines/value_iteration.py new file mode 100644 index 0000000..47668d8 --- /dev/null +++ b/src/qpd/networks/wrapper/stable_baselines/value_iteration.py @@ -0,0 +1,22 @@ +import numpy as np +import gym + +from stable_baselines3.common.preprocessing import maybe_transpose +from stable_baselines3.common.utils import is_vectorized_observation + +import qpd.networks.wrapper.stable_baselines.sb_wrapper as sbw + +class WrapperValueIteration(sbw.SBWrapper): + def _get_actions(self): + if not self.config.evaluator_config.deterministic and np.random.rand() < self.config.evaluator_config.exploration_rate: + if is_vectorized_observation(maybe_transpose(self.observation, self.policy.observation_space), self.policy.observation_space): + if isinstance(self.policy.observation_space, gym.spaces.Dict): + n_batch = self.observation[list(self.observation.keys())[0]].shape[0] + else: + n_batch = self.observation.shape[0] + action = np.array([self.action_space.sample() for _ in range(n_batch)]) + else: + action = np.array(self.action_space.sample()) + else: + action = self.q_values.argmax(dim=1).reshape(-1) + return action \ No newline at end of file diff --git a/src/qpd/networks/wrapper/student/__init__.py b/src/qpd/networks/wrapper/student/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/qpd/networks/wrapper/student/cnn_student.py b/src/qpd/networks/wrapper/student/cnn_student.py new file mode 100644 index 0000000..e26d3a8 --- /dev/null +++ b/src/qpd/networks/wrapper/student/cnn_student.py @@ -0,0 +1,39 @@ +from typing import Type +import torch as th +import torch.nn as nn +from gym import spaces + +import qpd.networks.wrapper.student.student_net as student + +class CNNStudentNet(student.StudentNet): + def __init__(self, 
observation_space: spaces.Box, n_output): + super(CNNStudentNet, self).__init__(observation_space, n_output) + # We assume CxHxW images (channels first) + # Re-ordering will be done by pre-preprocessing or wrapper + n_input_channels = observation_space.shape[0] + + self.cnn = nn.Sequential( + nn.Conv2d(n_input_channels, 16, kernel_size=8, stride=4), + nn.ReLU(), + + nn.Conv2d(16, 16, kernel_size=4, stride=2), + nn.ReLU(), + + nn.Conv2d(16, 16, kernel_size=3, stride=1), + nn.ReLU(), + nn.Flatten() + ) + + # Compute shape by doing one forward pass + with th.no_grad(): + n_flatten = self.cnn( + th.as_tensor(observation_space.sample()[None]).float() + ).shape[1] + + self.linear = nn.Sequential( + nn.Linear(n_flatten, self.features_dim), + nn.ReLU() + ) + + def forward(self, observations: th.Tensor) -> th.Tensor: + return self.linear(self.cnn(observations)) \ No newline at end of file diff --git a/src/qpd/networks/wrapper/student/fully_connected_student.py b/src/qpd/networks/wrapper/student/fully_connected_student.py new file mode 100644 index 0000000..149ae16 --- /dev/null +++ b/src/qpd/networks/wrapper/student/fully_connected_student.py @@ -0,0 +1,31 @@ +from typing import Type +import torch as th +import torch.nn as nn +from gym import spaces + +import qpd.networks.wrapper.student.student_net as student + +class FCStudentNet(student.StudentNet): + def __init__(self, observation_space: spaces.Box, n_output: int): + super(FCStudentNet, self).__init__(observation_space, 64) + + n_input_channels = observation_space.shape[0] + + self.net = nn.Sequential( + nn.Linear(n_input_channels, 64), + nn.ReLU(), + + nn.Linear(64, 64), + nn.ReLU(), + + nn.Linear(64, self.features_dim), + nn.ReLU() + ) + + for module in self.children(): + if type(module) == nn.Linear: + module.bias.data.uniform_(0.0) # type: ignore + module.weight.data.uniform_(0, 0.01) # type: ignore + + def forward(self, observations: th.Tensor) -> th.Tensor: + return self.net(observations) \ No newline at end of file diff --git a/src/qpd/networks/wrapper/student/student_net.py b/src/qpd/networks/wrapper/student/student_net.py new file mode 100644 index 0000000..61a7395 --- /dev/null +++ b/src/qpd/networks/wrapper/student/student_net.py @@ -0,0 +1,14 @@ +from typing import Type +import torch as th +import torch.nn as nn +from gym import spaces + +class StudentNet(nn.Module): + def __init__(self, observation_space: spaces.Box, n_output: int): + super(StudentNet, self).__init__() + + self._n_output = n_output + + @property + def features_dim(self): + return self._n_output \ No newline at end of file diff --git a/src/qpd/networks/wrapper/wrapper_decorator.py b/src/qpd/networks/wrapper/wrapper_decorator.py new file mode 100644 index 0000000..ab16fe8 --- /dev/null +++ b/src/qpd/networks/wrapper/wrapper_decorator.py @@ -0,0 +1,5 @@ +from . 
import registered_teachers + +def wrapper(cls): + registered_teachers.append(cls) + return cls \ No newline at end of file diff --git a/src/qpd/utils/__init__.py b/src/qpd/utils/__init__.py new file mode 100644 index 0000000..de1943e --- /dev/null +++ b/src/qpd/utils/__init__.py @@ -0,0 +1,4 @@ +from .configuration import get_config, get_config_mods, merge_config, import_config +from .summary import * +from .model_utils import get_network +from .webdav import WebdavManager diff --git a/src/qpd/utils/configuration.py b/src/qpd/utils/configuration.py new file mode 100644 index 0000000..94ad467 --- /dev/null +++ b/src/qpd/utils/configuration.py @@ -0,0 +1,91 @@ +from argparse import ArgumentParser +import copy + +def convert_nested_insert(dct, lst): + for x in lst[:-2]: + dct[x] = dct = dct.get(x, dict()) + dct.update({lst[-2]: lst[-1]}) + + +def convert_nested(dct): + # empty dict to store the result + result = dict() + + # create an iterator of lists + # representing nested or hierarchical flow + lsts = ([*k.split("."), v] for k, v in dct.items()) + + # insert each list into the result + for lst in lsts: + convert_nested_insert(result, lst) + return result + + +def merge_config(source, destination): + destination = copy.deepcopy(destination) + for key, value in source.items(): + if isinstance(value, dict): + if key in destination and isinstance(destination[key], dict): + destination_value = destination[key] + else: + destination_value = {} + destination[key] = merge_config(value, destination_value) + else: + try: + value = eval(value) + except Exception: + pass + + destination[key] = value + + return destination + + +def get_config_mods(known_args): + config_mods = dict(zip(map(lambda x: x[2:], known_args[1][:-1:2]), known_args[1][1::2])) + config_mods = convert_nested(config_mods) + return config_mods + + +def parse_args(): + parser = ArgumentParser() + parser.add_argument('--config', type=str, default="experiment") + + known_args = parser.parse_known_args() + + args = known_args[0] + config_mods = get_config_mods(known_args) + + return args, config_mods + + +def get_config(): + args, config_mods = parse_args() + + try: + config_module = __import__(f"configs.{args.config}", fromlist=['']) + configs = config_module.config + + if config_mods is not None: + configs = merge_config(config_mods, configs) + + return configs + except ModuleNotFoundError as ex: + print(f"ERROR: Could not find the experiment config file: {ex}") + quit() + + +def import_config(config, experiment_config): + config["network"]["conv1"] = experiment_config["network"]["conv1"] + config["network"]["conv2"] = experiment_config["network"]["conv2"] + config["network"]["conv3"] = experiment_config["network"]["conv3"] + config["network"]["linear"] = experiment_config["network"]["linear"] + config["network"]["teacher"] = experiment_config["network"]["teacher"] + if "critic_importance" in experiment_config["compression"]: + config["compression"]["critic_importance"] = experiment_config["compression"]["critic_importance"] + config["testing"]["randomness"] = experiment_config["testing"]["randomness"] + config["quantization"]["enabled"] = experiment_config["quantization"]["enabled"] + config["compression"]["epochs"] = experiment_config["compression"]["epochs"] + config["train_run"] = experiment_config["run_name"] + return config + diff --git a/src/qpd/utils/evaluator.py b/src/qpd/utils/evaluator.py new file mode 100644 index 0000000..65801a6 --- /dev/null +++ b/src/qpd/utils/evaluator.py @@ -0,0 +1,303 @@ +from abc import 
abstractmethod +import json +import torch +import time +import numpy +import random +import ray + +from typing import Dict, List, Optional +from torch.distributions import Normal + +from qpd.config import Config +from qpd.networks.wrapper.model_wrapper import ModelWrapper +from qpd.utils.model_utils import get_output +from ..compression.memory import Memory +from dataclasses import dataclass, field + + +@dataclass +class Worker: + idx: int = 0 + action_queue: List[int] = field(default_factory=list) + memory: Optional[Memory] = None + finished: bool = False + steps: int = 0 + rewards: int = 0 + repeats: int = 0 + + def reset(self, number: int): + self.action_queue = [0 for _ in range(int(numpy.round(random.random() * number)))] + self.steps = 0 + self.rewards = 0 + self.repeats = 0 + + +@dataclass +class Stats: + repeats: int = 0 + frames: int = 0 + effective_frames: int = 0 + steps: List[int] = field(default_factory=list) + rewards: List[int] = field(default_factory=list) + + def update(self, worker: Worker): + self.repeats += worker.repeats + self.effective_frames += worker.steps + self.steps.append(worker.steps - worker.repeats) + self.rewards.append(worker.rewards) + + +class Evaluator: + def __init__(self, config: Config, verbose: bool=False) -> None: + self.config = config + # self.env = create_env(config) + self.verbose = verbose + + ray.init( + _system_config={ + "local_fs_capacity_threshold": 0.9999999999, + "object_spilling_config": json.dumps( + { + "type": "filesystem", + "params": { + "directory_path": f"{config.data_directory}/spill" + } + }, + ) + }, + ) + + def run(self, network: torch.nn.Module, collect: bool=False, steps: int=0, episodes: int=0, teacher=None) -> Dict: + network.eval() + network.to(torch.device(self.config.evaluator_config.device)) + + @ray.remote(num_gpus=0.1) + def run_ray(function, *args, **kwargs): + return function(*args, **kwargs) + + test_function = Evaluator.test_vec if self.config.is_vec_env else Evaluator.test + + if self.config.evaluator_config.ray_workers: + worker_steps = list(map(lambda x: len(x), numpy.array_split(numpy.arange(steps), self.config.evaluator_config.ray_workers))) + worker_episodes = list(map(lambda x: len(x), numpy.array_split(numpy.arange(episodes), self.config.evaluator_config.ray_workers))) + old_metrics = ray.get( + [run_ray.remote( + test_function, + self.config, + network, + steps=s, + episodes=e, + collect=collect, + verbose=self.verbose, + teacher=teacher + ) for s, e in zip(worker_steps, worker_episodes)] + ) + metrics = {item: sum(list(map(lambda x: x[item], old_metrics)), []) for item in ["Memory"]} + metrics.update({item: numpy.mean(list(map(lambda x: x[item], old_metrics))) for item in ["Reward", "Steps", "EpisodeCount"]}) + # metrics = {item: sum(list(map(lambda x: x[item], old_metrics)), []) for item in ["Memory", "Reward", "Steps"]} + else: + metrics = test_function( + self.config, + network, + steps=steps, + episodes=episodes, + collect=collect, + verbose=self.verbose, + teacher=teacher) + + network.train() + return metrics + + @abstractmethod + def test(config: Config, + model_wrapper: ModelWrapper, + steps: int=0, + episodes: int=0, + collect: bool=False, + verbose=True, + teacher: ModelWrapper=None) -> Dict: + + env = config.environment_constructor() + # model = None + # model = model_wrapper.original_type(type(model_wrapper.policy), env, device=config.evaluator_config.device, **config.build_student_kwargs()) + # model.policy = model_wrapper.policy + + if teacher: + 
teacher.to(device=config.evaluator_config.device) + + total_rewards = [] + total_steps = [] + frames = 0 + effective_frames = 0 + + with torch.no_grad(): + env_state = env.reset() + + worker_action_queue = [0 for _ in range(int(numpy.round(random.random() * config.evaluator_config.initialize)))] + worker_steps_since_last_reward = 0 + worker_steps = 0 + worker_finished = False + + worker_memory = None + if collect: + worker_memory = Memory(config, size=int(steps) + 10000) + + start = time.time() + if verbose: + print(f"Start workers") + while not numpy.all(worker_finished): + current_actions = [] + + state, _ = model_wrapper.observation_to_tensor(env_state) + action, value, action_std = model_wrapper(state) + + if len(worker_action_queue) > 0: + current_actions.append(worker_action_queue.pop()) + else: + current_actions = model_wrapper.get_actions() + + if collect: + if teacher and teacher is not model_wrapper: + action, value, action_std = teacher(state) + + if not worker_finished: + if worker_memory.will_overflow: + print(f"Growing memory to size {worker_memory.max_size + 5000}!") + worker_memory.grow(5000) + + worker_memory.update( + state.unsqueeze(0), + get_output(action, value, action_std), + worker_steps == 0 + ) + + env_state, reward, done, infos = env.step(current_actions) #type:ignore + + frames += 1 + + worker_steps += 1 + if reward > 0: + worker_steps_since_last_reward = 0 + else: + worker_steps_since_last_reward += 1 + + if done: + if config.evaluator_config.environment_is_done_filter(infos): + assert "episode" in infos, "No episode key in infos dict from gym environment" + total_rewards.append(infos["episode"]["r"]) + worker_action_queue = [0 for _ in range(int(numpy.round(random.random() * config.evaluator_config.initialize)))] + effective_frames += worker_steps + total_steps.append(worker_steps) + worker_steps = 0 + + if (steps and effective_frames >= steps) or (episodes and len(total_rewards)): + worker_finished = True + + if collect: + worker_memory.prune() + + if verbose: + end = time.time() + print("Average and effective framerate:", round(frames / (end-start), 3), round(effective_frames / (end-start), 3)) + + return {"Steps": numpy.mean(total_steps) if len(total_steps) else 0, "Reward": numpy.mean(total_rewards) if len(total_rewards) else 0, "Memory": worker_memory if collect else []} + + @abstractmethod + def test_vec(config: Config, + model_wrapper: ModelWrapper, + steps: int=0, + episodes: int=0, + collect: bool=False, + verbose=True, + teacher: ModelWrapper=None) -> Dict: + + env = config.environment_constructor() + + model_wrapper.to(device=config.evaluator_config.device) + if teacher: + teacher.eval() + teacher.to(device=config.evaluator_config.device) + + total_rewards = [] + total_steps = [] + frames = 0 + effective_frames = 0 + + workers = config.evaluator_config.env_workers + + with torch.no_grad(): + state = env.reset() + worker_action_queue = [[0 for _ in range(int(numpy.round(random.random() * config.evaluator_config.initialize)))] for _ in range(workers)] + workers_steps_since_last_reward = [0 for _ in range(workers)] + worker_steps = [0 for _ in range(workers)] + worker_finished = [False for _ in range(workers)] + + worker_memories = [] + if collect: + worker_memories = [Memory(config, size=int(steps / workers) + 10000) for _ in range(workers)] + + start = time.time() + if verbose: + print(f"Start workers") + while not numpy.all(worker_finished): + current_actions = [] + state = model_wrapper.observation_to_tensor(state).float() + action, value, 
action_std = model_wrapper(state) + + if max([len(q) for q in worker_action_queue]) > 0: + for worker in range(workers): + if len(worker_action_queue[worker]): + current_actions.append(worker_action_queue[worker].pop()) + else: + current_actions = model_wrapper.get_actions() + + if collect: + if teacher and teacher is not model_wrapper: + action, value, action_std = teacher(state) + + for worker in range(workers): + if not worker_finished[worker]: + if worker_memories[worker].will_overflow(): + print(f"Growing memory to size {worker_memories[worker].max_size + 5000}!") + worker_memories[worker].grow(5000) + + outputs = get_output(action[worker], value[worker] if value is not None else None, action_std[worker] if action_std is not None else None) + + worker_memories[worker].update(state[worker].unsqueeze(0), outputs, worker_steps[worker] == 0) + + state, reward, done, infos = env.step(current_actions) #type:ignore + + frames += workers + for worker in range(workers): + worker_steps[worker] += 1 + if reward[worker] > 0: + workers_steps_since_last_reward[worker] = 0 + else: + workers_steps_since_last_reward[worker] += 1 + + if done[worker] and not worker_finished[worker]: + # print(f"Done: {infos}") + if config.evaluator_config.environment_is_done_filter(infos[worker]): + assert "episode" in infos[worker] + total_rewards.append(infos[worker]["episode"]["r"]) + worker_action_queue[worker] = [0 for _ in range(int(numpy.round(random.random() * config.evaluator_config.initialize)))] + effective_frames += worker_steps[worker] + total_steps.append(worker_steps[worker]) + worker_steps[worker] = 0 + + if (steps and effective_frames >= steps) or (episodes and len(total_rewards) + workers - 1 >= episodes): + if verbose: + print(f"Finished with {effective_frames} steps and {len(total_rewards)} episodes!") + worker_finished[worker] = True + + + if collect: + for worker in range(workers): + worker_memories[worker].prune() + + if verbose: + end = time.time() + print("Average and effective framerate:", round(frames / (end-start), 3), round(effective_frames / (end-start), 3)) + + return {"Steps": numpy.mean(total_steps) if len(total_steps) else 0, "Reward": numpy.mean(total_rewards) if len(total_rewards) else 0, "Memory": worker_memories if collect else [], "EpisodeCount": len(total_rewards)/workers} diff --git a/src/qpd/utils/model_utils.py b/src/qpd/utils/model_utils.py new file mode 100644 index 0000000..85a738b --- /dev/null +++ b/src/qpd/utils/model_utils.py @@ -0,0 +1,47 @@ +from distiller.quantization import DorefaQuantizer +from .summary import summary +import torch.nn as nn +import torch +import os + + +class Flatten(nn.Module): + def forward(self, x): + return x.reshape(x.size(0), -1) + + +def init_module(module, weight_init, bias_init, gain=1): + weight_init(module.weight.data, gain=gain) + bias_init(module.bias.data) + return module + + +def get_output_size(input_shape, layer): + sample = torch.zeros(size=(1, *input_shape)) + return layer(sample).view(1, -1).size(1) + + +def get_network(config, Network, load=None, quantization=0, verbose=True): + network = Network(config) + + if verbose: + summary(network) + if quantization: + bits = quantization + optimizer = torch.optim.RMSprop(network.parameters(), lr=0) + quantizer = DorefaQuantizer(network, optimizer, bits_activations=bits, bits_weights=bits, bits_bias=bits) + quantizer.prepare_model() + quantizer.quantize_params() + + if load and (type(load) != str or os.path.exists(load)): + network = network.load(load) + + network.eval() + return network 
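+
+# (Editor's illustrative note, not part of the original file.) Typical usage of
+# get_network -- the network class is any callable taking the config object,
+# e.g. one of the models under src/qpd/networks/models (the class name below is
+# a placeholder):
+#
+#     network = get_network(config, SomeStudentNetwork, quantization=8)
+#
+# With quantization > 0, DoReFa fake quantization at that bit width is applied
+# to the weights, activations and biases before the network is returned.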
+ +def get_output(action, value, action_std): + if action_std is not None: + return torch.cat([action, value, action_std], dim=0) + if value is not None: + return torch.cat([action, value], dim=0) + return action diff --git a/src/qpd/utils/onnx_helper.py b/src/qpd/utils/onnx_helper.py new file mode 100644 index 0000000..802df26 --- /dev/null +++ b/src/qpd/utils/onnx_helper.py @@ -0,0 +1,52 @@ +import os +from collections import OrderedDict +from distiller.quantization import DorefaQuantizer +import numpy as np +from onnxruntime.quantization import quantize_dynamic +import torch + +def filter_state_dict(state_dict: OrderedDict): + new = OrderedDict() + for key, state in state_dict.items(): + key_split = key.split(".") + + state_type = key_split[-1] + + if state_type == "float_weight" or state_type == "float_bias": + continue + + new[key] = state + return new + + +def clean_student(student, config): + clean_student = type(student)(config) + clean_student.load_state_dict(filter_state_dict(student.state_dict())) + return clean_student + + +def export_student_as_onnx(input, output, student_type, config, get_environment, verbose=True, onnx_input_names=["input_1"], onnx_output_names=["output_1"], opset_version=14, qat=True): + # Loading quantized student + student = student_type(config) + bits = config.quantization_config.bits + + if qat: + q = DorefaQuantizer(student, torch.optim.RMSprop(student.parameters(), lr=config.student_config.learning_rate), bits, bits, bits) + q.prepare_model() + + student.load(input) + + c_student = clean_student(student, config) + + env = get_environment(config) + state = env.reset().astype(np.float32) + + path = output.split("/")[:-1] + + orig_onnx = "/".join(path) + "/" + "orig.onnx" + temp_onnx = "/".join(path) + "/" + "clean.onnx" + + torch.onnx.export(student, torch.from_numpy(state), orig_onnx, verbose=verbose, input_names=onnx_input_names, output_names=onnx_output_names, opset_version=opset_version) + torch.onnx.export(c_student, torch.from_numpy(state), temp_onnx, verbose=verbose, input_names=onnx_input_names, output_names=onnx_output_names, opset_version=opset_version) + quantize_dynamic(orig_onnx, output) + # os.remove(temp_onnx) diff --git a/src/qpd/utils/summary.py b/src/qpd/utils/summary.py new file mode 100644 index 0000000..c9de55f --- /dev/null +++ b/src/qpd/utils/summary.py @@ -0,0 +1,40 @@ +from typing import List, Any, Tuple +import torch + +def summary(network): + print("======================================") + print("= Network Summary =") + print("======================================") + print(print_layers([(network.__class__.__name__, extract_parameters(network.parameters()), find_layers(network))])) + +def find_layers(network) -> List[Tuple[str, Any]]: + layers = [] + for child in network.children(): + children = list(child.children()) + if not len(children): + layers.append((child.__class__.__name__, extract_parameters(child.parameters()))) + else: + layers.append((child.__class__.__name__, extract_parameters(child.parameters()), find_layers(child))) + return layers + +def extract_parameters(parameters) -> int: + total = 0 + for param in parameters: + if isinstance(param, torch.Tensor): + total += torch.prod(torch.LongTensor(list(param.size()))) + else: + total += extract_parameters(param) + return int(total) + + +def print_layers(layers) -> str: + result = "" + for layer in layers: + result += "- " + layer[0] + if len(layer) == 2: + result += f': {layer[1]:n}\n' + else: + result += f': {layer[1]:n}\n' + for line in 
print_layers(layer[2]).splitlines(): + result += " %s\n" % line + return result
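
For reference, a minimal usage sketch of the `summary` helper above (an editor's addition, not part of the patch, assuming `qpd` and its dependencies import cleanly):

``` python
# Editor's sketch -- prints the nested layer/parameter-count tree produced by
# qpd.utils.summary for a small torch module.
import torch.nn as nn

from qpd.utils.summary import summary

net = nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 2))
summary(net)  # e.g. "- Sequential: 450" with the two Linear layers listed beneath
```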