Fix some bugs
This commit is contained in:
parent
862c55e03c
commit
8ca8413c5f
|
|
@ -118,7 +118,7 @@ class Memory:
|
||||||
self.max_size += amount
|
self.max_size += amount
|
||||||
self.did_overflow = False
|
self.did_overflow = False
|
||||||
|
|
||||||
if self.verifyConsistency:
|
if self.check_consistency:
|
||||||
self.verifyConsistency("Grow")
|
self.verifyConsistency("Grow")
|
||||||
return len(self)
|
return len(self)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
# import wandb
|
# import wandb
|
||||||
|
import copy
|
||||||
import random
|
import random
|
||||||
import os
|
import os
|
||||||
import gc
|
import gc
|
||||||
|
|
@ -18,7 +19,7 @@ from torch.distributions import Categorical
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from typing import Dict, Optional, Union
|
from typing import Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
from .memory import Memory
|
from .memory import Memory
|
||||||
|
|
||||||
|
|
@ -76,6 +77,7 @@ class PolicyDistillation:
|
||||||
if (epoch and epoch % self.config.evaluator_config.student_test_frequency == 0) or self.config.evaluator_config.student_test_frequency == 1:
|
if (epoch and epoch % self.config.evaluator_config.student_test_frequency == 0) or self.config.evaluator_config.student_test_frequency == 1:
|
||||||
best, is_best = self.test(epoch, best)
|
best, is_best = self.test(epoch, best)
|
||||||
if is_best:
|
if is_best:
|
||||||
|
self.best_student = copy.deepcopy(self.student)
|
||||||
self.save_student("best")
|
self.save_student("best")
|
||||||
|
|
||||||
if (epoch and epoch % self.check_point_frequency == 0) or self.check_point_frequency == 1:
|
if (epoch and epoch % self.check_point_frequency == 0) or self.check_point_frequency == 1:
|
||||||
|
|
@ -236,7 +238,7 @@ class PolicyDistillation:
|
||||||
# if self.webdav_manager:
|
# if self.webdav_manager:
|
||||||
# self.webdav_manager.upload_model(student_save_path, epoch, not self.config["testing"]["local_checkpoints"])
|
# self.webdav_manager.upload_model(student_save_path, epoch, not self.config["testing"]["local_checkpoints"])
|
||||||
|
|
||||||
def test(self, epoch: int, best: Optional[float]=None) -> Optional[float]:
|
def test(self, epoch: int, best: Optional[float]=None) -> Tuple[Optional[float], bool]:
|
||||||
is_best = False
|
is_best = False
|
||||||
print()
|
print()
|
||||||
print("= Testing =========================================================")
|
print("= Testing =========================================================")
|
||||||
|
|
|
||||||
|
|
@ -162,7 +162,7 @@ class Evaluator:
|
||||||
action, value, action_std = teacher(state)
|
action, value, action_std = teacher(state)
|
||||||
|
|
||||||
if not worker_finished:
|
if not worker_finished:
|
||||||
if worker_memory.will_overflow:
|
if worker_memory.will_overflow():
|
||||||
print(f"Growing memory to size {worker_memory.max_size + 5000}!")
|
print(f"Growing memory to size {worker_memory.max_size + 5000}!")
|
||||||
worker_memory.grow(5000)
|
worker_memory.grow(5000)
|
||||||
|
|
||||||
|
|
@ -191,7 +191,7 @@ class Evaluator:
|
||||||
total_steps.append(worker_steps)
|
total_steps.append(worker_steps)
|
||||||
worker_steps = 0
|
worker_steps = 0
|
||||||
|
|
||||||
if (steps and effective_frames >= steps) or (episodes and len(total_rewards)):
|
if (steps and effective_frames >= steps) or (episodes and len(total_rewards) >= episodes):
|
||||||
worker_finished = True
|
worker_finished = True
|
||||||
|
|
||||||
if collect:
|
if collect:
|
||||||
|
|
@ -201,7 +201,7 @@ class Evaluator:
|
||||||
end = time.time()
|
end = time.time()
|
||||||
print("Average and effective framerate:", round(frames / (end-start), 3), round(effective_frames / (end-start), 3))
|
print("Average and effective framerate:", round(frames / (end-start), 3), round(effective_frames / (end-start), 3))
|
||||||
|
|
||||||
return {"Steps": numpy.mean(total_steps) if len(total_steps) else 0, "Reward": numpy.mean(total_rewards) if len(total_rewards) else 0, "Memory": worker_memory if collect else []}
|
return {"Steps": numpy.mean(total_steps) if len(total_steps) else 0, "Reward": numpy.mean(total_rewards) if len(total_rewards) else 0, "Memory": worker_memory if collect else [], "EpisodeCount": len(total_rewards)}
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def test_vec(config: Config,
|
def test_vec(config: Config,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue