194 lines
7.8 KiB
Python
194 lines
7.8 KiB
Python
from __future__ import annotations
|
|
import torch
|
|
import random
|
|
import numpy
|
|
import gzip
|
|
import json
|
|
from pathlib import Path
|
|
from typing import Iterator, Tuple, Dict, Optional
|
|
|
|
from qpd.config import Config
|
|
|
|
|
|
class Memory:
    """Ring buffer of (state, output) tensor pairs with capacity ``max_size``.

    ``update`` writes at ``current_index`` and advances it modulo
    ``max_size``; once every slot has been written, ``did_overflow`` stays
    True and ``len()`` reports the full capacity. ``start`` holds the slot
    positions flagged as starts by ``update``/``load_memory``.

    With ``frame_stack_optimization`` enabled only the newest frame of each
    stacked observation is stored; ``sample`` and ``__getitem__`` rebuild
    stacks of ``framestack_size`` consecutive slots on read.
    """

    def __init__(self, config: Config, size=None, meta_class=None, path: Optional[str] = None) -> None:
        """Allocate an empty buffer, or load a saved one from ``path``.

        Args:
            config: project configuration; memory, compression, evaluator and
                shape settings are read from it.
            size: capacity override; ``config.memory_config.size`` is used
                when this is falsy.
            meta_class: object whose class name tags saved memories; required
                by ``save`` and ``load``.
            path: directory previously written by ``save``; when given, the
                buffers are loaded from it instead of being allocated.
        """
        # Identifier written to / checked against "teacher" in metadata.json.
        self.meta_name = str(meta_class.__class__) if meta_class else None

        self.max_size = size if size else config.memory_config.size
        self.did_overflow = False  # True once every slot has been written at least once
        self.current_index = 0  # next slot to write

        self.output_n = config.output_shape[0]
        self.device = config.memory_config.device
        self.check_consistency = config.memory_config.check_consistency
        self.critic = config.compression_config.critic_importance > 0
        self.frame_stack_optim = config.memory_config.frame_stack_optimization
        # With the optimization on, dim 0 of observation_shape is the stack depth.
        self.framestack_size = config.observation_shape[0] if self.frame_stack_optim else 0
        self.student_driven = config.evaluator_config.student_driven

        if path:
            print(f"Load memory on path: {path}")
            self.load(path)
        else:
            # Only a single frame per slot is kept when frame stacking is optimized.
            state_shape = tuple(config.observation_shape[1:] if self.frame_stack_optim else config.observation_shape)
            self.states = torch.zeros((self.max_size,) + state_shape, device=self.device)
            self.outputs = torch.zeros((self.max_size, self.output_n), device=self.device)
            self.start = set()

    def update(self, state: torch.Tensor, outputs: torch.Tensor, start: bool) -> None:
        """Write one (state, outputs) pair at ``current_index`` and advance it.

        Args:
            state: observation; with frame-stack optimization only
                ``state[:, -1]`` (the latest frame) is stored.
            outputs: target values; must contain no exact zeros
                (``verifyConsistency`` asserts the filled part of ``outputs``
                has none).
            start: whether this slot is flagged in ``self.start``.

        Raises:
            AssertionError: if ``outputs`` contains a zero.
        """
        assert not torch.any(outputs == 0), f"[ERROR]: Update input is not valid!"

        # Overflow is sticky: once the buffer has been filled it stays full.
        # Recomputing it from will_overflow() alone reset it after wrap-around.
        self.did_overflow = self.did_overflow or self.will_overflow()
        # prune()/grow() can leave the cursor at max_size on a full buffer; wrap it.
        self.current_index %= self.max_size

        if start:
            self.start.add(self.current_index)
        else:
            self.start.discard(self.current_index)

        state = state.to(self.device).detach()
        # Store only the latest frame; stacks are rebuilt on read.
        self.states[self.current_index] = state[:, -1] if self.frame_stack_optim else state
        self.outputs[self.current_index] = outputs.to(self.device).detach()
        self.current_index = (self.current_index + 1) % self.max_size

        if self.check_consistency:
            self.verifyConsistency(f"Update current_index: {self.current_index} is start: {start} outputs: {self.outputs}")

    def will_overflow(self, update_size=1):
        """Return True if writing ``update_size`` more items fills or wraps the buffer."""
        return self.current_index + update_size >= self.max_size

    def real_size(self):
        """Return the number of bytes held by the ``states`` and ``outputs`` tensors."""
        return sum(t.element_size() * t.nelement() for t in (self.states, self.outputs))

    def sample(self, batch_size: int, shuffle: bool = True) -> Iterator[Tuple[torch.Tensor, torch.Tensor]]:
        """Yield ``(states, outputs)`` mini-batches covering every filled slot once.

        Args:
            batch_size: maximum items per batch; the last batch may be smaller.
            shuffle: visit slots in ``random.shuffle`` order instead of index order.

        With frame-stack optimization each state is rebuilt by stacking the
        ``framestack_size`` slots ending at its index along a new dim 1,
        wrapping indices modulo ``len(self)``.
        """
        order = list(range(len(self)))
        if shuffle:
            random.shuffle(order)
        batches = [order[i:i + batch_size] for i in range(0, len(order), batch_size)]

        if not self.frame_stack_optim:
            for batch in batches:
                yield self.states[batch], self.outputs[batch]
            return

        # NOTE(review): stacking does not consult self.start, so a stack at a
        # flagged start slot (or at the oldest slot of a wrapped buffer) mixes
        # in frames from unrelated steps -- confirm this is intended.
        offsets = range(1 - self.framestack_size, 1)
        for batch in batches:
            idx = torch.tensor(batch)
            frames = [self.states[(idx + off) % len(self)].unsqueeze(1) for off in offsets]
            yield torch.cat(frames, dim=1), self.outputs[batch]

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return one ``(state, outputs)`` pair.

        Without frame-stack optimization this is slot ``index``. With it,
        ``index`` selects an entry of ``self.start`` and the state is the
        stack of the ``framestack_size`` slots ending at that position.
        """
        if not self.frame_stack_optim:
            return self.states[index], self.outputs[index]

        # NOTE(review): indexes into the iteration order of the start *set*,
        # which Python does not guarantee to be sorted -- confirm intended.
        states_index = list(self.start)[index]
        window = range(states_index - (self.framestack_size - 1), states_index + 1)
        frames = [self.states[i % len(self)].unsqueeze(0) for i in window]
        return torch.cat(frames), self.outputs[states_index]

    def __len__(self) -> int:
        """Number of filled slots: full capacity after overflow, else the write cursor."""
        return len(self.outputs) if self.did_overflow else self.current_index

    def prune(self) -> None:
        """Shrink the buffers to the filled prefix if the buffer never overflowed.

        Afterwards ``max_size == current_index``.
        """
        if not self.did_overflow:
            for item in ("states", "outputs"):
                setattr(self, item, getattr(self, item)[:self.current_index])
            self.max_size = self.current_index

        if self.check_consistency:
            self.verifyConsistency("Prune")

    def grow(self, amount: int) -> int:
        """Increase capacity by ``amount`` slots while keeping stored data.

        A wrapped (overflowed) buffer is first unrolled so its oldest item
        sits at slot 0, ``start`` positions are remapped to match, and new
        writes continue right after the newest item.

        Returns:
            The number of filled slots after growing.
        """
        print(f"Growing with size: {amount}")
        used = len(self)
        # Slot of the oldest item in a wrapped buffer; 0 when not wrapped.
        shift = self.current_index if self.did_overflow else 0

        for item in ("states", "outputs"):
            old = getattr(self, item)
            if shift:
                old = torch.roll(old, shifts=-shift, dims=0)
            new = torch.zeros((self.max_size + amount,) + tuple(old.shape[1:]), device=self.device, dtype=old.dtype)
            # Copy only the filled part: copying the whole old buffer into
            # new[:len(self)] failed for partially filled buffers.
            new[:used] = old[:used]
            setattr(self, item, new)

        if shift:
            self.start = {(i - shift) % self.max_size for i in self.start}

        self.current_index = used
        self.max_size += amount
        self.did_overflow = False

        if self.check_consistency:
            self.verifyConsistency("Grow")

        return len(self)

    def load_memory(self, other: Memory) -> None:
        """Copy the filled contents of ``other`` in at the write cursor, wrapping if needed.

        Each copied slot takes ``other``'s start flag for that item (a slot not
        flagged in ``other`` loses any previous flag here).

        Raises:
            AssertionError: if ``other`` has a larger capacity than this memory.
        """
        assert other.max_size <= self.max_size, f"Memory to load is bigger than current memory! {other.max_size} > {self.max_size}"

        count = len(other)
        self.did_overflow = self.did_overflow or self.will_overflow(count)

        head = min(count, self.max_size - self.current_index)  # items that fit before the end
        tail = count - head  # items that wrap around to slot 0
        for item in ("states", "outputs"):
            target = getattr(self, item)
            source = getattr(other, item)
            target[self.current_index:self.current_index + head] = source[:head]
            if tail > 0:
                # Bound by count: other's buffer may be larger than its filled part.
                target[:tail] = source[head:count]

        for offset in range(count):
            slot = (self.current_index + offset) % self.max_size
            if offset in other.start:
                self.start.add(slot)
            else:
                self.start.discard(slot)

        self.current_index = (self.current_index + count) % self.max_size

        if self.check_consistency:
            self.verifyConsistency("Load memory")

    def to(self, device: torch.device):
        """Move both buffers to ``device`` and record it for future allocations."""
        self.states = self.states.to(device)
        self.outputs = self.outputs.to(device)
        self.device = device

    def load(self, path: str) -> None:
        """Restore buffers, start flags and cursor state from a ``save`` directory.

        Loaded tensors are placed on ``self.device``.

        Raises:
            AssertionError: if no ``meta_class`` was given, or the saved
                "teacher" tag differs from this memory's ``meta_name``.
        """
        assert self.meta_name, "Can not load this memory! No meta_class specified!"

        with open(f"{path}/metadata.json") as f:
            metadata = json.load(f)
        assert metadata["teacher"] == self.meta_name, f"Saved memory belongs to {metadata['teacher']}, not {self.meta_name}"
        self.current_index = metadata["current_index"]
        self.did_overflow = metadata["did_overflow"]

        for item in ("states", "outputs"):
            with gzip.GzipFile(f"{path}/{item}.npy.gz") as f:
                # from_numpy yields a CPU tensor; move it to the configured device.
                setattr(self, item, torch.from_numpy(numpy.load(file=f)).to(self.device))

        with gzip.GzipFile(f"{path}/start.npy.gz") as f:
            # Plain ints, matching the entries update() adds.
            self.start = {int(i) for i in numpy.load(file=f)}

        self.max_size = self.states.shape[0]

        if self.check_consistency:
            self.verifyConsistency("Load")

    def save(self, path: str) -> None:
        """Write gzip'd ``.npy`` buffers, start flags and ``metadata.json`` under ``path``.

        Raises:
            AssertionError: if no ``meta_class`` was given.
        """
        assert self.meta_name, "Can not save this memory! No meta_class specified!"
        Path(path).mkdir(parents=True, exist_ok=True)

        for item in ("states", "outputs"):
            with gzip.GzipFile(f"{path}/{item}.npy.gz", "w") as f:
                numpy.save(f, getattr(self, item).cpu().numpy())

        with gzip.GzipFile(f"{path}/start.npy.gz", "w") as f:
            # Sorted with an explicit int dtype so even an empty set saves as integers.
            numpy.save(f, numpy.array(sorted(self.start), dtype=numpy.int64))

        with open(f"{path}/metadata.json", "w") as f:
            json.dump({"teacher": self.meta_name, "did_overflow": self.did_overflow, "current_index": self.current_index}, f)

    def verifyConsistency(self, mes) -> None:
        """Assert that no filled slot of ``outputs`` contains a zero.

        Args:
            mes: context string appended to the assertion message.
        """
        assert not torch.any(self.outputs[:len(self)] == 0), f"[ERROR]: Memory is not consistent! ({mes})"
|
|
|