1 year ago
#5680
albert828
How to continue Quantization Aware Training of saved model in PyTorch?
I have a DL model that is trained in two phases:
- Pretraining using synthetic data
- Finetuning using real world data
The model is saved after phase 1. In phase 2, the model is created, loaded from the .pth file, and training starts again with new data. I'd like to apply QAT, but I have a problem in phase 2: the losses are huge (comparable to the beginning of synthetic training without QAT, when they should be over 60x smaller). I suspect the observers are being re-initialized and then frozen. The question is: what is the correct way to load a QAT model and continue training?
Code for phase 1:
import torch
...
self.create_net()
self.net.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
torch.quantization.prepare_qat(self.net, inplace=True)
# Skip fuse Conv-Bn-ReLU
...
# In training loop
if train_iter == 40_000:
    print("Freeze batch norm mean and variance estimates")
    self.net.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
if train_iter == 50_000:
    print("Freeze quantizer parameters")
    self.net.apply(torch.quantization.disable_observer)
...
# After training
# Do not convert to quantized model since it'll be trained again
torch.save(self.net.state_dict(), str(filepath))
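For reference, the checkpoint saved here should already contain the fake-quant and observer state alongside the weights. A quick way to confirm that (a sketch; the key-name patterns assume the default naming that prepare_qat uses for observer and fake-quant modules):

import torch

# List the QAT-specific entries that prepare_qat added to the checkpoint.
# Key name patterns assume the default observer/fake-quant module names.
state_dict = torch.load(str(filepath), map_location="cpu")
qat_keys = [key for key in state_dict
            if "activation_post_process" in key or "fake_quant" in key]
print(f"{len(qat_keys)} QAT-related entries, e.g. {qat_keys[:5]}")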
Code for phase 2:
import torch
...
self.create_net()
custom_load_state_dict(self.net, torch.load(str(filepath), map_location="cpu"))
self.net.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
torch.quantization.prepare_qat(self.net, inplace=True)
# Freeze observers and bn immediately after model load
self.net.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
self.net.apply(torch.quantization.disable_observer)
...
# Another file
def custom_load_state_dict(target_module: torch.nn.Module, source_state_dict: dict) -> None:
    target_state_dict = target_module.state_dict()
    for key, source_tensor in source_state_dict.items():
        if key in target_state_dict:
            if target_state_dict[key].shape == source_tensor.shape:
                target_state_dict[key] = source_tensor
    # load_state_dict returns (missing_keys, unexpected_keys); the named
    # tuple itself is always truthy, so check its fields instead.
    unmatched_keys = target_module.load_state_dict(target_state_dict)
    if unmatched_keys.missing_keys or unmatched_keys.unexpected_keys:
        print(f'Unmatched keys during model loading:\n{unmatched_keys}')
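If my reading of the phase 2 order is right, this helper may be part of the problem: it runs before prepare_qat, so none of the phase-1 observer/fake-quant keys exist in the fresh net yet and they are silently skipped. A variant with logging (the logging lines are my addition) would make that visible:

import torch

def custom_load_state_dict_verbose(target_module: torch.nn.Module, source_state_dict: dict) -> None:
    # Same copy-matching-keys logic as above, but report what gets dropped.
    target_state_dict = target_module.state_dict()
    dropped = [key for key in source_state_dict if key not in target_state_dict]
    if dropped:
        # Called before prepare_qat, this should list every
        # activation_post_process / fake_quant buffer saved in phase 1.
        print(f"Dropping {len(dropped)} keys, e.g. {dropped[:3]}")
    for key, source_tensor in source_state_dict.items():
        if key in target_state_dict and target_state_dict[key].shape == source_tensor.shape:
            target_state_dict[key] = source_tensor
    target_module.load_state_dict(target_state_dict)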
I've tried preparing the model for QAT first and then loading the weights, but it doesn't change anything. That attempt looked roughly like this (a sketch of my own sequence; whether the observer state actually transfers depends on the keys matching):
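# Sketch of the prepare-first attempt: because prepare_qat runs before
# loading, the fake-quant/observer buffers already exist in the target
# net, so the checkpoint's learned scales and zero points should be
# copied in rather than dropped.
self.create_net()
self.net.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
torch.quantization.prepare_qat(self.net, inplace=True)
custom_load_state_dict(self.net, torch.load(str(filepath), map_location="cpu"))
self.net.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
self.net.apply(torch.quantization.disable_observer)
I've also tried manually converting the model to QAT: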
# Instead:
torch.quantization.prepare_qat(self.net, inplace=True)
# Do:
from torch.ao.quantization import get_default_qat_module_mappings, propagate_qconfig_, convert
mapping = get_default_qat_module_mappings()
propagate_qconfig_(model, qconfig_dict=None)
convert(model, mapping=mapping, inplace=True, remove_qconfig=False)
But after training, when I try to convert to a quantized model, it throws an error:
# Throws error - missing observers
quantized_model = torch.quantization.convert(quantized_model, inplace=True)
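If I read the prepare_qat source correctly, it runs prepare(...) after convert(...), and that prepare call is what attaches the activation observers my manual version is missing. The full manual equivalent would then be something like this (my reading of the source, which may differ between PyTorch versions):

from torch.ao.quantization import (convert, get_default_qat_module_mappings,
                                   prepare, propagate_qconfig_)

mapping = get_default_qat_module_mappings()
propagate_qconfig_(model, qconfig_dict=None)
convert(model, mapping=mapping, inplace=True, remove_qconfig=False)
# The step missing from my attempt above: attach activation observers,
# without which the final convert() fails.
prepare(model, observer_non_leaf_module_list=set(mapping.values()), inplace=True)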
When I skip freezing BN and the observers after loading the model, it seems to work fine. But is that correct? Doesn't that destroy the quantization levels learned previously?
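To answer that for myself, I'm planning to compare a learned quantization parameter from the checkpoint with the live value after loading (a hypothetical check; buffer names like ".scale" assume the default FakeQuantize naming):

# Compare one learned scale from the phase-1 checkpoint with the value
# actually present in the reloaded phase-2 net.
checkpoint = torch.load(str(filepath), map_location="cpu")
live_state = self.net.state_dict()
for key in checkpoint:
    if key.endswith(".scale") and key in live_state:
        print(key, "checkpoint:", checkpoint[key], "live:", live_state[key])
        break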
python
deep-learning
pytorch
quantization-aware-training