1 year ago

#5680


albert828

How to continue Quantization Aware Training of saved model in PyTorch?

I have a DL model that is trained in two phases:

  1. Pretraining using synthetic data
  2. Finetuning using real world data

The model is saved after phase 1. In phase 2, the model is created, loaded from the .pth file, and training starts again with new data. I'd like to apply QAT, but I have a problem in phase 2: the losses are huge (comparable to the beginning of synthetic training without QAT; they should be over 60x smaller). I suspect the cause is the re-initialization and freezing of the observers. The question is: what is the correct way to load a QAT model and continue training?

Code for phase 1:

import torch

...
self.create_net()
self.net.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
torch.quantization.prepare_qat(self.net, inplace=True)
# Skip fuse Conv-Bn-ReLU
...

# In training loop
if train_iter == 40_000:
    print("Freeze batch norm mean and variance estimates")
    self.net.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
if train_iter == 50_000:
    print("Freeze quantizer parameters")
    self.net.apply(torch.quantization.disable_observer)
...

# After training
# Do not convert to quantized model since it'll be trained again
torch.save(self.net.state_dict(), str(filepath))
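
For reference, a checkpoint saved from a prepared model holds more than the weights. The following is a minimal sketch, assuming a hypothetical `TinyNet` in place of the real network built by `create_net()`, that lists the state-dict entries which exist only after `prepare_qat`:

```python
import torch.nn as nn
from torch.ao.quantization import get_default_qat_qconfig, prepare_qat


class TinyNet(nn.Module):
    """Hypothetical stand-in for the network built by create_net()."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 4, kernel_size=3)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))


def build_qat_net():
    net = TinyNet().train()  # prepare_qat expects a model in training mode
    net.qconfig = get_default_qat_qconfig("fbgemm")
    prepare_qat(net, inplace=True)
    return net


plain_keys = set(TinyNet().state_dict())
qat_keys = set(build_qat_net().state_dict())

# Entries that only a QAT-prepared model has (fake-quant and observer buffers)
qat_only_keys = sorted(qat_keys - plain_keys)
for key in qat_only_keys:
    print(key)
```

Any of these entries that the target model lacks at load time cannot be restored from the checkpoint.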

Code for phase 2:


import torch

...
self.create_net()
custom_load_state_dict(self.net, torch.load(str(filepath), map_location="cpu"))
self.net.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
torch.quantization.prepare_qat(self.net, inplace=True)
# Freeze observers and bn immediately after model load
self.net.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
self.net.apply(torch.quantization.disable_observer)
...

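# Another file
import torch

def custom_load_state_dict(target_module: torch.nn.Module, source_state_dict: dict) -> None:
    target_state_dict = target_module.state_dict()
    for key, source_tensor in source_state_dict.items():
        # Copy only tensors whose key and shape match; all others are silently skipped
        if key in target_state_dict:
            if target_state_dict[key].shape == source_tensor.shape:
                target_state_dict[key] = source_tensor
    unmatched_keys = target_module.load_state_dict(target_state_dict)
    # load_state_dict returns a (missing_keys, unexpected_keys) named tuple, which is
    # always truthy, so check its fields rather than the tuple itself
    if unmatched_keys.missing_keys or unmatched_keys.unexpected_keys:
        print(f'Unmatched keys during model loading:\n{unmatched_keys}')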
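
One way to check whether the quantization state survives the round trip is to prepare the model first and then load with the built-in strict `load_state_dict`. This is a sketch under assumptions, not a confirmed fix: `TinyNet` and `build_qat_net` are hypothetical stand-ins, and the checkpoint is kept in memory instead of being written to a .pth file.

```python
import torch
import torch.nn as nn
from torch.ao.quantization import (
    disable_observer,
    get_default_qat_qconfig,
    prepare_qat,
)


class TinyNet(nn.Module):
    """Hypothetical stand-in for the network built by create_net()."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 4, kernel_size=3)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))


def build_qat_net():
    net = TinyNet().train()
    net.qconfig = get_default_qat_qconfig("fbgemm")
    prepare_qat(net, inplace=True)
    return net


torch.manual_seed(0)

# "Phase 1": one forward pass so the observers record ranges, then freeze them
phase1 = build_qat_net()
phase1(torch.randn(8, 3, 16, 16))
phase1.apply(disable_observer)
checkpoint = phase1.state_dict()

# Buffers whose shape differs from a freshly prepared model; a loader that
# filters on shape equality would skip these
fresh_state = build_qat_net().state_dict()
shape_mismatches = [
    key for key, tensor in checkpoint.items()
    if key in fresh_state and fresh_state[key].shape != tensor.shape
]
print("Shape differs from a fresh model:", shape_mismatches)

# "Phase 2": prepare first, then load with the built-in loader (strict by default)
phase2 = build_qat_net()
result = phase2.load_state_dict(checkpoint)
print(result)
```

If `shape_mismatches` is non-empty (per-channel quantizer buffers may change shape once they have observed data), a shape-filtering loader such as `custom_load_state_dict` would leave those entries at their initial values even when the model is prepared before loading.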

I've tried initializing QAT first and then loading the weights, but it doesn't change anything. I've also tried manually converting the model to QAT:

# Instead:
torch.quantization.prepare_qat(self.net, inplace=True)

# Do:
from torch.ao.quantization import get_default_qat_module_mappings, propagate_qconfig_, convert

mapping = get_default_qat_module_mappings()
propagate_qconfig_(model, qconfig_dict=None)
convert(model, mapping=mapping, inplace=True, remove_qconfig=False)

But after training, when I try to convert to a quantized model, it throws an error:

# Throws error - missing observers
quantized_model = torch.quantization.convert(quantized_model, inplace=True)
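
For comparison, eager-mode `prepare_qat` appears to run one more step after the module swap: a `prepare` call that attaches the activation observers. The sketch below assumes that behavior and uses a hypothetical `TinyNet`. It repeats the manual sequence with that extra call and compares the result against `prepare_qat`:

```python
import torch.nn as nn
from torch.ao.quantization import (
    convert,
    get_default_qat_module_mappings,
    get_default_qat_qconfig,
    prepare,
    prepare_qat,
    propagate_qconfig_,
)


class TinyNet(nn.Module):
    """Hypothetical stand-in for the real network."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 4, kernel_size=3)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))


def new_net():
    net = TinyNet().train()
    net.qconfig = get_default_qat_qconfig("fbgemm")
    return net


# Reference: the one-call API
reference = prepare_qat(new_net(), inplace=True)

# Manual sequence: swap float modules for QAT modules ...
manual = new_net()
mapping = get_default_qat_module_mappings()
propagate_qconfig_(manual, qconfig_dict=None)
convert(manual, mapping=mapping, inplace=True, remove_qconfig=False)
has_observer_before = hasattr(manual.conv, "activation_post_process")

# ... then attach activation observers, including on the swapped QAT modules
prepare(manual, observer_non_leaf_module_list=set(mapping.values()), inplace=True)
has_observer_after = hasattr(manual.conv, "activation_post_process")

print("observer before prepare:", has_observer_before)
print("observer after prepare:", has_observer_after)
```

Without the final `prepare` call the swapped modules carry weight fake-quantizers but no activation observers, which would be consistent with the missing-observers error at conversion time.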

When I skip freezing BN and the observers after loading the model, it seems to work fine. But is that correct? Doesn't it destroy the quantization levels learned previously?
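
Whether the learned quantization parameters drift can be checked directly by snapshotting the `scale` and `zero_point` buffers around a forward pass. This is a minimal sketch, assuming a hypothetical `TinyNet` and random inputs in place of real data; `qparams_snapshot` is an illustrative helper, not a PyTorch API.

```python
import torch
import torch.nn as nn
from torch.ao.quantization import (
    disable_observer,
    get_default_qat_qconfig,
    prepare_qat,
)


class TinyNet(nn.Module):
    """Hypothetical stand-in for the real network."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 4, kernel_size=3)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))


def qparams_snapshot(net):
    """Copy every scale / zero_point buffer (illustrative helper)."""
    return {
        key: tensor.clone()
        for key, tensor in net.state_dict().items()
        if key.endswith(("scale", "zero_point"))
    }


torch.manual_seed(0)
net = TinyNet().train()
net.qconfig = get_default_qat_qconfig("fbgemm")
prepare_qat(net, inplace=True)
net(torch.randn(8, 3, 16, 16))  # initial observation, as in phase 1

# Observers enabled: inputs with a different range move the parameters
before_on = qparams_snapshot(net)
net(10 * torch.randn(8, 3, 16, 16))
after_on = qparams_snapshot(net)

# Observers disabled: the parameters stay fixed regardless of the input range
net.apply(disable_observer)
before_off = qparams_snapshot(net)
net(100 * torch.randn(8, 3, 16, 16))
after_off = qparams_snapshot(net)

drifted = [k for k in before_on if not torch.equal(before_on[k], after_on[k])]
frozen_changed = [k for k in before_off if not torch.equal(before_off[k], after_off[k])]
print("changed with observers enabled:", drifted)
print("changed with observers disabled:", frozen_changed)
```

Running the same comparison on the real model right after loading the checkpoint would show whether the phase 1 values were actually restored before freezing them.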

python

deep-learning

pytorch

quantization-aware-training
