Initial public commit — commit 862c55e03c

@@ -0,0 +1,35 @@
# tiny-rl

> **_NOTE_**: Currently only Stable Baselines3 is supported out of the box.

The goal of this repository is to provide an easy-to-use framework that wraps the QPD algorithm to compress the models of RL policies.


## Installation

First, install these required system libraries:

``` sh
sudo apt update
sudo apt install libpng-dev libjpeg-dev zlib1g-dev
```

Afterwards, install the Python library.

> **_Important_**: Make sure you use Python 3.10.*, setuptools 65.0, and wheel 0.38.4.

``` sh
conda create -n compression python=3.10
conda activate compression
python -m pip install setuptools==65.0 wheel==0.38.4
python -m pip install -e .
```
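
To verify the environment, check that the package imports (the `qpd` package name comes from `pyproject.toml`):

``` sh
python -c "import qpd"
```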

## Currently supported

- A2C
- PPO
- DQN
- QRDQN
- SAC
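
A minimal usage sketch, based on the example scripts included in this commit (the empty `config` dict falls back to the defaults in `qpd.config`; see the example scripts for the full option set):

``` python
from huggingface_sb3 import load_from_hub
from stable_baselines3 import A2C
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv

from qpd.config import Config
from qpd.compressor import Compressor
from qpd.networks.wrapper.student.fully_connected_student import FCStudentNet

def get_environment(config: Config):
    return make_vec_env("CartPole-v1", n_envs=config.evaluator_config.env_workers, vec_env_cls=DummyVecEnv)

c = Config(get_environment, {})  # default configuration

# Distill a pretrained teacher from the Hugging Face hub into a small student
checkpoint = load_from_hub(repo_id="sb3/a2c-CartPole-v1", filename="a2c-CartPole-v1.zip")
model = A2C.load(checkpoint, env=get_environment(c))

compressed = Compressor(model, get_environment, c).student_network(FCStudentNet).compress()
compressed.save("a2c-cartpole-student")
```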

## HalfCheetah environment installation when using conda

``` sh
conda install -c conda-forge xorg-libx11
conda install -c conda-forge glew
conda install -c conda-forge mesalib
conda install -c menpo glfw3
```
@@ -0,0 +1,46 @@

from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from huggingface_sb3 import load_from_hub
from stable_baselines3.dqn.dqn import DQN
import torch

from qpd.networks.models.student_six_model_dqn_ import StudentSixModelDQN
from run import config
from qpd.config import Config
from qpd.networks.wrapper.model_wrapper import ModelWrapper


def get_environment(config: Config):
    return make_vec_env("CartPole-v1", n_envs=config.evaluator_config.env_workers, vec_env_cls=DummyVecEnv, env_kwargs={"render_mode": "human"})


c = Config(get_environment, config)

env = get_environment(c)

# Pretrained original model
checkpoint = load_from_hub(
    repo_id="sb3/dqn-CartPole-v1",
    filename="dqn-CartPole-v1.zip",
)

custom_objects = {
    "learning_rate": 0.0,
    "lr_schedule": lambda _: 0.0,
    "clip_range": lambda _: 0.0,
}

model = DQN.load(checkpoint, env)  # , custom_objects=custom_objects)
teacher = ModelWrapper.construct_wrapper(c, model)

c.init_env_model_based_params(teacher)

# Loading quantized student
student = StudentSixModelDQN(c)
student.load("/home/ian/projects-idlab/euROBIN/code/framework/test/Data/compression/2024-06-14_15:56:55/best/student.pth")


rewards_list = []
state = env.reset()
steps = 0
rewards = 0

# Export the student to ONNX (note: `qat` is not a standard torch.onnx.export argument and
# presumably requires the patched torch/distiller stack used by this project)
torch.onnx.export(student, torch.from_numpy(state), "/home/ian/projects-idlab/euROBIN/code/framework/test/Data/compression/2024-06-14_15:56:55/best/clean_student.onnx", verbose=True, input_names=["input_1"], output_names=["output_1"], opset_version=14, qat=True)
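
A hypothetical sketch of how the exported student could be run with onnxruntime (not part of this commit; the `onnxruntime` dependency, output path, and observation shape are assumptions):

import numpy as np
import onnxruntime as ort  # assumed extra dependency, not in pyproject.toml

session = ort.InferenceSession("clean_student.onnx")  # hypothetical local copy of the export
obs = np.zeros((1, 4), dtype=np.float32)  # CartPole-v1 observation, batch of 1
(q_values,) = session.run(["output_1"], {"input_1": obs})  # names match the export call above
print(q_values.argmax(axis=1))  # greedy action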
@@ -0,0 +1,45 @@

import time

import serial
import numpy as np
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv

env = make_vec_env("CartPole-v1", n_envs=1, vec_env_cls=DummyVecEnv, env_kwargs={"render_mode": "human"})

rewards_list = []
state = env.reset()
steps = 0
rewards = 0

port = serial.Serial("/dev/ttyUSB0", 115200, timeout=1)

def read_to_np():
    # Read one line from the serial port and parse it as a single action index
    string = port.readline()
    return np.array([int(string)])

def write_ser(cmd):
    port.write(cmd)

while True:
    tick = time.time()
    # Send the current observation to the device running the student
    write_ser(str(state).encode())

    actions = read_to_np()

    tock = time.time()

    state, reward, dones, info = env.step(actions)
    env.render()
    # time.sleep(0.02)
    rewards += reward
    steps += 1

    print(1 / (tock - tick))  # Round-trip inference rate in Hz
    if np.all(dones):
        print(steps)
        print(info[0]["episode"]["r"])
        print(rewards)
        print(info)
        rewards = 0
        steps = 0
@@ -0,0 +1,92 @@

import gym
import numpy as np

from qpd.config import Config
from qpd.compressor import Compressor
from huggingface_sb3 import load_from_hub
from stable_baselines3 import A2C
from stable_baselines3.common.env_util import make_atari_env, make_vec_env
from stable_baselines3.common.vec_env.vec_frame_stack import VecFrameStack
from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from stable_baselines3.common.vec_env.base_vec_env import VecEnv

from qpd.networks.wrapper.student.cnn_student import CNNStudentNet

from datetime import datetime

# Student driven, no memory saving!

config = {
    "memory": {
        "size": 54000,  # Size of memory used for distillation
        "update_frequency": 1,  # Epoch frequency for updating the memory
        "update_size": 5400,  # Minimum update size in steps
        "device": "cpu",

        # Only used with framestacked environments
        "frame_stack_optimization": True,  # Only store the last frame

        "check_consistency": True
    },
    "evaluator": {
        "student_driven": True,  # The student decides the transitions in the environment
        "student_test_frequency": 10,  # Epoch frequency
        "episodes": 50,  # Minimum episodes for testing the student
        "initialize": 30,  # Number of actions to skip at the beginning of an episode
        "ray_workers": 10,  # Parallel ray workers used for updating and testing
        "device": "cpu",
        "deterministic": False,
        "exploration_rate": 0,
    },
    "compression": {
        "checkpoint_frequency": 2,  # Epoch frequency for saving students
        "epochs": 600,
        "learning_rate": 1e-4,
        "batch_size": 256,
        "device": "cuda",

        # Only used in discrete action spaces
        "T": 0.01,  # Softmax temperature hyperparameter
        "categorical": False,
        "critic_importance": 0.5,

        # Only used in continuous action spaces
        "distribution": "Std",  # Std, Mean
        "loss": "KL"  # KL, Huber, MSE
    },
    "quantization": {
        "enabled": False,
        "bits": 8
    },
    "data_directory": "./data",
    "run_name": "test",  # Change this for every run
}


def env_info_done_filter(info):
    key = "ale.lives" if "ale.lives" in info.keys() else "lives"
    print(f"Lives: {info[key]}")
    return info[key] == 0

def get_environment(config: Config):
    vec_env_cls = DummyVecEnv  # SubprocVecEnv if subprocenv else DummyVecEnv
    env = make_atari_env("BreakoutNoFrameskip-v4", n_envs=config.evaluator_config.env_workers, seed=np.random.randint(0, 1000), vec_env_cls=vec_env_cls)  # type: ignore
    env = VecFrameStack(env, n_stack=4)
    return env


if __name__ == "__main__":
    checkpoint = load_from_hub(repo_id="sb3/a2c-BreakoutNoFrameskip-v4", filename="a2c-BreakoutNoFrameskip-v4.zip")
    print(checkpoint)
    model = A2C.load(checkpoint)

    c = Config(get_environment, config)

    comp = Compressor(model, get_environment, c) \
        .student_network(CNNStudentNet) \
        .set_environment_done_filter(env_info_done_filter)

    compressed_model = comp.compress()

    compressed_model.save("test")
@@ -0,0 +1,85 @@

from qpd.config import Config

from qpd.compressor import Compressor
from huggingface_sb3 import load_from_hub
from stable_baselines3 import A2C
from stable_baselines3.dqn.dqn import DQN
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv

from qpd.networks.models.student_six_model import StudentSixModel
from qpd.networks.models.student_six_model_dqn_ import StudentSixModelDQN
from qpd.networks.models.student_tiny_dqn import TinyStudentDQN
from qpd.networks.wrapper.student.fully_connected_student import FCStudentNet

# Student driven, no memory saving!

config = {
    "memory": {
        "size": 100000,  # Size of memory used for distillation
        "update_frequency": 1,  # Epoch frequency for updating the memory
        "update_size": 10000,  # Minimum update size in steps
        "device": "cpu",

        # Only used with framestacked environments
        "frame_stack_optimization": False,  # Only store the last frame

        "check_consistency": True
    },
    "evaluator": {
        "student_driven": False,  # The student decides the transitions in the environment
        "student_test_frequency": 10,  # Epoch frequency
        "episodes": 20,  # Minimum episodes for testing the student
        "initialize": 0,  # Number of actions to skip at the beginning of an episode
        "ray_workers": 10,  # Parallel ray workers used for updating and testing
        "device": "cpu",
        "deterministic": False
    },
    "compression": {
        "checkpoint_frequency": 2,  # Epoch frequency for saving students
        "epochs": 600,
        "learning_rate": 5e-4,
        "batch_size": 64,
        "device": "cuda",

        # Only used in discrete action spaces
        "T": 0.01,  # Softmax temperature hyperparameter
        "categorical": False,
        "critic_importance": 0.5,

        # Only used in continuous action spaces
        "distribution": "Std",  # Std, Mean
        "loss": "KL"  # KL, Huber, MSE
    },
    "quantization": {
        "enabled": True,
        "bits": 8
    },
    "data_directory": "./data",
    "run_name": "test",  # Change this for every run
}
# "/home/user/Workspace/University/PhD/Experiments/QPD",


def get_environment(config: Config):
    return make_vec_env("CartPole-v1", n_envs=config.evaluator_config.env_workers, vec_env_cls=DummyVecEnv)


if __name__ == "__main__":
    # checkpoint = load_from_hub(repo_id="sb3/a2c-CartPole-v1", filename="a2c-CartPole-v1.zip")
    checkpoint = load_from_hub(repo_id="sb3/dqn-CartPole-v1", filename="dqn-CartPole-v1.zip")
    # custom_objects = {
    #     "learning_rate": 0.0,
    #     "lr_schedule": lambda _: 0.0,
    #     "clip_range": lambda _: 0.0,
    # }
    print(checkpoint)

    c = Config(get_environment, config)

    model = DQN.load(checkpoint, env=get_environment(c))  # , custom_objects=custom_objects)

    # comp = Compressor(model, get_environment, c).student_network(FCStudentNet)
    comp = Compressor(model, get_environment, c).student_model(TinyStudentDQN)
    comp.compress()
@@ -0,0 +1,86 @@

from qpd.config import Config
from qpd.networks.models.student_six_model import StudentSixModel
from qpd.compressor import Compressor
from huggingface_sb3 import load_from_hub
from stable_baselines3.sac.sac import SAC
from stable_baselines3.a2c.a2c import A2C
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from stable_baselines3.common.vec_env.vec_normalize import VecNormalize

from qpd.networks.wrapper.student.fully_connected_student import FCStudentNet

from datetime import datetime

config = {
    "memory": {
        "size": 100000,  # Size of memory used for distillation
        "update_frequency": 1,  # Epoch frequency for updating the memory
        "update_size": 10000,  # Minimum update size in steps
        "device": "cpu",

        # Only used with framestacked environments
        "frame_stack_optimization": False,  # Only store the last frame

        "check_consistency": True
    },
    "evaluator": {
        "student_driven": True,  # The student decides the transitions in the environment
        "student_test_frequency": 10,  # Epoch frequency
        "episodes": 20,  # Minimum episodes for testing the student
        "initialize": 0,  # Number of actions to skip at the beginning of an episode
        "ray_workers": 10,  # Parallel ray workers used for updating and testing
        "device": "cpu",
        "deterministic": False
    },
    "compression": {
        "checkpoint_frequency": 2,  # Epoch frequency for saving students
        "epochs": 600,
        "learning_rate": 5e-4,
        "batch_size": 64,
        "device": "cuda",

        # Only used in discrete action spaces
        "T": 0.01,  # Softmax temperature hyperparameter
        "categorical": False,
        "critic_importance": 0.5,

        # Only used in continuous action spaces
        "distribution": "Std",  # Std, Mean
        "loss": "KL"  # KL, Huber, MSE
    },
    "quantization": {
        "enabled": False,
        "bits": 8
    },
    "data_directory": "./data",
    "run_name": "cheetah_no_quant",  # Change this for every run
}
# "/home/user/Workspace/University/PhD/Experiments/QPD",


def get_environment_normalized(config: Config):
    env = make_vec_env("HalfCheetah-v3", n_envs=config.evaluator_config.env_workers, vec_env_cls=DummyVecEnv)

    normalize = load_from_hub(
        repo_id="sb3/a2c-HalfCheetah-v3",
        filename="vec_normalize.pkl",
    )
    return VecNormalize.load(normalize, env)

def get_environment(config: Config):
    env = make_vec_env("HalfCheetah-v3", n_envs=config.evaluator_config.env_workers, vec_env_cls=DummyVecEnv)
    return env


if __name__ == "__main__":
    checkpoint = load_from_hub(repo_id="sb3/sac-HalfCheetah-v3", filename="sac-HalfCheetah-v3.zip")
    print(checkpoint)
    model = SAC.load(checkpoint)

    c = Config(get_environment, config)

    # comp = Compressor(model, get_environment, c).student_network(FCStudentNet)
    comp = Compressor(model, get_environment, c).student_model(StudentSixModel)
    compressed_model = comp.compress()
@@ -0,0 +1,36 @@

[project]
name = "qpd"
version = "0.0.1"
authors = [
    { name="Thomas Avé", email="thomas.ave@uantwerpen.be" },
    { name="Ian Ravijts", email="ian.ravijts@uantwerpen.be" }
]
readme = "README.md"
requires-python = ">=3.9"
dependencies = [
    "wheel==0.38.4",
    "shimmy>=0.2.1",
    "numpy",
    "atari-py==0.2.9",
    "gymnasium[accept-rom-license]",
    "ale-py",
    "pillow",
    "readchar",
    "matplotlib",
    "distiller@git+https://github.com/tiny-rl/distiller.git",
    "wandb",
    "stable-baselines3[extra]",
    "sb3-contrib",
    "tqdm",
    "ray[default]",
    "webdavclient3",
    "prefetch_generator",
    "pygame",
    "pyglet",
    "cython<3",
]

[build-system]
requires = [ "setuptools>=61.0,<=66" ]
build-backend = "setuptools.build_meta"
@@ -0,0 +1,193 @@

from __future__ import annotations
import torch
import random
import numpy
import gzip
import json
from pathlib import Path
from typing import Iterator, Tuple, Dict, Optional

from qpd.config import Config


class Memory:
    def __init__(self, config: Config, size=None, meta_class=None, path: Optional[str] = None) -> None:
        self.meta_name = str(meta_class.__class__) if meta_class else None

        self.max_size = size if size else config.memory_config.size
        self.did_overflow = False
        self.current_index = 0

        self.output_n = config.output_shape[0]
        self.device = config.memory_config.device
        self.check_consistency = config.memory_config.check_consistency
        self.critic = config.compression_config.critic_importance > 0
        self.frame_stack_optim = config.memory_config.frame_stack_optimization
        self.framestack_size = config.observation_shape[0] if self.frame_stack_optim else 0
        self.student_driven = config.evaluator_config.student_driven

        if path:
            print(f"Load memory on path: {path}")
            self.load(path)
        else:
            self.states = torch.zeros((self.max_size,) + (config.observation_shape[1:] if self.frame_stack_optim else config.observation_shape), device=self.device)  # [(Input frame, output actions)]
            self.outputs = torch.zeros((self.max_size, self.output_n), device=self.device)  # + 1 = critic
            # self.repeats = torch.zeros(self.size, device=self.device)
            self.start = set()


    def update(self, state: torch.Tensor, outputs: torch.Tensor, start: bool) -> None:
        assert not torch.any(outputs == 0), "[ERROR]: Update input is not valid!"

        # update overflow
        self.did_overflow = self.will_overflow()

        if start:
            self.start.add(self.current_index)
        else:
            self.start.discard(self.current_index)

        state = state.to(self.device).data
        self.states[self.current_index] = state[:, -1] if self.frame_stack_optim else state  # Store only the latest frame

        self.outputs[self.current_index] = outputs.to(self.device).data
        self.current_index = (self.current_index + 1) % self.max_size

        if self.check_consistency:
            self.verifyConsistency(f"Update current_index: {self.current_index} is start: {start} outputs: {self.outputs}")

    def will_overflow(self, update_size=1):
        if (self.current_index + update_size) >= self.max_size:
            return True
        return False

    def real_size(self):
        return (self.states.element_size() * self.states.nelement()) + (self.outputs.element_size() * self.outputs.nelement())

    def sample(self, batch_size: int, shuffle: bool = True) -> Iterator[Tuple[torch.Tensor, torch.Tensor]]:
        state_indexes = list(range(len(self)))
        if shuffle:
            random.shuffle(state_indexes)
        splits = [state_indexes[i:i + batch_size] for i in range(0, len(state_indexes), batch_size)]

        if self.frame_stack_optim:
            for split in splits:
                current_states = []
                indexes = torch.tensor(split)
                for i in range(self.framestack_size):
                    current_states.append(self.states[(indexes - (self.framestack_size - 1) + i) % len(self)].unsqueeze(1))
                yield (torch.cat(current_states, dim=1), self.outputs[split])
        else:
            for split in splits:
                yield (self.states[split], self.outputs[split])


    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        if self.frame_stack_optim:
            states_index = list(self.start)[index]
            return (torch.cat([self.states[i % len(self)].unsqueeze(0) for i in range(states_index - 3, states_index + 1)]), self.outputs[states_index])
        else:
            return (self.states[index], self.outputs[index])


    def __len__(self) -> int:
        if self.did_overflow:
            return len(self.outputs)
        else:
            return self.current_index


    def prune(self) -> None:
        if not self.did_overflow:
            for item in ["states", "outputs"]:
                setattr(self, item, getattr(self, item)[0: self.current_index])
            self.max_size = self.current_index

        if self.check_consistency:
            self.verifyConsistency("Prune")


    def grow(self, amount: int) -> int:
        print(f"Growing with size: {amount}")
        for item in ["states", "outputs"]:
            obj = getattr(self, item)
            tmp = torch.zeros((self.max_size + amount,) + obj.shape[1:], device=self.device, dtype=obj.dtype)
            tmp[:len(self)] = obj
            setattr(self, item, tmp)

        self.max_size += amount
        self.did_overflow = False

        if self.check_consistency:
            self.verifyConsistency("Grow")
        return len(self)


    def load_memory(self, other: Memory) -> None:
        assert other.max_size <= self.max_size, f"Memory to load is bigger than current memory! {other.max_size} > {self.max_size}"

        self.did_overflow = self.will_overflow(len(other))

        for item in ["states", "outputs"]:
            getattr(self, item)[self.current_index:min(self.current_index + len(other), self.max_size)] = getattr(other, item)[:min(len(other), self.max_size - self.current_index)]

            if self.current_index + len(other) > self.max_size:  # Wrap around and fill from the front of the buffer
                getattr(self, item)[:len(other) - (self.max_size - self.current_index)] = getattr(other, item)[self.max_size - self.current_index:]

        for idx in range(len(other)):
            new_idx = (self.current_index + idx) % self.max_size
            self.start.discard(new_idx)
            if idx in other.start:
                self.start.add(new_idx)

        self.current_index = (self.current_index + len(other)) % self.max_size

        if self.check_consistency:
            self.verifyConsistency("Load memory")


    def to(self, device: torch.device):
        for item in ["states", "outputs"]:
            setattr(self, item, getattr(self, item).to(device))

        self.device = device


    def load(self, path: str) -> None:
        assert self.meta_name, "Can not load this memory! No meta_class specified!"
        with open(f"{path}/metadata.json") as f:
            j = json.loads(f.read())
            assert j["teacher"] == self.meta_name
            self.current_index = j["current_index"]
            self.did_overflow = j["did_overflow"]

        for item in ["states", "outputs"]:
            with gzip.GzipFile(f"{path}/{item}.npy.gz") as f:
                setattr(self, item, torch.from_numpy(numpy.load(file=f)))

        with gzip.GzipFile(f"{path}/start.npy.gz") as f:
            self.start = set(numpy.load(file=f))

        self.max_size = self.states.shape[0]

        if self.check_consistency:
            self.verifyConsistency("Load")


    def save(self, path: str) -> None:
        assert self.meta_name, "Can not save this memory! No meta_class specified!"
        Path(path).mkdir(parents=True, exist_ok=True)
        for item in ["states", "outputs"]:
            with gzip.GzipFile(f"{path}/{item}.npy.gz", "w") as f:
                numpy.save(f, getattr(self, item).cpu().numpy())

        with gzip.GzipFile(f"{path}/start.npy.gz", "w") as f:
            numpy.save(f, numpy.array(list(self.start)))
        with open(f"{path}/metadata.json", "w") as f:
            f.write(json.dumps({"teacher": self.meta_name, "did_overflow": self.did_overflow, "current_index": self.current_index}))

    def verifyConsistency(self, mes) -> None:
        assert not torch.any(self.outputs[:len(self)] == 0), f"[ERROR]: Memory is not consistent! ({mes})"
        # assert not torch.any(torch.bitwise_or(self.outputs[:len(self)] > 100, self.outputs[:len(self)] < -100)), \
        #     f"[ERROR]: Memory is not consistent in a weird way! max: {self.outputs.max()} min: {self.outputs.min()} ({mes})"
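
A standalone sketch (hypothetical values) of the frame-stack index arithmetic used in `Memory.sample`: when `frame_stack_optimization` is enabled only single frames are stored, and each sampled index is expanded to the previous `framestack_size - 1` frames, modulo the buffer length, to rebuild the stacked observation:

buffer_len = 10
framestack_size = 4
index = 1  # sampled position in the ring buffer

frames = [(index - (framestack_size - 1) + i) % buffer_len for i in range(framestack_size)]
print(frames)  # [8, 9, 0, 1] -- wraps around the ring buffer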
@@ -0,0 +1,275 @@

# import wandb
import random
import os
import gc

# from ..utils import WebdavManager
from collections import defaultdict
from distiller.quantization import DorefaQuantizer
from gym.spaces import Box, Discrete
from pathlib import Path
from qpd.config import Config
from qpd.networks.wrapper.stable_baselines.sb_wrapper import SBWrapper
from qpd.utils import summary
from qpd.utils.evaluator import Evaluator
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.vec_env.base_vec_env import VecEnv
from torch.distributions import Categorical
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Optional, Tuple, Union

from .memory import Memory

from ..networks import *
from ..networks.wrapper.model_wrapper import ModelWrapper
# from ..utils import test_network, merge_config

class PolicyDistillation:
    def __init__(self, teacher: ModelWrapper, student: ModelWrapper, config: Config):  # , webdav_manager: Optional[WebdavManager] = None):
        self.teacher = teacher
        self.student = student
        self.config = config

        self.epochs = self.config.compression_config.epochs
        self.batch_size = self.config.compression_config.batch_size
        self.device = self.config.compression_config.device
        self.check_point_frequency = self.config.compression_config.checkpoint_frequency
        self.mem_update_frequency = self.config.memory_config.update_frequency

        # self.webdav_manager = webdav_manager
        self.results = defaultdict(list)

        self.student.to(self.device)
        self.optimizer = torch.optim.RMSprop(self.student.parameters(), lr=config.student_config.learning_rate)
        self.evaluator = Evaluator(config, verbose=True)
        self.quantizer = self.get_quantizer()

        self.initial_environment = self.config.environment_constructor()
        self.best_student = None

        self.memory_order = list(range(1, self.config.compression_config.epochs + 1))
        self.memory = self.load_memory()
        random.shuffle(self.memory_order)


    def train(self):
        best = -float("inf")
        for epoch in range(self.epochs):
            print(f"Start epoch {epoch}")
            epoch_loss = 0

            self.student.train()
            self.student.to(self.device)

            with torch.enable_grad():
                for states, outputs in self.memory.sample(self.batch_size):
                    loss = self.continuous_loss(states.to(self.device, dtype=torch.float), outputs.to(self.device)) if self.config.env_continuous_action_space else self.discrete_loss(states.to(self.device, dtype=torch.float), outputs.to(self.device))

                    assert not torch.isnan(loss), "Error: loss can't be nan!"
                    epoch_loss += self.step(loss)

            # wandb.log({"loss": epoch_loss, "Epoch": epoch})
            print(f"Epoch {epoch} completed, loss: {epoch_loss}")
            self.results["Loss"].append({"Epoch": epoch, "Value": epoch_loss})
            if (epoch and epoch % self.config.evaluator_config.student_test_frequency == 0) or self.config.evaluator_config.student_test_frequency == 1:
                best, is_best = self.test(epoch, best)
                if is_best:
                    self.save_student("best")

            if (epoch and epoch % self.check_point_frequency == 0) or self.check_point_frequency == 1:
                self.save_student(epoch)

            if epoch and epoch % self.mem_update_frequency == 0:
                self.load_memory(epoch, self.memory)

        return self.best_student
        # if self.webdav_manager:
        #     self.webdav_manager.upload_results(self.results)

    def load_memory(self, epoch: int = 0, memory: Optional[Memory] = None) -> Memory:
        memory_index = (self.memory_order.pop() if epoch else 0)
        # memory_index = (epoch if epoch else self.memory_order.pop())
        path = f"{self.config.data_directory}/Resources/{str(self.initial_environment.envs[0].env) if self.config.is_vec_env else str(self.initial_environment.env)}/Memory/{self.teacher.__class__.__name__}/sd_{str(self.config.evaluator_config.student_driven)}/{memory_index}"
        if os.path.exists(path):
            tmp_memory = Memory(self.config, meta_class=self.teacher, path=path)

            if memory:
                memory.load_memory(tmp_memory)
                return memory
            return tmp_memory
        else:
            print()
            print("= Memory ==========================================================")
            print(f"= Could not find memory with index: {memory_index}, generating new data...")
            if not memory:
                memory = Memory(self.config, meta_class=self.teacher)
                steps = memory.max_size
            else:
                steps = self.config.memory_config.update_size

            print(f"= Collecting data for {steps} steps")
            results = self.evaluator.run(self.student, steps=steps, collect=True, teacher=self.teacher) if self.config.evaluator_config.student_driven else self.evaluator.run(self.teacher, steps=steps, collect=True)
            print("= Current average steps:", results["Steps"])
            print("= Current average reward:", results["Reward"])
            print("= Average number of episodes:", results["EpisodeCount"])
            print("===================================================================")
            print()

            if "GPULab" in os.environ or self.config.evaluator_config.student_driven:
                for _ in range(len(results["Memory"])):
                    memory.load_memory(results["Memory"].pop())
                    gc.collect()
            else:
                total_size = sum([len(m) for m in results["Memory"]])
                tmp_memory = Memory(self.config, size=total_size, meta_class=self.teacher)
                for _ in range(len(results["Memory"])):
                    tmp_memory.load_memory(results["Memory"].pop())
                    gc.collect()
                tmp_memory.save(path)
                if epoch:
                    memory.load_memory(tmp_memory)
                else:
                    return tmp_memory
            return memory


    def discrete_loss(self, states: torch.Tensor, teacher_outputs: torch.Tensor) -> torch.Tensor:
        T = self.config.compression_config.T
        student_policy, student_value, _ = self.student(states)

        teacher_policy, teacher_value_std = teacher_outputs[:, :student_policy.shape[1]], teacher_outputs[:, student_policy.shape[1]:]
        teacher_value = teacher_value_std[:, :teacher_policy.shape[1]]

        if self.config.compression_config.categorical:
            dist = Categorical(logits=student_policy)
            student_policy = dist.probs
            dist2 = Categorical(logits=teacher_policy)
            teacher_policy = dist2.probs

        policy_loss = nn.KLDivLoss(reduction="batchmean")(F.log_softmax(student_policy, dim=1), F.softmax(teacher_policy / T, dim=1))  # type: ignore
        critic_importance = self.config.compression_config.critic_importance
        if critic_importance and student_value is not None:
            value_loss = nn.HuberLoss(reduction="mean")(student_value.to(self.device), teacher_value.to(self.device))
            value_sum = value_loss.sum().detach()
            policy_sum = policy_loss.sum().detach()
            if policy_sum < 1e-6 or value_sum < 1e-6:  # Avoid division by 0
                return (1 - critic_importance) * policy_loss + critic_importance * value_loss
            else:
                return ((1 - critic_importance) * (policy_loss / policy_sum) + critic_importance * (value_loss / value_sum)) * (value_sum + policy_sum)
        else:
            return policy_loss


    def continuous_loss(self, states: torch.Tensor, teacher_outputs: torch.Tensor) -> torch.Tensor:
        self.student.to(self.device)
        student_policy, _, student_std = self.student(states)
        # student_policy, student_value = student_outputs[:, :actions], student_outputs[:, actions:]
        teacher_policy, teacher_value_std = teacher_outputs[:, :student_policy.shape[1]], teacher_outputs[:, student_policy.shape[1]:]


        if self.config.compression_config.distribution == "Std":
            teacher_std = teacher_value_std[:, -teacher_policy.shape[1]:]  # type: ignore

            if not teacher_std.gt(0).all():
                teacher_std[teacher_std == 0] = 1
                print("Fixing teacher std")
            if not student_std.gt(0).all():
                student_std[student_std == 0] = 1
                print("Fixing student std")

            if self.config.compression_config.loss == "KL":
                policy_loss = torch.log(teacher_std / student_std) + (student_std ** 2 + (student_policy - teacher_policy) ** 2) / (2 * (teacher_std ** 2)) - 1/2  # type: ignore
            elif self.config.compression_config.loss == "Huber":
                policy_loss = nn.HuberLoss(reduction="none")(student_policy, teacher_policy)
                std_loss = nn.HuberLoss(reduction="none")(student_std, teacher_std)
                policy_loss += std_loss
            elif self.config.compression_config.loss == "MSE":
                policy_loss = nn.MSELoss(reduction="none")(student_policy, teacher_policy)
                std_loss = nn.MSELoss(reduction="none")(student_std, teacher_std)
                policy_loss += std_loss
            else:
                raise RuntimeError("Unknown loss type: " + self.config.compression_config.loss)
        else:
            if self.config.compression_config.loss == "Huber":
                policy_loss = nn.HuberLoss(reduction="none")(student_policy, teacher_policy)  # type: ignore
            elif self.config.compression_config.loss == "MSE":
                policy_loss = nn.MSELoss(reduction="none")(student_policy, teacher_policy)  # type: ignore
            else:
                raise RuntimeError("Unsupported combination of distribution and loss!")

        return policy_loss.sum()


    def step(self, loss: torch.Tensor) -> float:
        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()
        self.student.zero_grad()

        if self.quantizer:
            self.quantizer.quantize_params()

        return loss.item()


    def get_quantizer(self) -> Optional[DorefaQuantizer]:
        if self.config.quantization_config.enabled:
            bits = self.config.quantization_config.bits
            # self.student = distiller.modules.convert_model_to_distiller_lstm(self.student)
            quantizer = DorefaQuantizer(self.student, self.optimizer, bits_activations=bits, bits_weights=bits, bits_bias=bits)
            quantizer.prepare_model()
            quantizer.quantize_params()
            return quantizer
        return None


    def save_student(self, epoch: Union[int, str], name: str = "student"):
        checkpoint_directory = self.config.data_directory
        base_path = f"{checkpoint_directory}/Data/{self.config.run_name}/compression/{self.config.compression_config.starting_time}"
        Path(f"{base_path}/{epoch}").mkdir(parents=True, exist_ok=True)

        self.student.save(f"{base_path}/{epoch}/{name}")

        # if self.webdav_manager:
        #     self.webdav_manager.upload_model(student_save_path, epoch, not self.config["testing"]["local_checkpoints"])

    def test(self, epoch: int, best: Optional[float] = None) -> Tuple[Optional[float], bool]:
        is_best = False
        print()
        print("= Testing =========================================================")
        results = self.evaluator.run(self.student, episodes=self.config.evaluator_config.episodes)
        self.student.to(self.config.compression_config.device)
        log = {"Epoch": epoch}
        for metric in ["Steps", "Reward"]:
            self.results[metric].append({"Epoch": epoch, "Value": results[metric]})
            log[metric] = results[metric]

        # wandb.log(log)

        print("= Current average steps:", results["Steps"])
        print("= Current average reward:", results["Reward"])
        print("= Average number of episodes:", results["EpisodeCount"])
        print("===================================================================")
        print()

        if best is not None and results["Reward"] > best:
            # wandb.log({"Best Reward": results["Reward"]})
            best = results["Reward"]
            is_best = True

        return best, is_best

    def get_action_space_shape(self):
        action_space = self.student.policy.action_space
        is_box = isinstance(action_space, Box)
        assert is_box or isinstance(action_space, Discrete), "[ERROR]: Unsupported action space!"

        if is_box:
            return action_space.shape
        else:
            return action_space.n
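
For reference, the `"KL"` branch of `continuous_loss` above is the closed-form KL divergence between the student and teacher Gaussian policies, applied per action dimension and summed in the final `policy_loss.sum()`:

$$
D_{\mathrm{KL}}\big(\mathcal{N}(\mu_s, \sigma_s^2) \,\|\, \mathcal{N}(\mu_t, \sigma_t^2)\big)
= \log\frac{\sigma_t}{\sigma_s} + \frac{\sigma_s^2 + (\mu_s - \mu_t)^2}{2\sigma_t^2} - \frac{1}{2}
$$

where $\mu_s, \sigma_s$ come from the student heads and $\mu_t, \sigma_t$ from the stored teacher outputs.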
@@ -0,0 +1,77 @@

from typing import Type
import qpd.networks.network_interface as ni

from qpd.config import Config
from qpd.networks.wrapper.student.student_net import StudentNet
from qpd.utils import summary
from .compression.policy_distillation import PolicyDistillation
from .networks.wrapper.model_wrapper import ModelWrapper

import json
from pathlib import Path
import torch
import torch.nn as nn

class Compressor:
    def __init__(self, model, environment_constructor, config):
        self.model = model
        self.environment_constructor = environment_constructor
        self.config: Config = config

        self.teacher: ModelWrapper = ModelWrapper.construct_wrapper(self.config, self.model).to(torch.device(self.config.compression_config.device))
        self.student = None
        self.config.init_env_model_based_params(self.teacher)


    def student_network(self, net: Type[StudentNet]):
        """Add a student network to the compressor, using the existing wrappers.

        Args:
            net (Type[StudentNet]): Student network

        Returns:
            self
        """
        assert self.student is None, "Can not set a second student!"
        self.config.student_config.student_network = net
        self.student_model = self.teacher.construct_student()
        self.student = ModelWrapper.construct_wrapper(self.config, self.student_model).to(torch.device(self.config.compression_config.device))
        return self

    def student_model(self, student: Type["ni.NetworkInterface"]):
        """Add an end-to-end student model to the compressor.

        Args:
            student (Type["ni.NetworkInterface"]): Student model class

        Returns:
            self
        """
        assert self.student is None, "Can not set a second student!"
        self.student = student(self.config)
        return self

    def set_environment_done_filter(self, func):
        self.config.evaluator_config.environment_is_done_filter = func
        return self

    def compress(self):
        assert self.student is not None, "No student assigned! Use Compressor(**).student_network(**) or Compressor(**).student_model(**)"

        Path(f"{self.config.data_directory}/Data/{self.config.run_name}").mkdir(parents=True, exist_ok=True)
        config_dump = json.dumps(self.config.serializable(), indent=4, sort_keys=True)
        with open(f"{self.config.data_directory}/Data/{self.config.run_name}/config.json", "w") as f:
            f.write(config_dump)

        print("Current configuration:")
        print(config_dump)

        distillation = PolicyDistillation(
            self.teacher,
            self.student,
            self.config
        )

        distillation.train()

        return distillation.student
@@ -0,0 +1,165 @@

import datetime

from gymnasium.spaces import Box, Discrete

import qpd.networks.wrapper.model_wrapper as wrapper
import qpd.utils.model_utils as model_utils
import torch
import torch.nn as nn

from qpd.networks.wrapper.student.fully_connected_student import FCStudentNet
from stable_baselines3.common.vec_env.base_vec_env import VecEnv

class AbstractConfig:
    key = None

    def init(self, config: dict):
        if self.key in config.keys():
            self.update(config[self.key])
        return self

    def update(self, config):
        # Only keys that already exist on this config object are applied
        self.__dict__.update((k, config[k]) for k in set(config).intersection(self.__dict__))

    def serializable(self) -> dict:
        tmp = {}

        for key, value in self.__dict__.items():
            v = None

            if isinstance(value, AbstractConfig):
                v = value.serializable()
            elif isinstance(value, type):
                v = str(value)
            elif callable(value):
                continue
            else:
                v = value

            tmp[key] = v

        return tmp

class MemoryConfig(AbstractConfig):
    key = "memory"
    def __init__(self):
        self.size = 54000
        self.update_frequency = 1
        self.update_size = 5400
        self.device = "cpu"
        self.check_consistency = True
        self.frame_stack_optimization = False

class StudentNetConfig(AbstractConfig):
    key = "student"
    def __init__(self):
        self.student_network = FCStudentNet
        self.learning_rate = 1e-4
        self.activation_func = nn.ReLU
        self.extra_student_kwargs = {}

class CompressionConfig(AbstractConfig):
    key = "compression"
    def __init__(self):
        self.checkpoint_frequency = 2
        self.starting_time = datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
        self.epochs = 600
        self.batch_size = 256
        self.T = 0.01
        self.device = torch.device("cuda")
        self.critic_importance = 0.5
        self.categorical = False
        self.loss = "MSE"
        self.distribution = "Std"

class EvaluatorConfig(AbstractConfig):
    key = "evaluator"
    def __init__(self):
        self.environment_is_done_filter = lambda env_info: True
        self.student_test_frequency = 1
        self.episodes = 50
        self.initialize = 30
        self.env_subproc = False
        self.env_workers = 1
        self.ray_workers = 24
        self.device = "cpu"
        self.deterministic = False
        self.student_driven = False
        self.exploration_rate = 0.

class QuantizationConfig(AbstractConfig):
    key = "quantization"
    def __init__(self):
        self.enabled = False
        self.bits = 8

class Config(AbstractConfig):
    def __init__(self, environment_constructor, config={}):
        self.environment_constructor = lambda: environment_constructor(self)

        self.memory_config = MemoryConfig().init(config)
        self.student_config = StudentNetConfig().init(config)
        self.compression_config = CompressionConfig().init(config)
        self.evaluator_config = EvaluatorConfig().init(config)
        self.quantization_config = QuantizationConfig().init(config)

        self.verbose = True
        self.data_directory = "."
        self.run_name = "Test"

        # AUTO-DEFINED
        self.model_has_critic = False
        self.observation_shape = None
        self.action_shape = None
        self.output_shape = None
        self.is_vec_env = False
        self.env_continuous_action_space = False

        self.update(config)

    def update_memory_config(self, **kwargs):
        self.memory_config.update(kwargs)

    def update_student_config(self, **kwargs):
        self.student_config.update(kwargs)

    def update_compressor_config(self, **kwargs):
        self.compression_config.update(kwargs)

    def update_evaluator_config(self, **kwargs):
        self.evaluator_config.update(kwargs)

    def update_quantization_config(self, **kwargs):
        self.quantization_config.update(kwargs)

    def init_env_model_based_params(self, wrapper: "wrapper.ModelWrapper"):
        wrapper.to(self.compression_config.device)
        initialized_environment = self.environment_constructor()

        action_space = initialized_environment.action_space
        is_box = isinstance(action_space, Box)
        assert is_box or isinstance(action_space, Discrete), f"[ERROR]: Unsupported action space of type {type(action_space)}!"

        state = initialized_environment.reset()  # Get a state to determine the shape of the observation
        state = wrapper.observation_to_tensor(state).to(self.compression_config.device).float()  # Get a usable state and determine if the env is a vec_env

        self.is_vec_env = isinstance(initialized_environment, VecEnv)

        if self.is_vec_env:
            self.observation_shape = state.shape[1:]
        else:
            self.observation_shape = state.shape

        self.env_continuous_action_space = is_box

        action, value, action_std = wrapper(state)

        self.model_has_critic = value is not None

        outputs = model_utils.get_output(action[0], value[0] if value is not None else None, action_std[0] if action_std is not None else None) if self.is_vec_env else model_utils.get_output(action, value, action_std)

        self.action_shape = action.shape[1:] if self.is_vec_env else action.shape
        self.output_shape = outputs.shape
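
A small illustration of the intersection-based `update` above: only keys that already exist on a sub-config are applied, so misspelled override keys are silently dropped rather than raising (hypothetical values; the no-op environment constructor is never called during construction):

from qpd.config import Config

cfg = Config(lambda c: None, {"compression": {"epochs": 10, "epcohs": 999}})
print(cfg.compression_config.epochs)  # 10 -- the misspelled "epcohs" key is silently ignored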
@@ -0,0 +1,2 @@

# from .student import Student
# from .teacher import Teacher, TeacherA2C, TeacherDQN, TeacherActorCritic, TeacherPPO, TeacherQRDQN  # type: ignore
@@ -0,0 +1,45 @@

from qpd.config import Config
from qpd.networks.network_interface import NetworkInterface
import torch
import torch.nn as nn
from torch.distributions import Normal

class StudentSixModel(NetworkInterface):
    def __init__(self, config: Config):
        super(StudentSixModel, self).__init__(config)

        self.net = nn.Sequential(
            nn.Linear(self.config.observation_shape[0], 64),
            nn.ReLU(),

            nn.Linear(64, 64),
            nn.ReLU(),

            nn.Linear(64, 64),
            nn.ReLU()
        )

        self.action_head = nn.Linear(64, 6)
        self.std_head = nn.Linear(64, 6)
        self.critic_head = nn.Linear(64, 1)

        for module in self.children():
            if type(module) == nn.Linear:
                module.bias.data.uniform_(0.0)  # type: ignore
                module.weight.data.uniform_(0, 0.01)  # type: ignore

    def forward(self, observations):
        x = self.net(observations)
        mean_actions = self.action_head(x)
        std = self.std_head(x) ** 2  # Square to keep the standard deviation non-negative
        dist = Normal(mean_actions, std)
        self.actions = dist.sample()

        value = self.critic_head(x)

        value = torch.cat([value, value], dim=0)

        return mean_actions, value, std

    def get_actions(self):
        return self.actions.cpu().numpy()
@@ -0,0 +1,37 @@

from qpd.config import Config
from qpd.networks.network_interface import NetworkInterface
import torch
import torch.nn as nn
from torch.distributions import Normal

class StudentSixModelDQN(NetworkInterface):
    def __init__(self, config: Config):
        super(StudentSixModelDQN, self).__init__(config)

        self.net = nn.Sequential(
            nn.Linear(self.config.observation_shape[0], 64),
            nn.ReLU(),

            nn.Linear(64, 64),
            nn.ReLU(),

            nn.Linear(64, 64),
            nn.ReLU()
        )

        self.action_head = nn.Linear(64, 2)

        for module in self.children():
            if type(module) == nn.Linear:
                module.bias.data.uniform_(0.0)  # type: ignore
                module.weight.data.uniform_(0, 0.01)  # type: ignore

    def forward(self, observations):
        x = self.net(observations)
        self.mean_actions = self.action_head(x)

        return self.mean_actions, None, None

    def get_actions(self):
        actions = self.mean_actions.argmax(dim=1).reshape(-1)  # Greedy action: argmax over Q-values
        return actions.cpu().numpy()
@@ -0,0 +1,37 @@

from qpd.config import Config
from qpd.networks.network_interface import NetworkInterface
import torch
import torch.nn as nn
from torch.distributions import Normal

class TinyStudentDQN(NetworkInterface):
    def __init__(self, config: Config):
        super(TinyStudentDQN, self).__init__(config)

        self.net = nn.Sequential(
            nn.Linear(self.config.observation_shape[0], 16),
            nn.ReLU(),

            nn.Linear(16, 16),
            nn.ReLU(),

            # nn.Linear(16, 16),
            # nn.ReLU()
        )

        self.action_head = nn.Linear(16, 2)

        for module in self.children():
            if type(module) == nn.Linear:
                module.bias.data.uniform_(0.0)  # type: ignore
                module.weight.data.uniform_(0, 0.01)  # type: ignore

    def forward(self, observations):
        x = self.net(observations)
        self.mean_actions = self.action_head(x)

        return self.mean_actions, None, None

    def get_actions(self):
        actions = self.mean_actions.argmax(dim=1).reshape(-1)
        return actions.cpu().numpy()
@@ -0,0 +1,25 @@

import torch
import torch.nn as nn

import qpd.config as config

class NetworkInterface(nn.Module):
    def __init__(self, config: "config.Config"):
        super(NetworkInterface, self).__init__()
        self.config: "config.Config" = config

    def forward(self, x):
        raise NotImplementedError(f"Forward method not implemented for {type(self)}!")

    def get_actions(self):
        raise NotImplementedError(f"Actions method not implemented for {type(self)}!")

    def save(self, path):
        torch.save(self.state_dict(), path + ".pth")

    def load(self, path):
        self.load_state_dict(torch.load(path))
        self.eval()

    def observation_to_tensor(self, x):
        return torch.from_numpy(x)
@@ -0,0 +1,7 @@

from typing import List, Type

import qpd.networks.wrapper.model_wrapper as wrapper
registered_teachers: List[Type[wrapper.ModelWrapper]] = []

# from .stable_baselines import *
import qpd.networks.wrapper.stable_baselines
@@ -0,0 +1,19 @@

import qpd.config as config
from qpd.networks.network_interface import NetworkInterface

import qpd.networks.wrapper as wrapper

class ModelWrapper(NetworkInterface):
    algorithm_type = None

    def __init__(self, config: "config.Config", model):
        super(ModelWrapper, self).__init__(config)
        self.original_type = type(model)
        self.policy = None

    def construct_student(self):
        raise NotImplementedError("Construct student method not implemented!")

    @staticmethod
    def construct_wrapper(config, model):
        return next(teacher for teacher in wrapper.registered_teachers if teacher.algorithm_type == type(model))(config, model)
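
`construct_wrapper` picks the first registered wrapper whose `algorithm_type` matches the teacher model's type. The `wrapper` decorator applied to the subclasses below lives in `wrapper_decorator.py`, which is not part of this commit view; based on how `registered_teachers` is consumed, a hypothetical sketch of what it presumably does:

# Hypothetical sketch of qpd/networks/wrapper/wrapper_decorator.py, inferred from usage.
import qpd.networks.wrapper as wrapper_pkg

def wrapper(cls):
    # Register the wrapped class so ModelWrapper.construct_wrapper can look it up
    # by matching cls.algorithm_type against the teacher model's type.
    wrapper_pkg.registered_teachers.append(cls)
    return cls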
@@ -0,0 +1,16 @@

# from .sb_wrapper import SBWrapper
# from .value_iteration import WrapperValueIteration
# from .on_actor_critic import WrapperONPActorCritic
# from .qrdqn import WrapperQRDQN
# from .ppo import WrapperPPO
# from .dqn import WrapperDQN
# from .a2c import WrapperA2C
# from .sac import WrapperSAC
import qpd.networks.wrapper.stable_baselines.sb_wrapper
import qpd.networks.wrapper.stable_baselines.value_iteration
import qpd.networks.wrapper.stable_baselines.on_actor_critic
import qpd.networks.wrapper.stable_baselines.qrdqn
import qpd.networks.wrapper.stable_baselines.ppo
import qpd.networks.wrapper.stable_baselines.dqn
import qpd.networks.wrapper.stable_baselines.a2c
import qpd.networks.wrapper.stable_baselines.sac
@@ -0,0 +1,8 @@

import qpd.networks.wrapper.stable_baselines.on_actor_critic as sboac

from ..wrapper_decorator import wrapper
from stable_baselines3 import A2C

@wrapper
class WrapperA2C(sboac.WrapperONPActorCritic):
    algorithm_type = A2C
@@ -0,0 +1,18 @@

import numpy as np
import gym

from stable_baselines3 import DQN
from stable_baselines3.common.preprocessing import maybe_transpose
from stable_baselines3.common.utils import is_vectorized_observation
import qpd.networks.wrapper.stable_baselines.value_iteration as sbvi

from ..wrapper_decorator import wrapper

@wrapper
class WrapperDQN(sbvi.WrapperValueIteration):
    algorithm_type = DQN

    def forward(self, x):
        self.observation = x
        self.q_values = self.policy.q_net(self.observation)
        return self.q_values, None, None
@@ -0,0 +1,25 @@

from typing import Type, Union

import qpd.networks.wrapper.stable_baselines.sb_wrapper as sbw
from qpd.networks.wrapper.wrapper_decorator import wrapper

from qpd.networks.wrapper.model_wrapper import ModelWrapper
from stable_baselines3.sac.sac import SAC
import torch

class WrapperOFFPActorCritic(sbw.SBWrapper):
    def forward(self, obs):
        mean_actions, log_std, kwargs = self.policy.actor.get_action_dist_params(obs)
        self.actions = self.policy.actor.action_dist.actions_from_params(mean_actions, log_std, deterministic=self.config.evaluator_config.deterministic, **kwargs)

        value = torch.cat(self.policy.critic(obs, self.actions), dim=1)

        action_std = None

        if self.config.compression_config.distribution == "Std":
            action_std = torch.ones_like(mean_actions) * log_std.exp()

        return mean_actions, value, action_std

    def _get_actions(self):
        return self.actions
@@ -0,0 +1,62 @@

import torch

from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.distributions import (
    BernoulliDistribution,
    CategoricalDistribution,
    DiagGaussianDistribution,
    MultiCategoricalDistribution,
    StateDependentNoiseDistribution
)
from stable_baselines3.common.policies import BasePolicy
import qpd.networks.wrapper.stable_baselines.sb_wrapper as sbw

class WrapperONPActorCritic(sbw.SBWrapper):
    def forward(self, obs):
        features = self.policy.extract_features(obs)
        if self.policy.share_features_extractor:
            self.latent_pi, latent_vf = self.policy.mlp_extractor(features)
        else:
            pi_features, vf_features = features
            self.latent_pi = self.policy.mlp_extractor.forward_actor(pi_features)
            latent_vf = self.policy.mlp_extractor.forward_critic(vf_features)

        self.mean_actions = self.policy.action_net(self.latent_pi)
        value = self.policy.value_net(latent_vf)

        action_std = None

        if self.config.compression_config.distribution == "Std":
            if isinstance(self.policy.action_dist, StateDependentNoiseDistribution):
                variance = torch.mm(self.latent_pi ** 2, self.policy.action_dist.get_std(self.policy.log_std) ** 2)
                action_std = torch.sqrt(variance + self.policy.action_dist.epsilon)

            elif isinstance(self.policy.action_dist, DiagGaussianDistribution):
                action_std = torch.ones_like(self.mean_actions) * self.policy.log_std.exp()
            elif isinstance(self.policy.action_dist, CategoricalDistribution):
                action_std = None  # In case of discrete action space
            else:
                raise RuntimeError("Invalid distribution type!")

        return self.mean_actions, value, action_std

    def _get_actions(self):
        distribution = None

        if isinstance(self.policy.action_dist, DiagGaussianDistribution):
            distribution = self.policy.action_dist.proba_distribution(self.mean_actions, self.policy.log_std)
        elif isinstance(self.policy.action_dist, CategoricalDistribution):
            # Here mean_actions are the logits before the softmax
            distribution = self.policy.action_dist.proba_distribution(action_logits=self.mean_actions)
        elif isinstance(self.policy.action_dist, MultiCategoricalDistribution):
            # Here mean_actions are the flattened logits
            distribution = self.policy.action_dist.proba_distribution(action_logits=self.mean_actions)
        elif isinstance(self.policy.action_dist, BernoulliDistribution):
            # Here mean_actions are the logits (before rounding to get the binary actions)
            distribution = self.policy.action_dist.proba_distribution(action_logits=self.mean_actions)
        elif isinstance(self.policy.action_dist, StateDependentNoiseDistribution):
            distribution = self.policy.action_dist.proba_distribution(self.mean_actions, self.policy.log_std, self.latent_pi)
        else:
            raise ValueError("Invalid action distribution")

        return distribution.get_actions(self.config.evaluator_config.deterministic)
@@ -0,0 +1,10 @@

import qpd.networks.wrapper.stable_baselines.on_actor_critic as sboac

from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.policies import BasePolicy
from ..wrapper_decorator import wrapper
from stable_baselines3 import PPO

@wrapper
class WrapperPPO(sboac.WrapperONPActorCritic):
    algorithm_type = PPO
@@ -0,0 +1,12 @@

from ..wrapper_decorator import wrapper
import qpd.networks.wrapper.stable_baselines.value_iteration as sbvi
from sb3_contrib import QRDQN

@wrapper
class WrapperQRDQN(sbvi.WrapperValueIteration):
    algorithm_type = QRDQN

    def forward(self, obs):
        self.observation = obs
        self.q_values = self.policy.quantile_net(self.observation)
        return self.q_values.view(-1, self.policy.n_quantiles, self.policy.action_space.n).mean(dim=1), None, None
@ -0,0 +1,8 @@
import qpd.networks.wrapper.stable_baselines.off_actor_critic as sbac
from qpd.networks.wrapper.wrapper_decorator import wrapper

from stable_baselines3.sac.sac import SAC


@wrapper
class WrapperSAC(sbac.WrapperOFFPActorCritic):
    algorithm_type = SAC
@ -0,0 +1,67 @@
from typing import Type, Union

import gym
import numpy as np
import torch

import qpd.networks.wrapper.stable_baselines.student as student
import qpd.networks.wrapper.model_wrapper as model_wrapper
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.policies import BasePolicy


class SBWrapper(model_wrapper.ModelWrapper):
    algorithm_type: Type[BaseAlgorithm] = None

    def __init__(self, config, model: Union[BaseAlgorithm, BasePolicy]):
        super(SBWrapper, self).__init__(config, model)
        self.policy: BasePolicy = model.policy if isinstance(model, BaseAlgorithm) else model
        assert isinstance(self.policy, BasePolicy), "Incorrect input model type!"

    def forward(self, x):
        raise NotImplementedError("Not implemented forward method!")

    def _get_actions(self):
        raise NotImplementedError("Not implemented _get_actions method!")

    def observation_to_tensor(self, x):
        return self.policy.obs_to_tensor(x)[0]

    def save(self, path):
        # Local name kept distinct from the `student` module imported above
        student_algo = self.construct_student()
        student_algo.policy = self.policy

        student_algo.save(path + ".zip")

    def load(self, path):
        self.policy = self.algorithm_type.load(path).policy

    def get_actions(self):
        with torch.no_grad():
            actions = self._get_actions()

        actions = actions.cpu().numpy()
        if isinstance(self.policy.action_space, gym.spaces.Box):
            if self.policy.squash_output:
                # Rescale to proper domain when using squashing
                actions = self.policy.unscale_action(actions)
            else:
                # Actions could be on arbitrary scale, so clip the actions to avoid
                # out of bound error (e.g. if sampling from a Gaussian distribution)
                actions = np.clip(actions, self.policy.action_space.low, self.policy.action_space.high)

        return actions

    def construct_student(self):
        tmp = {
            "learning_rate": self.config.student_config.learning_rate,
            "policy_kwargs": {
                "net_arch": [],
                "features_extractor_class": student.SBStudentNet,
                "features_extractor_kwargs": {
                    "network": self.config.student_config.student_network
                },
                "activation_fn": self.config.student_config.activation_func
            }
        }
        tmp.update(self.config.student_config.extra_student_kwargs)
        return self.original_type(type(self.policy), self.config.environment_constructor(), **tmp)
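
# Hedged sketch of the save/load round trip (path is illustrative):
#
#   teacher.save("Data/compression/best/student")      # writes student.zip via the SB3 algorithm
#   teacher.load("Data/compression/best/student.zip")  # restores only the policy weights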
@ -0,0 +1,21 @@
from typing import Type
import torch as th
import torch.nn as nn
from gym import spaces

from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

from qpd.networks.wrapper.student.student_net import StudentNet

from ..student.cnn_student import CNNStudentNet


class SBStudentNet(BaseFeaturesExtractor):
    def __init__(self, observation_space: spaces.Space, features_dim: int = 16, network: Type[StudentNet] = CNNStudentNet):
        super(SBStudentNet, self).__init__(observation_space, features_dim)

        self.network = network(observation_space, features_dim)

        # The wrapped network may override the requested feature dimension
        self._features_dim = self.network.features_dim

    def forward(self, observations: th.Tensor) -> th.Tensor:
        return self.network(observations)
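
# Illustrative wiring (a sketch mirroring SBWrapper.construct_student; assumes an SB3 PPO):
#
#   model = PPO("CnnPolicy", env, policy_kwargs={
#       "net_arch": [],
#       "features_extractor_class": SBStudentNet,
#       "features_extractor_kwargs": {"network": CNNStudentNet},
#   })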
@ -0,0 +1,22 @@
import numpy as np
import gym
import torch

from stable_baselines3.common.preprocessing import maybe_transpose
from stable_baselines3.common.utils import is_vectorized_observation

import qpd.networks.wrapper.stable_baselines.sb_wrapper as sbw


class WrapperValueIteration(sbw.SBWrapper):
    def _get_actions(self):
        # Epsilon-greedy: with probability exploration_rate sample random actions,
        # otherwise take the greedy action from the Q-values cached by forward().
        if not self.config.evaluator_config.deterministic and np.random.rand() < self.config.evaluator_config.exploration_rate:
            if is_vectorized_observation(maybe_transpose(self.observation, self.policy.observation_space), self.policy.observation_space):
                if isinstance(self.policy.observation_space, gym.spaces.Dict):
                    n_batch = self.observation[list(self.observation.keys())[0]].shape[0]
                else:
                    n_batch = self.observation.shape[0]
                action = np.array([self.policy.action_space.sample() for _ in range(n_batch)])
            else:
                action = np.array(self.policy.action_space.sample())
            # get_actions() calls .cpu().numpy() on the result, so return a tensor here too
            action = torch.as_tensor(action)
        else:
            action = self.q_values.argmax(dim=1).reshape(-1)
        return action
@ -0,0 +1,39 @@
from typing import Type
import torch as th
import torch.nn as nn
from gym import spaces

import qpd.networks.wrapper.student.student_net as student


class CNNStudentNet(student.StudentNet):
    def __init__(self, observation_space: spaces.Box, n_output):
        super(CNNStudentNet, self).__init__(observation_space, n_output)
        # We assume CxHxW images (channels first)
        # Re-ordering will be done by pre-preprocessing or wrapper
        n_input_channels = observation_space.shape[0]

        self.cnn = nn.Sequential(
            nn.Conv2d(n_input_channels, 16, kernel_size=8, stride=4),
            nn.ReLU(),

            nn.Conv2d(16, 16, kernel_size=4, stride=2),
            nn.ReLU(),

            nn.Conv2d(16, 16, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Flatten()
        )

        # Compute shape by doing one forward pass
        with th.no_grad():
            n_flatten = self.cnn(
                th.as_tensor(observation_space.sample()[None]).float()
            ).shape[1]

        self.linear = nn.Sequential(
            nn.Linear(n_flatten, self.features_dim),
            nn.ReLU()
        )

    def forward(self, observations: th.Tensor) -> th.Tensor:
        return self.linear(self.cnn(observations))
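
# Shape sketch (illustrative, for a hypothetical 4x84x84 Atari-style observation):
#   conv 8/4: 84 -> 20, conv 4/2: 20 -> 9, conv 3/1: 9 -> 7, so n_flatten = 16*7*7 = 784,
#   which the final Linear maps down to features_dim.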
@ -0,0 +1,31 @@
from typing import Type
import torch as th
import torch.nn as nn
from gym import spaces

import qpd.networks.wrapper.student.student_net as student


class FCStudentNet(student.StudentNet):
    def __init__(self, observation_space: spaces.Box, n_output: int):
        # The feature dimension is fixed at 64 here, regardless of n_output
        super(FCStudentNet, self).__init__(observation_space, 64)

        n_input_channels = observation_space.shape[0]

        self.net = nn.Sequential(
            nn.Linear(n_input_channels, 64),
            nn.ReLU(),

            nn.Linear(64, 64),
            nn.ReLU(),

            nn.Linear(64, self.features_dim),
            nn.ReLU()
        )

        # Use modules() rather than children() so the Linear layers inside the
        # Sequential are actually visited by this initialization loop.
        for module in self.modules():
            if type(module) == nn.Linear:
                module.bias.data.uniform_(0.0)  # type: ignore
                module.weight.data.uniform_(0, 0.01)  # type: ignore

    def forward(self, observations: th.Tensor) -> th.Tensor:
        return self.net(observations)
@ -0,0 +1,14 @@
from typing import Type
import torch as th
import torch.nn as nn
from gym import spaces


class StudentNet(nn.Module):
    def __init__(self, observation_space: spaces.Box, n_output: int):
        super(StudentNet, self).__init__()

        self._n_output = n_output

    @property
    def features_dim(self):
        return self._n_output
@ -0,0 +1,5 @@
from . import registered_teachers


def wrapper(cls):
    registered_teachers.append(cls)
    return cls
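
# Usage (as in the wrapper modules above):
#
#   @wrapper
#   class WrapperPPO(...):
#       algorithm_type = PPO
#
# Decorated classes are appended to `registered_teachers`, the registry the
# framework consults (presumably via ModelWrapper.construct_wrapper) to match
# a loaded model to its wrapper.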
@ -0,0 +1,4 @@
from .configuration import get_config, get_config_mods, merge_config, import_config
from .summary import *
from .model_utils import get_network
from .webdav import WebdavManager
@ -0,0 +1,91 @@
from argparse import ArgumentParser
import copy


def convert_nested_insert(dct, lst):
    # Walk down the dotted path in lst[:-2], creating nested dicts as needed,
    # then store the final key/value pair.
    for x in lst[:-2]:
        dct[x] = dct = dct.get(x, dict())
    dct.update({lst[-2]: lst[-1]})


def convert_nested(dct):
    # empty dict to store the result
    result = dict()

    # create an iterator of lists
    # representing nested or hierarchical flow
    lsts = ([*k.split("."), v] for k, v in dct.items())

    # insert each list into the result
    for lst in lsts:
        convert_nested_insert(result, lst)
    return result
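
# Example (illustrative):
#   convert_nested({"compression.epochs": "10", "network.teacher": "dqn"})
#   -> {"compression": {"epochs": "10"}, "network": {"teacher": "dqn"}}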


def merge_config(source, destination):
    destination = copy.deepcopy(destination)
    for key, value in source.items():
        if isinstance(value, dict):
            if key in destination and isinstance(destination[key], dict):
                destination_value = destination[key]
            else:
                destination_value = {}
            destination[key] = merge_config(value, destination_value)
        else:
            # CLI overrides arrive as strings; interpret them as Python
            # literals where possible and fall back to the raw string.
            try:
                value = eval(value)
            except Exception:
                pass

            destination[key] = value

    return destination
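
# Example (illustrative):
#   merge_config({"compression": {"epochs": "5"}}, {"compression": {"epochs": 10, "lr": 1e-3}})
#   -> {"compression": {"epochs": 5, "lr": 0.001}}   (the string leaf is eval'ed to an int)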


def get_config_mods(known_args):
    # Unknown CLI args come in "--dotted.key value" pairs; strip the leading
    # dashes, zip keys with values, then nest the dotted keys.
    config_mods = dict(zip(map(lambda x: x[2:], known_args[1][:-1:2]), known_args[1][1::2]))
    config_mods = convert_nested(config_mods)
    return config_mods
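
# Example (illustrative): `python run.py --config experiment --compression.epochs 10`
# leaves ["--compression.epochs", "10"] unparsed; get_config_mods turns that into
# {"compression": {"epochs": "10"}}, which merge_config folds into the config.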


def parse_args():
    parser = ArgumentParser()
    parser.add_argument('--config', type=str, default="experiment")

    known_args = parser.parse_known_args()

    args = known_args[0]
    config_mods = get_config_mods(known_args)

    return args, config_mods


def get_config():
    args, config_mods = parse_args()

    try:
        config_module = __import__(f"configs.{args.config}", fromlist=[''])
        configs = config_module.config

        if config_mods is not None:
            configs = merge_config(config_mods, configs)

        return configs
    except ModuleNotFoundError as ex:
        print(f"ERROR: Could not find the experiment config file: {ex}")
        quit()


def import_config(config, experiment_config):
    config["network"]["conv1"] = experiment_config["network"]["conv1"]
    config["network"]["conv2"] = experiment_config["network"]["conv2"]
    config["network"]["conv3"] = experiment_config["network"]["conv3"]
    config["network"]["linear"] = experiment_config["network"]["linear"]
    config["network"]["teacher"] = experiment_config["network"]["teacher"]
    if "critic_importance" in experiment_config["compression"]:
        config["compression"]["critic_importance"] = experiment_config["compression"]["critic_importance"]
    config["testing"]["randomness"] = experiment_config["testing"]["randomness"]
    config["quantization"]["enabled"] = experiment_config["quantization"]["enabled"]
    config["compression"]["epochs"] = experiment_config["compression"]["epochs"]
    config["train_run"] = experiment_config["run_name"]
    return config
@ -0,0 +1,303 @@
import json
import torch
import time
import numpy
import random
import ray

from typing import Dict, List, Optional

from qpd.config import Config
from qpd.networks.wrapper.model_wrapper import ModelWrapper
from qpd.utils.model_utils import get_output
from ..compression.memory import Memory
from dataclasses import dataclass, field


@dataclass
class Worker:
    idx: int = 0
    action_queue: List[int] = field(default_factory=list)
    memory: Optional[Memory] = None
    finished: bool = False
    steps: int = 0
    rewards: int = 0
    repeats: int = 0

    def reset(self, number: int):
        # Queue a random number (up to `number`) of no-op actions so episodes
        # start from slightly different states.
        self.action_queue = [0 for _ in range(int(numpy.round(random.random() * number)))]
        self.steps = 0
        self.rewards = 0
        self.repeats = 0


@dataclass
class Stats:
    repeats: int = 0
    frames: int = 0
    effective_frames: int = 0
    steps: List[int] = field(default_factory=list)
    rewards: List[int] = field(default_factory=list)

    def update(self, worker: Worker):
        self.repeats += worker.repeats
        self.effective_frames += worker.steps
        self.steps.append(worker.steps - worker.repeats)
        self.rewards.append(worker.rewards)


class Evaluator:
    def __init__(self, config: Config, verbose: bool=False) -> None:
        self.config = config
        # self.env = create_env(config)
        self.verbose = verbose

        # Spill Ray objects to disk once the local object store is nearly full,
        # so large collected memories do not crash the run.
        ray.init(
            _system_config={
                "local_fs_capacity_threshold": 0.9999999999,
                "object_spilling_config": json.dumps(
                    {
                        "type": "filesystem",
                        "params": {
                            "directory_path": f"{config.data_directory}/spill"
                        }
                    },
                )
            },
        )

    def run(self, network: torch.nn.Module, collect: bool=False, steps: int=0, episodes: int=0, teacher=None) -> Dict:
        network.eval()
        network.to(torch.device(self.config.evaluator_config.device))

        @ray.remote(num_gpus=0.1)
        def run_ray(function, *args, **kwargs):
            return function(*args, **kwargs)

        test_function = Evaluator.test_vec if self.config.is_vec_env else Evaluator.test

        if self.config.evaluator_config.ray_workers:
            # Split the requested steps/episodes roughly evenly across the Ray workers
            worker_steps = list(map(lambda x: len(x), numpy.array_split(numpy.arange(steps), self.config.evaluator_config.ray_workers)))
            worker_episodes = list(map(lambda x: len(x), numpy.array_split(numpy.arange(episodes), self.config.evaluator_config.ray_workers)))
            old_metrics = ray.get(
                [run_ray.remote(
                    test_function,
                    self.config,
                    network,
                    steps=s,
                    episodes=e,
                    collect=collect,
                    verbose=self.verbose,
                    teacher=teacher
                ) for s, e in zip(worker_steps, worker_episodes)]
            )
            metrics = {item: sum(list(map(lambda x: x[item], old_metrics)), []) for item in ["Memory"]}
            metrics.update({item: numpy.mean(list(map(lambda x: x[item], old_metrics))) for item in ["Reward", "Steps", "EpisodeCount"]})
            # metrics = {item: sum(list(map(lambda x: x[item], old_metrics)), []) for item in ["Memory", "Reward", "Steps"]}
        else:
            metrics = test_function(
                self.config,
                network,
                steps=steps,
                episodes=episodes,
                collect=collect,
                verbose=self.verbose,
                teacher=teacher)

        network.train()
        return metrics

    @staticmethod
    def test(config: Config,
             model_wrapper: ModelWrapper,
             steps: int=0,
             episodes: int=0,
             collect: bool=False,
             verbose=True,
             teacher: ModelWrapper=None) -> Dict:

        env = config.environment_constructor()
        # model = None
        # model = model_wrapper.original_type(type(model_wrapper.policy), env, device=config.evaluator_config.device, **config.build_student_kwargs())
        # model.policy = model_wrapper.policy

        if teacher:
            teacher.to(device=config.evaluator_config.device)

        total_rewards = []
        total_steps = []
        frames = 0
        effective_frames = 0

        with torch.no_grad():
            env_state = env.reset()

            worker_action_queue = [0 for _ in range(int(numpy.round(random.random() * config.evaluator_config.initialize)))]
            worker_steps_since_last_reward = 0
            worker_steps = 0
            worker_finished = False

            worker_memory = None
            if collect:
                worker_memory = Memory(config, size=int(steps) + 10000)

            start = time.time()
            if verbose:
                print("Start worker")
            while not worker_finished:
                current_actions = []

                state = model_wrapper.observation_to_tensor(env_state)
                action, value, action_std = model_wrapper(state)

                if len(worker_action_queue) > 0:
                    current_actions.append(worker_action_queue.pop())
                else:
                    current_actions = model_wrapper.get_actions()

                if collect:
                    if teacher and teacher is not model_wrapper:
                        action, value, action_std = teacher(state)

                    if worker_memory.will_overflow():
                        print(f"Growing memory to size {worker_memory.max_size + 5000}!")
                        worker_memory.grow(5000)

                    worker_memory.update(
                        state.unsqueeze(0),
                        get_output(action, value, action_std),
                        worker_steps == 0
                    )

                env_state, reward, done, infos = env.step(current_actions)  # type: ignore

                frames += 1

                worker_steps += 1
                if reward > 0:
                    worker_steps_since_last_reward = 0
                else:
                    worker_steps_since_last_reward += 1

                if done:
                    if config.evaluator_config.environment_is_done_filter(infos):
                        assert "episode" in infos, "No episode key in infos dict from gym environment"
                        total_rewards.append(infos["episode"]["r"])
                    worker_action_queue = [0 for _ in range(int(numpy.round(random.random() * config.evaluator_config.initialize)))]
                    effective_frames += worker_steps
                    total_steps.append(worker_steps)
                    worker_steps = 0

                    if (steps and effective_frames >= steps) or (episodes and len(total_rewards) >= episodes):
                        worker_finished = True

            if collect:
                worker_memory.prune()

            if verbose:
                end = time.time()
                print("Average and effective framerate:", round(frames / (end-start), 3), round(effective_frames / (end-start), 3))

        return {"Steps": numpy.mean(total_steps) if len(total_steps) else 0, "Reward": numpy.mean(total_rewards) if len(total_rewards) else 0, "Memory": [worker_memory] if collect else [], "EpisodeCount": len(total_rewards)}

    @staticmethod
    def test_vec(config: Config,
                 model_wrapper: ModelWrapper,
                 steps: int=0,
                 episodes: int=0,
                 collect: bool=False,
                 verbose=True,
                 teacher: ModelWrapper=None) -> Dict:

        env = config.environment_constructor()

        model_wrapper.to(device=config.evaluator_config.device)
        if teacher:
            teacher.eval()
            teacher.to(device=config.evaluator_config.device)

        total_rewards = []
        total_steps = []
        frames = 0
        effective_frames = 0

        workers = config.evaluator_config.env_workers

        with torch.no_grad():
            state = env.reset()
            worker_action_queue = [[0 for _ in range(int(numpy.round(random.random() * config.evaluator_config.initialize)))] for _ in range(workers)]
            workers_steps_since_last_reward = [0 for _ in range(workers)]
            worker_steps = [0 for _ in range(workers)]
            worker_finished = [False for _ in range(workers)]

            worker_memories = []
            if collect:
                worker_memories = [Memory(config, size=int(steps / workers) + 10000) for _ in range(workers)]

            start = time.time()
            if verbose:
                print("Start workers")
            while not numpy.all(worker_finished):
                current_actions = []
                state = model_wrapper.observation_to_tensor(state).float()
                action, value, action_std = model_wrapper(state)

                if max(len(q) for q in worker_action_queue) > 0:
                    # Mix queued no-op actions with model actions so every worker
                    # contributes exactly one action per step.
                    model_actions = model_wrapper.get_actions()
                    for worker in range(workers):
                        if len(worker_action_queue[worker]):
                            current_actions.append(worker_action_queue[worker].pop())
                        else:
                            current_actions.append(model_actions[worker])
                else:
                    current_actions = model_wrapper.get_actions()

                if collect:
                    if teacher and teacher is not model_wrapper:
                        action, value, action_std = teacher(state)

                    for worker in range(workers):
                        if not worker_finished[worker]:
                            if worker_memories[worker].will_overflow():
                                print(f"Growing memory to size {worker_memories[worker].max_size + 5000}!")
                                worker_memories[worker].grow(5000)

                            outputs = get_output(action[worker], value[worker] if value is not None else None, action_std[worker] if action_std is not None else None)

                            worker_memories[worker].update(state[worker].unsqueeze(0), outputs, worker_steps[worker] == 0)

                state, reward, done, infos = env.step(current_actions)  # type: ignore

                frames += workers
                for worker in range(workers):
                    worker_steps[worker] += 1
                    if reward[worker] > 0:
                        workers_steps_since_last_reward[worker] = 0
                    else:
                        workers_steps_since_last_reward[worker] += 1

                    if done[worker] and not worker_finished[worker]:
                        # print(f"Done: {infos}")
                        if config.evaluator_config.environment_is_done_filter(infos[worker]):
                            assert "episode" in infos[worker]
                            total_rewards.append(infos[worker]["episode"]["r"])
                        worker_action_queue[worker] = [0 for _ in range(int(numpy.round(random.random() * config.evaluator_config.initialize)))]
                        effective_frames += worker_steps[worker]
                        total_steps.append(worker_steps[worker])
                        worker_steps[worker] = 0

                        if (steps and effective_frames >= steps) or (episodes and len(total_rewards) + workers - 1 >= episodes):
                            if verbose:
                                print(f"Finished with {effective_frames} steps and {len(total_rewards)} episodes!")
                            worker_finished[worker] = True

            if collect:
                for worker in range(workers):
                    worker_memories[worker].prune()

            if verbose:
                end = time.time()
                print("Average and effective framerate:", round(frames / (end-start), 3), round(effective_frames / (end-start), 3))

        return {"Steps": numpy.mean(total_steps) if len(total_steps) else 0, "Reward": numpy.mean(total_rewards) if len(total_rewards) else 0, "Memory": worker_memories if collect else [], "EpisodeCount": len(total_rewards)/workers}
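
# Hedged usage sketch (assumes `config` and a wrapped student built elsewhere in the framework):
#
#   evaluator = Evaluator(config, verbose=True)
#   metrics = evaluator.run(student, episodes=10)  # {"Reward": ..., "Steps": ..., ...}
#   memory = metrics["Memory"]                     # Memory buffers when collect=True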
@ -0,0 +1,47 @@
from distiller.quantization import DorefaQuantizer
from .summary import summary
import torch.nn as nn
import torch
import os


class Flatten(nn.Module):
    def forward(self, x):
        return x.reshape(x.size(0), -1)


def init_module(module, weight_init, bias_init, gain=1):
    weight_init(module.weight.data, gain=gain)
    bias_init(module.bias.data)
    return module


def get_output_size(input_shape, layer):
    sample = torch.zeros(size=(1, *input_shape))
    return layer(sample).view(1, -1).size(1)


def get_network(config, Network, load=None, quantization=0, verbose=True):
    network = Network(config)

    if verbose:
        summary(network)
    if quantization:
        # Prepare the network for quantization-aware training at the given bit width
        bits = quantization
        optimizer = torch.optim.RMSprop(network.parameters(), lr=0)
        quantizer = DorefaQuantizer(network, optimizer, bits_activations=bits, bits_weights=bits, bits_bias=bits)
        quantizer.prepare_model()
        quantizer.quantize_params()

    if load and (type(load) != str or os.path.exists(load)):
        network = network.load(load)

    network.eval()
    return network


def get_output(action, value, action_std):
    if action_std is not None:
        return torch.cat([action, value, action_std], dim=0)
    if value is not None:
        return torch.cat([action, value], dim=0)
    return action
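
# Example (illustrative): for a continuous policy with a 2-dim action space,
#   get_output(action (2,), value (1,), action_std (2,)) -> tensor of shape (5,)
# which the evaluator stores in Memory as a single flat target vector.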
@ -0,0 +1,52 @@
import os
from collections import OrderedDict
from distiller.quantization import DorefaQuantizer
import numpy as np
from onnxruntime.quantization import quantize_dynamic
import torch


def filter_state_dict(state_dict: OrderedDict):
    # Drop the quantizer's float_weight/float_bias shadow parameters so the
    # state dict matches an unquantized copy of the network.
    new = OrderedDict()
    for key, state in state_dict.items():
        key_split = key.split(".")

        state_type = key_split[-1]

        if state_type == "float_weight" or state_type == "float_bias":
            continue

        new[key] = state
    return new


def clean_student(student, config):
    cleaned = type(student)(config)
    cleaned.load_state_dict(filter_state_dict(student.state_dict()))
    return cleaned


def export_student_as_onnx(input, output, student_type, config, get_environment, verbose=True, onnx_input_names=["input_1"], onnx_output_names=["output_1"], opset_version=14, qat=True):
    # Loading quantized student
    student = student_type(config)
    bits = config.quantization_config.bits

    if qat:
        q = DorefaQuantizer(student, torch.optim.RMSprop(student.parameters(), lr=config.student_config.learning_rate), bits, bits, bits)
        q.prepare_model()

    student.load(input)

    c_student = clean_student(student, config)

    env = get_environment(config)
    state = env.reset().astype(np.float32)

    path = output.split("/")[:-1]

    orig_onnx = "/".join(path) + "/" + "orig.onnx"
    temp_onnx = "/".join(path) + "/" + "clean.onnx"

    torch.onnx.export(student, torch.from_numpy(state), orig_onnx, verbose=verbose, input_names=onnx_input_names, output_names=onnx_output_names, opset_version=opset_version)
    torch.onnx.export(c_student, torch.from_numpy(state), temp_onnx, verbose=verbose, input_names=onnx_input_names, output_names=onnx_output_names, opset_version=opset_version)
    quantize_dynamic(orig_onnx, output)
    # os.remove(temp_onnx)
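
# Hedged usage sketch (paths and the student class are illustrative placeholders):
#
#   export_student_as_onnx("best/student.pth", "best/student.onnx",
#                          SomeStudentNet, config, get_environment)
#
# This writes orig.onnx and clean.onnx next to the output and a dynamically
# quantized model at the output path itself.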
@ -0,0 +1,40 @@
from typing import List, Any, Tuple
import torch


def summary(network):
    print("======================================")
    print("=          Network Summary           =")
    print("======================================")
    print(print_layers([(network.__class__.__name__, extract_parameters(network.parameters()), find_layers(network))]))


def find_layers(network) -> List[Tuple[str, Any]]:
    layers = []
    for child in network.children():
        children = list(child.children())
        if not len(children):
            layers.append((child.__class__.__name__, extract_parameters(child.parameters())))
        else:
            layers.append((child.__class__.__name__, extract_parameters(child.parameters()), find_layers(child)))
    return layers


def extract_parameters(parameters) -> int:
    total = 0
    for param in parameters:
        if isinstance(param, torch.Tensor):
            total += torch.prod(torch.LongTensor(list(param.size())))
        else:
            total += extract_parameters(param)
    return int(total)


def print_layers(layers) -> str:
    result = ""
    for layer in layers:
        result += "- " + layer[0]
        result += f': {layer[1]:n}\n'
        # Recurse into nested (name, params, children) entries
        if len(layer) == 3:
            for line in print_layers(layer[2]).splitlines():
                result += " %s\n" % line
    return result