Source code for chop.passes.graph.transforms.tensorrt.quantize.fine_tune
import logging
import os

import torch
from .utils import prepare_save_path, check_for_value_in_dict
def tensorrt_fine_tune_transform_pass(graph, pass_args=None):
"""
Fine-tunes a quantized model using Quantization Aware Training (QAT) to improve its accuracy post-quantization.
    This pass adjusts the quantized model's weights while accounting for quantization
    effects, aiming to recover, or even surpass, the original model's accuracy. Training
    runs for a reduced number of epochs at a significantly lower learning rate than the
    initial training phase, following a cosine annealing learning rate schedule.
:param graph: The model graph to be fine-tuned. This graph should already be quantized.
:type graph: MaseGraph
:param pass_args: A dictionary containing arguments for fine-tuning, such as the number of epochs (`epochs`), the initial learning rate (`initial_learning_rate`), and the final learning rate (`final_learning_rate`). These parameters allow customization of the training regime based on the specific needs of the model and dataset.
:type pass_args: dict, optional
    :return: A tuple containing the fine-tuned graph and a dictionary holding the
        checkpoint save path under the key `ckpt_save_path` (`None` if fine tuning
        was skipped).
:rtype: tuple(MaseGraph, dict)
The default training regime involves:
- Using 10% of the original training epochs.
- Starting with 1% of the original training learning rate.
- Employing a cosine annealing schedule to reduce the learning rate to 0.01% of the initial training learning rate by the end of fine-tuning.
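    For example, assuming (purely for illustration) an original training run of 100
    epochs at a learning rate of 1e-3, these defaults give:

    - fine-tuning epochs: 100 * 0.1 = 10
    - initial fine-tuning learning rate: 1e-3 * 0.01 = 1e-5
    - final learning rate at the end of the cosine schedule: 1e-5 * 0.01 = 1e-7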
    The resulting fine-tuned model checkpoints are saved in the following directory
    structure, facilitating easy access and version control:

    - mase_output
        - tensorrt
            - quantization
                - model_task_dataset_date
                    - cache
                    - ckpts
                        - fine_tuning
                    - json
                    - onnx
                    - trt
    Example of usage::

        graph = MaseGraph(...)
        fine_tuned_graph, _ = tensorrt_fine_tune_transform_pass(
            graph,
            {"epochs": 5, "initial_learning_rate": 0.001, "final_learning_rate": 0.00001},
        )

    This example starts fine-tuning with custom epoch and learning-rate settings,
    adapting the training regime to the requirements of the quantized model.
"""
trainer = FineTuning(graph, pass_args)
ckpt = trainer.train()
# Link the model with the graph for further operations or evaluations
graph.model = torch.fx.GraphModule(graph.model, graph.fx_graph)
return graph, {"ckpt_save_path": ckpt}
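
# Minimal usage sketch, assuming `graph` is a MaseGraph that has already been
# int8-quantized for TensorRT and `pass_args` carries the training `data_module`,
# `accelerator` and (optionally) the fine-tuning overrides described above:
#
#     graph, info = tensorrt_fine_tune_transform_pass(graph, pass_args)
#     best_ckpt = info["ckpt_save_path"]  # None if fine tuning was skipped
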
class FineTuning:
def __init__(self, graph, config):
self.logger = logging.getLogger(__name__)
self.config = config
self.graph = graph
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.graph.model.to(self.device)
def train(self):
"""
For QAT it is typical to employ 10% of the original training epochs,
starting at 1% of the initial training learning rate, and a cosine
annealing learning rate schedule that follows the decreasing half of
a cosine period, down to 1% of the initial fine tuning learning rate
        (0.01% of the initial training learning rate). However, these defaults
        can be overridden by setting `epochs`, `initial_learning_rate` and
        `final_learning_rate` in `passes.tensorrt.fine_tune`.
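
        As an illustration only, an override fragment might look like the following
        (key layout inferred from how this method reads `self.config`; required
        entries such as `data_module` and `accelerator` are omitted)::

            {
                "fine_tune": {"fine_tune": True, "epochs": 5},
                "initial_learning_rate": 1e-3,
                "final_learning_rate": 1e-7,
            }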
"""
if not self.config.get("fine_tune", {}).get("fine_tune", True):
self.logger.warning(
"Fine tuning is disabled in the config. Skipping QAT fine tuning."
)
return None
if not check_for_value_in_dict(self.config, "int8"):
self.logger.warning(
"int8 precision not found in config. Skipping QAT fine tuning."
)
return None
from chop.actions import train
from chop.models import get_model_info
# load the settings and default to chop default parameters
model_info = get_model_info(self.config["data_module"].model_name)
        weight_decay = self.config.get("weight_decay", 0)
        optimizer = self.config.get("optimizer", "adam")
        # Prefer a user-supplied `initial_learning_rate`, otherwise fall back to
        # chop's `learning_rate`; the fine-tuning LR is 1% of that value.
        try:
            initial_fine_tune_lr = self.config["initial_learning_rate"] * 0.01
        except KeyError:
            initial_fine_tune_lr = self.config.get("learning_rate", 1e-5) * 0.01
        # Prefer a user-supplied `final_learning_rate`, otherwise default eta_min to
        # 1% of the initial fine-tuning LR (0.01% of the original learning rate).
        try:
            eta_min = self.config["final_learning_rate"]
        except KeyError:
            eta_min = initial_fine_tune_lr * 0.01
        # Prefer a user-supplied number of fine-tuning epochs, otherwise default to
        # 10% of the original training epochs.
try:
epochs = self.config["fine_tune"]["epochs"]
except KeyError:
epochs = int(self.config.get("max_epochs", 20) * 0.1)
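        # The cosine annealing period is measured in optimiser steps: training
        # batches per epoch times the number of fine-tuning epochs.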
t_max = int(len(self.config["data_module"].train_dataloader()) * epochs)
ckpt_save_path = prepare_save_path(
self.config, method="ckpts/fine_tuning", suffix="ckpt"
)
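        # Scheduler arguments for the cosine annealing schedule described in the
        # docstring (maximum number of steps and minimum learning rate).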
scheduler_args = {"t_max": t_max, "eta_min": eta_min}
plt_trainer_args = {
"max_epochs": epochs,
"accelerator": self.config["accelerator"],
}
self.logger.info(f"Starting Fine Tuning for {epochs} epochs...")
train(
self.graph.model,
model_info,
self.config["data_module"],
self.config["data_module"].dataset_info,
"Quantization Fine Tuning",
optimizer,
initial_fine_tune_lr,
weight_decay,
scheduler_args,
plt_trainer_args,
False,
ckpt_save_path,
None,
None,
"",
)
self.logger.info("Fine Tuning Complete")
return ckpt_save_path / "best.ckpt"