Tutorial 7: MX Post-Training Quantization with Mase

This tutorial walks through Mase's MX post-training quantization (PTQ) flow end to end using the chop APIs, with unsloth/Llama-3.2-1B as the demo model.

Contents

  • The MX quantization config — block-scaled integer / floating-point formats expressed as a single TOML file consumed by quantize_module_transform_pass.

  • Quantization algorithm — RTN, then [gptq] for weight quantization, then [rotation_search] for activation quantization.

Live-execution timings

Each step is sized to run on a single GPU. The [gptq] and [rotation_search] passes cache their results to disk, so reruns finish in seconds once the first run completes.

Prerequisites

A CUDA 12.8 toolkit (with nvcc reachable on PATH) and an NVIDIA driver that supports it. The mx-ptq optional dependencies include lm-eval and fast-hadamard-transform; the latter is built from source, which takes roughly 2 to 3 minutes of CUDA compilation.

Using uv:

uv sync --no-build-isolation --extra mx-ptq

Using pip:

pip install --no-build-isolation --extra-index-url https://download.pytorch.org/whl/cu128 ".[mx-ptq]"

Verify:

uv run python -c "from chop.passes.module.transforms import quantize_module_transform_pass; print('OK')"

Run the full tutorial script:

uv run python docs/source/modules/documentation/tutorials/tutorial_7_mx_ptq.py

MX quantization refresher

MX (Microscaling) formats store a tensor as low-precision elements grouped into fixed-size blocks that share a single power-of-two scale (Rouhani et al. 2023; OCP MX standard). Mase’s configurable MX quantizers (chop.nn.quantizers.mxint, chop.nn.quantizers.mxfp) follow this single-level scaling scheme.

[Figure: mx_annotated.png]

MX data formats: MXFP and MXINT (figure adapted from the PLENA paper). MXFP elements carry sign + exponent + mantissa bits; MXINT elements carry sign + mantissa bits. In both, every block shares a single power-of-two scale, and the formats are parameterised by the tunable tuples \((M, E, S, B)\) for MXFP and \((M, S, B)\) for MXINT.

The format tuple

We describe an MX data format as a tuple

\[\tau = (d,\, b,\, B)\]

where \(d\) is the element datatype (INT for MXINT, minifloat for MXFP), \(b\) is the element bit-width, and \(B\) is the block size. So \(\tau = (\texttt{INT},\, 4,\, 16)\) is MXINT4 with block size 16, and \(\tau = (\texttt{minifloat},\, 4,\, 16)\) is MXFP4 with the same block size. Every element in a block shares one scale \(s\) and one zero-point \(z\); we use symmetric quantization throughout, so \(z = 0\).

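Quantizing a block

Given a high-precision tensor \(\mathbf{W}\) partitioned into blocks \(w \in \mathbb{R}^{B}\), the shared scale for each block is derived from its absmax against the format's representable maximum. The expression below is the real-valued scale; MX formats additionally round it to a power of two so it can be stored as a shared exponent (the exact rounding rule varies between implementations):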

\[s = \frac{\max |w|}{\max_{\tau}}\]

Each element is then projected into the low-precision grid by scaling, round-to-nearest, and clipping to the representable range:

\[w_{\tau} = \mathrm{clip}\!\left(\mathrm{RTN}\!\left(\tfrac{w}{s}\right),\ \min_{\tau},\ \max_{\tau}\right)\]

Dequantization reconstructs an approximation of the original block using the same shared scale:

\[Q(w;\, s,\, \tau) = s \cdot w_{\tau}\]
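The three equations above can be traced in a few lines of NumPy. This is a minimal sketch, not Mase's chop.nn.quantizers.mxint implementation. The function name mxint_quantize is hypothetical, and the sketch assumes a symmetric \(b\)-bit integer grid \([-(2^{b-1}-1),\, 2^{b-1}-1]\) plus a power-of-two scale obtained by rounding \(\log_2 s\) up, so that the block's absmax never clips.

```python
import numpy as np


def mxint_quantize(x, width=4, block_size=32):
    """Hypothetical MXInt fake-quantizer: returns (dequantized, int codes, scales).

    Assumes a symmetric integer grid and one power-of-two scale per block,
    rounded up so the block's absmax stays representable (no clipping).
    """
    x = np.asarray(x, dtype=np.float64)
    if x.size % block_size != 0:
        raise ValueError("pad x to a multiple of block_size first")
    blocks = x.reshape(-1, block_size)
    qmax = 2 ** (width - 1) - 1                        # max_tau, e.g. 7 for 4 bits
    absmax = np.abs(blocks).max(axis=1, keepdims=True)
    absmax = np.where(absmax > 0, absmax, 1.0)         # avoid log2(0) on all-zero blocks
    scale = np.exp2(np.ceil(np.log2(absmax / qmax)))   # s, rounded up to a power of two
    codes = np.clip(np.round(blocks / scale), -qmax, qmax)  # w_tau = clip(RTN(w / s))
    dequant = (codes * scale).reshape(x.shape)         # Q(w; s, tau) = s * w_tau
    return dequant, codes.astype(np.int8), scale


w = np.array([1.0, -2.0, 3.0, 14.0])
deq, codes, scale = mxint_quantize(w, width=4, block_size=4)
print(scale.ravel(), codes.ravel(), deq)
```

Here \(\max|w| / \max_\tau = 14 / 7 = 2\), so the block scale is exactly 2. Note that np.round rounds ties to even, so 0.5 maps to 0 and 1.5 maps to 2.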

What the TOML exposes

Format   Element                   Width knobs
MXInt    Signed integer            weight_width, data_in_width
MXFP     Mini-float (exp + frac)   weight_exponent_width / weight_frac_width,
                                   data_in_exponent_width / data_in_frac_width

Block size is set via weight_block_size / data_in_block_size. The typical “MX4” setting used throughout this tutorial is 4-bit integer elements with block size 32.
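For MXFP, the element grid, and with it \(\max_\tau\), is fixed by the exponent and fraction widths. The sketch below enumerates the non-negative values of an IEEE-style minifloat with bias \(2^{E-1}-1\) and subnormals. minifloat_grid is a hypothetical helper, and it assumes that no bit patterns are reserved for inf or NaN. That assumption holds for the OCP FP4 (E2M1) element type but not for every format.

```python
def minifloat_grid(exp_width, frac_width):
    """Sorted non-negative values of a sign + exponent + fraction minifloat.

    Assumes an IEEE-style bias, subnormals when the exponent field is 0,
    and no encodings reserved for inf/NaN.
    """
    bias = 2 ** (exp_width - 1) - 1
    values = set()
    for e in range(2 ** exp_width):
        for m in range(2 ** frac_width):
            frac = m / 2 ** frac_width
            if e == 0:   # subnormal: no implicit leading 1
                values.add(2.0 ** (1 - bias) * frac)
            else:        # normal: implicit leading 1
                values.add(2.0 ** (e - bias) * (1.0 + frac))
    return sorted(values)


print(minifloat_grid(2, 1))  # [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
```

If weight_exponent_width and weight_frac_width map to \(E\) and \(M\) as their names suggest, the values 2 and 1 give the FP4 (E2M1) grid, so \(\max_\tau = 6\). For E4M3 the OCP spec reserves the all-ones pattern for NaN, so the true maximum there is 448, not the 480 this sketch would report.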

The PTQ Algorithm

Llama-3.2-1B is quantized to W4 A4 KV4 (weights, activations and KV cache all in MXInt4) three ways, with WikiText perplexity measured after each step. Every config targets the same final precision, and each one adds a single new optimization tool:

Config   Adds                                  Where the gain comes from
A        MXInt4 round-to-nearest               Baseline: naive W4 A4 KV4
B        + [gptq] Hessian-aware calibration    Re-derives the W4 weights so that quantization error is minimized at the layer output
C        + [rotation_search], greedy           Per-matmul online Hadamard rotation on the matmul types with the largest outliers

Perplexity is measured with lm-eval-harness's wikitext task.
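Config C's rotation rests on one identity. For an orthogonal \(H\), \(xW^\top = (xH)(WH)^\top\), so rotating the activations at runtime ("online") and folding the same rotation into the weights leaves the matmul output unchanged while spreading an outlier channel across all channels. The NumPy sketch below illustrates only that identity, using a normalized Sylvester Hadamard matrix and a synthetic outlier channel. It does not reproduce how [rotation_search] greedily picks which matmul types to rotate, and the names hadamard and peak_to_rms are hypothetical.

```python
import numpy as np


def hadamard(n):
    """Orthonormal Sylvester Hadamard matrix (n must be a power of two)."""
    if n < 1 or n & (n - 1):
        raise ValueError("n must be a power of two")
    H = np.array([[1.0]])
    while H.shape[0] < n:
        H = np.block([[H, H], [H, -H]])
    return H / np.sqrt(n)


def peak_to_rms(t):
    """Largest magnitude divided by the RMS value; high means outlier-heavy."""
    return np.abs(t).max() / np.sqrt(np.mean(t ** 2))


rng = np.random.default_rng(0)
d = 64
x = rng.standard_normal((8, d))     # synthetic activations for 8 tokens
x[:, 3] = 50.0                      # one synthetic outlier channel
W = rng.standard_normal((16, d))    # linear layer weight, y = x @ W.T

H = hadamard(d)
x_rot = x @ H                       # applied online to the activations
W_rot = W @ H                       # folded offline into the weights

y = x @ W.T
y_rot = x_rot @ W_rot.T             # same output, since H @ H.T = I
print(f"peak/RMS before: {peak_to_rms(x):.1f}, after: {peak_to_rms(x_rot):.1f}")
```

A lower peak-to-RMS ratio lets a block's shared scale fit its other elements more tightly. Whether rotation actually lowers quantization error for a given matmul is an empirical question, which is presumably why the search is done per matmul type.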

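Setup

Helpers used throughout this tutorial: load_quant_config parses a TOML config into the pass_args dict that quantize_module_transform_pass consumes; evaluate_perplexity wraps the lm-eval-harness WikiText perplexity evaluation; and fresh_model reloads the fp16 checkpoint so that each config starts from unmodified weights. The tomllib module requires Python 3.11 or newer.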

import tomllib
from copy import deepcopy
from pathlib import Path

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from chop.passes.module.transforms import quantize_module_transform_pass

from lm_eval import simple_evaluate
from lm_eval.models.huggingface import HFLM
from lm_eval.utils import make_table

MODEL_NAME = "unsloth/Llama-3.2-1B"
OUTPUT_DIR = Path("tutorial_7_output")
OUTPUT_DIR.mkdir(exist_ok=True)
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"


def load_quant_config(path):
    with open(path, "rb") as f:
        raw = tomllib.load(f)
    by = raw.pop("by", "regex_name")
    pass_args = {"by": by}
    gptq = raw.pop("gptq", None)
    if gptq is not None:
        pass_args["gptq"] = deepcopy(gptq)
    for key, value in raw.items():
        pass_args[key] = {"config": deepcopy(value)}
    return pass_args


def evaluate_perplexity(model, tokenizer, max_length=2048, batch_size=8):
    """Run lm-eval-harness wikitext task; return word_perplexity."""
    lm = HFLM(pretrained=model, tokenizer=tokenizer,
              max_length=max_length, batch_size=batch_size)
    results = simple_evaluate(model=lm, tasks=["wikitext"],
                              batch_size=batch_size, limit=None)
    print(make_table(results))
    metrics = results["results"]["wikitext"]
    for k, v in metrics.items():
        if k.startswith("word_perplexity"):
            return float(v)
    raise KeyError(f"word_perplexity not found: {list(metrics)}")


def fresh_model(attn_implementation="eager"):
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME, torch_dtype=torch.float16,
        attn_implementation=attn_implementation,
    )
    model.eval()
    return model


tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)

fp16 baseline

Reference perplexity before any quantization is applied.

model_fp16 = fresh_model(attn_implementation="sdpa").to(DEVICE)
ppl_fp16 = evaluate_perplexity(model_fp16, tokenizer)
print(f"fp16 ppl = {ppl_fp16:.4f}")
del model_fp16
if torch.cuda.is_available():
    torch.cuda.empty_cache()

Note

Reference result: fp16 ppl = 12.8835 (WikiText word_perplexity).

Config A: MXInt4 RTN (W4 A4 KV4)

Every linear projection plus the KV cache gets MXInt4 with block size 32 and no calibration. The config uses two tiers of selectors:

  • self_attn$ swaps LlamaAttention → LlamaAttentionMXInt so that the KV cache inside attention becomes quantizable. The $ end-anchor stops this pattern from also matching child modules such as self_attn.q_proj. Its nested [kv_cache] sub-block quantizes the KV cache, while the [qk_matmul], [av_matmul], [softmax] and [rope] sub-blocks are bypassed; the config can be extended to cover those modules too (see the TOML reference).

  • self_attn.(q|k|v|o)_proj and mlp.(gate|up|down)_proj swap each nn.Linear → LinearMXInt.

by = "regex_name"

# Attention block — swap LlamaAttention for the MXInt variant, quantize KV cache
['model\.layers\.\d+\.self_attn$']
name = "mxint"

['model\.layers\.\d+\.self_attn$'.qk_matmul]
bypass = true

['model\.layers\.\d+\.self_attn$'.av_matmul]
bypass = true

['model\.layers\.\d+\.self_attn$'.kv_cache]
data_in_block_size = 32
data_in_width      = 4

['model\.layers\.\d+\.self_attn$'.softmax]
bypass = true

['model\.layers\.\d+\.self_attn$'.rope]
bypass = true

# Linear projections: W4 + A4
['model\.layers\.\d+\.self_attn\.(q|k|v|o)_proj']
name = "mxint"
weight_block_size = 32
weight_width = 4
data_in_block_size = 32
data_in_width = 4

['model\.layers\.\d+\.mlp\.(gate|up|down)_proj']
name = "mxint"
weight_block_size = 32
weight_width = 4
data_in_block_size = 32
data_in_width = 4
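Why the end-anchor matters can be checked with Python's re module against Llama's module names. This sketch assumes that selectors are matched with re.search; Mase's exact matching call may differ (re.match gives the same results for these patterns), so treat it as an illustration of the regexes only.

```python
import re

ATTN = r"model\.layers\.\d+\.self_attn$"
ATTN_UNANCHORED = r"model\.layers\.\d+\.self_attn"
ATTN_LINEAR = r"model\.layers\.\d+\.self_attn\.(q|k|v|o)_proj"
MLP_LINEAR = r"model\.layers\.\d+\.mlp\.(gate|up|down)_proj"

module_names = [
    "model.layers.0.self_attn",
    "model.layers.0.self_attn.q_proj",
    "model.layers.15.self_attn.o_proj",
    "model.layers.0.mlp.gate_proj",
    "model.layers.0.mlp.act_fn",
]


def matches(pattern):
    """Module names selected by a pattern (assumes re.search semantics)."""
    return [name for name in module_names if re.search(pattern, name)]


for label, pattern in [("attn", ATTN), ("attn, no $", ATTN_UNANCHORED),
                       ("attn linears", ATTN_LINEAR), ("mlp linears", MLP_LINEAR)]:
    print(f"{label:>12}: {matches(pattern)}")
```

Without the $, the attention selector would also claim the q/k/v/o projections that the linear selectors are meant to handle.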

Apply the config in Python — three steps:

  1. Load the model.

  2. Parse the TOML into a pass_args dict.

  3. Transform the model in place — every matching nn.Linear becomes a LinearMXInt and every LlamaAttention becomes a LlamaAttentionMXInt that quantizes the KV cache on every call.

config_a_toml = OUTPUT_DIR / "config_a_rtn.toml"
config_a_toml.write_text(...)   # TOML literal shown above

model_a = fresh_model(attn_implementation="eager")
pass_args_a = load_quant_config(config_a_toml)
model_a, _ = quantize_module_transform_pass(model_a, pass_args_a)
model_a.to(DEVICE)
ppl_a = evaluate_perplexity(model_a, tokenizer)
print(f"Config A ppl = {ppl_a:.4f}")

Note

Reference result: Config A ppl = 54.8872 vs fp16 12.8835. With 4-bit weights, activations and KV cache, no calibration and no outlier mitigation, perplexity rises by more than 4×; Configs B and C recover much of that loss.

Config B: add [gptq]

GPTQ (Frantar et al., 2022) is a one-time pre-pass: it walks the decoder layer by layer and rewrites each linear’s weight in place using Hessian-based error compensation against a calibration set (wikitext2 here). Mase’s [gptq] pass extends the original integer-only formulation to MX data formats with block-wise clipping search.
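The column-by-column update at the heart of GPTQ fits in a short NumPy sketch. It follows the original formulation: quantize one weight column, then push its rounding error onto the not-yet-quantized columns using the upper Cholesky factor of the inverse of the layer-wise Hessian \(H = X^\top X\) (up to a constant factor). It leaves out lazy batching, grouping, and Mase's MX-specific clipping search. The quantizer is a plain per-output-channel symmetric INT4 grid rather than an MX block format, and gptq_quantize is a hypothetical name.

```python
import numpy as np


def quantize_rtn(w, scale, qmax):
    """Round-to-nearest onto a symmetric integer grid, then dequantize."""
    return np.clip(np.round(w / scale), -qmax, qmax) * scale


def gptq_quantize(W, X, width=4, damp=0.01):
    """Simplified GPTQ for one linear layer y = x @ W.T.

    W: (out_features, in_features); X: (n_samples, in_features) calibration
    inputs. No lazy batching or groups; per-output-channel integer scales.
    """
    W = np.array(W, dtype=np.float64)            # working copy, updated in place
    qmax = 2 ** (width - 1) - 1
    scale = np.abs(W).max(axis=1) / qmax         # fixed from the original weights
    scale[scale == 0] = 1.0
    H = X.T @ X / X.shape[0]                     # layer-wise Hessian (up to a constant)
    H += damp * np.mean(np.diag(H)) * np.eye(H.shape[0])  # dampening for stability
    U = np.linalg.cholesky(np.linalg.inv(H)).T   # upper Cholesky factor of H^-1
    Q = np.zeros_like(W)
    for i in range(W.shape[1]):
        Q[:, i] = quantize_rtn(W[:, i], scale, qmax)
        err = (W[:, i] - Q[:, i]) / U[i, i]
        W[:, i + 1:] -= np.outer(err, U[i, i + 1:])  # compensate later columns
    return Q


rng = np.random.default_rng(0)
d_in, d_out = 64, 32
idx = np.arange(d_in)
cov = 0.9 ** np.abs(idx[:, None] - idx[None, :])     # correlated input features
X = rng.multivariate_normal(np.zeros(d_in), cov, size=512)
W = rng.standard_normal((d_out, d_in))

Q_gptq = gptq_quantize(W, X)
rtn_scale = (np.abs(W).max(axis=1) / 7)[:, None]
Q_rtn = quantize_rtn(W, rtn_scale, 7)
for name, Q in [("RTN", Q_rtn), ("GPTQ", Q_gptq)]:
    print(f"{name}: output error {np.linalg.norm(X @ (W - Q).T):.2f}")
```

Both versions land on the same integer grid with the same scales. GPTQ only changes which grid point each weight takes, trading per-weight rounding error for lower error at the layer output.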

The TOML keeps Config A’s selectors unchanged (attention swap + linear selectors + KV cache config) and adds a top-level [gptq] block plus a gptq = true flag on each linear selector so the linear classes consume the calibrated weights instead of re-quantizing.

The checkpoint_dir enables auto-resume: re-running this cell only recomputes layers that aren’t already on disk. First run takes a few minutes for Llama-3.2-1B; subsequent runs finish in seconds.

by = "regex_name"

[gptq]
model_name       = "unsloth/Llama-3.2-1B"
format           = "mxint"
dataset          = "wikitext2"
nsamples         = 32
seqlen           = 512
cali_batch_size  = 8
quantile_search  = true
clip_search_y    = true
checkpoint_dir   = "tutorial_7_output/checkpoints/config_b_gptq"

    [gptq.weight_config]
    weight_block_size = 32
    weight_width      = 4

# Same attention + linear selectors as Config A (KV cache stays quantized).

['model\.layers\.\d+\.self_attn$']
name = "mxint"

['model\.layers\.\d+\.self_attn$'.qk_matmul]
bypass = true

['model\.layers\.\d+\.self_attn$'.av_matmul]
bypass = true

['model\.layers\.\d+\.self_attn$'.kv_cache]
data_in_block_size = 32
data_in_width      = 4

['model\.layers\.\d+\.self_attn$'.softmax]
bypass = true

['model\.layers\.\d+\.self_attn$'.rope]
bypass = true

['model\.layers\.\d+\.self_attn\.(q|k|v|o)_proj']
name = "mxint"
weight_block_size = 32
weight_width = 4
data_in_block_size = 32
data_in_width = 4
gptq = true

['model\.layers\.\d+\.mlp\.(gate|up|down)_proj']
name = "mxint"
weight_block_size = 32
weight_width = 4
data_in_block_size = 32
data_in_width = 4
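gptq = true

config_b_toml = OUTPUT_DIR / "config_b_gptq.toml"   # illustrative file name, mirroring Config A
config_b_toml.write_text(...)   # TOML literal shown above

model_b = fresh_model(attn_implementation="eager")
pass_args_b = load_quant_config(config_b_toml)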
model_b, _ = quantize_module_transform_pass(model_b, pass_args_b)
model_b.to(DEVICE)
ppl_b = evaluate_perplexity(model_b, tokenizer)
print(f"Config B ppl = {ppl_b:.4f}")

Note

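Reference result: Config B ppl = 31.0363, which more than halves the gap to fp16: it shrinks from 42.0 perplexity points (Config A) to 18.2. GPTQ only recalibrates the weights, so activation quantization error is untouched; that is the next target.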

Recap: the progressive cost

All three configs target the same precision (W4 A4 KV4); the config is the only thing that changes. Perplexity numbers come from lm-eval-harness’s wikitext task (word_perplexity).
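For reference, word_perplexity normalizes the corpus log-likelihood by the number of words rather than tokens, which keeps it comparable across tokenizers. Writing \(\log p_\theta(d)\) for the model's rolling-window log-likelihood of WikiText document \(d\) and \(N_{\text{words}}\) for the total word count, it is:

\[\mathrm{PPL}_{\text{word}} = \exp\!\left(-\frac{1}{N_{\text{words}}} \sum_{d} \log p_\theta(d)\right)\]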

Config                 What it adds                         WikiText word_perplexity
fp16                   (reference)                          12.8835
A (RTN)                MXInt4 round-to-nearest              54.8872
B (+GPTQ)              Hessian-aware weight calibration     31.0363
C (+rotation search)   Greedy per-matmul online Hadamard    20.4134

Reference run: Llama-3.2-1B on a single GPU. Results may vary slightly across hardware.
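One way to read the recap is as the fraction of the RTN-to-fp16 perplexity gap that each config recovers. The snippet below only performs that arithmetic on the reference numbers from the table; gap_recovered is a hypothetical helper.

```python
# Reference word_perplexity values from the recap table.
PPL = {"fp16": 12.8835, "A": 54.8872, "B": 31.0363, "C": 20.4134}


def gap_recovered(ppl, ref="fp16", base="A"):
    """Fraction of the base-to-reference perplexity gap closed by `ppl`."""
    return (PPL[base] - ppl) / (PPL[base] - PPL[ref])


for cfg in ("B", "C"):
    print(f"Config {cfg}: {gap_recovered(PPL[cfg]):.1%} of the A-to-fp16 gap recovered")
```

With these numbers, GPTQ recovers about 57% of the gap, and adding rotation search brings that to about 82%.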

Going further

This tutorial covers Mase’s core MX-PTQ flow end to end. For a broader evaluation surface, see PLENA Software, which builds on Mase and adds:

  • A unified CLI driving the same TOML config across multiple eval surfaces — eval_ppl, eval_lm (lm-eval-harness), eval_evalplus (HumanEval+), eval_phase_lm (different precision for prefill vs decode), eval_phase_bfcl (function calling), eval_dllm / eval_llada (diffusion LMs), eval_osworld (agentic tasks).

  • Paper-table reproductions for Llama-2 / Llama-3 across model sizes and bit configurations.

Wrap-up

This tutorial covered:

  1. Three MX quantization configs targeting W4 A4 KV4, each adding one new tool: RTN → [gptq] → [rotation_search].

  2. Applying them to Llama-3.2-1B via quantize_module_transform_pass, with perplexity recovering at each step.

  3. How the same pass_args interface handles plain RTN, weight calibration, and per-matmul rotation through a single TOML config.