from typing import Type
import qpd.networks.network_interface as ni
from qpd.config import Config
from qpd.networks.wrapper.student.student_net import StudentNet
from qpd.utils import summary
from .compression.policy_distillation import PolicyDistillation
from .networks.wrapper.model_wrapper import ModelWrapper
import json
from pathlib import Path
import torch
import torch.nn as nn


class Compressor:
    """Fluent builder that pairs a teacher model with a student and distills it.

    Typical use is ``Compressor(model, env_ctor, config)`` followed by exactly
    one of :meth:`student_network` / :meth:`student_model`, then
    :meth:`compress`.
    """

    def __init__(self, model, environment_constructor, config):
        """Wrap ``model`` as the teacher and initialise model-based config values.

        Args:
            model: Model to compress; it is passed to
                ``ModelWrapper.construct_wrapper`` to build the teacher.
            environment_constructor: Stored as-is on the instance; it is not
                called anywhere in this class.
            config (Config): Run configuration. It is mutated by
                ``init_env_model_based_params`` once the teacher exists.
        """
        self.model = model
        self.environment_constructor = environment_constructor
        self.config: Config = config

        device = torch.device(self.config.compression_config.device)
        self.teacher: ModelWrapper = ModelWrapper.construct_wrapper(
            self.config, self.model
        ).to(device)
        self.student = None

        # Must run after the teacher wrapper has been built: it receives it.
        self.config.init_env_model_based_params(self.teacher)

    def student_network(self, net: Type[StudentNet]):
        """Add a student network to the compressor, reusing the teacher's wrappers.

        ``net`` is recorded on ``config.student_config.student_network``, the
        teacher wrapper is asked to construct a student model, and that model
        is wrapped and moved to the configured compression device.

        Args:
            net (Type[StudentNet]): Student network class.

        Returns:
            self, so calls can be chained.

        Raises:
            AssertionError: If a student has already been set.
        """
        assert self.student is None, "Can not set second student!"
        self.config.student_config.student_network = net

        # Keep the constructed model in a local variable: storing it as
        # ``self.student_model`` would shadow the ``student_model`` method
        # on this instance.
        constructed_student = self.teacher.construct_student()
        device = torch.device(self.config.compression_config.device)
        self.student = ModelWrapper.construct_wrapper(
            self.config, constructed_student
        ).to(device)
        return self

    def student_model(self, student: Type["ni.NetworkInterface"]):
        """Add an end-to-end student model to the compressor.

        Args:
            student (Type["ni.NetworkInterface"]): Student class; it is
                instantiated as ``student(self.config)``.

        Returns:
            self, so calls can be chained.

        Raises:
            AssertionError: If a student has already been set.
        """
        assert self.student is None, "Can not set second student!"
        self.student = student(self.config)
        return self

    def set_environment_done_filter(self, func):
        """Store ``func`` as ``config.evaluator_config.environment_is_done_filter``.

        Returns:
            self, so calls can be chained.
        """
        self.config.evaluator_config.environment_is_done_filter = func
        return self

    def compress(self):
        """Dump the configuration, run policy distillation and return the student.

        The serialised configuration is written to
        ``<data_directory>/Data/<run_name>/config.json`` (the directory is
        created if needed) and printed to stdout before training starts.

        Returns:
            The ``student`` attribute of the ``PolicyDistillation`` instance
            after ``train()`` has run.

        Raises:
            AssertionError: If no student has been assigned yet.
        """
        assert self.student is not None, (
            "No student assigned! \n"
            "Use Compressor(**).student_network(**) or Use Compressor(**).student_model(**)"
        )

        run_dir = Path(f"{self.config.data_directory}/Data/{self.config.run_name}")
        run_dir.mkdir(parents=True, exist_ok=True)

        config_dump = json.dumps(self.config.serializable(), indent=4, sort_keys=True)
        with open(run_dir / "config.json", "w") as f:
            f.write(config_dump)

        print("Current configuration:")
        print(config_dump)

        distillation = PolicyDistillation(self.teacher, self.student, self.config)
        distillation.train()
        return distillation.student