Tutorial 3: Running Quantization-Aware Training (QAT) on Bert#
In this tutorial, we’ll build on Tutorial 2 by taking the Bert model fine-tuned for sequence classification and running Mase’s quantization pass. First, we’ll run simple Post-Training Quantization (PTQ) and see how much the accuracy drops. Then, we’ll run further training iterations on the quantized model (i.e. QAT) and see whether the accuracy of the trained quantized model approaches that of the original (full-precision) model.
checkpoint = "prajjwal1/bert-tiny"
tokenizer_checkpoint = "bert-base-uncased"
dataset_name = "imdb"
Importing the model#
If you are starting from scratch, you can create a MaseGraph for Bert by running the following cell.
from transformers import AutoModelForSequenceClassification
from chop import MaseGraph
import chop.passes as passes
model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
model.config.problem_type = "single_label_classification"
mg = MaseGraph(
    model,
    hf_input_names=[
        "input_ids",
        "attention_mask",
        "labels",
    ],
)
mg, _ = passes.init_metadata_analysis_pass(mg)
mg, _ = passes.add_common_metadata_analysis_pass(mg)
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at prajjwal1/bert-tiny and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
`past_key_values` were not specified as input names, but model.config.use_cache = True. Setting model.config.use_cache = False.
INFO Getting dummy input for prajjwal1/bert-tiny.
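You can optionally inspect the traced graph at this point. Since mg.model is a torch.fx GraphModule, the standard fx API applies; the node names depend on the trace, so treat the output as illustrative.
# Print the first few nodes of the traced graph (op kind, name, target).
for node in list(mg.model.graph.nodes)[:10]:
    print(node.op, node.name, node.target)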
If you have previously run the tutorial on LoRA finetuning (Tutorial 2), run the following cell to import the fine-tuned checkpoint.
from pathlib import Path
from chop import MaseGraph
mg = MaseGraph.from_checkpoint(f"{Path.home()}/tutorial_2_lora")
WARNING Node finfo not found in loaded metadata.
WARNING Node getattr_2 not found in loaded metadata.
Post-Training Quantization (PTQ)#
Here, we simply quantize the model and evaluate the effect on its accuracy. First, let’s evaluate the model’s accuracy before quantization (if you’re coming from Tutorial 2, this should match the post-LoRA evaluation accuracy). As seen in Tutorial 2, we can use the get_tokenized_dataset and get_trainer utilities to generate a HuggingFace Trainer instance for training and evaluation.
from chop.tools import get_tokenized_dataset, get_trainer
dataset, tokenizer = get_tokenized_dataset(
    dataset=dataset_name,
    checkpoint=tokenizer_checkpoint,
    return_tokenizer=True,
)

trainer = get_trainer(
    model=mg.model,
    tokenized_dataset=dataset,
    tokenizer=tokenizer,
    evaluate_metric="accuracy",
)
# Evaluate accuracy
eval_results = trainer.evaluate()
print(f"Evaluation accuracy: {eval_results['eval_accuracy']}")
INFO Tokenizing dataset imdb with AutoTokenizer for bert-base-uncased.
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
- Avoid using `tokenizers` before the fork if possible
- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
100%|██████████| 3125/3125 [05:14<00:00, 9.95it/s]
Evaluation accuracy: 0.8282
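As an aside, the tokenizers fork warning above can be silenced by setting the TOKENIZERS_PARALLELISM environment variable before creating the trainer, as the message itself suggests (optional):
import os

# Silence the huggingface/tokenizers parallelism warning, as suggested by the
# warning message itself.
os.environ["TOKENIZERS_PARALLELISM"] = "false"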
To run the quantization pass, we pass a quantization configuration dictionary as an argument. This defines the quantization mode, numerical format and precision for each operator in the graph. We’ll run the quantization in “by type” mode, meaning nodes are quantized according to their mase_op; other modes include by name and by regex name. We’ll quantize all activations, weights and biases in the model to fixed-point with the same precision. This may be sub-optimal, but it works as an example. In future tutorials, we’ll see how to run the search flow in Mase to find quantization configurations that minimize accuracy loss.
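Since the “by type” mode matches nodes on their mase_op, it can help to first check which operator types appear in the graph. The snippet below is a sketch that assumes the common metadata layout populated by add_common_metadata_analysis_pass (node.meta["mase"].parameters["common"]["mase_op"]); adjust it if your version of Mase stores the metadata differently.
# Collect the set of mase_op values present in the traced graph (sketch; the
# metadata path used here is an assumption based on the common metadata pass).
mase_ops = set()
for node in mg.fx_graph.nodes:
    common_meta = node.meta["mase"].parameters.get("common", {})
    if "mase_op" in common_meta:
        mase_ops.add(common_meta["mase_op"])
print(mase_ops)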
import chop.passes as passes
quantization_config = {
    "by": "type",
    "default": {
        "config": {
            "name": None,
        }
    },
    "linear": {
        "config": {
            "name": "integer",
            # data
            "data_in_width": 8,
            "data_in_frac_width": 4,
            # weight
            "weight_width": 8,
            "weight_frac_width": 4,
            # bias
            "bias_width": 8,
            "bias_frac_width": 4,
        }
    },
}

mg, _ = passes.quantize_transform_pass(
    mg,
    pass_args=quantization_config,
)
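For reference, the same pass can also target individual nodes. A “by name” configuration would look roughly like the sketch below (not run here); the keys are node names from the traced graph, and the name shown is a hypothetical placeholder that will differ in your trace.
# Sketch only (not run): quantize a specific node by name. The node name below
# is a hypothetical placeholder; list mg.fx_graph.nodes to find the real ones.
by_name_config = {
    "by": "name",
    "default": {"config": {"name": None}},
    "some_linear_node_name": {  # hypothetical node name
        "config": {
            "name": "integer",
            "data_in_width": 8,
            "data_in_frac_width": 4,
            "weight_width": 8,
            "weight_frac_width": 4,
            "bias_width": 8,
            "bias_frac_width": 4,
        }
    },
}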
Let’s evaluate the immediate effect of quantization on the model accuracy.
trainer = get_trainer(
    model=mg.model,
    tokenized_dataset=dataset,
    tokenizer=tokenizer,
    evaluate_metric="accuracy",
)
eval_results = trainer.evaluate()
print(f"Evaluation accuracy: {eval_results['eval_accuracy']}")
100%|██████████| 3125/3125 [00:55<00:00, 56.59it/s]
Evaluation accuracy: 0.8282
We can save the current checkpoint for future reference (optional).
from pathlib import Path
mg.export(f"{Path.home()}/tutorial_3_ptq")
INFO Exporting MaseGraph to /Users/yz10513/tutorial_3_ptq.pt, /Users/yz10513/tutorial_3_ptq.mz
INFO Exporting GraphModule to /Users/yz10513/tutorial_3_ptq.pt
INFO Exporting MaseMetadata to /Users/yz10513/tutorial_3_ptq.mz
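If you later want to resume from this PTQ checkpoint, it can be reloaded in the same way as the LoRA checkpoint was imported above:
from pathlib import Path
from chop import MaseGraph

# Reload the PTQ checkpoint exported above (optional).
mg = MaseGraph.from_checkpoint(f"{Path.home()}/tutorial_3_ptq")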
Quantization-Aware Training (QAT)#
Quantization can lead to a drop in accuracy, the severity of which depends on the model and the chosen precision. Next, we’ll run QAT to evaluate whether any performance gap can be reduced. To run QAT in Mase, all you need to do is include the quantized model back in your training loop after running the quantization pass.
# Evaluate accuracy
trainer.train()
eval_results = trainer.evaluate()
print(f"Evaluation accuracy: {eval_results['eval_accuracy']}")
16%|█▌ | 501/3125 [00:36<02:57, 14.80it/s]
{'loss': 0.4017, 'grad_norm': 10.568283081054688, 'learning_rate': 4.2e-05, 'epoch': 0.16}
32%|███▏ | 1000/3125 [00:59<01:57, 18.05it/s]
{'loss': 0.3918, 'grad_norm': 7.2999749183654785, 'learning_rate': 3.4000000000000007e-05, 'epoch': 0.32}
48%|████▊ | 1501/3125 [01:21<00:53, 30.53it/s]
{'loss': 0.3987, 'grad_norm': 11.782495498657227, 'learning_rate': 2.6000000000000002e-05, 'epoch': 0.48}
64%|██████▍ | 2000/3125 [01:41<01:07, 16.55it/s]
{'loss': 0.3906, 'grad_norm': 6.373114585876465, 'learning_rate': 1.8e-05, 'epoch': 0.64}
80%|████████ | 2504/3125 [01:59<00:20, 30.69it/s]
{'loss': 0.382, 'grad_norm': 8.522379875183105, 'learning_rate': 1e-05, 'epoch': 0.8}
96%|█████████▌| 3000/3125 [02:15<00:04, 25.00it/s]
{'loss': 0.3914, 'grad_norm': 10.293235778808594, 'learning_rate': 2.0000000000000003e-06, 'epoch': 0.96}
100%|██████████| 3125/3125 [02:20<00:00, 22.26it/s]
{'train_runtime': 140.3713, 'train_samples_per_second': 178.099, 'train_steps_per_second': 22.262, 'train_loss': 0.3934635339355469, 'epoch': 1.0}
100%|██████████| 3125/3125 [00:51<00:00, 60.90it/s]
Evaluation accuracy: 0.84232
We can see that the accuracy of the quantized model can match (and sometimes exceed) that of the full-precision model, while requiring much less memory to store the weights; a rough estimate of the saving is sketched below. Finally, save the final checkpoint for future tutorials.
from pathlib import Path
mg.export(f"{Path.home()}/tutorial_3_qat")
INFO Exporting MaseGraph to /Users/yz10513/tutorial_3_qat.pt, /Users/yz10513/tutorial_3_qat.mz
INFO Exporting GraphModule to /Users/yz10513/tutorial_3_qat.pt
INFO Exporting MaseMetadata to /Users/yz10513/tutorial_3_qat.mz
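As a back-of-the-envelope check on the memory claim above, you can compare the storage needed for the model’s weights at 32-bit floating point against the 8-bit fixed-point format used in the quantization config. This is only a rough estimate of weight storage: it counts all parameters, while the config above only quantizes the linear layers, and it ignores activations and any packing overhead.
# Rough weight-storage estimate: 32-bit floats vs 8-bit fixed point.
# The true saving for the whole model lies between these two numbers, since
# only the linear layers are quantized by the config above.
total_params = sum(p.numel() for p in mg.model.parameters())
fp32_mib = total_params * 4 / 2**20  # 4 bytes per parameter at fp32
int8_mib = total_params * 1 / 2**20  # 1 byte per parameter at 8 bits
print(f"Parameters: {total_params}")
print(f"~{fp32_mib:.2f} MiB at fp32 vs ~{int8_mib:.2f} MiB at 8-bit fixed point")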