Initial public commit

Thomas Avé 2024-11-23 21:35:13 +01:00
commit 862c55e03c
43 changed files with 2202 additions and 0 deletions

README.md
@@ -0,0 +1,35 @@
# tiny-rl
>**_NOTE_**: Currently, only Stable Baselines3 models are supported out of the box.

The goal of this repository is to provide an easy-to-use framework that wraps the QPD algorithm to compress RL policy models.
## Installation
First, install the required system libraries:
``` sh
sudo apt update
sudo apt install libpng-dev libjpeg-dev zlib1g-dev
```
Afterwards, install the Python library.
>**_Important_**: Make sure you use Python 3.10.*, setuptools 65.0, and wheel 0.38.4
``` sh
conda create -n compression python=3.10
conda activate compression
python -m pip install setuptools==65.0 wheel==0.38.4
python -m pip install -e .
```
## Currently supported
- A2C
- PPO
- DQN
- QRDQN
- SAC
## HalfCheetah environment installation (when using conda)
``` sh
conda install -c conda-forge xorg-libx11
conda install -c conda-forge glew
conda install -c conda-forge mesalib
conda install -c menpo glfw3
```
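## Usage
The snippet below is a minimal sketch of the intended workflow, condensed from `examples/run_cartpole.py` in this repository (the `config` dict is abbreviated; see the examples for the full set of options):
``` python
from huggingface_sb3 import load_from_hub
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from stable_baselines3.dqn.dqn import DQN

from qpd.config import Config
from qpd.compressor import Compressor
from qpd.networks.wrapper.student.fully_connected_student import FCStudentNet

config = {
    "compression": {"epochs": 600, "device": "cuda"},
    "quantization": {"enabled": True, "bits": 8},
    "data_directory": "./data",
    "run_name": "cartpole_test",  # change this for every run
}

def get_environment(config: Config):
    return make_vec_env("CartPole-v1", n_envs=config.evaluator_config.env_workers, vec_env_cls=DummyVecEnv)

if __name__ == "__main__":
    checkpoint = load_from_hub(repo_id="sb3/dqn-CartPole-v1", filename="dqn-CartPole-v1.zip")
    c = Config(get_environment, config)
    model = DQN.load(checkpoint, env=get_environment(c))

    compressed = Compressor(model, get_environment, c).student_network(FCStudentNet).compress()
```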

@@ -0,0 +1,46 @@
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from huggingface_sb3 import load_from_hub
from stable_baselines3.dqn.dqn import DQN
import torch

from qpd.networks.models.student_six_model_dqn_ import StudentSixModelDQN
from run import config  # local run.py providing the config dict
from qpd.config import Config
from qpd.networks.wrapper.model_wrapper import ModelWrapper


def get_environment(config: Config):
    return make_vec_env("CartPole-v1", n_envs=config.evaluator_config.env_workers, vec_env_cls=DummyVecEnv, env_kwargs={"render_mode": "human"})


c = Config(get_environment, config)
env = get_environment(c)

# Pretrained original (teacher) model
checkpoint = load_from_hub(
    repo_id="sb3/dqn-CartPole-v1",
    filename="dqn-CartPole-v1.zip",
)
custom_objects = {
    "learning_rate": 0.0,
    "lr_schedule": lambda _: 0.0,
    "clip_range": lambda _: 0.0,
}
model = DQN.load(checkpoint, env)  # , custom_objects=custom_objects)
teacher = ModelWrapper.construct_wrapper(c, model)
c.init_env_model_based_params(teacher)

# Loading the quantized student
student = StudentSixModelDQN(c)
student.load("/home/ian/projects-idlab/euROBIN/code/framework/test/Data/compression/2024-06-14_15:56:55/best/student.pth")

rewards_list = []
state = env.reset()
steps = 0
rewards = 0

# Export the student to ONNX. NOTE: "qat" is not an argument of the standard torch.onnx.export;
# this call assumes a patched/extended PyTorch build.
torch.onnx.export(student, torch.from_numpy(state), "/home/ian/projects-idlab/euROBIN/code/framework/test/Data/compression/2024-06-14_15:56:55/best/clean_student.onnx", verbose=True, input_names=["input_1"], output_names=["output_1"], opset_version=14, qat=True)
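To sanity-check the exported network, here is a minimal sketch using ONNX Runtime. This is an assumption on my part: `onnxruntime` is not a dependency of this project, the path is shortened, and the tensor names simply follow the `input_names`/`output_names` passed to `torch.onnx.export` above.

```python
import numpy as np
import onnxruntime as ort  # assumed extra dependency, not listed in pyproject.toml

session = ort.InferenceSession("clean_student.onnx")  # shortened path; point this at the file exported above

obs = np.zeros((1, 4), dtype=np.float32)  # CartPole-v1 observation: batch of 1, 4 features
(q_values,) = session.run(["output_1"], {"input_1": obs})
print(int(q_values.argmax(axis=1)[0]))  # greedy action, as in StudentSixModelDQN.get_actions
```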

@@ -0,0 +1,45 @@
import time
import serial
import numpy as np
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv

env = make_vec_env("CartPole-v1", n_envs=1, vec_env_cls=DummyVecEnv, env_kwargs={"render_mode": "human"})

rewards_list = []
state = env.reset()
steps = 0
rewards = 0

# Serial link to the external device that computes the action
port = serial.Serial("/dev/ttyUSB0", 115200, timeout=1)


def read_to_np():
    string = port.readline()
    return np.array([int(string)])


def write_ser(cmd):
    port.write(cmd)


while True:
    tick = time.time()
    write_ser(str(state).encode())  # send the observation
    actions = read_to_np()          # receive the chosen action
    tock = time.time()

    state, reward, dones, info = env.step(actions)
    env.render()
    # time.sleep(0.02)
    rewards += reward
    steps += 1
    print(1 / (tock - tick))  # round-trip frequency in Hz

    if np.all(dones):
        print(steps)
        print(info[0]["episode"]["r"])
        print(rewards)
        print(info)
        rewards = 0
        steps = 0

examples/run_atari.py
@@ -0,0 +1,92 @@
import gym
import numpy as np
from qpd.config import Config
from qpd.compressor import Compressor
from huggingface_sb3 import load_from_hub
from stable_baselines3 import A2C
from stable_baselines3.common.env_util import make_atari_env, make_vec_env
from stable_baselines3.common.vec_env.vec_frame_stack import VecFrameStack
from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from stable_baselines3.common.vec_env.base_vec_env import VecEnv
from qpd.networks.wrapper.student.cnn_student import CNNStudentNet
from datetime import datetime

# Student driven, no memory saving!
config = {
    "memory": {
        "size": 54000,                      # Size of memory used for distillation
        "update_frequency": 1,              # Epoch frequency for updating the memory
        "update_size": 5400,                # Minimum update size in steps
        "device": "cpu",

        # Only used with framestacked environments
        "frame_stack_optimization": True,   # Only store last frame
        "check_consistency": True
    },
    "evaluator": {
        "student_driven": True,             # Student decides the transitions in the environment
        "student_test_frequency": 10,       # Epoch frequency
        "episodes": 50,                     # Minimum episodes for testing student
        "initialize": 30,                   # Amount of actions to skip at beginning of episode
        "ray_workers": 10,                  # Parallel ray workers used for updating and testing
        "device": "cpu",
        "deterministic": False,
        "exploration_rate": 0,
    },
    "compression": {
        "checkpoint_frequency": 2,          # Epoch frequency for saving students
        "epochs": 600,
        "learning_rate": 1e-4,
        "batch_size": 256,
        "device": "cuda",

        # Only used in discrete action spaces
        "T": 0.01,                          # Softmax hyperparameter
        "categorical": False,
        "critic_importance": 0.5,

        # Only used in continuous action spaces
        "distribution": "Std",              # Std, Mean
        "loss": "KL"                        # KL, Huber, MSE
    },
    "quantization": {
        "enabled": False,
        "bits": 8
    },
    "data_directory": "./data",
    "run_name": "test",                     # Change this for every run
}


def env_info_done_filter(info):
    key = "ale.lives" if "ale.lives" in info.keys() else "lives"
    print(f"Lives: {info[key]}")
    return info[key] == 0


def get_environment(config: Config):
    vec_env_cls = DummyVecEnv  # SubprocVecEnv if subprocenv else DummyVecEnv
    env = make_atari_env("BreakoutNoFrameskip-v4", n_envs=config.evaluator_config.env_workers, seed=np.random.randint(0, 1000), vec_env_cls=vec_env_cls)  # type: ignore
    env = VecFrameStack(env, n_stack=4)
    return env


if __name__ == "__main__":
    checkpoint = load_from_hub(repo_id="sb3/a2c-BreakoutNoFrameskip-v4", filename="a2c-BreakoutNoFrameskip-v4.zip")
    print(checkpoint)
    model = A2C.load(checkpoint)

    c = Config(get_environment, config)
    comp = Compressor(model, get_environment, c) \
        .student_network(CNNStudentNet) \
        .set_environment_done_filter(env_info_done_filter)

    compressed_model = comp.compress()
    compressed_model.save("test")

examples/run_cartpole.py
@@ -0,0 +1,85 @@
from qpd.config import Config
from qpd.compressor import Compressor
from huggingface_sb3 import load_from_hub
from stable_baselines3 import A2C
from stable_baselines3.dqn.dqn import DQN
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from qpd.networks.models.student_six_model import StudentSixModel
from qpd.networks.models.student_six_model_dqn_ import StudentSixModelDQN
from qpd.networks.models.student_tiny_dqn import TinyStudentDQN
from qpd.networks.wrapper.student.fully_connected_student import FCStudentNet

# Student driven, no memory saving!
config = {
    "memory": {
        "size": 100000,                     # Size of memory used for distillation
        "update_frequency": 1,              # Epoch frequency for updating the memory
        "update_size": 10000,               # Minimum update size in steps
        "device": "cpu",

        # Only used with framestacked environments
        "frame_stack_optimization": False,  # Only store last frame
        "check_consistency": True
    },
    "evaluator": {
        "student_driven": False,            # Student decides the transitions in the environment
        "student_test_frequency": 10,       # Epoch frequency
        "episodes": 20,                     # Minimum episodes for testing student
        "initialize": 0,                    # Amount of actions to skip at beginning of episode
        "ray_workers": 10,                  # Parallel ray workers used for updating and testing
        "device": "cpu",
        "deterministic": False
    },
    "compression": {
        "checkpoint_frequency": 2,          # Epoch frequency for saving students
        "epochs": 600,
        "learning_rate": 5e-4,
        "batch_size": 64,
        "device": "cuda",

        # Only used in discrete action spaces
        "T": 0.01,                          # Softmax hyperparameter
        "categorical": False,
        "critic_importance": 0.5,

        # Only used in continuous action spaces
        "distribution": "Std",              # Std, Mean
        "loss": "KL"                        # KL, Huber, MSE
    },
    "quantization": {
        "enabled": True,
        "bits": 8
    },
    "data_directory": "./data",
    "run_name": "test",                     # Change this for every run
}
# "/home/user/Workspace/University/PhD/Experiments/QPD",


def get_environment(config: Config):
    return make_vec_env("CartPole-v1", n_envs=config.evaluator_config.env_workers, vec_env_cls=DummyVecEnv)


if __name__ == "__main__":
    # checkpoint = load_from_hub(repo_id="sb3/a2c-CartPole-v1", filename="a2c-CartPole-v1.zip")
    checkpoint = load_from_hub(repo_id="sb3/dqn-CartPole-v1", filename="dqn-CartPole-v1.zip")
    # custom_objects = {
    #     "learning_rate": 0.0,
    #     "lr_schedule": lambda _: 0.0,
    #     "clip_range": lambda _: 0.0,
    # }
    print(checkpoint)

    c = Config(get_environment, config)
    model = DQN.load(checkpoint, env=get_environment(c))  # , custom_objects=custom_objects)

    # comp = Compressor(model, get_environment, c).student_network(FCStudentNet)
    comp = Compressor(model, get_environment, c).student_model(TinyStudentDQN)
    comp.compress()

examples/run_cheetah.py
@@ -0,0 +1,86 @@
from qpd.config import Config
from qpd.networks.models.student_six_model import StudentSixModel
from qpd.compressor import Compressor
from huggingface_sb3 import load_from_hub
from stable_baselines3.sac.sac import SAC
from stable_baselines3.a2c.a2c import A2C
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from stable_baselines3.common.vec_env.vec_normalize import VecNormalize
from qpd.networks.wrapper.student.fully_connected_student import FCStudentNet
from datetime import datetime

config = {
    "memory": {
        "size": 100000,                     # Size of memory used for distillation
        "update_frequency": 1,              # Epoch frequency for updating the memory
        "update_size": 10000,               # Minimum update size in steps
        "device": "cpu",

        # Only used with framestacked environments
        "frame_stack_optimization": False,  # Only store last frame
        "check_consistency": True
    },
    "evaluator": {
        "student_driven": True,             # Student decides the transitions in the environment
        "student_test_frequency": 10,       # Epoch frequency
        "episodes": 20,                     # Minimum episodes for testing student
        "initialize": 0,                    # Amount of actions to skip at beginning of episode
        "ray_workers": 10,                  # Parallel ray workers used for updating and testing
        "device": "cpu",
        "deterministic": False
    },
    "compression": {
        "checkpoint_frequency": 2,          # Epoch frequency for saving students
        "epochs": 600,
        "learning_rate": 5e-4,
        "batch_size": 64,
        "device": "cuda",

        # Only used in discrete action spaces
        "T": 0.01,                          # Softmax hyperparameter
        "categorical": False,
        "critic_importance": 0.5,

        # Only used in continuous action spaces
        "distribution": "Std",              # Std, Mean
        "loss": "KL"                        # KL, Huber, MSE
    },
    "quantization": {
        "enabled": False,
        "bits": 8
    },
    "data_directory": "./data",
    "run_name": "cheetah_no_quant",         # Change this for every run
}
# "/home/user/Workspace/University/PhD/Experiments/QPD",


def get_environment_normalized(config: Config):
    env = make_vec_env("HalfCheetah-v3", n_envs=config.evaluator_config.env_workers, vec_env_cls=DummyVecEnv)
    normalize = load_from_hub(
        repo_id="sb3/a2c-HalfCheetah-v3",
        filename="vec_normalize.pkl",
    )
    return VecNormalize.load(normalize, env)


def get_environment(config: Config):
    env = make_vec_env("HalfCheetah-v3", n_envs=config.evaluator_config.env_workers, vec_env_cls=DummyVecEnv)
    return env


if __name__ == "__main__":
    checkpoint = load_from_hub(repo_id="sb3/sac-HalfCheetah-v3", filename="sac-HalfCheetah-v3.zip")
    print(checkpoint)
    model = SAC.load(checkpoint)

    c = Config(get_environment, config)
    # comp = Compressor(model, get_environment, c).student_network(FCStudentNet)
    comp = Compressor(model, get_environment, c).student_model(StudentSixModel)
    compressed_model = comp.compress()

pyproject.toml
@@ -0,0 +1,36 @@
[project]
name = "qpd"
version = "0.0.1"
authors = [
{ name="Thomas Avé", email="thomas.ave@uantwerpen.be" },
{ name="Ian Ravijts", email="ian.ravijts@uantwerpen.be" }
]
readme = "README.md"
requires-python = ">=3.9"
dependencies = [
"wheel==0.38.4",
"shimmy>=0.2.1",
"numpy",
"atari-py==0.2.9",
"gymnasium[accept-rom-license]",
"ale-py",
"pillow",
"readchar",
"matplotlib",
"distiller@git+https://github.com/tiny-rl/distiller.git",
"wandb",
"stable-baselines3[extra]",
"sb3-contrib",
"tqdm",
"ray[default]",
"webdavclient3",
"prefetch_generator",
"pygame",
"pyglet",
"cython<3",
]
[build-system]
requires = [ "setuptools>=61.0,<=66" ]
build-backend = "setuptools.build_meta"

src/qpd/__init__.py (empty)

@@ -0,0 +1,193 @@
from __future__ import annotations
import torch
import random
import numpy
import gzip
import json
from pathlib import Path
from typing import Iterator, Tuple, Dict, Optional
from qpd.config import Config
class Memory:
def __init__(self, config: Config, size=None, meta_class = None, path: Optional[str] = None) -> None:
self.meta_name = str(meta_class.__class__) if meta_class else None
self.max_size = size if size else config.memory_config.size
self.did_overflow = False
self.current_index = 0
self.output_n = config.output_shape[0]
self.device = config.memory_config.device
self.check_consistency = config.memory_config.check_consistency
self.critic = config.compression_config.critic_importance > 0
self.frame_stack_optim = config.memory_config.frame_stack_optimization
self.framestack_size = config.observation_shape[0] if self.frame_stack_optim else 0
self.student_driven = config.evaluator_config.student_driven
if path:
print(f"Load memory on path: {path}")
self.load(path)
else:
self.states = torch.zeros((self.max_size,) + (config.observation_shape[1:] if self.frame_stack_optim else config.observation_shape), device=self.device) # [(Input frame, output actions)]
self.outputs = torch.zeros((self.max_size, self.output_n), device=self.device) # + 1 = critic
# self.repeats = torch.zeros(self.size, device=self.device)
self.start = set()
def update(self, state: torch.Tensor, outputs: torch.Tensor, start: bool) -> None:
assert not torch.any(outputs == 0), f"[ERROR]: Update input is not valid!"
# update overflow
self.did_overflow = self.will_overflow()
if start:
self.start.add(self.current_index)
else:
self.start.discard(self.current_index)
state = state.to(self.device).data
self.states[self.current_index] = state[:, -1] if self.frame_stack_optim else state # Store only the latest frame
self.outputs[self.current_index] = outputs.to(self.device).data
self.current_index = (self.current_index + 1) % self.max_size
if self.check_consistency:
self.verifyConsistency(f"Update current_index: {self.current_index} is start: {start} outputs: {self.outputs}")
def will_overflow(self, update_size=1):
if (self.current_index + update_size) >= self.max_size:
return True
return False
def real_size(self):
return (self.states.element_size() * self.states.nelement()) + (self.outputs.element_size() * self.outputs.nelement())
def sample(self, batch_size: int, shuffle: bool=True) -> Iterator[Tuple[torch.Tensor, torch.Tensor]]:
state_indexes = list(range(len(self)))
if shuffle:
random.shuffle(state_indexes)
splits = [state_indexes[i:i + batch_size] for i in range(0, len(state_indexes), batch_size)]
if self.frame_stack_optim:
for split in splits:
current_states = []
indexes = torch.tensor(split)
for i in range(self.framestack_size):
current_states.append(self.states[(indexes - (self.framestack_size - 1) + i) % len(self)].unsqueeze(1))
yield (torch.cat(current_states, dim=1), self.outputs[split])
else:
for split in splits:
yield (self.states[split], self.outputs[split])
def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
if self.frame_stack_optim:
states_index = list(self.start)[index]
return (torch.cat([self.states[i % len(self)].unsqueeze(0) for i in range(states_index-3, states_index+1)]), self.outputs[states_index])
else:
return (self.states[index], self.outputs[index])
def __len__(self) -> int:
if self.did_overflow:
return len(self.outputs)
else:
return self.current_index
def prune(self) -> None:
if not self.did_overflow:
for item in ["states", "outputs"]:
setattr(self, item, getattr(self, item)[0: self.current_index])
self.max_size = self.current_index
if self.check_consistency:
self.verifyConsistency("Prune")
def grow(self, amount: int) -> int:
print(f"Growing with size: {amount}")
for item in ["states", "outputs"]:
obj = getattr(self, item)
tmp = torch.zeros((self.max_size + amount,) + obj.shape[1:], device = self.device, dtype=obj.dtype)
tmp[:len(self)] = obj
setattr(self, item, tmp)
self.max_size += amount
self.did_overflow = False
if self.check_consistency:
self.verifyConsistency("Grow")
return len(self)
def load_memory(self, other: Memory) -> None:
assert other.max_size <= self.max_size, f"Memory to load is bigger than current memory! {other.max_size} > {self.max_size}"
self.did_overflow = self.will_overflow(len(other))
for item in ["states", "outputs"]:
getattr(self, item)[self.current_index:min(self.current_index+len(other), self.max_size)] = getattr(other, item)[:min(len(other), self.max_size - self.current_index)]
if self.current_index + len(other) > self.max_size:
getattr(self, item)[:len(other) - (self.max_size - self.current_index)] = getattr(other, item)[self.max_size - self.current_index:]
for idx in range(len(other)):
new_idx = (self.current_index + idx) % self.max_size
self.start.discard(new_idx)
if idx in other.start:
self.start.add(new_idx)
self.current_index = (self.current_index + len(other)) % self.max_size
if self.check_consistency:
self.verifyConsistency("Load memory")
def to(self, device: torch.device):
for item in ["states", "outputs"]:
setattr(self, item, getattr(self, item).to(device))
self.device = device
def load(self, path: str) -> None:
assert self.meta_name, "Can not load this memory! No meta_class specified!"
with open(f"{path}/metadata.json") as f:
j = json.loads(f.read())
assert j["teacher"] == self.meta_name
self.current_index = j["current_index"]
self.did_overflow = j["did_overflow"]
for item in ["states", "outputs"]:
with gzip.GzipFile(f"{path}/{item}.npy.gz") as f:
setattr(self, item, torch.from_numpy(numpy.load(file=f)))
with gzip.GzipFile(f"{path}/start.npy.gz") as f:
self.start = set(numpy.load(file=f))
self.max_size = self.states.shape[0]
if self.check_consistency:
self.verifyConsistency("Load")
def save(self, path: str) -> None:
assert self.meta_name, "Can not save this memory! No meta_class specified!"
Path(path).mkdir(parents=True, exist_ok=True)
for item in ["states", "outputs"]:
with gzip.GzipFile(f"{path}/{item}.npy.gz", "w") as f:
numpy.save(f, getattr(self, item).cpu().numpy())
with gzip.GzipFile(f"{path}/start.npy.gz", "w") as f:
numpy.save(f, numpy.array(list(self.start)))
with open(f"{path}/metadata.json", "w") as f:
f.write(json.dumps({"teacher": self.meta_name, "did_overflow": self.did_overflow, "current_index": self.current_index}))
def verifyConsistency(self, mes) -> None:
assert not torch.any(self.outputs[:len(self)] == 0), f"[ERROR]: Memory is not consistent! ({mes})"
# assert not torch.any(torch.bitwise_or(self.outputs[:len(self)] > 100, self.outputs[:len(self)] < -100)), \
# f"[ERROR]: Memory is not consistent in a weird way! max: {self.outputs.max()} min: {self.outputs.min()} ({mes})"

@@ -0,0 +1,275 @@
# import wandb
import random
import os
import gc
# from ..utils import WebdavManager
from collections import defaultdict
from distiller.quantization import DorefaQuantizer
from gym.spaces import Box, Discrete
from pathlib import Path
from qpd.config import Config
from qpd.networks.wrapper.stable_baselines.sb_wrapper import SBWrapper
from qpd.utils import summary
from qpd.utils.evaluator import Evaluator
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.vec_env.base_vec_env import VecEnv
from torch.distributions import Categorical
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Optional, Union
from .memory import Memory
from ..networks import *
from ..networks.wrapper.model_wrapper import ModelWrapper
# from ..utils import test_network, merge_config
class PolicyDistillation:
def __init__(self, teacher: ModelWrapper, student: ModelWrapper, config: Config): #, webdav_manager: Optional[WebdavManager] = None):
self.teacher = teacher
self.student = student
self.config = config
self.epochs = self.config.compression_config.epochs
self.batch_size = self.config.compression_config.batch_size
self.device = self.config.compression_config.device
self.check_point_frequency = self.config.compression_config.checkpoint_frequency
self.mem_update_frequency = self.config.memory_config.update_frequency
# self.webdav_manager = webdav_manager
self.results = defaultdict(list)
self.student.to(self.device)
self.optimizer = torch.optim.RMSprop(self.student.parameters(), lr=config.student_config.learning_rate)
self.evaluator = Evaluator(config, verbose=True)
self.quantizer = self.get_quantizer()
self.initial_environment = self.config.environment_constructor()
self.best_student = None
self.memory_order = list(range(1, self.config.compression_config.epochs+1))
self.memory = self.load_memory()
random.shuffle(self.memory_order)
def train(self):
best = -1 * float("inf")
for epoch in range(self.epochs):
print(f"Start epoch {epoch}")
epoch_loss = 0
self.student.train()
self.student.to(self.device)
with torch.enable_grad():
for states, outputs in self.memory.sample(self.batch_size):
loss = self.continuous_loss(states.to(self.device, dtype=torch.float), outputs.to(self.device)) if self.config.env_continuous_action_space else self.discrete_loss(states.to(self.device, dtype=torch.float), outputs.to(self.device))
assert not torch.isnan(loss), "Error: loss can't be nan!"
epoch_loss += self.step(loss)
# wandb.log({"loss": epoch_loss, "Epoch": epoch})
print(f"Epoch {epoch} completed, loss: {epoch_loss}")
self.results["Loss"].append({"Epoch": epoch, "Value": epoch_loss})
if (epoch and epoch % self.config.evaluator_config.student_test_frequency == 0) or self.config.evaluator_config.student_test_frequency == 1:
best, is_best = self.test(epoch, best)
if is_best:
self.save_student("best")
if (epoch and epoch % self.check_point_frequency == 0) or self.check_point_frequency == 1:
self.save_student(epoch)
if epoch and epoch % self.mem_update_frequency == 0:
self.load_memory(epoch, self.memory)
return self.best_student
# if self.webdav_manager:
# self.webdav_manager.upload_results(self.results)
def load_memory(self, epoch: int=0, memory: Optional[Memory]=None) -> Memory:
memory_index = (self.memory_order.pop() if epoch else 0)
# memory_index = (epoch if epoch else self.memory_order.pop())
path = f"{self.config.data_directory}/Resources/{str(self.initial_environment.envs[0].env) if self.config.is_vec_env else str(self.initial_environment.env)}/Memory/{self.teacher.__class__.__name__}/sd_{str(self.config.evaluator_config.student_driven)}/{memory_index}"
if os.path.exists(path):
tmp_memory = Memory(self.config, meta_class=self.teacher, path=path)
if memory:
memory.load_memory(tmp_memory)
return memory
return tmp_memory
else:
print()
print("= Memory ==========================================================")
print(f"= Could not find memory with index: {memory_index}, generating new data...")
if not memory:
memory = Memory(self.config, meta_class=self.teacher)
steps = memory.max_size
else:
steps = self.config.memory_config.update_size
print(f"= Collecting data for {steps}")
results = self.evaluator.run(self.student, steps=steps, collect=True, teacher=self.teacher) if self.config.evaluator_config.student_driven else self.evaluator.run(self.teacher, steps=steps, collect=True)
print("= Current average steps:", results["Steps"])
print("= Current average reward:", results["Reward"])
print("= Average amount of episode:", results["EpisodeCount"])
print("===================================================================")
print()
if "GPULab" in os.environ or self.config.evaluator_config.student_driven:
for _ in range(len(results["Memory"])):
memory.load_memory(results["Memory"].pop())
gc.collect()
else:
total_size = sum([len(m) for m in results["Memory"]])
tmp_memory = Memory(self.config, size=total_size, meta_class=self.teacher)
for _ in range(len(results["Memory"])):
tmp_memory.load_memory(results["Memory"].pop())
gc.collect()
tmp_memory.save(path)
if epoch:
memory.load_memory(tmp_memory)
else:
return tmp_memory
return memory
def discrete_loss(self, states:torch.Tensor, teacher_outputs:torch.Tensor) -> torch.Tensor:
T = self.config.compression_config.T
student_policy, student_value, _ = self.student(states)
teacher_policy, teacher_value_std = teacher_outputs[:, :student_policy.shape[1]], teacher_outputs[:, student_policy.shape[1]:]
teacher_value = teacher_value_std[:,:teacher_policy.shape[1]]
if self.config.compression_config.categorical:
dist = Categorical(logits=student_policy)
student_policy = dist.probs
dist2 = Categorical(logits=teacher_policy)
teacher_policy = dist2.probs
policy_loss = nn.KLDivLoss(reduction="batchmean")(F.log_softmax(student_policy, dim=1), F.softmax(teacher_policy/T, dim=1)) #type: ignore
critic_importance = self.config.compression_config.critic_importance
if critic_importance and not student_value == None:
value_loss = nn.HuberLoss(reduction="mean")(student_value.to(self.device), teacher_value.to(self.device))
value_sum = value_loss.sum().detach()
policy_sum = policy_loss.sum().detach()
if policy_sum < 1e-6 or value_sum < 1e-6: # Avoid divide by 0
return (1-critic_importance) * policy_loss + critic_importance * value_loss
else:
return ((1-critic_importance) * (policy_loss / policy_sum) + critic_importance * (value_loss / value_sum)) * (value_sum + policy_sum)
else:
return policy_loss
def continuous_loss(self, states:torch.Tensor, teacher_outputs:torch.Tensor) -> torch.Tensor:
self.student.to(self.device)
student_policy, _, student_std = self.student(states)
# student_policy, student_value = student_outputs[:, :actions], student_outputs[:, actions:]
teacher_policy, teacher_value_std = teacher_outputs[:, :student_policy.shape[1]], teacher_outputs[:, student_policy.shape[1]:]
if self.config.compression_config.distribution == "Std":
teacher_std = teacher_value_std[:, -teacher_policy.shape[1]:] # type: ignore
if not teacher_std.gt(0).all():
teacher_std[teacher_std==0] = 1
print("Fixing teacher std")
if not student_std.gt(0).all():
student_std[student_std==0] = 1
print("Fixing student std")
if self.config.compression_config.loss == "KL":
policy_loss = torch.log((teacher_std / student_std)) + (student_std ** 2 + (student_policy - teacher_policy) ** 2) / (2 * (teacher_std ** 2)) - 1/2 # type: ignore
elif self.config.compression_config.loss == "Huber":
policy_loss = nn.HuberLoss(reduction="none")(student_policy, teacher_policy)
std_loss = nn.HuberLoss(reduction="none")(student_std, teacher_std)
policy_loss += std_loss
elif self.config.compression_config.loss == "MSE":
policy_loss = nn.MSELoss(reduction="none")(student_policy, teacher_policy)
std_loss = nn.MSELoss(reduction="none")(student_std, teacher_std)
policy_loss += std_loss
else:
raise RuntimeError("Unknown loss type: " + self.config.compression_config.loss)
else:
if self.config.compression_config.loss == "Huber":
policy_loss = nn.HuberLoss(reduction="none")(student_policy, teacher_policy) #type: ignore
elif self.config.compression_config.loss == "MSE":
policy_loss = nn.MSELoss(reduction="none")(student_policy, teacher_policy) #type: ignore
else:
raise RuntimeError("Unsupported combination of distribution and loss!")
return policy_loss.sum()
def step(self, loss: torch.Tensor) -> float:
loss.backward()
self.optimizer.step()
self.optimizer.zero_grad()
self.student.zero_grad()
if self.quantizer:
self.quantizer.quantize_params()
return loss.item()
def get_quantizer(self) -> Optional[DorefaQuantizer]:
if self.config.quantization_config.enabled:
bits = self.config.quantization_config.bits
# self.student = distiller.modules.convert_model_to_distiller_lstm(self.student)
quantizer = DorefaQuantizer(self.student, self.optimizer, bits_activations=bits, bits_weights=bits, bits_bias=bits)
quantizer.prepare_model()
quantizer.quantize_params()
return quantizer
return None
def save_student(self, epoch: Union[int, str], name: str="student"):
checkpoint_directory = self.config.data_directory
base_path = f"{checkpoint_directory}/Data/{self.config.run_name}/compression/{self.config.compression_config.starting_time}"
Path(f"{base_path}/{epoch}").mkdir(parents=True, exist_ok=True)
self.student.save(f"{base_path}/{epoch}/{name}")
# if self.webdav_manager:
# self.webdav_manager.upload_model(student_save_path, epoch, not self.config["testing"]["local_checkpoints"])
def test(self, epoch: int, best: Optional[float]=None) -> Optional[float]:
is_best = False
print()
print("= Testing =========================================================")
results = self.evaluator.run(self.student, episodes=self.config.evaluator_config.episodes)
self.student.to(self.config.compression_config.device)
log = {"Epoch": epoch}
for metric in ["Steps", "Reward"]:
self.results[metric].append({"Epoch": epoch, "Value": results[metric]})
log[metric] = results[metric]
# wandb.log(log)
print("= Current average steps:", results["Steps"])
print("= Current average reward:", results["Reward"])
print("= Average amount of episode:", results["EpisodeCount"])
print("===================================================================")
print()
if best and results["Reward"] > best:
# wandb.log({"Best Reward": results["Reward"]})
best = results["Reward"]
is_best = True
return best, is_best
def get_action_space_shape(self):
action_space = self.student.policy.action_space
is_box = isinstance(action_space, Box)
assert is_box or isinstance(action_space, Discrete), "[ERROR]: Unsupported action space!"
if is_box:
return action_space.shape
else:
return action_space.n
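For reference, the two distillation losses implemented above correspond to the following expressions (notation: z_s, z_t are student/teacher logits, T is the softmax temperature, and mu, sigma are the means and standard deviations of the Gaussian policies). This is only a summary of the code above, not an additional specification.

```latex
% Discrete case (discrete_loss): temperature-softened KL between action distributions
\mathcal{L}_{\text{policy}} = \mathrm{KL}\!\left(\mathrm{softmax}(z_t / T)\,\big\|\,\mathrm{softmax}(z_s)\right)

% Continuous case with distribution = "Std" and loss = "KL" (continuous_loss):
% closed-form KL divergence between student and teacher Gaussians, per action dimension
\mathrm{KL}\!\left(\mathcal{N}(\mu_s, \sigma_s^2)\,\|\,\mathcal{N}(\mu_t, \sigma_t^2)\right)
  = \log\frac{\sigma_t}{\sigma_s}
  + \frac{\sigma_s^2 + (\mu_s - \mu_t)^2}{2\sigma_t^2}
  - \frac{1}{2}
```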

src/qpd/compressor.py
@@ -0,0 +1,77 @@
from typing import Type
import qpd.networks.network_interface as ni
from qpd.config import Config
from qpd.networks.wrapper.student.student_net import StudentNet
from qpd.utils import summary
from .compression.policy_distillation import PolicyDistillation
from .networks.wrapper.model_wrapper import ModelWrapper
import json
from pathlib import Path
import torch
import torch.nn as nn


class Compressor:
    def __init__(self, model, environment_constructor, config):
        self.model = model
        self.environment_constructor = environment_constructor
        self.config: Config = config

        self.teacher: ModelWrapper = ModelWrapper.construct_wrapper(self.config, self.model).to(torch.device(self.config.compression_config.device))
        self.student = None

        self.config.init_env_model_based_params(self.teacher)

    def student_network(self, net: Type[StudentNet]):
        """Add a student network to the compressor, reusing the teacher's existing policy wrapper.

        Args:
            net (Type[StudentNet]): Student network

        Returns:
            self
        """
        assert self.student is None, "Can not set a second student!"
        self.config.student_config.student_network = net
        self.student_model = self.teacher.construct_student()
        self.student = ModelWrapper.construct_wrapper(self.config, self.student_model).to(torch.device(self.config.compression_config.device))
        return self

    def student_model(self, student: Type["ni.NetworkInterface"]):
        """Add an end-to-end student model to the compressor.

        Args:
            student (Type["ni.NetworkInterface"]): Student model class

        Returns:
            self
        """
        assert self.student is None, "Can not set a second student!"
        self.student = student(self.config)
        return self

    def set_environment_done_filter(self, func):
        self.config.evaluator_config.environment_is_done_filter = func
        return self

    def compress(self):
        assert self.student is not None, "No student assigned! Use Compressor(**).student_network(**) or Compressor(**).student_model(**)"

        Path(f"{self.config.data_directory}/Data/{self.config.run_name}").mkdir(parents=True, exist_ok=True)
        config_dump = json.dumps(self.config.serializable(), indent=4, sort_keys=True)
        with open(f"{self.config.data_directory}/Data/{self.config.run_name}/config.json", "w") as f:
            f.write(config_dump)

        print("Current configuration:")
        print(config_dump)

        distillation = PolicyDistillation(
            self.teacher,
            self.student,
            self.config
        )
        distillation.train()

        return distillation.student

src/qpd/config.py
@@ -0,0 +1,165 @@
import datetime
from gymnasium.spaces import Box, Discrete
import qpd.networks.wrapper.model_wrapper as wrapper
import qpd.utils.model_utils as model_utils
import torch
import torch.nn as nn
from qpd.networks.wrapper.student.fully_connected_student import FCStudentNet
from stable_baselines3.common.vec_env.base_vec_env import VecEnv
class AbstractConfig:
key = None
def init(self, config: dict):
if self.key in config.keys():
self.update(config[self.key])
return self
def update(self, config):
self.__dict__.update((k, config[k]) for k in set(config).intersection(self.__dict__))
def serializable(self) -> dict:
tmp = {}
for key, value in self.__dict__.items():
v = None
if isinstance(value, AbstractConfig):
v = value.serializable()
elif isinstance(value, type):
v = str(value)
elif callable(value):
continue
else:
v = value
tmp[key] = v
return tmp
class MemoryConfig(AbstractConfig):
key = "memory"
def __init__(self):
self.size = 54000
self.update_frequency = 1
self.update_size = 5400
self.device = "cpu"
self.check_consistency = True
self.frame_stack_optimization = False
class StudentNetConfig(AbstractConfig):
key = "student"
def __init__(self):
self.student_network = FCStudentNet
self.learning_rate = 1e-4
self.activation_func = nn.ReLU
self.extra_student_kwargs = {}
class CompressionConfig(AbstractConfig):
key = "compression"
def __init__(self):
self.checkpoint_frequency = 2
self.starting_time = datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
self.epochs = 600
self.batch_size = 256
self.T = 0.01
self.device = torch.device("cuda")
self.critic_importance = 0.5
self.categorical = False
self.loss = "MSE"
self.distribution = "Std"
class EvaluatorConfig(AbstractConfig):
key = "evaluator"
def __init__(self):
self.environment_is_done_filter = lambda env_info: True
self.student_test_frequency = 1
self.episodes = 50
self.initialize = 30
self.env_subproc = False
self.env_workers = 1
self.ray_workers = 24
self.device = "cpu"
self.deterministic = False
self.student_driven = False
self.exploration_rate = 0.
class QuantizationConfig(AbstractConfig):
key = "quantization"
def __init__(self):
self.enabled = False
self.bits = 8
class Config(AbstractConfig):
def __init__(self, environment_constructor, config={}):
self.environment_constructor = lambda: environment_constructor(self)
self.memory_config = MemoryConfig().init(config)
self.student_config = StudentNetConfig().init(config)
self.compression_config = CompressionConfig().init(config)
self.evaluator_config = EvaluatorConfig().init(config)
self.quantization_config = QuantizationConfig().init(config)
self.verbose = True
self.data_directory = "."
self.run_name = "Test"
# AUTO-DEFINED
self.model_has_critic = False
self.observation_shape = None
self.action_shape = None
self.output_shape = None
self.is_vec_env = False
self.env_continuous_action_space = False
self.update(config)
def update_memory_config(self, **kwargs):
self.memory_config.update(**kwargs)
def update_student_config(self, **kwargs):
self.student_config.update(**kwargs)
def update_compressor_config(self, **kwargs):
self.compression_config.update(**kwargs)
def update_evaluator_config(self, **kwargs):
self.evaluator_config.update(**kwargs)
def update_quantization_config(self, **kwargs):
self.quantization_config.update(**kwargs)
def init_env_model_based_params(self, wrapper: "wrapper.ModelWrapper"):
wrapper.to(self.compression_config.device)
initialized_environment = self.environment_constructor()
action_space = initialized_environment.action_space
is_box = isinstance(action_space, Box)
assert is_box or isinstance(action_space, Discrete), f"[ERROR]: Unsupported action space of type {type(action_space)}!"
state = initialized_environment.reset() # Get a state to determine the shape of the observation
state = wrapper.observation_to_tensor(state).to(self.compression_config.device).float() # get a usable state and determine if env is a vec_env
self.is_vec_env = isinstance(initialized_environment, VecEnv)
if self.is_vec_env:
self.observation_shape = state.shape[1:]
else:
self.observation_shape = state.shape
if is_box:
self.env_continuous_action_space = True
else:
self.env_continuous_action_space = False
action, value, action_std = wrapper(state)
self.model_has_critic = True if value is not None else False
outputs = model_utils.get_output(action[0], value[0] if value is not None else None, action_std[0] if action_std is not None else None) if self.is_vec_env else model_utils.get_output(action, value, action_std)
self.action_shape = action.shape[1:] if self.is_vec_env else action.shape
self.output_shape = outputs.shape
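The nested `config` dicts used in the examples map onto these classes through `AbstractConfig.init`/`update`: each top-level key (`memory`, `evaluator`, `compression`, `quantization`, `student`) selects the sub-config with the matching `key`, and only fields that already exist as defaults are overwritten. A small sketch of that behaviour (the environment constructor here is just a placeholder):

```python
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from qpd.config import Config

overrides = {
    "compression": {"epochs": 10, "batch_size": 32},  # known fields are overwritten
    "quantization": {"enabled": True},
    "run_name": "sketch",                             # top-level fields go through Config.update
}

c = Config(lambda cfg: make_vec_env("CartPole-v1", n_envs=1, vec_env_cls=DummyVecEnv), overrides)
print(c.compression_config.epochs, c.quantization_config.enabled, c.run_name)  # 10 True sketch
```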

@@ -0,0 +1,2 @@
# from .student import Student
# from .teacher import Teacher, TeacherA2C, TeacherDQN, TeacherActorCritic, TeacherPPO, TeacherQRDQN #type: ignore

@@ -0,0 +1,45 @@
from qpd.config import Config
from qpd.networks.network_interface import NetworkInterface
import torch
import torch.nn as nn
from torch.distributions import Normal


class StudentSixModel(NetworkInterface):
    def __init__(self, config: Config):
        super(StudentSixModel, self).__init__(config)
        self.net = nn.Sequential(
            nn.Linear(self.config.observation_shape[0], 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU()
        )
        self.action_head = nn.Linear(64, 6)
        self.std_head = nn.Linear(64, 6)
        self.critic_head = nn.Linear(64, 1)

        for module in self.children():
            if type(module) == nn.Linear:
                module.bias.data.uniform_(0.0)  # type: ignore
                module.weight.data.uniform_(0, 0.01)  # type: ignore

    def forward(self, observations):
        x = self.net(observations)
        mean_actions = self.action_head(x)
        std = self.std_head(x)**2
        dist = Normal(mean_actions, std)
        self.actions = dist.sample()
        value = self.critic_head(x)
        value = torch.cat([value, value], dim=0)
        return mean_actions, value, std

    def get_actions(self):
        return self.actions.cpu().numpy()

@@ -0,0 +1,37 @@
from qpd.config import Config
from qpd.networks.network_interface import NetworkInterface
import torch
import torch.nn as nn
from torch.distributions import Normal


class StudentSixModelDQN(NetworkInterface):
    def __init__(self, config: Config):
        super(StudentSixModelDQN, self).__init__(config)
        self.net = nn.Sequential(
            nn.Linear(self.config.observation_shape[0], 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU()
        )
        self.action_head = nn.Linear(64, 2)

        for module in self.children():
            if type(module) == nn.Linear:
                module.bias.data.uniform_(0.0)  # type: ignore
                module.weight.data.uniform_(0, 0.01)  # type: ignore

    def forward(self, observations):
        x = self.net(observations)
        self.mean_actions = self.action_head(x)
        return self.mean_actions, None, None

    def get_actions(self):
        actions = self.mean_actions.argmax(dim=1).reshape(-1)
        return actions.cpu().numpy()

@@ -0,0 +1,37 @@
from qpd.config import Config
from qpd.networks.network_interface import NetworkInterface
import torch
import torch.nn as nn
from torch.distributions import Normal


class TinyStudentDQN(NetworkInterface):
    def __init__(self, config: Config):
        super(TinyStudentDQN, self).__init__(config)
        self.net = nn.Sequential(
            nn.Linear(self.config.observation_shape[0], 16),
            nn.ReLU(),
            nn.Linear(16, 16),
            nn.ReLU(),
            # nn.Linear(16, 16),
            # nn.ReLU()
        )
        self.action_head = nn.Linear(16, 2)

        for module in self.children():
            if type(module) == nn.Linear:
                module.bias.data.uniform_(0.0)  # type: ignore
                module.weight.data.uniform_(0, 0.01)  # type: ignore

    def forward(self, observations):
        x = self.net(observations)
        self.mean_actions = self.action_head(x)
        return self.mean_actions, None, None

    def get_actions(self):
        actions = self.mean_actions.argmax(dim=1).reshape(-1)
        return actions.cpu().numpy()

@@ -0,0 +1,25 @@
import torch
import torch.nn as nn
import qpd.config as config


class NetworkInterface(nn.Module):
    def __init__(self, config: "config.Config"):
        super(NetworkInterface, self).__init__()
        self.config: "config.Config" = config

    def forward(self, x):
        raise NotImplementedError(f"Forward method not implemented for {type(self)}!")

    def get_actions(self):
        raise NotImplementedError(f"Actions method not implemented for {type(self)}!")

    def save(self, path):
        torch.save(self.state_dict(), path + ".pth")

    def load(self, path):
        self.load_state_dict(torch.load(path))
        self.eval()

    def observation_to_tensor(self, x):
        return torch.from_numpy(x)

@@ -0,0 +1,7 @@
from typing import List, Type
import qpd.networks.wrapper.model_wrapper as wrapper
registered_teachers: List[Type[wrapper.ModelWrapper]] = []
# from .stable_baselines import *
import qpd.networks.wrapper.stable_baselines

@@ -0,0 +1,19 @@
import qpd.config as config
from qpd.networks.network_interface import NetworkInterface
import qpd.networks.wrapper as wrapper


class ModelWrapper(NetworkInterface):
    algorithm_type = None

    def __init__(self, config: "config.Config", model):
        super(ModelWrapper, self).__init__(config)
        self.original_type = type(model)
        self.policy = None

    def construct_student(self):
        raise NotImplementedError("Construct student method not implemented!")

    @staticmethod
    def construct_wrapper(config, model):
        return next(teacher for teacher in wrapper.registered_teachers if teacher.algorithm_type == type(model))(config, model)

@@ -0,0 +1,16 @@
# from .sb_wrapper import SBWrapper
# from .value_iteration import WrapperValueIteration
# from .on_actor_critic import WrapperONPActorCritic
# from .qrdqn import WrapperQRDQN
# from .ppo import WrapperPPO
# from .dqn import WrapperDQN
# from .a2c import WrapperA2C
# from .sac import WrapperSAC
import qpd.networks.wrapper.stable_baselines.sb_wrapper
import qpd.networks.wrapper.stable_baselines.value_iteration
import qpd.networks.wrapper.stable_baselines.on_actor_critic
import qpd.networks.wrapper.stable_baselines.qrdqn
import qpd.networks.wrapper.stable_baselines.ppo
import qpd.networks.wrapper.stable_baselines.dqn
import qpd.networks.wrapper.stable_baselines.a2c
import qpd.networks.wrapper.stable_baselines.sac

@@ -0,0 +1,8 @@
import qpd.networks.wrapper.stable_baselines.on_actor_critic as sboac
from ..wrapper_decorator import wrapper
from stable_baselines3 import A2C


@wrapper
class WrapperA2C(sboac.WrapperONPActorCritic):
    algorithm_type = A2C

@@ -0,0 +1,18 @@
import numpy as np
import gym
from stable_baselines3 import DQN
from stable_baselines3.common.preprocessing import maybe_transpose
from stable_baselines3.common.utils import is_vectorized_observation
import qpd.networks.wrapper.stable_baselines.value_iteration as sbvi
from ..wrapper_decorator import wrapper


@wrapper
class WrapperDQN(sbvi.WrapperValueIteration):
    algorithm_type = DQN

    def forward(self, x):
        self.observation = x
        self.q_values = self.policy.q_net(self.observation)
        return self.q_values, None, None

@@ -0,0 +1,25 @@
from typing import Type, Union
import qpd.networks.wrapper.stable_baselines.sb_wrapper as sbw
from qpd.networks.wrapper.wrapper_decorator import wrapper
from qpd.networks.wrapper.model_wrapper import ModelWrapper
from stable_baselines3.sac.sac import SAC
import torch


class WrapperOFFPActorCritic(sbw.SBWrapper):
    def forward(self, obs):
        mean_actions, log_std, kwargs = self.policy.actor.get_action_dist_params(obs)
        self.actions = self.policy.actor.action_dist.actions_from_params(mean_actions, log_std, deterministic=self.config.evaluator_config.deterministic, **kwargs)
        value = torch.cat(self.policy.critic(obs, self.actions), dim=1)

        action_std = None
        if self.config.compression_config.distribution == "Std":
            action_std = torch.ones_like(mean_actions) * log_std.exp()

        return mean_actions, value, action_std

    def _get_actions(self):
        return self.actions

@@ -0,0 +1,62 @@
import torch

from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.distributions import (
    BernoulliDistribution,
    CategoricalDistribution,
    DiagGaussianDistribution,
    MultiCategoricalDistribution,
    StateDependentNoiseDistribution
)
from stable_baselines3.common.policies import BasePolicy

import qpd.networks.wrapper.stable_baselines.sb_wrapper as sbw


class WrapperONPActorCritic(sbw.SBWrapper):
    def forward(self, obs):
        features = self.policy.extract_features(obs)
        if self.policy.share_features_extractor:
            self.latent_pi, latent_vf = self.policy.mlp_extractor(features)
        else:
            pi_features, vf_features = features
            self.latent_pi = self.policy.mlp_extractor.forward_actor(pi_features)
            latent_vf = self.policy.mlp_extractor.forward_critic(vf_features)

        self.mean_actions = self.policy.action_net(self.latent_pi)
        value = self.policy.value_net(latent_vf)

        action_std = None
        if self.config.compression_config.distribution == "Std":
            if isinstance(self.policy.action_dist, StateDependentNoiseDistribution):
                variance = torch.mm(self.latent_pi ** 2, self.policy.action_dist.get_std(self.policy.log_std) ** 2)
                action_std = torch.sqrt(variance + self.policy.action_dist.epsilon)
            elif isinstance(self.policy.action_dist, DiagGaussianDistribution):
                action_std = torch.ones_like(self.mean_actions) * self.policy.log_std.exp()
            elif isinstance(self.policy.action_dist, CategoricalDistribution):
                action_std = None  # In case of discrete action space
            else:
                raise RuntimeError("Invalid distribution type!")

        return self.mean_actions, value, action_std

    def _get_actions(self):
        distribution = None
        if isinstance(self.policy.action_dist, DiagGaussianDistribution):
            distribution = self.policy.action_dist.proba_distribution(self.mean_actions, self.policy.log_std)
        elif isinstance(self.policy.action_dist, CategoricalDistribution):
            # Here mean_actions are the logits before the softmax
            distribution = self.policy.action_dist.proba_distribution(action_logits=self.mean_actions)
        elif isinstance(self.policy.action_dist, MultiCategoricalDistribution):
            # Here mean_actions are the flattened logits
            distribution = self.policy.action_dist.proba_distribution(action_logits=self.mean_actions)
        elif isinstance(self.policy.action_dist, BernoulliDistribution):
            # Here mean_actions are the logits (before rounding to get the binary actions)
            distribution = self.policy.action_dist.proba_distribution(action_logits=self.mean_actions)
        elif isinstance(self.policy.action_dist, StateDependentNoiseDistribution):
            distribution = self.policy.action_dist.proba_distribution(self.mean_actions, self.policy.log_std, self.latent_pi)
        else:
            raise ValueError("Invalid action distribution")

        return distribution.get_actions(self.config.evaluator_config.deterministic)

@@ -0,0 +1,10 @@
import qpd.networks.wrapper.stable_baselines.on_actor_critic as sboac
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.policies import BasePolicy
from ..wrapper_decorator import wrapper
from stable_baselines3 import PPO


@wrapper
class WrapperPPO(sboac.WrapperONPActorCritic):
    algorithm_type = PPO

@@ -0,0 +1,12 @@
from ..wrapper_decorator import wrapper
import qpd.networks.wrapper.stable_baselines.value_iteration as sbvi
from sb3_contrib import QRDQN


@wrapper
class WrapperQRDQN(sbvi.WrapperValueIteration):
    algorithm_type = QRDQN

    def forward(self, obs):
        self.observation = obs
        self.q_values = self.policy.quantile_net(self.observation)
        return self.q_values.view(-1, self.policy.n_quantiles, self.policy.action_space.n).mean(dim=1), None, None

@@ -0,0 +1,8 @@
import qpd.networks.wrapper.stable_baselines.off_actor_critic as sbac
from qpd.networks.wrapper.wrapper_decorator import wrapper
from stable_baselines3.sac.sac import SAC


@wrapper
class WrapperSAC(sbac.WrapperOFFPActorCritic):
    algorithm_type = SAC

@@ -0,0 +1,67 @@
from typing import Type, Union

import gym
import numpy as np
import torch

import qpd.networks.wrapper.stable_baselines.student as student
import qpd.networks.wrapper.model_wrapper as model_wrapper

from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.policies import BasePolicy


class SBWrapper(model_wrapper.ModelWrapper):
    algorithm_type: Type[BaseAlgorithm] = None

    def __init__(self, config, model: Union[BaseAlgorithm, BasePolicy]):
        super(SBWrapper, self).__init__(config, model)
        self.policy: BasePolicy = model.policy if isinstance(model, BaseAlgorithm) else model
        assert isinstance(self.policy, BasePolicy), "Incorrect input model type!"

    def forward(self, x):
        raise NotImplementedError("Not implemented forward method!")

    def _get_actions(self):
        raise NotImplementedError("Not implemented _get_actions method!")

    def observation_to_tensor(self, x):
        return self.policy.obs_to_tensor(x)[0]

    def save(self, path):
        student = self.construct_student()
        student.policy = self.policy
        student.save(path + ".zip")

    def load(self, path):
        self.policy = self.algorithm_type.load(path).policy

    def get_actions(self):
        with torch.no_grad():
            actions = self._get_actions()
            actions = actions.cpu().numpy()

        if isinstance(self.policy.action_space, gym.spaces.Box):
            if self.policy.squash_output:
                # Rescale to proper domain when using squashing
                actions = self.policy.unscale_action(actions)
            else:
                # Actions could be on arbitrary scale, so clip the actions to avoid
                # out of bound error (e.g. if sampling from a Gaussian distribution)
                actions = np.clip(actions, self.policy.action_space.low, self.policy.action_space.high)

        return actions

    def construct_student(self):
        tmp = {
            "learning_rate": self.config.student_config.learning_rate,
            "policy_kwargs": {
                "net_arch": [],
                "features_extractor_class": student.SBStudentNet,
                "features_extractor_kwargs": {
                    "network": self.config.student_config.student_network
                },
                "activation_fn": self.config.student_config.activation_func
            }
        }
        tmp.update(self.config.student_config.extra_student_kwargs)
        return self.original_type(type(self.policy), self.config.environment_constructor(), **tmp)

@@ -0,0 +1,21 @@
from typing import Type

import torch as th
import torch.nn as nn
from gym import spaces
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

from qpd.networks.wrapper.student.student_net import StudentNet
from ..student.cnn_student import CNNStudentNet


class SBStudentNet(BaseFeaturesExtractor):
    def __init__(self, observation_space: spaces.Space, features_dim: int = 16, network: Type[StudentNet] = CNNStudentNet):
        super(SBStudentNet, self).__init__(observation_space, features_dim)
        self.network = network(observation_space, features_dim)
        self._features_dim = self.network.features_dim

    def forward(self, observations: th.Tensor) -> th.Tensor:
        return self.network(observations)

@@ -0,0 +1,22 @@
import numpy as np
import gym

from stable_baselines3.common.preprocessing import maybe_transpose
from stable_baselines3.common.utils import is_vectorized_observation

import qpd.networks.wrapper.stable_baselines.sb_wrapper as sbw


class WrapperValueIteration(sbw.SBWrapper):
    def _get_actions(self):
        if not self.config.evaluator_config.deterministic and np.random.rand() < self.config.evaluator_config.exploration_rate:
            if is_vectorized_observation(maybe_transpose(self.observation, self.policy.observation_space), self.policy.observation_space):
                if isinstance(self.policy.observation_space, gym.spaces.Dict):
                    n_batch = self.observation[list(self.observation.keys())[0]].shape[0]
                else:
                    n_batch = self.observation.shape[0]
                action = np.array([self.action_space.sample() for _ in range(n_batch)])
            else:
                action = np.array(self.action_space.sample())
        else:
            action = self.q_values.argmax(dim=1).reshape(-1)
        return action

@@ -0,0 +1,39 @@
from typing import Type

import torch as th
import torch.nn as nn
from gym import spaces

import qpd.networks.wrapper.student.student_net as student


class CNNStudentNet(student.StudentNet):
    def __init__(self, observation_space: spaces.Box, n_output):
        super(CNNStudentNet, self).__init__(observation_space, n_output)
        # We assume CxHxW images (channels first)
        # Re-ordering will be done by pre-preprocessing or wrapper
        n_input_channels = observation_space.shape[0]
        self.cnn = nn.Sequential(
            nn.Conv2d(n_input_channels, 16, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(16, 16, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(16, 16, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Flatten()
        )

        # Compute shape by doing one forward pass
        with th.no_grad():
            n_flatten = self.cnn(
                th.as_tensor(observation_space.sample()[None]).float()
            ).shape[1]

        self.linear = nn.Sequential(
            nn.Linear(n_flatten, self.features_dim),
            nn.ReLU()
        )

    def forward(self, observations: th.Tensor) -> th.Tensor:
        return self.linear(self.cnn(observations))

@@ -0,0 +1,31 @@
from typing import Type

import torch as th
import torch.nn as nn
from gym import spaces

import qpd.networks.wrapper.student.student_net as student


class FCStudentNet(student.StudentNet):
    def __init__(self, observation_space: spaces.Box, n_output: int):
        super(FCStudentNet, self).__init__(observation_space, 64)
        n_input_channels = observation_space.shape[0]
        self.net = nn.Sequential(
            nn.Linear(n_input_channels, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, self.features_dim),
            nn.ReLU()
        )

        for module in self.children():
            if type(module) == nn.Linear:
                module.bias.data.uniform_(0.0)  # type: ignore
                module.weight.data.uniform_(0, 0.01)  # type: ignore

    def forward(self, observations: th.Tensor) -> th.Tensor:
        return self.net(observations)

@@ -0,0 +1,14 @@
from typing import Type

import torch as th
import torch.nn as nn
from gym import spaces


class StudentNet(nn.Module):
    def __init__(self, observation_space: spaces.Box, n_output: int):
        super(StudentNet, self).__init__()
        self._n_output = n_output

    @property
    def features_dim(self):
        return self._n_output

@@ -0,0 +1,5 @@
from . import registered_teachers


def wrapper(cls):
    registered_teachers.append(cls)
    return cls
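The `@wrapper` decorator is what lets `ModelWrapper.construct_wrapper` find the right wrapper class for a loaded model: importing the wrapper package registers every wrapper, and the lookup matches `type(model)` against each class's `algorithm_type`. A short sketch of that lookup (it mirrors the `next(...)` expression in `model_wrapper.py`):

```python
from stable_baselines3 import DQN

import qpd.networks.wrapper as wrapper  # importing the package triggers the @wrapper registrations

# The same lookup that ModelWrapper.construct_wrapper performs:
match = next(w for w in wrapper.registered_teachers if w.algorithm_type == DQN)
print(match.__name__)  # -> "WrapperDQN"
```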

@@ -0,0 +1,4 @@
from .configuration import get_config, get_config_mods, merge_config, import_config
from .summary import *
from .model_utils import get_network
from .webdav import WebdavManager

@@ -0,0 +1,91 @@
from argparse import ArgumentParser
import copy
def convert_nested_insert(dct, lst):
for x in lst[:-2]:
dct[x] = dct = dct.get(x, dict())
dct.update({lst[-2]: lst[-1]})
def convert_nested(dct):
# empty dict to store the result
result = dict()
# create an iterator of lists
# representing nested or hierarchical flow
lsts = ([*k.split("."), v] for k, v in dct.items())
# insert each list into the result
for lst in lsts:
convert_nested_insert(result, lst)
return result
def merge_config(source, destination):
destination = copy.deepcopy(destination)
for key, value in source.items():
if isinstance(value, dict):
if key in destination and isinstance(destination[key], dict):
destination_value = destination[key]
else:
destination_value = {}
destination[key] = merge_config(value, destination_value)
else:
try:
value = eval(value)
except Exception:
pass
destination[key] = value
return destination
def get_config_mods(known_args):
config_mods = dict(zip(map(lambda x: x[2:], known_args[1][:-1:2]), known_args[1][1::2]))
config_mods = convert_nested(config_mods)
return config_mods
def parse_args():
parser = ArgumentParser()
parser.add_argument('--config', type=str, default="experiment")
known_args = parser.parse_known_args()
args = known_args[0]
config_mods = get_config_mods(known_args)
return args, config_mods
def get_config():
args, config_mods = parse_args()
try:
config_module = __import__(f"configs.{args.config}", fromlist=[''])
configs = config_module.config
if config_mods is not None:
configs = merge_config(config_mods, configs)
return configs
except ModuleNotFoundError as ex:
print(f"ERROR: Could not find the experiment config file: {ex}")
quit()
def import_config(config, experiment_config):
config["network"]["conv1"] = experiment_config["network"]["conv1"]
config["network"]["conv2"] = experiment_config["network"]["conv2"]
config["network"]["conv3"] = experiment_config["network"]["conv3"]
config["network"]["linear"] = experiment_config["network"]["linear"]
config["network"]["teacher"] = experiment_config["network"]["teacher"]
if "critic_importance" in experiment_config["compression"]:
config["compression"]["critic_importance"] = experiment_config["compression"]["critic_importance"]
config["testing"]["randomness"] = experiment_config["testing"]["randomness"]
config["quantization"]["enabled"] = experiment_config["quantization"]["enabled"]
config["compression"]["epochs"] = experiment_config["compression"]["epochs"]
config["train_run"] = experiment_config["run_name"]
return config
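For clarity, `get_config_mods`/`convert_nested` turn dotted command-line overrides into the nested dict that `merge_config` folds into the experiment config (string values are passed through `eval`, so numbers and booleans are parsed). A small illustration of that conversion:

```python
from qpd.utils.configuration import convert_nested, merge_config

# Equivalent of passing:  --compression.epochs 100 --quantization.enabled True
mods = convert_nested({"compression.epochs": "100", "quantization.enabled": "True"})
base = {"compression": {"epochs": 600}, "quantization": {"enabled": False, "bits": 8}}

print(merge_config(mods, base))
# {'compression': {'epochs': 100}, 'quantization': {'enabled': True, 'bits': 8}}
```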

src/qpd/utils/evaluator.py
@@ -0,0 +1,303 @@
from abc import abstractmethod
import json
import torch
import time
import numpy
import random
import ray
from typing import Dict, List, Optional
from torch.distributions import Normal
from qpd.config import Config
from qpd.networks.wrapper.model_wrapper import ModelWrapper
from qpd.utils.model_utils import get_output
from ..compression.memory import Memory
from dataclasses import dataclass, field
@dataclass
class Worker:
idx: int = 0
action_queue: List[int] = field(default_factory=list)
memory: Optional[Memory] = None
finished: bool = False
steps: int = 0
rewards: int = 0
repeats: int = 0
def reset(self, number: int):
self.action_queue = [0 for _ in range(int(numpy.round(random.random() * number)))]
self.steps = 0
self.rewards = 0
self.repeats = 0
@dataclass
class Stats:
repeats: int = 0
frames: int = 0
effective_frames: int = 0
steps: List[int] = field(default_factory=list)
rewards: List[int] = field(default_factory=list)
def update(self, worker: Worker):
self.repeats += worker.repeats
self.effective_frames += worker.steps
self.steps.append(worker.steps - worker.repeats)
self.rewards.append(worker.rewards)
class Evaluator:
def __init__(self, config: Config, verbose: bool=False) -> None:
self.config = config
# self.env = create_env(config)
self.verbose = verbose
ray.init(
_system_config={
"local_fs_capacity_threshold": 0.9999999999,
"object_spilling_config": json.dumps(
{
"type": "filesystem",
"params": {
"directory_path": f"{config.data_directory}/spill"
}
},
)
},
)
def run(self, network: torch.nn.Module, collect: bool=False, steps: int=0, episodes: int=0, teacher=None) -> Dict:
network.eval()
network.to(torch.device(self.config.evaluator_config.device))
@ray.remote(num_gpus=0.1)
def run_ray(function, *args, **kwargs):
return function(*args, **kwargs)
test_function = Evaluator.test_vec if self.config.is_vec_env else Evaluator.test
if self.config.evaluator_config.ray_workers:
worker_steps = list(map(lambda x: len(x), numpy.array_split(numpy.arange(steps), self.config.evaluator_config.ray_workers)))
worker_episodes = list(map(lambda x: len(x), numpy.array_split(numpy.arange(episodes), self.config.evaluator_config.ray_workers)))
old_metrics = ray.get(
[run_ray.remote(
test_function,
self.config,
network,
steps=s,
episodes=e,
collect=collect,
verbose=self.verbose,
teacher=teacher
) for s, e in zip(worker_steps, worker_episodes)]
)
metrics = {item: sum(list(map(lambda x: x[item], old_metrics)), []) for item in ["Memory"]}
metrics.update({item: numpy.mean(list(map(lambda x: x[item], old_metrics))) for item in ["Reward", "Steps", "EpisodeCount"]})
# metrics = {item: sum(list(map(lambda x: x[item], old_metrics)), []) for item in ["Memory", "Reward", "Steps"]}
else:
metrics = test_function(
self.config,
network,
steps=steps,
episodes=episodes,
collect=collect,
verbose=self.verbose,
teacher=teacher)
network.train()
return metrics
@abstractmethod
def test(config: Config,
model_wrapper: ModelWrapper,
steps: int=0,
episodes: int=0,
collect: bool=False,
verbose=True,
teacher: ModelWrapper=None) -> Dict:
env = config.environment_constructor()
# model = None
# model = model_wrapper.original_type(type(model_wrapper.policy), env, device=config.evaluator_config.device, **config.build_student_kwargs())
# model.policy = model_wrapper.policy
if teacher:
teacher.to(device=config.evaluator_config.device)
total_rewards = []
total_steps = []
frames = 0
effective_frames = 0
with torch.no_grad():
env_state = env.reset()
worker_action_queue = [0 for _ in range(int(numpy.round(random.random() * config.evaluator_config.initialize)))]
worker_steps_since_last_reward = 0
worker_steps = 0
worker_finished = False
worker_memory = None
if collect:
worker_memory = Memory(config, size=int(steps) + 10000)
start = time.time()
if verbose:
print(f"Start workers")
while not numpy.all(worker_finished):
current_actions = []
state, _ = model_wrapper.observation_to_tensor(env_state)
action, value, action_std = model_wrapper(state)
if len(worker_action_queue) > 0:
current_actions.append(worker_action_queue.pop())
else:
current_actions = model_wrapper.get_actions()
if collect:
if teacher and teacher is not model_wrapper:
action, value, action_std = teacher(state)
if not worker_finished:
if worker_memory.will_overflow:
print(f"Growing memory to size {worker_memory.max_size + 5000}!")
worker_memory.grow(5000)
worker_memory.update(
state.unsqueeze(0),
get_output(action, value, action_std),
worker_steps == 0
)
env_state, reward, done, infos = env.step(current_actions) #type:ignore
frames += 1
worker_steps += 1
if reward > 0:
worker_steps_since_last_reward = 0
else:
worker_steps_since_last_reward += 1
if done:
if config.evaluator_config.environment_is_done_filter(infos):
assert "episode" in infos, "No episode key in infos dict from gym environment"
total_rewards.append(infos["episode"]["r"])
worker_action_queue = [0 for _ in range(int(numpy.round(random.random() * config.evaluator_config.initialize)))]
effective_frames += worker_steps
total_steps.append(worker_steps)
worker_steps = 0
if (steps and effective_frames >= steps) or (episodes and len(total_rewards) >= episodes):
worker_finished = True
if collect:
worker_memory.prune()
if verbose:
end = time.time()
print("Average and effective framerate:", round(frames / (end-start), 3), round(effective_frames / (end-start), 3))
return {"Steps": numpy.mean(total_steps) if len(total_steps) else 0, "Reward": numpy.mean(total_rewards) if len(total_rewards) else 0, "Memory": worker_memory if collect else []}
@abstractmethod
def test_vec(config: Config,
model_wrapper: ModelWrapper,
steps: int=0,
episodes: int=0,
collect: bool=False,
verbose=True,
teacher: ModelWrapper=None) -> Dict:
env = config.environment_constructor()
model_wrapper.to(device=config.evaluator_config.device)
if teacher:
teacher.eval()
teacher.to(device=config.evaluator_config.device)
total_rewards = []
total_steps = []
frames = 0
effective_frames = 0
workers = config.evaluator_config.env_workers
with torch.no_grad():
state = env.reset()
worker_action_queue = [[0 for _ in range(int(numpy.round(random.random() * config.evaluator_config.initialize)))] for _ in range(workers)]
workers_steps_since_last_reward = [0 for _ in range(workers)]
worker_steps = [0 for _ in range(workers)]
worker_finished = [False for _ in range(workers)]
worker_memories = []
if collect:
worker_memories = [Memory(config, size=int(steps / workers) + 10000) for _ in range(workers)]
start = time.time()
if verbose:
print(f"Start workers")
while not numpy.all(worker_finished):
current_actions = []
state = model_wrapper.observation_to_tensor(state).float()
action, value, action_std = model_wrapper(state)
if max([len(q) for q in worker_action_queue]) > 0:
for worker in range(workers):
if len(worker_action_queue[worker]):
current_actions.append(worker_action_queue[worker].pop())
else:
current_actions = model_wrapper.get_actions()
if collect:
if teacher and teacher is not model_wrapper:
action, value, action_std = teacher(state)
for worker in range(workers):
if not worker_finished[worker]:
if worker_memories[worker].will_overflow():
print(f"Growing memory to size {worker_memories[worker].max_size + 5000}!")
worker_memories[worker].grow(5000)
outputs = get_output(action[worker], value[worker] if value is not None else None, action_std[worker] if action_std is not None else None)
worker_memories[worker].update(state[worker].unsqueeze(0), outputs, worker_steps[worker] == 0)
state, reward, done, infos = env.step(current_actions) #type:ignore
frames += workers
for worker in range(workers):
worker_steps[worker] += 1
if reward[worker] > 0:
workers_steps_since_last_reward[worker] = 0
else:
workers_steps_since_last_reward[worker] += 1
if done[worker] and not worker_finished[worker]:
# print(f"Done: {infos}")
if config.evaluator_config.environment_is_done_filter(infos[worker]):
assert "episode" in infos[worker]
total_rewards.append(infos[worker]["episode"]["r"])
worker_action_queue[worker] = [0 for _ in range(int(numpy.round(random.random() * config.evaluator_config.initialize)))]
effective_frames += worker_steps[worker]
total_steps.append(worker_steps[worker])
worker_steps[worker] = 0
if (steps and effective_frames >= steps) or (episodes and len(total_rewards) + workers - 1 >= episodes):
if verbose:
print(f"Finished with {effective_frames} steps and {len(total_rewards)} episodes!")
worker_finished[worker] = True
if collect:
for worker in range(workers):
worker_memories[worker].prune()
if verbose:
end = time.time()
print("Average and effective framerate:", round(frames / (end-start), 3), round(effective_frames / (end-start), 3))
return {"Steps": numpy.mean(total_steps) if len(total_steps) else 0, "Reward": numpy.mean(total_rewards) if len(total_rewards) else 0, "Memory": worker_memories if collect else [], "EpisodeCount": len(total_rewards)/workers}

View File

@ -0,0 +1,47 @@
from distiller.quantization import DorefaQuantizer
from .summary import summary
import torch.nn as nn
import torch
import os
class Flatten(nn.Module):
def forward(self, x):
return x.reshape(x.size(0), -1)
def init_module(module, weight_init, bias_init, gain=1):
weight_init(module.weight.data, gain=gain)
bias_init(module.bias.data)
return module
def get_output_size(input_shape, layer):
sample = torch.zeros(size=(1, *input_shape))
return layer(sample).view(1, -1).size(1)
def get_network(config, Network, load=None, quantization=0, verbose=True):
network = Network(config)
if verbose:
summary(network)
if quantization:
bits = quantization
optimizer = torch.optim.RMSprop(network.parameters(), lr=0)
quantizer = DorefaQuantizer(network, optimizer, bits_activations=bits, bits_weights=bits, bits_bias=bits)
quantizer.prepare_model()
quantizer.quantize_params()
if load and (type(load) != str or os.path.exists(load)):
network = network.load(load)
network.eval()
return network
def get_output(action, value, action_std):
if action_std is not None:
return torch.cat([action, value, action_std], dim=0)
if value is not None:
return torch.cat([action, value], dim=0)
return action
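
A sketch of the intended call; `StudentClass` and the checkpoint path are placeholders:

``` python
# StudentClass is a placeholder for one of the student network classes in qpd.networks.models.
network = get_network(config, StudentClass, load="student.pth", quantization=8)
# With quantization=8 the network is wrapped by a DoReFa quantizer
# (8-bit weights, activations and biases) before the optional checkpoint is loaded.
```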

View File

@ -0,0 +1,52 @@
import os
from collections import OrderedDict
from distiller.quantization import DorefaQuantizer
import numpy as np
from onnxruntime.quantization import quantize_dynamic
import torch
def filter_state_dict(state_dict: OrderedDict):
new = OrderedDict()
for key, state in state_dict.items():
key_split = key.split(".")
state_type = key_split[-1]
if state_type == "float_weight" or state_type == "float_bias":
continue
new[key] = state
return new
def clean_student(student, config):
cleaned = type(student)(config)
cleaned.load_state_dict(filter_state_dict(student.state_dict()))
return cleaned
def export_student_as_onnx(input, output, student_type, config, get_environment, verbose=True, onnx_input_names=["input_1"], onnx_output_names=["output_1"], opset_version=14, qat=True):
# Loading quantized student
student = student_type(config)
bits = config.quantization_config.bits
if qat:
q = DorefaQuantizer(student, torch.optim.RMSprop(student.parameters(), lr=config.student_config.learning_rate), bits, bits, bits)
q.prepare_model()
student.load(input)
c_student = clean_student(student, config)
env = get_environment(config)
state = env.reset().astype(np.float32)
path = output.split("/")[:-1]
orig_onnx = "/".join(path) + "/" + "orig.onnx"
temp_onnx = "/".join(path) + "/" + "clean.onnx"
torch.onnx.export(student, torch.from_numpy(state), orig_onnx, verbose=verbose, input_names=onnx_input_names, output_names=onnx_output_names, opset_version=opset_version)
torch.onnx.export(c_student, torch.from_numpy(state), temp_onnx, verbose=verbose, input_names=onnx_input_names, output_names=onnx_output_names, opset_version=opset_version)
quantize_dynamic(orig_onnx, output)
# os.remove(temp_onnx)
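
Illustrative call; the paths and the student class are placeholders:

``` python
export_student_as_onnx(
    input="Data/compression/<run>/best/student.pth",    # trained (QAT) checkpoint
    output="Data/compression/<run>/best/student.onnx",  # dynamically quantized ONNX output
    student_type=StudentClass,                          # placeholder student network class
    config=config,
    get_environment=get_environment,
    qat=True,
)
```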

40
src/qpd/utils/summary.py Normal file
View File

@ -0,0 +1,40 @@
from typing import List, Any, Tuple
import torch
def summary(network):
print("======================================")
print("= Network Summary =")
print("======================================")
print(print_layers([(network.__class__.__name__, extract_parameters(network.parameters()), find_layers(network))]))
def find_layers(network) -> List[Tuple[str, Any]]:
layers = []
for child in network.children():
children = list(child.children())
if not len(children):
layers.append((child.__class__.__name__, extract_parameters(child.parameters())))
else:
layers.append((child.__class__.__name__, extract_parameters(child.parameters()), find_layers(child)))
return layers
def extract_parameters(parameters) -> int:
total = 0
for param in parameters:
if isinstance(param, torch.Tensor):
total += torch.prod(torch.LongTensor(list(param.size())))
else:
total += extract_parameters(param)
return int(total)
def print_layers(layers) -> str:
result = ""
for layer in layers:
result += "- " + layer[0]
if len(layer) == 2:
result += f': {layer[1]:n}\n'
else:
result += f': {layer[1]:n}\n'
for line in print_layers(layer[2]).splitlines():
result += " %s\n" % line
return result
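
For example, on a small MLP the summary prints a parameter-count tree (output shown approximately, banner omitted):

``` python
import torch.nn as nn

net = nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 2))
summary(net)
# - Sequential: 450
#  - Linear: 320
#  - ReLU: 0
#  - Linear: 130
```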