# QPD/src/qpd/compressor.py
# (file-viewer metadata: 78 lines, 2.7 KiB, Python)

from typing import Type
import qpd.networks.network_interface as ni
from qpd.config import Config
from qpd.networks.wrapper.student.student_net import StudentNet
from qpd.utils import summary
from .compression.policy_distillation import PolicyDistillation
from .networks.wrapper.model_wrapper import ModelWrapper
import json
from pathlib import Path
import torch
import torch.nn as nn
class Compressor:
    """Fluent front end for compressing a model via policy distillation.

    The wrapped ``model`` acts as the teacher. A student is attached with
    either ``student_network`` or ``student_model`` (only one, only once).
    ``compress`` then runs ``PolicyDistillation`` and returns the trained
    student. The configuration methods return ``self`` so calls can be
    chained::

        Compressor(model, env_ctor, config).student_network(Net).compress()
    """

    def __init__(self, model, environment_constructor, config):
        """Wrap the teacher model and initialise model-dependent config.

        Args:
            model: Teacher model. It is wrapped with
                ``ModelWrapper.construct_wrapper`` and moved to
                ``config.compression_config.device``.
            environment_constructor: Stored as-is on the instance. It is not
                referenced anywhere else in this class.
            config (Config): Run configuration.
                ``config.init_env_model_based_params(teacher)`` is called on
                it with the wrapped teacher (presumably to fill in
                environment/model-derived settings; confirm in ``Config``).
        """
        self.model = model
        self.environment_constructor = environment_constructor
        self.config: Config = config
        self.teacher: ModelWrapper = self._wrap(self.model)
        self.student = None
        # Raw model built by ``student_network``. It is kept under a name that
        # does not shadow the public ``student_model`` method.
        self.student_network_model = None
        self.config.init_env_model_based_params(self.teacher)

    def _wrap(self, model):
        """Wrap ``model`` in a ``ModelWrapper`` and move it to the configured device.

        Shared by the teacher and network-based students so both end up on
        ``config.compression_config.device``.
        """
        wrapper = ModelWrapper.construct_wrapper(self.config, model)
        return wrapper.to(torch.device(self.config.compression_config.device))

    def _require_no_student(self):
        """Raise if a student has already been attached (only one is allowed).

        An explicit ``raise`` is used instead of ``assert`` so the guard is
        not stripped under ``python -O``. The exception type stays
        ``AssertionError`` for callers that already expect it.
        """
        if self.student is not None:
            raise AssertionError("Can not set second student!")

    def student_network(self, net: Type["StudentNet"]):
        """Attach a student network that reuses the teacher's wrappers.

        ``net`` is stored on ``config.student_config.student_network``. The
        student model is then built with ``teacher.construct_student()``
        (presumably from the class just stored in the config) and wrapped with
        ``ModelWrapper`` on the configured device.

        Args:
            net (Type[StudentNet]): Student network class.

        Returns:
            Compressor: ``self``, for chaining.

        Raises:
            AssertionError: If a student has already been attached.
        """
        self._require_no_student()
        self.config.student_config.student_network = net
        self.student_network_model = self.teacher.construct_student()
        self.student = self._wrap(self.student_network_model)
        return self

    def student_model(self, student: Type["ni.NetworkInterface"]):
        """Attach an end-to-end student model.

        Args:
            student (Type[ni.NetworkInterface]): Student class. It is
                instantiated as ``student(self.config)``.

        Returns:
            Compressor: ``self``, for chaining.

        Raises:
            AssertionError: If a student has already been attached.
        """
        self._require_no_student()
        self.student = student(self.config)
        return self

    def set_environment_done_filter(self, func):
        """Store ``func`` as ``config.evaluator_config.environment_is_done_filter``.

        Returns:
            Compressor: ``self``, for chaining.
        """
        self.config.evaluator_config.environment_is_done_filter = func
        return self

    def compress(self):
        """Run policy distillation from the teacher into the attached student.

        The serialized configuration is written to
        ``<data_directory>/Data/<run_name>/config.json`` (the directory is
        created if needed) and echoed to stdout before training starts.

        Returns:
            The ``student`` attribute of the ``PolicyDistillation`` instance
            after ``train()`` has run.

        Raises:
            AssertionError: If no student has been attached.
        """
        if self.student is None:
            raise AssertionError(
                "No student assigned! Use Compressor(**).student_network(**) "
                "or Use Compressor(**).student_model(**)"
            )
        run_dir = Path(f"{self.config.data_directory}/Data/{self.config.run_name}")
        run_dir.mkdir(parents=True, exist_ok=True)
        config_dump = json.dumps(self.config.serializable(), indent=4, sort_keys=True)
        # json.dumps escapes non-ASCII by default. The explicit encoding just
        # avoids depending on the platform's locale default.
        (run_dir / "config.json").write_text(config_dump, encoding="utf-8")
        print("Current configuration:")
        print(config_dump)
        distillation = PolicyDistillation(self.teacher, self.student, self.config)
        distillation.train()
        return distillation.student