Fix some bugs
This commit is contained in:
parent
862c55e03c
commit
8ca8413c5f
|
|
@ -118,7 +118,7 @@ class Memory:
|
|||
self.max_size += amount
|
||||
self.did_overflow = False
|
||||
|
||||
if self.verifyConsistency:
|
||||
if self.check_consistency:
|
||||
self.verifyConsistency("Grow")
|
||||
return len(self)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
# import wandb
|
||||
import copy
|
||||
import random
|
||||
import os
|
||||
import gc
|
||||
|
|
@ -18,7 +19,7 @@ from torch.distributions import Categorical
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from typing import Dict, Optional, Union
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
from .memory import Memory
|
||||
|
||||
|
|
@ -76,6 +77,7 @@ class PolicyDistillation:
|
|||
if (epoch and epoch % self.config.evaluator_config.student_test_frequency == 0) or self.config.evaluator_config.student_test_frequency == 1:
|
||||
best, is_best = self.test(epoch, best)
|
||||
if is_best:
|
||||
self.best_student = copy.deepcopy(self.student)
|
||||
self.save_student("best")
|
||||
|
||||
if (epoch and epoch % self.check_point_frequency == 0) or self.check_point_frequency == 1:
|
||||
|
|
@ -236,7 +238,7 @@ class PolicyDistillation:
|
|||
# if self.webdav_manager:
|
||||
# self.webdav_manager.upload_model(student_save_path, epoch, not self.config["testing"]["local_checkpoints"])
|
||||
|
||||
def test(self, epoch: int, best: Optional[float]=None) -> Optional[float]:
|
||||
def test(self, epoch: int, best: Optional[float]=None) -> Tuple[Optional[float], bool]:
|
||||
is_best = False
|
||||
print()
|
||||
print("= Testing =========================================================")
|
||||
|
|
|
|||
|
|
@ -162,7 +162,7 @@ class Evaluator:
|
|||
action, value, action_std = teacher(state)
|
||||
|
||||
if not worker_finished:
|
||||
if worker_memory.will_overflow:
|
||||
if worker_memory.will_overflow():
|
||||
print(f"Growing memory to size {worker_memory.max_size + 5000}!")
|
||||
worker_memory.grow(5000)
|
||||
|
||||
|
|
@ -191,7 +191,7 @@ class Evaluator:
|
|||
total_steps.append(worker_steps)
|
||||
worker_steps = 0
|
||||
|
||||
if (steps and effective_frames >= steps) or (episodes and len(total_rewards)):
|
||||
if (steps and effective_frames >= steps) or (episodes and len(total_rewards) >= episodes):
|
||||
worker_finished = True
|
||||
|
||||
if collect:
|
||||
|
|
@ -201,7 +201,7 @@ class Evaluator:
|
|||
end = time.time()
|
||||
print("Average and effective framerate:", round(frames / (end-start), 3), round(effective_frames / (end-start), 3))
|
||||
|
||||
return {"Steps": numpy.mean(total_steps) if len(total_steps) else 0, "Reward": numpy.mean(total_rewards) if len(total_rewards) else 0, "Memory": worker_memory if collect else []}
|
||||
return {"Steps": numpy.mean(total_steps) if len(total_steps) else 0, "Reward": numpy.mean(total_rewards) if len(total_rewards) else 0, "Memory": worker_memory if collect else [], "EpisodeCount": len(total_rewards)}
|
||||
|
||||
@abstractmethod
|
||||
def test_vec(config: Config,
|
||||
|
|
|
|||
Loading…
Reference in New Issue