78 lines
2.7 KiB
Python
78 lines
2.7 KiB
Python
from typing import Type
|
|
import qpd.networks.network_interface as ni
|
|
|
|
from qpd.config import Config
|
|
from qpd.networks.wrapper.student.student_net import StudentNet
|
|
from qpd.utils import summary
|
|
from .compression.policy_distillation import PolicyDistillation
|
|
from .networks.wrapper.model_wrapper import ModelWrapper
|
|
|
|
import json
|
|
from pathlib import Path
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
class Compressor:
    """Fluent front-end for distilling a teacher model into a student.

    The constructor wraps ``model`` as the teacher via
    ``ModelWrapper.construct_wrapper``. Exactly one student must then be
    registered with either :meth:`student_network` or :meth:`student_model`
    before :meth:`compress` is called. The configuration methods return
    ``self`` so that calls can be chained.
    """

    def __init__(self, model, environment_constructor, config):
        """Wrap the teacher model and prepare the configuration.

        Args:
            model: Teacher model handed to ``ModelWrapper.construct_wrapper``.
            environment_constructor: Stored on the instance; not used by any
                method of this class itself.
            config (Config): Run configuration. It is mutated in place by
                ``init_env_model_based_params`` and by the setter methods.
        """
        self.model = model
        self.environment_constructor = environment_constructor
        self.config: Config = config

        self.teacher: ModelWrapper = ModelWrapper.construct_wrapper(
            self.config, self.model
        ).to(self._device())
        self.student = None
        # Presumably derives environment/model dependent settings from the
        # wrapped teacher -- confirm against Config.
        self.config.init_env_model_based_params(self.teacher)

    def _device(self):
        """Return the ``torch.device`` named in the compression config."""
        return torch.device(self.config.compression_config.device)

    def _require_no_student(self):
        """Raise if a student has already been registered."""
        # Explicit raise instead of ``assert`` so the guard is still enforced
        # when Python runs with -O; the exception type is kept for callers.
        if self.student is not None:
            raise AssertionError("Can not set second student!")

    def student_network(self, net: Type[StudentNet]):
        """Add student network to compressor with existing wrappers.

        The network class is stored in ``config.student_config``, the teacher
        wrapper constructs the student from it, and the result is wrapped the
        same way as the teacher and moved to the configured device.

        Args:
            net (Type[StudentNet]): Student network class.

        Returns:
            Compressor: ``self``, to allow call chaining.

        Raises:
            AssertionError: If a student has already been set.
        """
        self._require_no_student()
        self.config.student_config.student_network = net
        # NOTE(review): this instance attribute shadows the ``student_model``
        # method on this object from here on. Kept unchanged because external
        # code may read the attribute.
        self.student_model = self.teacher.construct_student()
        self.student = ModelWrapper.construct_wrapper(
            self.config, self.student_model
        ).to(self._device())
        return self

    def student_model(self, student: Type["ni.NetworkInterface"]):
        """Add end-to-end student model to compressor.

        Args:
            student (Type["ni.NetworkInterface"]): Student class; it is
                instantiated here as ``student(self.config)``.

        Returns:
            Compressor: ``self``, to allow call chaining.

        Raises:
            AssertionError: If a student has already been set.
        """
        self._require_no_student()
        self.student = student(self.config)
        return self

    def set_environment_done_filter(self, func):
        """Store ``func`` as the evaluator's environment-is-done filter.

        Args:
            func: Value assigned to
                ``config.evaluator_config.environment_is_done_filter``.

        Returns:
            Compressor: ``self``, to allow call chaining.
        """
        self.config.evaluator_config.environment_is_done_filter = func
        return self

    def compress(self):
        """Run policy distillation from the teacher into the student.

        Writes the serialized configuration to
        ``<data_directory>/Data/<run_name>/config.json`` (creating the
        directory if needed), prints it, and then trains a
        ``PolicyDistillation`` instance.

        Returns:
            The ``student`` attribute of the distillation object after
            training.

        Raises:
            AssertionError: If no student has been assigned yet.
        """
        if self.student is None:
            raise AssertionError(
                "No student assigned! Use Compressor(**).student_network(**) or Use Compressor(**).student_model(**)"
            )

        # Build the run directory once and reuse it for the config dump.
        run_dir = Path(f"{self.config.data_directory}/Data/{self.config.run_name}")
        run_dir.mkdir(parents=True, exist_ok=True)
        config_dump = json.dumps(self.config.serializable(), indent=4, sort_keys=True)
        (run_dir / "config.json").write_text(config_dump, encoding="utf-8")

        print("Current configuration:")
        print(config_dump)

        distillation = PolicyDistillation(
            self.teacher,
            self.student,
            self.config,
        )

        distillation.train()

        return distillation.student
|