From 8ca8413c5f5f1dcd4c207a605acc31608124fbcf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thomas=20Av=C3=A9?= Date: Sat, 4 Jul 2026 20:13:47 +0700 Subject: [PATCH] Fix some bugs --- src/qpd/compression/memory.py | 2 +- src/qpd/compression/policy_distillation.py | 6 ++++-- src/qpd/utils/evaluator.py | 6 +++--- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/qpd/compression/memory.py b/src/qpd/compression/memory.py index 4d9d814..51bcc8a 100644 --- a/src/qpd/compression/memory.py +++ b/src/qpd/compression/memory.py @@ -118,7 +118,7 @@ class Memory: self.max_size += amount self.did_overflow = False - if self.verifyConsistency: + if self.check_consistency: self.verifyConsistency("Grow") return len(self) diff --git a/src/qpd/compression/policy_distillation.py b/src/qpd/compression/policy_distillation.py index fd17fc6..f37bfea 100644 --- a/src/qpd/compression/policy_distillation.py +++ b/src/qpd/compression/policy_distillation.py @@ -1,4 +1,5 @@ # import wandb +import copy import random import os import gc @@ -18,7 +19,7 @@ from torch.distributions import Categorical import torch import torch.nn as nn import torch.nn.functional as F -from typing import Dict, Optional, Union +from typing import Dict, Optional, Tuple, Union from .memory import Memory @@ -76,6 +77,7 @@ class PolicyDistillation: if (epoch and epoch % self.config.evaluator_config.student_test_frequency == 0) or self.config.evaluator_config.student_test_frequency == 1: best, is_best = self.test(epoch, best) if is_best: + self.best_student = copy.deepcopy(self.student) self.save_student("best") if (epoch and epoch % self.check_point_frequency == 0) or self.check_point_frequency == 1: @@ -236,7 +238,7 @@ class PolicyDistillation: # if self.webdav_manager: # self.webdav_manager.upload_model(student_save_path, epoch, not self.config["testing"]["local_checkpoints"]) - def test(self, epoch: int, best: Optional[float]=None) -> Optional[float]: + def test(self, epoch: int, best: Optional[float]=None) -> Tuple[Optional[float], bool]: is_best = False print() print("= Testing =========================================================") diff --git a/src/qpd/utils/evaluator.py b/src/qpd/utils/evaluator.py index 65801a6..96f2ad9 100644 --- a/src/qpd/utils/evaluator.py +++ b/src/qpd/utils/evaluator.py @@ -162,7 +162,7 @@ class Evaluator: action, value, action_std = teacher(state) if not worker_finished: - if worker_memory.will_overflow: + if worker_memory.will_overflow(): print(f"Growing memory to size {worker_memory.max_size + 5000}!") worker_memory.grow(5000) @@ -191,7 +191,7 @@ class Evaluator: total_steps.append(worker_steps) worker_steps = 0 - if (steps and effective_frames >= steps) or (episodes and len(total_rewards)): + if (steps and effective_frames >= steps) or (episodes and len(total_rewards) >= episodes): worker_finished = True if collect: @@ -201,7 +201,7 @@ class Evaluator: end = time.time() print("Average and effective framerate:", round(frames / (end-start), 3), round(effective_frames / (end-start), 3)) - return {"Steps": numpy.mean(total_steps) if len(total_steps) else 0, "Reward": numpy.mean(total_rewards) if len(total_rewards) else 0, "Memory": worker_memory if collect else []} + return {"Steps": numpy.mean(total_steps) if len(total_steps) else 0, "Reward": numpy.mean(total_rewards) if len(total_rewards) else 0, "Memory": worker_memory if collect else [], "EpisodeCount": len(total_rewards)} @abstractmethod def test_vec(config: Config,