Tutorial 3: Running Quantization-Aware Training (QAT) on Bert#

In this tutorial, we’ll build on top of Tutorial 2 by taking the Bert model fine-tuned for sequence classification and running Mase’s quantization pass. First, we’ll run simple Post-Training Quantization (PTQ) and see how much the accuracy drops. Then, we’ll run further training iterations on the quantized model (i.e. QAT) and see whether the accuracy of the trained quantized model approaches the accuracy of the original (full-precision) model.

checkpoint = "prajjwal1/bert-tiny"
tokenizer_checkpoint = "bert-base-uncased"
dataset_name = "imdb"

Importing the model#

If you are starting from scratch, you can create a MaseGraph for Bert by running the following cell.

from transformers import AutoModelForSequenceClassification

from chop import MaseGraph
import chop.passes as passes

model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
model.config.problem_type = "single_label_classification"

mg = MaseGraph(
    model,
    hf_input_names=[
        "input_ids",
        "attention_mask",
        "labels",
    ],
)

mg, _ = passes.init_metadata_analysis_pass(mg)
mg, _ = passes.add_common_metadata_analysis_pass(mg)
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at prajjwal1/bert-tiny and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
`past_key_values` were not specified as input names, but model.config.use_cache = True. Setting model.config.use_cache = False.
INFO     Getting dummy input for prajjwal1/bert-tiny.

If you have previously run the tutorial on LoRA finetuning, run the following cell to import the fine-tuned checkpoint.

from pathlib import Path
from chop import MaseGraph

mg = MaseGraph.from_checkpoint(f"{Path.home()}/tutorial_2_lora")
WARNING  Node finfo not found in loaded metadata.
WARNING  Node getattr_2 not found in loaded metadata.

Post-Training Quantization (PTQ)#

Here, we simply quantize the model and evaluate the effect on its accuracy. First, let’s evaluate the model accuracy before quantization (if you’re coming from Tutorial 2, this should be the same as the post-LoRA evaluation accuracy). As seen in Tutorial 2, we can use the get_tokenized_dataset and get_trainer utilities to generate a HuggingFace Trainer instance for training and evaluation.

from chop.tools import get_tokenized_dataset, get_trainer

dataset, tokenizer = get_tokenized_dataset(
    dataset=dataset_name,
    checkpoint=tokenizer_checkpoint,
    return_tokenizer=True,
)

trainer = get_trainer(
    model=mg.model,
    tokenized_dataset=dataset,
    tokenizer=tokenizer,
    evaluate_metric="accuracy",
)

# Evaluate accuracy
eval_results = trainer.evaluate()
print(f"Evaluation accuracy: {eval_results['eval_accuracy']}")
INFO     Tokenizing dataset imdb with AutoTokenizer for bert-base-uncased.
100%|██████████| 3125/3125 [05:14<00:00,  9.95it/s]
Evaluation accuracy: 0.8282

To run the quantization pass, we pass a quantization configuration dictionary as an argument. This defines the quantization mode, numerical format and precision for each operator in the graph. We’ll run the quantization in “by type” mode, meaning nodes are quantized according to their mase_op. Other modes include “by name” and “by regex name”. We’ll quantize all activations, weights and biases in the model to fixed-point with the same precision. This may be sub-optimal, but works as an example. In future tutorials, we’ll see how to run Mase’s search flow to find quantization configurations that minimize accuracy loss.
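To build intuition for what the width and fractional width settings below mean, here’s a minimal sketch of fixed-point quantization (an illustration of the numerics, not Mase’s internal implementation). With a total width of 8 bits and a fractional width of 4 bits, values snap to multiples of 1/16 within the representable signed range [-8, 7.9375].

import torch

def fixed_point_quantize(x, width=8, frac_width=4):
    # Scale onto the integer grid, round, clamp to the signed range, rescale.
    scale = 2**frac_width
    int_min = -(2 ** (width - 1))
    int_max = 2 ** (width - 1) - 1
    return torch.clamp(torch.round(x * scale), int_min, int_max) / scale

x = torch.tensor([0.1234, -1.5, 100.0])
print(fixed_point_quantize(x))  # tensor([ 0.1250, -1.5000,  7.9375])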

import chop.passes as passes

quantization_config = {
    "by": "type",
    "default": {
        "config": {
            "name": None,
        }
    },
    "linear": {
        "config": {
            "name": "integer",
            # data
            "data_in_width": 8,
            "data_in_frac_width": 4,
            # weight
            "weight_width": 8,
            "weight_frac_width": 4,
            # bias
            "bias_width": 8,
            "bias_frac_width": 4,
        }
    },
}

mg, _ = passes.quantize_transform_pass(
    mg,
    pass_args=quantization_config,
)
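For contrast, a “by name” configuration targets individual nodes by their names in the graph rather than by operator type. Below is a hypothetical sketch of what such a configuration could look like; the node name key is illustrative (you can list the real node names by printing mg.model.graph), and the exact schema should be checked against the pass documentation.

# Hypothetical "by name" configuration: keys are node names in the graph.
by_name_config = {
    "by": "name",
    "default": {"config": {"name": None}},
    "some_linear_node_name": {  # illustrative placeholder, not a real node name
        "config": {
            "name": "integer",
            "data_in_width": 8,
            "data_in_frac_width": 4,
            "weight_width": 8,
            "weight_frac_width": 4,
            "bias_width": 8,
            "bias_frac_width": 4,
        }
    },
}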

Let’s evaluate the immediate effect of quantization on the model accuracy.

trainer = get_trainer(
    model=mg.model,
    tokenized_dataset=dataset,
    tokenizer=tokenizer,
    evaluate_metric="accuracy",
)
eval_results = trainer.evaluate()
print(f"Evaluation accuracy: {eval_results['eval_accuracy']}")
100%|██████████| 3125/3125 [00:55<00:00, 56.59it/s]
Evaluation accuracy: 0.8282

We can save the current checkpoint for future reference (optional).

from pathlib import Path

mg.export(f"{Path.home()}/tutorial_3_ptq")
INFO     Exporting MaseGraph to /Users/yz10513/tutorial_3_ptq.pt, /Users/yz10513/tutorial_3_ptq.mz
INFO     Exporting GraphModule to /Users/yz10513/tutorial_3_ptq.pt
INFO     Exporting MaseMetadata to /Users/yz10513/tutorial_3_ptq.mz
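To restore this checkpoint later, use the same from_checkpoint utility shown earlier in this tutorial.

# Reload the exported PTQ checkpoint into a new MaseGraph.
mg = MaseGraph.from_checkpoint(f"{Path.home()}/tutorial_3_ptq")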

Quantization-Aware Training (QAT)#

Quantization can lead to a significant drop in accuracy, although for this model and precision the PTQ accuracy held up well. Next, we’ll run QAT to evaluate whether any performance gap can be reduced. To run QAT in Mase, all you need to do is include the model back in your training loop after running the quantization pass.
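This works because quantized layers typically apply “fake quantization” in the forward pass while letting gradients flow through the rounding unchanged, a trick known as the straight-through estimator (STE). Here is a minimal sketch of the idea (an illustration of the general technique, not Mase’s exact implementation):

import torch

class FakeQuantize(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, frac_width):
        # Round to the fixed-point grid so the loss sees quantized values.
        scale = 2**frac_width
        return torch.round(x * scale) / scale

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: treat rounding as the identity, so
        # gradients reach the underlying full-precision weights.
        return grad_output, None

x = torch.randn(4, requires_grad=True)
y = FakeQuantize.apply(x, 4)
y.sum().backward()
print(x.grad)  # tensor([1., 1., 1., 1.]): gradients pass straight through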

# Continue training the (now quantized) model, then evaluate accuracy
trainer.train()
eval_results = trainer.evaluate()
print(f"Evaluation accuracy: {eval_results['eval_accuracy']}")
 16%|█▌        | 501/3125 [00:36<02:57, 14.80it/s]
{'loss': 0.4017, 'grad_norm': 10.568283081054688, 'learning_rate': 4.2e-05, 'epoch': 0.16}
 32%|███▏      | 1000/3125 [00:59<01:57, 18.05it/s]
{'loss': 0.3918, 'grad_norm': 7.2999749183654785, 'learning_rate': 3.4000000000000007e-05, 'epoch': 0.32}
 48%|████▊     | 1501/3125 [01:21<00:53, 30.53it/s]
{'loss': 0.3987, 'grad_norm': 11.782495498657227, 'learning_rate': 2.6000000000000002e-05, 'epoch': 0.48}
 64%|██████▍   | 2000/3125 [01:41<01:07, 16.55it/s]
{'loss': 0.3906, 'grad_norm': 6.373114585876465, 'learning_rate': 1.8e-05, 'epoch': 0.64}
 80%|████████  | 2504/3125 [01:59<00:20, 30.69it/s]
{'loss': 0.382, 'grad_norm': 8.522379875183105, 'learning_rate': 1e-05, 'epoch': 0.8}
 96%|█████████▌| 3000/3125 [02:15<00:04, 25.00it/s]
{'loss': 0.3914, 'grad_norm': 10.293235778808594, 'learning_rate': 2.0000000000000003e-06, 'epoch': 0.96}
100%|██████████| 3125/3125 [02:20<00:00, 22.26it/s]
{'train_runtime': 140.3713, 'train_samples_per_second': 178.099, 'train_steps_per_second': 22.262, 'train_loss': 0.3934635339355469, 'epoch': 1.0}
100%|██████████| 3125/3125 [00:51<00:00, 60.90it/s]
Evaluation accuracy: 0.84232

We can see that the accuracy of the quantized model can match (or sometimes exceed) that of the full-precision model, with a much lower memory requirement for storing the weights.
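As a rough back-of-the-envelope check (assuming every weight is stored at the quantized width), moving from 32-bit floats to 8-bit fixed point cuts weight storage by 4x:

# Rough weight-storage estimate: 4 bytes per fp32 weight vs 1 byte at 8 bits.
num_params = sum(p.numel() for p in mg.model.parameters())
print(f"fp32: {num_params * 4 / 2**20:.2f} MiB, 8-bit: {num_params / 2**20:.2f} MiB")

Finally, save the final checkpoint for future tutorials.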

from pathlib import Path

mg.export(f"{Path.home()}/tutorial_3_qat")
INFO     Exporting MaseGraph to /Users/yz10513/tutorial_3_qat.pt, /Users/yz10513/tutorial_3_qat.mz
INFO     Exporting GraphModule to /Users/yz10513/tutorial_3_qat.pt
INFO     Exporting MaseMetadata to /Users/yz10513/tutorial_3_qat.mz