
Commit 2242304

API to measure MSE for target quantizers (#940)
## What does this PR do?

**Type of change:** new feature

**Overview:** add an API to measure MSE for target quantizers given a forward loop.

## Usage

```python
# 1. Quantize the model as usual
model = mtq.quantize(model, quant_cfg, forward_loop)

# 2. Compute MSE for all quantizers
mse = mtq.compute_quantization_mse(model, forward_loop)

# 3. Print the top-5 noisiest quantizers
for name, err in sorted(mse.items(), key=lambda x: -x[1])[:5]:
    print(f"{name}: {err:.4e}")
```

## Testing

Unit tests, plus testing with HF PTQ.

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes/No
- **Did you write any new necessary tests?**: Yes/No
- **Did you add or update any necessary documentation?**: Yes/No
- **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No

## Summary by CodeRabbit

* **New Features**
  * Added an API to measure per-quantizer mean-squared error (MSE) between original and fake-quantized tensors; supports wildcard and callable filtering, skips disabled/non-fake-quant quantizers, and runs safely under no-grad.
* **Tests**
  * Added comprehensive tests for MSE validity, pattern and callable filtering, union behavior, exclusion of disabled quantizers, preservation of model state, and forward-hook cleanup.
* **Documentation**
  * Updated changelog to document the new MSE measurement API.

Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com>
Signed-off-by: Wei-Ming Chen <17592131+meenchen@users.noreply.github.com>
1 parent be6dfad commit 2242304

File tree

3 files changed: +219 additions, −0 deletions


CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
```diff
@@ -19,6 +19,7 @@ NVIDIA Model Optimizer Changelog
 - Enable PTQ workflow for Qwen3.5 MoE models.
 - Add ``nvfp4_omlp_only`` quantization format for NVFP4 quantization. This is similar to ``nvfp4_mlp_only`` but also quantizes the output projection layer in attention.
 - ``pass_through_bwd`` in the quantization config is now default to True. Please set it to False if you want to use STE with zeroed outlier gradients for potentially better QAT accuracy.
+- Add :meth:`compute_quantization_mse <modelopt.torch.quantization.model_quant.compute_quantization_mse>` API to measure per-quantizer mean-squared quantization error, with flexible wildcard and callable filtering.

 **Misc**
```

modelopt/torch/quantization/model_quant.py

Lines changed: 77 additions & 0 deletions
```diff
@@ -41,6 +41,7 @@
 __all__ = [
     "auto_quantize",
     "calibrate",
+    "compute_quantization_mse",
     "disable_quantizer",
     "enable_quantizer",
     "fold_weight",
@@ -535,3 +536,79 @@ def fold_weight(model: nn.Module, keep_attrs: bool = False):
     for name, module in model.named_modules():
         if isinstance(module, QuantModule):
             module.fold_weight(keep_attrs)
+
+
+@torch.no_grad()
+def compute_quantization_mse(
+    model: nn.Module,
+    forward_loop: ForwardLoop,
+    wildcards: str | Callable | list[str | Callable] = "*",
+) -> dict[str, float]:
+    """Compute the mean-squared quantization error for selected quantizers.
+
+    Runs ``forward_loop`` through the model while recording, for every matching
+    :class:`TensorQuantizer`, the MSE between the original float tensor and
+    its fake-quantized (Q→DQ) counterpart. Values are averaged over all
+    calibration batches.
+
+    Args:
+        model: A quantized model (output of :func:`quantize`).
+        forward_loop: Callable that takes ``model`` and runs data through it.
+        wildcards: One or more fnmatch glob patterns (or callable filters)
+            matched against :class:`TensorQuantizer` module names in
+            ``model.named_modules()``. Follows the same convention as
+            ``quant_cfg`` wildcard keys. Defaults to ``"*"`` (all quantizers).
+
+    Returns:
+        A dict mapping each matched quantizer's fully-qualified name to its
+        mean MSE (float). Quantizers that are disabled or not in fake-quant
+        mode are skipped and absent from the output.
+
+    Example::
+
+        mse = mtq.compute_quantization_mse(
+            model,
+            forward_loop,
+            wildcards=["*k_bmm_quantizer", "*v_bmm_quantizer"],
+        )
+        for name, err in sorted(mse.items()):
+            print(f"{name}: {err:.4e}")
+    """
+    if not isinstance(wildcards, list):
+        wildcards = [wildcards]
+
+    def _matches(name: str) -> bool:
+        return any(fnmatch.fnmatch(name, w) if isinstance(w, str) else w(name) for w in wildcards)
+
+    accumulators: dict[str, dict] = {}  # name -> {"sum": float, "count": int}
+    hooks = []
+
+    for name, module in model.named_modules():
+        if not isinstance(module, TensorQuantizer):
+            continue
+        if not _matches(name):
+            continue
+        if not (module._if_quant and module._fake_quant) or module._disabled:
+            continue
+        accumulators[name] = {"sum": 0.0, "count": 0}
+
+        def _make_hook(acc):
+            def hook(mod, inp, out):
+                original = inp[0].detach().float()
+                quantized = out.detach().float()
+                acc["sum"] += torch.mean((original - quantized) ** 2).item()
+                acc["count"] += 1
+
+            return hook
+
+        hooks.append(module.register_forward_hook(_make_hook(accumulators[name])))
+
+    try:
+        forward_loop(model)
+    finally:
+        for h in hooks:
+            h.remove()
+
+    return {
+        name: acc["sum"] / acc["count"] for name, acc in accumulators.items() if acc["count"] > 0
+    }
```
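
The `wildcards` argument accepts glob strings, callables, or a mixed list of both, combined as a union. The matching logic boils down to stdlib `fnmatch` plus predicate calls; here is a minimal standalone sketch of that filter (the quantizer names below are made up for illustration):

```python
import fnmatch


def matches(name: str, wildcards: list) -> bool:
    # String entries are fnmatch glob patterns; callable entries are
    # arbitrary predicates on the module name. Any hit selects the name,
    # mirroring the union semantics of the `wildcards` argument.
    return any(
        fnmatch.fnmatch(name, w) if isinstance(w, str) else w(name)
        for w in wildcards
    )


# Hypothetical quantizer module names for illustration
names = [
    "net.0.weight_quantizer",
    "net.0.input_quantizer",
    "net.2.weight_quantizer",
]

glob_hits = [n for n in names if matches(n, ["*weight_quantizer"])]
callable_hits = [n for n in names if matches(n, [lambda n: "net.0" in n])]

print(glob_hits)      # ['net.0.weight_quantizer', 'net.2.weight_quantizer']
print(callable_hits)  # ['net.0.weight_quantizer', 'net.0.input_quantizer']
```

Callables are handy where globs fall short, e.g. selecting by layer index range or by regex.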
Lines changed: 141 additions & 0 deletions
```python
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for mtq.compute_quantization_mse()."""

import torch
from _test_utils.torch.quantization.models import SimpleLinear

import modelopt.torch.quantization as mtq
from modelopt.torch.quantization.nn import TensorQuantizer

INT8_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {"num_bits": 8, "axis": 0},
        "*input_quantizer": {"num_bits": 8, "axis": None},
    },
    "algorithm": "max",
}


def _make_quantized_model():
    model = SimpleLinear()
    calib_data = [model.get_input() for _ in range(4)]

    def forward_loop(m):
        for batch in calib_data:
            m(batch)

    mtq.quantize(model, INT8_CFG, forward_loop)
    return model, forward_loop


class TestComputeQuantizationMse:
    def test_returns_nonnegative_values(self):
        """MSE values must be >= 0 for all quantizers."""
        model, forward_loop = _make_quantized_model()
        mse = mtq.compute_quantization_mse(model, forward_loop)
        assert len(mse) > 0
        assert all(v >= 0.0 for v in mse.values())

    def test_wildcard_star_covers_all_enabled_fake_quant(self):
        """Default wildcard '*' should return an entry for every enabled fake-quant quantizer."""
        model, forward_loop = _make_quantized_model()
        mse = mtq.compute_quantization_mse(model, forward_loop, wildcards="*")

        expected_names = {
            name
            for name, module in model.named_modules()
            if isinstance(module, TensorQuantizer)
            and module._if_quant
            and module._fake_quant
            and not module._disabled
        }
        assert set(mse.keys()) == expected_names

    def test_wildcard_filters_by_suffix(self):
        """A suffix pattern should restrict results to matching quantizer names."""
        model, forward_loop = _make_quantized_model()
        mse = mtq.compute_quantization_mse(model, forward_loop, wildcards="*weight_quantizer")
        assert len(mse) > 0
        assert all("weight_quantizer" in k for k in mse)
        # No input quantizers should appear
        assert not any("input_quantizer" in k for k in mse)

    def test_list_of_wildcards(self):
        """A list of patterns should return the union of matched quantizers."""
        model, forward_loop = _make_quantized_model()
        mse_weight = mtq.compute_quantization_mse(
            model, forward_loop, wildcards="*weight_quantizer"
        )
        mse_input = mtq.compute_quantization_mse(model, forward_loop, wildcards="*input_quantizer")
        mse_both = mtq.compute_quantization_mse(
            model, forward_loop, wildcards=["*weight_quantizer", "*input_quantizer"]
        )
        assert set(mse_both.keys()) == set(mse_weight.keys()) | set(mse_input.keys())

    def test_callable_filter(self):
        """A callable wildcard should select quantizers by arbitrary predicate."""
        model, forward_loop = _make_quantized_model()
        # Pick only quantizers belonging to the first linear layer (net.0)
        mse = mtq.compute_quantization_mse(model, forward_loop, wildcards=lambda n: "net.0" in n)
        assert len(mse) > 0
        assert all("net.0" in k for k in mse)

    def test_disabled_quantizer_absent_from_result(self):
        """A quantizer disabled after calibration must not appear in the output."""
        model, forward_loop = _make_quantized_model()

        # Disable one quantizer and record its name
        disabled_name = None
        for name, module in model.named_modules():
            if isinstance(module, TensorQuantizer) and module._if_quant and module._fake_quant:
                module.disable()
                disabled_name = name
                break

        assert disabled_name is not None, "No enabled quantizer found to disable"

        mse = mtq.compute_quantization_mse(model, forward_loop)
        assert disabled_name not in mse

    def test_no_matching_wildcard_returns_empty_dict(self):
        """A pattern that matches nothing should return an empty dict."""
        model, forward_loop = _make_quantized_model()
        mse = mtq.compute_quantization_mse(
            model, forward_loop, wildcards="*nonexistent_quantizer_xyz*"
        )
        assert mse == {}

    def test_does_not_modify_model_parameters(self):
        """Running MSE measurement must leave model weights unchanged."""
        model, forward_loop = _make_quantized_model()
        params_before = {k: v.clone() for k, v in model.named_parameters()}
        mtq.compute_quantization_mse(model, forward_loop)
        for k, v in model.named_parameters():
            assert torch.equal(v, params_before[k]), f"Parameter {k} was modified"

    def test_hooks_removed_after_call(self):
        """All forward hooks registered during the call must be cleaned up."""
        model, forward_loop = _make_quantized_model()

        hooks_before = sum(
            len(m._forward_hooks) for m in model.modules() if isinstance(m, TensorQuantizer)
        )
        mtq.compute_quantization_mse(model, forward_loop)
        hooks_after = sum(
            len(m._forward_hooks) for m in model.modules() if isinstance(m, TensorQuantizer)
        )
        assert hooks_after == hooks_before
```
