Tutorial 3: Running Quantization-Aware Training (QAT) on BERT#

In this tutorial, we build on Tutorial 2 by taking a BERT sequence-classification model and running MASE quantization passes on it. We first apply Post-Training Quantization (PTQ), then continue training with Quantization-Aware Training (QAT) to recover the accuracy lost to quantization.

Run this tutorial#

From the repository root:

uv run python docs/source/modules/documentation/tutorials/tutorial_3_qat.py

Expected terminal output (excerpt)#

============================================================
Tutorial 3: Post-Training Quantization + QAT on BERT
============================================================

[1/5] Loading model...
WARNING  Node finfo not found in loaded metadata.
WARNING  Node getattr_2 not found in loaded metadata.
      Loaded from tutorial_2_lora  ✓

[2/5] Loading dataset and evaluating baseline accuracy...
INFO     Tokenizing dataset imdb with AutoTokenizer for bert-base-uncased.
      Dataset loaded: 25000 train / 25000 test
      [Baseline] Accuracy: 0.8350

[3/5] Applying integer quantization (PTQ)...
      quantize_transform_pass  ✓
      [PTQ] Accuracy: 0.7738
INFO     Exporting MaseGraph to ~/tutorial_3_ptq.pt, ~/tutorial_3_ptq.mz
INFO     Exporting GraphModule to ~/tutorial_3_ptq.pt
INFO     Saving full model format
INFO     Exporting MaseMetadata to ~/tutorial_3_ptq.mz
      PTQ checkpoint saved to ~/tutorial_3_ptq

[4/5] Running QAT (1 epoch)...
{'loss': 0.4101, 'grad_norm': 10.378958702087402, 'learning_rate': 4.2016e-05, 'epoch': 0.16}
{'loss': 0.401, 'grad_norm': 5.116357326507568, 'learning_rate': 3.4016e-05, 'epoch': 0.32}
{'loss': 0.3959, 'grad_norm': 15.316791534423828, 'learning_rate': 2.6016000000000003e-05, 'epoch': 0.48}
...
{'train_runtime': 141.1579, 'train_samples_per_second': 177.107, 'train_steps_per_second': 22.138, 'train_loss': 0.39675407958984377, 'epoch': 1.0}
      [QAT] Accuracy: 0.8399

[5/5] Exporting QAT checkpoint...
INFO     Exporting MaseGraph to ~/tutorial_3_qat.pt, ~/tutorial_3_qat.mz
INFO     Exporting GraphModule to ~/tutorial_3_qat.pt
INFO     Saving full model format
INFO     Exporting MaseMetadata to ~/tutorial_3_qat.mz
      QAT checkpoint saved to ~/tutorial_3_qat

============================================================
Tutorial 3 complete!
============================================================

Importing the model#

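By default, the script loads the MaseGraph exported at the end of Tutorial 2 (Option A). If you are starting from scratch, create a fresh MaseGraph for BERT instead (Option B):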

Step 1: Import model and build MaseGraph#

print("\n[1/5] Loading model...", flush=True)
from pathlib import Path  # needed for Path.home() below

from transformers import AutoModelForSequenceClassification
import chop.passes as passes
from chop import MaseGraph

# Option A: load from Tutorial 2 LoRA checkpoint
mg = MaseGraph.from_checkpoint(f"{Path.home()}/tutorial_2_lora")
print("      Loaded from tutorial_2_lora  ✓", flush=True)

# Option B: start from scratch (use if Tutorial 2 checkpoint is not available)
# model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
# model.config.problem_type = "single_label_classification"
# mg = MaseGraph(
#     model,
#     hf_input_names=["input_ids", "attention_mask", "labels"],
# )
# mg, _ = passes.init_metadata_analysis_pass(mg)
# mg, _ = passes.add_common_metadata_analysis_pass(mg)
# print("      Fresh MaseGraph built  ✓", flush=True)

Example output:

[1/5] Loading model...
WARNING  Node finfo not found in loaded metadata.
WARNING  Node getattr_2 not found in loaded metadata.
      Loaded from tutorial_2_lora  ✓

If Tutorial 2 has not been run yet, you can build a fresh MaseGraph instead (comment out Option A and uncomment Option B in the script). Here, checkpoint is the Hugging Face model identifier defined earlier in the script:

model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
model.config.problem_type = "single_label_classification"
mg = MaseGraph(model, hf_input_names=["input_ids", "attention_mask", "labels"])
mg, _ = passes.init_metadata_analysis_pass(mg)
mg, _ = passes.add_common_metadata_analysis_pass(mg)

Post-Training Quantization (PTQ)#

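Before quantization, evaluate baseline accuracy with the tokenized IMDb dataset and a Hugging Face trainer. The variables dataset_name and tokenizer_checkpoint are set earlier in the script; the log below shows them resolving to imdb and bert-base-uncased.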

Step 2: Baseline evaluation#

print("\n[2/5] Loading dataset and evaluating baseline accuracy...", flush=True)
from chop.tools import get_tokenized_dataset, get_trainer

dataset, tokenizer = get_tokenized_dataset(
    dataset=dataset_name,
    checkpoint=tokenizer_checkpoint,
    return_tokenizer=True,
)
print(f"      Dataset loaded: {len(dataset['train'])} train / {len(dataset['test'])} test", flush=True)

trainer = get_trainer(
    model=mg.model,
    tokenized_dataset=dataset,
    tokenizer=tokenizer,
    evaluate_metric="accuracy",
)
eval_results = trainer.evaluate()
print(f"      [Baseline] Accuracy: {eval_results['eval_accuracy']:.4f}", flush=True)

Example output:

[2/5] Loading dataset and evaluating baseline accuracy...
INFO     Tokenizing dataset imdb with AutoTokenizer for bert-base-uncased.
      Dataset loaded: 25000 train / 25000 test
      [Baseline] Accuracy: 0.8350

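Next, run quantization with a “by type” config, where a quantization scheme is assigned per mase_op. In this tutorial, every linear op has its input activations, weights, and biases quantized to the same integer format: 8 bits wide with 4 fractional bits. All other ops use the default entry, whose name is None.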
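To make the width and frac_width fields concrete, here is a minimal sketch of signed fixed-point rounding in plain Python. It assumes the common interpretation in which a value is scaled by 2^frac_width, rounded to the nearest integer, and clamped to the signed range of width bits. The helper quantize_fixed_point is hypothetical, written for illustration only, and is not the MASE quantizer itself.

```python
def quantize_fixed_point(x: float, width: int = 8, frac_width: int = 4) -> float:
    """Round x onto a signed fixed-point grid (illustrative helper, not a MASE API)."""
    scale = 2 ** frac_width              # 4 fractional bits -> grid step of 1/16
    q_min = -(2 ** (width - 1))          # -128 for 8 bits
    q_max = 2 ** (width - 1) - 1         # +127 for 8 bits
    q = round(x * scale)                 # nearest integer code
    q = max(q_min, min(q_max, q))        # saturate values outside the range
    return q / scale                     # back to a real value on the grid


for value in (0.3, -1.27, 100.0):
    print(value, "->", quantize_fixed_point(value))
# 0.3 -> 0.3125
# -1.27 -> -1.25
# 100.0 -> 7.9375
```

Under these assumptions, the representable range is [-8.0, 7.9375] in steps of 0.0625, which explains why small weights lose precision and large values saturate.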

Step 3: Apply PTQ and evaluate#

print("\n[3/5] Applying integer quantization (PTQ)...", flush=True)

quantization_config = {
    "by": "type",
    "default": {"config": {"name": None}},
    "linear": {
        "config": {
            "name": "integer",
            "data_in_width": 8,
            "data_in_frac_width": 4,
            "weight_width": 8,
            "weight_frac_width": 4,
            "bias_width": 8,
            "bias_frac_width": 4,
        }
    },
}

mg, _ = passes.quantize_transform_pass(mg, pass_args=quantization_config)
print("      quantize_transform_pass  ✓", flush=True)

# Rebuild the trainer so that it wraps the quantized model
trainer = get_trainer(
    model=mg.model,
    tokenized_dataset=dataset,
    tokenizer=tokenizer,
    evaluate_metric="accuracy",
)
eval_results = trainer.evaluate()
print(f"      [PTQ] Accuracy: {eval_results['eval_accuracy']:.4f}", flush=True)

mg.export(f"{Path.home()}/tutorial_3_ptq")
print(f"      PTQ checkpoint saved to {Path.home()}/tutorial_3_ptq", flush=True)

Example output:

[3/5] Applying integer quantization (PTQ)...
      quantize_transform_pass  ✓
      [PTQ] Accuracy: 0.7738
INFO     Exporting MaseGraph to ~/tutorial_3_ptq.pt, ~/tutorial_3_ptq.mz
INFO     Exporting GraphModule to ~/tutorial_3_ptq.pt
INFO     Saving full model format
INFO     Exporting MaseMetadata to ~/tutorial_3_ptq.mz
      PTQ checkpoint saved to ~/tutorial_3_ptq

Quantization-Aware Training (QAT)#

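PTQ alone can reduce accuracy: in the run shown above, it fell from 0.8350 to 0.7738. To close this gap, put the quantized model back in the training loop and fine-tune it so that the weights adapt to the quantization error. This is Quantization-Aware Training.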
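One common way to train through a rounding step is the straight-through estimator (STE): the forward pass uses the quantized weight, while the backward pass treats rounding as the identity so that gradients still reach the full-precision weight. The toy below fits a single weight this way in plain Python. It is a sketch of the general idea; the functions fake_quantize and qat_fit are hypothetical, and whether the MASE quantizers behave exactly like this is an assumption.

```python
def fake_quantize(w: float, width: int = 8, frac_width: int = 4) -> float:
    """Signed fixed-point rounding (illustrative helper, not a MASE API)."""
    scale = 2 ** frac_width
    lo, hi = -(2 ** (width - 1)), 2 ** (width - 1) - 1
    return max(lo, min(hi, round(w * scale))) / scale


def qat_fit(x: float, y: float, w: float = 0.0, lr: float = 0.05, steps: int = 200):
    """Fit y ~ fake_quantize(w) * x with squared error and an STE gradient."""
    for _ in range(steps):
        w_q = fake_quantize(w)       # forward pass sees the quantized weight
        err = w_q * x - y
        # STE: pretend d(w_q)/d(w) == 1. The true derivative of round() is
        # zero almost everywhere, which would leave w stuck at its start value.
        grad = 2 * err * x
        w -= lr * grad               # update the full-precision (latent) weight
    return w, fake_quantize(w)


w, w_q = qat_fit(x=1.0, y=0.3)
print("latent weight:", w)
print("quantized weight:", w_q)
```

With a target of 0.3, the quantized weight ends on one of the two grid points adjacent to it (0.25 or 0.3125), while the latent weight hovers near the rounding boundary between them.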

Step 4: Run QAT#

print("\n[4/5] Running QAT (1 epoch)...", flush=True)
# Reuses the trainer built in Step 3, which already wraps the quantized model
trainer.train()
eval_results = trainer.evaluate()
print(f"      [QAT] Accuracy: {eval_results['eval_accuracy']:.4f}", flush=True)

Example output:

{'loss': 0.4101, 'grad_norm': 10.378958702087402, 'learning_rate': 4.2016e-05, 'epoch': 0.16}
{'loss': 0.401, 'grad_norm': 5.116357326507568, 'learning_rate': 3.4016e-05, 'epoch': 0.32}
{'loss': 0.3959, 'grad_norm': 15.316791534423828, 'learning_rate': 2.6016000000000003e-05, 'epoch': 0.48}
{'loss': 0.3906, 'grad_norm': 9.798357009887695, 'learning_rate': 1.8015999999999998e-05, 'epoch': 0.64}
{'loss': 0.3874, 'grad_norm': 6.183642864227295, 'learning_rate': 1.0016e-05, 'epoch': 0.8}
{'loss': 0.3918, 'grad_norm': 10.731794357299805, 'learning_rate': 2.0160000000000003e-06, 'epoch': 0.96}
{'train_runtime': 141.1579, 'train_samples_per_second': 177.107, 'train_steps_per_second': 22.138, 'train_loss': 0.39675407958984377, 'epoch': 1.0}
      [QAT] Accuracy: 0.8399
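The learning_rate values in this log are consistent with a linear decay to zero. The sketch below reproduces them under three assumptions inferred from the log, not stated in the tutorial: a base learning rate of 5e-5, a batch size of 8 (train_samples_per_second divided by train_steps_per_second), and therefore 3125 optimizer steps for the 25000 training samples. The function linear_decay_lr is a hypothetical helper.

```python
BASE_LR = 5e-5        # assumed starting learning rate
TOTAL_STEPS = 3125    # assumed: 25000 training samples / batch size 8


def linear_decay_lr(step: int) -> float:
    """Learning rate after `step` scheduler updates, decaying linearly to 0."""
    return BASE_LR * (1 - step / TOTAL_STEPS)


# The logged values line up with steps 499, 999, 1499, ... under these assumptions.
for step in (499, 999, 1499, 1999, 2499, 2999):
    print(f"step {step}: lr = {linear_decay_lr(step):.4e}")
```

If your run logs different values, the base learning rate or batch size passed to the trainer most likely differs from the ones assumed here.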

Step 5: Export final QAT checkpoint#

print("\n[5/5] Exporting QAT checkpoint...", flush=True)
mg.export(f"{Path.home()}/tutorial_3_qat")
print(f"      QAT checkpoint saved to {Path.home()}/tutorial_3_qat", flush=True)

print("\n" + "=" * 60, flush=True)
print("Tutorial 3 complete!", flush=True)
print("=" * 60, flush=True)

Example output:

INFO     Exporting MaseGraph to ~/tutorial_3_qat.pt, ~/tutorial_3_qat.mz
INFO     Exporting GraphModule to ~/tutorial_3_qat.pt
INFO     Saving full model format
INFO     Exporting MaseMetadata to ~/tutorial_3_qat.mz
      QAT checkpoint saved to ~/tutorial_3_qat

Conclusion#

Tutorial 3 demonstrates a standard PTQ→QAT workflow:

  • PTQ gives a quick quantized baseline but can reduce model accuracy (from 0.8350 to 0.7738 in the run shown here).

  • QAT can recover (and in some runs exceed) full-precision accuracy (0.8399 after one epoch in the run shown here).

  • Exported checkpoints are saved to your home directory for follow-up tutorials:

    • tutorial_3_ptq.pt and tutorial_3_ptq.mz

    • tutorial_3_qat.pt and tutorial_3_qat.mz
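The accuracy figures from the example outputs above can be compared directly. This short sketch hard-codes the three numbers reported in this tutorial's run and prints each stage's change relative to the baseline; your own run will produce different values.

```python
# Accuracies copied from the example outputs in this tutorial
accuracy = {"baseline": 0.8350, "ptq": 0.7738, "qat": 0.8399}


def delta_vs_baseline(stage: str) -> float:
    """Accuracy change of a stage relative to the unquantized baseline."""
    return round(accuracy[stage] - accuracy["baseline"], 4)


for stage in ("ptq", "qat"):
    print(f"{stage}: {accuracy[stage]:.4f} ({delta_vs_baseline(stage):+.4f} vs baseline)")
# ptq: 0.7738 (-0.0612 vs baseline)
# qat: 0.8399 (+0.0049 vs baseline)
```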