# QPD/src/qpd/compression/memory.py
#
# 194 lines
# 7.8 KiB
# Python

from __future__ import annotations
import torch
import random
import numpy
import gzip
import json
from pathlib import Path
from typing import Iterator, Tuple, Dict, Optional
from qpd.config import Config
class Memory:
    """Fixed-capacity ring buffer of (state, output) pairs held in two tensors.

    ``states`` and ``outputs`` are pre-allocated with ``max_size`` rows.
    ``current_index`` is the row the next :meth:`update` writes to and wraps
    around modulo ``max_size``.  ``did_overflow`` records that the buffer has
    been filled at least once, in which case every row holds valid data and
    ``len(self)`` is the full row count; otherwise only the first
    ``current_index`` rows are valid.  ``start`` is the set of row indices
    that were written with ``start=True``.

    With frame-stack optimisation enabled only one frame per row is stored
    and stacks of ``framestack_size`` consecutive rows are rebuilt on read.
    """

    def __init__(self, config: "Config", size=None, meta_class=None, path: Optional[str] = None) -> None:
        """Allocate an empty memory, or load one from ``path`` when it is given.

        Args:
            config: Project configuration; only the attributes read below are used.
            size: Row capacity; falls back to ``config.memory_config.size`` when falsy.
            meta_class: Object whose class name tags saved files; required by
                :meth:`save` and :meth:`load`.
            path: Directory previously written by :meth:`save`.
        """
        # `is not None` rather than truthiness: an object that defines __len__
        # (e.g. an empty container module) must still count as "provided".
        self.meta_name = str(meta_class.__class__) if meta_class is not None else None
        self.max_size = size if size else config.memory_config.size
        self.did_overflow = False
        self.current_index = 0
        self.output_n = config.output_shape[0]
        self.device = config.memory_config.device
        self.check_consistency = config.memory_config.check_consistency
        # NOTE(review): `critic` and `student_driven` are stored but not read
        # anywhere in this class; presumably consumed by callers.
        self.critic = config.compression_config.critic_importance > 0
        self.frame_stack_optim = config.memory_config.frame_stack_optimization
        self.framestack_size = config.observation_shape[0] if self.frame_stack_optim else 0
        self.student_driven = config.evaluator_config.student_driven
        if path:
            print(f"Load memory on path: {path}")
            self.load(path)
        else:
            # With frame-stack optimisation the leading (stack) axis is dropped:
            # each row stores a single frame.
            if self.frame_stack_optim:
                row_shape = tuple(config.observation_shape[1:])
            else:
                row_shape = tuple(config.observation_shape)
            self.states = torch.zeros((self.max_size,) + row_shape, device=self.device)
            self.outputs = torch.zeros((self.max_size, self.output_n), device=self.device)
            self.start = set()

    def update(self, state: torch.Tensor, outputs: torch.Tensor, start: bool) -> None:
        """Write one (state, outputs) pair at ``current_index`` and advance it.

        Args:
            state: Observation; with frame-stack optimisation only ``state[:, -1]``
                is stored.
            outputs: Output vector for this state; must not contain zeros.
            start: Whether this row should be recorded in ``start``.

        Raises:
            AssertionError: If ``outputs`` contains a zero, or the consistency
                check fails.
        """
        assert not torch.any(outputs == 0), "[ERROR]: Update input is not valid!"
        # Sticky: once the buffer has wrapped, every row stays valid.  Recomputing
        # the flag from scratch would reset it to False right after the wrap.
        self.did_overflow = self.did_overflow or self.will_overflow()
        if start:
            self.start.add(self.current_index)
        else:
            self.start.discard(self.current_index)
        state = state.to(self.device).detach()
        # Store only the latest frame when frame stacking is rebuilt on read.
        self.states[self.current_index] = state[:, -1] if self.frame_stack_optim else state
        self.outputs[self.current_index] = outputs.to(self.device).detach()
        self.current_index = (self.current_index + 1) % self.max_size
        if self.check_consistency:
            self.verifyConsistency(f"Update current_index: {self.current_index} is start: {start} outputs: {self.outputs}")

    def will_overflow(self, update_size=1):
        """Return True if writing ``update_size`` rows reaches or passes the end."""
        return (self.current_index + update_size) >= self.max_size

    def real_size(self):
        """Return the number of bytes spanned by the elements of both tensors."""
        states_bytes = self.states.element_size() * self.states.nelement()
        outputs_bytes = self.outputs.element_size() * self.outputs.nelement()
        return states_bytes + outputs_bytes

    def sample(self, batch_size: int, shuffle: bool = True) -> Iterator[Tuple[torch.Tensor, torch.Tensor]]:
        """Yield ``(states, outputs)`` batches covering every valid row once.

        With frame-stack optimisation each returned state is a stack of the
        ``framestack_size`` rows ending at the sampled row (indices wrap
        modulo ``len(self)``), stacked along dimension 1.
        """
        count = len(self)
        state_indexes = list(range(count))
        if shuffle:
            random.shuffle(state_indexes)
        # Offsets -(k-1) .. 0 select the k rows ending at each sampled index.
        offsets = torch.arange(self.framestack_size) - (self.framestack_size - 1)
        for begin in range(0, count, batch_size):
            split = state_indexes[begin:begin + batch_size]
            if self.frame_stack_optim:
                rows = (torch.tensor(split).unsqueeze(1) + offsets) % count
                yield (self.states[rows], self.outputs[split])
            else:
                yield (self.states[split], self.outputs[split])

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return one ``(state, outputs)`` pair.

        Without frame-stack optimisation ``index`` is a row index.  With it,
        ``index`` selects the ``index``-th entry of ``start`` (in ascending
        order) and the state is the stack of ``framestack_size`` rows ending
        at that row.
        NOTE(review): indexing through ``start`` is kept from the original
        code; confirm this is what callers expect.
        """
        if not self.frame_stack_optim:
            return (self.states[index], self.outputs[index])
        # sorted() gives a deterministic order; plain set iteration does not.
        states_index = sorted(self.start)[index]
        first = states_index - (self.framestack_size - 1)
        count = len(self)
        frames = [self.states[i % count].unsqueeze(0) for i in range(first, states_index + 1)]
        return (torch.cat(frames), self.outputs[states_index])

    def __len__(self) -> int:
        """Return the number of valid rows."""
        if self.did_overflow:
            return len(self.outputs)
        return self.current_index

    def prune(self) -> None:
        """Shrink the tensors to the valid rows (no-op once the buffer wrapped).

        Afterwards ``max_size == current_index``, so the memory is exactly
        full; call :meth:`grow` before writing further rows.
        """
        if not self.did_overflow:
            for item in ["states", "outputs"]:
                # clone(): a plain slice is a view that keeps the whole
                # original allocation alive, so nothing would be freed.
                setattr(self, item, getattr(self, item)[0:self.current_index].clone())
            self.max_size = self.current_index
        if self.check_consistency:
            self.verifyConsistency("Prune")

    def grow(self, amount: int) -> int:
        """Enlarge capacity by ``amount`` rows, keeping all valid rows.

        A wrapped buffer is rotated so its oldest row lands at index 0 (and
        ``start`` is remapped accordingly); new rows are appended after the
        existing data.

        Returns:
            The number of valid rows after growing.
        """
        print(f"Growing with size: {amount}")
        stored = len(self)
        new_size = self.max_size + amount
        # After a wrap the oldest row sits at current_index.
        pivot = self.current_index if self.did_overflow else 0
        for item in ["states", "outputs"]:
            obj = getattr(self, item)
            tmp = torch.zeros((new_size,) + obj.shape[1:], device=self.device, dtype=obj.dtype)
            if pivot:
                tmp[:stored] = torch.cat((obj[pivot:stored], obj[:pivot]))
            else:
                # Copy only the valid rows; `obj` may have more rows than that.
                tmp[:stored] = obj[:stored]
            setattr(self, item, tmp)
        if pivot:
            self.start = {(idx - pivot) % stored for idx in self.start}
        # Only possible when amount == 0: the buffer is still exactly full.
        still_full = new_size > 0 and stored == new_size
        self.max_size = new_size
        self.current_index = 0 if still_full else stored
        self.did_overflow = still_full
        if self.check_consistency:
            self.verifyConsistency("Grow")
        return len(self)

    def load_memory(self, other: "Memory") -> None:
        """Append the valid rows of ``other`` at ``current_index``, wrapping around.

        Rows are copied in ``other``'s storage order.  ``start`` markers are
        carried over to the destination rows, replacing any markers those
        rows had before.

        Raises:
            AssertionError: If ``other`` has a larger capacity than this memory.
        """
        assert other.max_size <= self.max_size, f"Memory to load is bigger than current memory! {other.max_size} > {self.max_size}"
        incoming = len(other)
        # Sticky, for the same reason as in update().
        self.did_overflow = self.did_overflow or self.will_overflow(incoming)
        head = min(incoming, self.max_size - self.current_index)  # rows that fit before the end
        tail = incoming - head  # rows that wrap to the front
        for item in ["states", "outputs"]:
            target = getattr(self, item)
            source = getattr(other, item)
            target[self.current_index:self.current_index + head] = source[:head]
            if tail:
                # Stop at `incoming`: rows beyond len(other) are not valid data.
                target[:tail] = source[head:incoming]
        for idx in range(incoming):
            new_idx = (self.current_index + idx) % self.max_size
            self.start.discard(new_idx)
            if idx in other.start:
                self.start.add(new_idx)
        self.current_index = (self.current_index + incoming) % self.max_size
        if self.check_consistency:
            self.verifyConsistency("Load memory")

    def to(self, device: torch.device):
        """Move both tensors to ``device`` and remember it for later writes."""
        for item in ["states", "outputs"]:
            setattr(self, item, getattr(self, item).to(device))
        self.device = device

    def load(self, path: str) -> None:
        """Restore tensors, ``start`` and bookkeeping from a :meth:`save` directory.

        Raises:
            AssertionError: If no ``meta_class`` was given, or the stored
                ``teacher`` tag differs from this memory's ``meta_name``.
        """
        assert self.meta_name, "Can not load this memory! No meta_class specified!"
        with open(f"{path}/metadata.json", encoding="utf-8") as f:
            j = json.load(f)
        assert j["teacher"] == self.meta_name
        self.current_index = j["current_index"]
        self.did_overflow = j["did_overflow"]
        for item in ["states", "outputs"]:
            with gzip.GzipFile(f"{path}/{item}.npy.gz") as f:
                # Place the data where a freshly allocated memory would live.
                setattr(self, item, torch.from_numpy(numpy.load(file=f)).to(self.device))
        with gzip.GzipFile(f"{path}/start.npy.gz") as f:
            # Plain ints, so the set matches the indices update() inserts.
            self.start = {int(idx) for idx in numpy.load(file=f).tolist()}
        self.max_size = self.states.shape[0]
        if self.check_consistency:
            self.verifyConsistency("Load")

    def save(self, path: str) -> None:
        """Write tensors, ``start`` and bookkeeping into directory ``path``.

        Creates ``states.npy.gz``, ``outputs.npy.gz``, ``start.npy.gz`` and
        ``metadata.json``; the directory is created if missing.

        Raises:
            AssertionError: If no ``meta_class`` was given at construction.
        """
        assert self.meta_name, "Can not save this memory! No meta_class specified!"
        Path(path).mkdir(parents=True, exist_ok=True)
        for item in ["states", "outputs"]:
            with gzip.GzipFile(f"{path}/{item}.npy.gz", "w") as f:
                numpy.save(f, getattr(self, item).cpu().numpy())
        with gzip.GzipFile(f"{path}/start.npy.gz", "w") as f:
            # Explicit integer dtype: an empty list would otherwise become float64.
            numpy.save(f, numpy.array(sorted(self.start), dtype=numpy.int64))
        with open(f"{path}/metadata.json", "w", encoding="utf-8") as f:
            f.write(json.dumps({"teacher": self.meta_name, "did_overflow": self.did_overflow, "current_index": self.current_index}))

    def verifyConsistency(self, mes) -> None:
        """Assert that no valid output row contains a zero.

        Args:
            mes: Context appended to the assertion message.
        """
        assert not torch.any(self.outputs[:len(self)] == 0), f"[ERROR]: Memory is not consistent! ({mes})"