Fix some bugs

This commit is contained in:
Thomas Avé 2026-07-04 20:13:47 +07:00
parent 862c55e03c
commit 8ca8413c5f
Signed by: thomasave
SSH Key Fingerprint: SHA256:bvIbWy6TO9+PdMTPzWy6dqkRlVQ3eSky+vQcc9aRIiE
3 changed files with 8 additions and 6 deletions

View File

@ -118,7 +118,7 @@ class Memory:
self.max_size += amount self.max_size += amount
self.did_overflow = False self.did_overflow = False
if self.verifyConsistency: if self.check_consistency:
self.verifyConsistency("Grow") self.verifyConsistency("Grow")
return len(self) return len(self)

View File

@ -1,4 +1,5 @@
# import wandb # import wandb
import copy
import random import random
import os import os
import gc import gc
@ -18,7 +19,7 @@ from torch.distributions import Categorical
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from typing import Dict, Optional, Union from typing import Dict, Optional, Tuple, Union
from .memory import Memory from .memory import Memory
@ -76,6 +77,7 @@ class PolicyDistillation:
if (epoch and epoch % self.config.evaluator_config.student_test_frequency == 0) or self.config.evaluator_config.student_test_frequency == 1: if (epoch and epoch % self.config.evaluator_config.student_test_frequency == 0) or self.config.evaluator_config.student_test_frequency == 1:
best, is_best = self.test(epoch, best) best, is_best = self.test(epoch, best)
if is_best: if is_best:
self.best_student = copy.deepcopy(self.student)
self.save_student("best") self.save_student("best")
if (epoch and epoch % self.check_point_frequency == 0) or self.check_point_frequency == 1: if (epoch and epoch % self.check_point_frequency == 0) or self.check_point_frequency == 1:
@ -236,7 +238,7 @@ class PolicyDistillation:
# if self.webdav_manager: # if self.webdav_manager:
# self.webdav_manager.upload_model(student_save_path, epoch, not self.config["testing"]["local_checkpoints"]) # self.webdav_manager.upload_model(student_save_path, epoch, not self.config["testing"]["local_checkpoints"])
def test(self, epoch: int, best: Optional[float]=None) -> Optional[float]: def test(self, epoch: int, best: Optional[float]=None) -> Tuple[Optional[float], bool]:
is_best = False is_best = False
print() print()
print("= Testing =========================================================") print("= Testing =========================================================")

View File

@ -162,7 +162,7 @@ class Evaluator:
action, value, action_std = teacher(state) action, value, action_std = teacher(state)
if not worker_finished: if not worker_finished:
if worker_memory.will_overflow: if worker_memory.will_overflow():
print(f"Growing memory to size {worker_memory.max_size + 5000}!") print(f"Growing memory to size {worker_memory.max_size + 5000}!")
worker_memory.grow(5000) worker_memory.grow(5000)
@ -191,7 +191,7 @@ class Evaluator:
total_steps.append(worker_steps) total_steps.append(worker_steps)
worker_steps = 0 worker_steps = 0
if (steps and effective_frames >= steps) or (episodes and len(total_rewards)): if (steps and effective_frames >= steps) or (episodes and len(total_rewards) >= episodes):
worker_finished = True worker_finished = True
if collect: if collect:
@ -201,7 +201,7 @@ class Evaluator:
end = time.time() end = time.time()
print("Average and effective framerate:", round(frames / (end-start), 3), round(effective_frames / (end-start), 3)) print("Average and effective framerate:", round(frames / (end-start), 3), round(effective_frames / (end-start), 3))
return {"Steps": numpy.mean(total_steps) if len(total_steps) else 0, "Reward": numpy.mean(total_rewards) if len(total_rewards) else 0, "Memory": worker_memory if collect else []} return {"Steps": numpy.mean(total_steps) if len(total_steps) else 0, "Reward": numpy.mean(total_rewards) if len(total_rewards) else 0, "Memory": worker_memory if collect else [], "EpisodeCount": len(total_rewards)}
@abstractmethod @abstractmethod
def test_vec(config: Config, def test_vec(config: Config,