from __future__ import annotations import torch import random import numpy import gzip import json from pathlib import Path from typing import Iterator, Tuple, Dict, Optional from qpd.config import Config class Memory: def __init__(self, config: Config, size=None, meta_class = None, path: Optional[str] = None) -> None: self.meta_name = str(meta_class.__class__) if meta_class else None self.max_size = size if size else config.memory_config.size self.did_overflow = False self.current_index = 0 self.output_n = config.output_shape[0] self.device = config.memory_config.device self.check_consistency = config.memory_config.check_consistency self.critic = config.compression_config.critic_importance > 0 self.frame_stack_optim = config.memory_config.frame_stack_optimization self.framestack_size = config.observation_shape[0] if self.frame_stack_optim else 0 self.student_driven = config.evaluator_config.student_driven if path: print(f"Load memory on path: {path}") self.load(path) else: self.states = torch.zeros((self.max_size,) + (config.observation_shape[1:] if self.frame_stack_optim else config.observation_shape), device=self.device) # [(Input frame, output actions)] self.outputs = torch.zeros((self.max_size, self.output_n), device=self.device) # + 1 = critic # self.repeats = torch.zeros(self.size, device=self.device) self.start = set() def update(self, state: torch.Tensor, outputs: torch.Tensor, start: bool) -> None: assert not torch.any(outputs == 0), f"[ERROR]: Update input is not valid!" # update overflow self.did_overflow = self.will_overflow() if start: self.start.add(self.current_index) else: self.start.discard(self.current_index) state = state.to(self.device).data self.states[self.current_index] = state[:, -1] if self.frame_stack_optim else state # Story only latest frame self.outputs[self.current_index] = outputs.to(self.device).data self.current_index = (self.current_index + 1) % self.max_size if self.check_consistency: self.verifyConsistency(f"Update current_index: {self.current_index} is start: {start} outputs: {self.outputs}") def will_overflow(self, update_size=1): if (self.current_index + update_size) >= self.max_size: return True return False def real_size(self): return (self.states.element_size() * self.states.nelement()) + (self.outputs.element_size() * self.outputs.nelement()) def sample(self, batch_size: int, shuffle: bool=True) -> Iterator[Tuple[torch.Tensor, torch.Tensor]]: state_indexes = list(range(len(self))) if shuffle: random.shuffle(state_indexes) splits = [state_indexes[i:i + batch_size] for i in range(0, len(state_indexes), batch_size)] if self.frame_stack_optim: for split in splits: current_states = [] indexes = torch.tensor(split) for i in range(self.framestack_size): current_states.append(self.states[(indexes - (self.framestack_size - 1) + i) % len(self)].unsqueeze(1)) yield (torch.cat(current_states, dim=1), self.outputs[split]) else: for split in splits: yield (self.states[split], self.outputs[split]) def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]: if self.frame_stack_optim: states_index = list(self.start)[index] return (torch.cat([self.states[i % len(self)].unsqueeze(0) for i in range(states_index-3, states_index+1)]), self.outputs[states_index]) else: return (self.states[index], self.outputs[index]) def __len__(self) -> int: if self.did_overflow: return len(self.outputs) else: return self.current_index def prune(self) -> None: if not self.did_overflow: for item in ["states", "outputs"]: setattr(self, item, getattr(self, item)[0: self.current_index]) self.max_size = self.current_index if self.check_consistency: self.verifyConsistency("Prune") def grow(self, amount: int) -> int: print(f"Growing with size: {amount}") for item in ["states", "outputs"]: obj = getattr(self, item) tmp = torch.zeros((self.max_size + amount,) + obj.shape[1:], device = self.device, dtype=obj.dtype) tmp[:len(self)] = obj setattr(self, item, tmp) self.max_size += amount self.did_overflow = False if self.check_consistency: self.verifyConsistency("Grow") return len(self) def load_memory(self, other: Memory) -> None: assert other.max_size <= self.max_size, f"Memory to load is bigger than current memory! {other.max_size} > {self.max_size}" self.did_overflow = self.will_overflow(len(other)) for item in ["states", "outputs"]: getattr(self, item)[self.current_index:min(self.current_index+len(other), self.max_size)] = getattr(other, item)[:min(len(other), self.max_size - self.current_index)] if self.current_index + len(other) > self.max_size: getattr(self, item)[:len(other) - (self.max_size - self.current_index)] = getattr(other, item)[self.max_size - self.current_index:] for idx in range(len(other)): new_idx = (self.current_index + idx) % self.max_size self.start.discard(new_idx) if idx in other.start: self.start.add(new_idx) self.current_index = (self.current_index + len(other)) % self.max_size if self.check_consistency: self.verifyConsistency("Load memory") def to(self, device: torch.device): for item in ["states", "outputs"]: setattr(self, item, getattr(self, item).to(device)) self.device = device def load(self, path: str) -> None: assert self.meta_name, "Can not load this memory! No meta_class specified!" with open(f"{path}/metadata.json") as f: j = json.loads(f.read()) assert j["teacher"] == self.meta_name self.current_index = j["current_index"] self.did_overflow = j["did_overflow"] for item in ["states", "outputs"]: with gzip.GzipFile(f"{path}/{item}.npy.gz") as f: setattr(self, item, torch.from_numpy(numpy.load(file=f))) with gzip.GzipFile(f"{path}/start.npy.gz") as f: self.start = set(numpy.load(file=f)) self.max_size = self.states.shape[0] if self.check_consistency: self.verifyConsistency("Load") def save(self, path: str) -> None: assert self.meta_name, "Can not save this memory! No meta_class specified!" Path(path).mkdir(parents=True, exist_ok=True) for item in ["states", "outputs"]: with gzip.GzipFile(f"{path}/{item}.npy.gz", "w") as f: numpy.save(f, getattr(self, item).cpu().numpy()) with gzip.GzipFile(f"{path}/start.npy.gz", "w") as f: numpy.save(f, numpy.array(list(self.start))) with open(f"{path}/metadata.json", "w") as f: f.write(json.dumps({"teacher": self.meta_name, "did_overflow": self.did_overflow, "current_index": self.current_index})) def verifyConsistency(self, mes) -> None: assert not torch.any(self.outputs[:len(self)] == 0), f"[ERROR]: Memory is not consistent! ({mes})" # assert not torch.any(torch.bitwise_or(self.outputs[:len(self)] > 100, self.outputs[:len(self)] < -100)), \ # f"[ERROR]: Memory is not consistent in a weird way! max: {self.outputs.max()} min: {self.outputs.min()} ({mes})"