
API to measure MSE for target quantizers#940

Merged
meenchen merged 8 commits into main from weimingc/api_for_quantization_mse
Mar 6, 2026

Conversation

@meenchen
Contributor

@meenchen meenchen commented Feb 26, 2026

What does this PR do?

Type of change: New feature

Overview: add an API to measure MSE for target quantizers given a forward loop

Usage

# 1. Quantize the model as usual
model = mtq.quantize(model, quant_cfg, forward_loop)
# 2. Compute MSE for all quantizers
mse = mtq.compute_quantization_mse(model, forward_loop)
# 3. Print the top-5 noisiest quantizers
for name, err in sorted(mse.items(), key=lambda x: -x[1])[:5]:
    print(f"{name}: {err:.4e}")

Testing

Unit tests and testing with HF PTQ

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes/No
  • Did you write any new necessary tests?: Yes/No
  • Did you add or update any necessary documentation?: Yes/No
  • Did you update Changelog?: Yes/No

Additional Information

Summary by CodeRabbit

  • New Features

    • Added an API to measure per-quantizer mean-squared error (MSE) between original and fake-quantized tensors; supports wildcard and callable filtering, skips disabled/non-fake-quant quantizers, and runs safely under no-grad.
  • Tests

    • Added comprehensive tests for MSE validity, pattern and callable filtering, union behavior, exclusion of disabled quantizers, preservation of model state, and forward-hook cleanup.
  • Documentation

    • Updated changelog to document the new MSE measurement API.

@meenchen meenchen requested a review from a team as a code owner February 26, 2026 19:53
@meenchen meenchen requested a review from Edwardf0t1 February 26, 2026 19:53
@copy-pr-bot

copy-pr-bot bot commented Feb 26, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@coderabbitai
Contributor

coderabbitai bot commented Feb 26, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Walkthrough

Adds a public API compute_quantization_mse(model, forward_loop, wildcards="*") that selects TensorQuantizer modules, registers forward hooks to accumulate per-quantizer mean-squared error during a provided forward loop, and returns average MSE per matched quantizer.

Changes

Cohort / File(s) Summary
Core API Implementation
modelopt/torch/quantization/model_quant.py
Adds and exports compute_quantization_mse(). Selects quantizers via wildcard strings or callables, skips disabled/non-fake-quant modules with rank-0 logs, registers forward hooks to compute and accumulate per-element MSE during the provided forward_loop, and removes hooks after execution; returns a dict of average MSEs for matched quantizers.
Unit tests
tests/unit/torch/quantization/test_compute_quantization_mse.py
New test module validating non-negative MSE outputs, wildcard matching (default "*", suffix matches, list unions), callable filters, exclusion of disabled quantizers, empty-match behavior (returns empty dict), preservation of model parameters, and cleanup of forward hooks.
Changelog
CHANGELOG.rst
Documents the new compute_quantization_mse API and its behavior.
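
As a rough illustration of the wildcard/callable selection behavior summarized above, the matching logic can be sketched with Python's `fnmatch`. The helper `select_quantizers` and the names below are hypothetical stand-ins for illustration, not the actual ModelOpt code:

```python
import fnmatch

def select_quantizers(names, wildcards="*"):
    """Return the quantizer names matched by any wildcard string or callable filter."""
    filters = wildcards if isinstance(wildcards, (list, tuple)) else [wildcards]
    selected = []
    for name in names:
        for f in filters:
            # A callable filter decides membership directly; a string is a glob pattern.
            if f(name) if callable(f) else fnmatch.fnmatch(name, f):
                selected.append(name)
                break  # union semantics: one matching filter is enough
    return selected

names = ["fc1.weight_quantizer", "fc1.input_quantizer", "fc2.weight_quantizer"]
print(select_quantizers(names, "*weight_quantizer"))
# → ['fc1.weight_quantizer', 'fc2.weight_quantizer']
print(select_quantizers(names, [lambda n: "fc2" in n, "*input*"]))
# → ['fc1.input_quantizer', 'fc2.weight_quantizer']
```

Passing a list mixes string and callable filters, and the union behavior mirrors the "list unions" case exercised by the unit tests.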

Sequence Diagram(s)

sequenceDiagram
    participant Caller as Caller
    participant Model as Model
    participant Hooks as ForwardHooks
    participant Quantizers as TensorQuantizerModules
    participant Accumulator as MSEAccumulator

    Caller->>Model: compute_quantization_mse(model, forward_loop, wildcards)
    Model->>Model: select quantizers by wildcard/callable
    Model->>Hooks: register forward hooks on matched quantizers

    loop forward_loop iterations
        Caller->>Model: run one forward iteration
        Model->>Quantizers: forward through hooked modules
        Quantizers->>Hooks: invoke hook with (orig_tensor, quantized_tensor)
        Hooks->>Accumulator: compute per-element MSE and accumulate (sum, count)
    end

    Model->>Hooks: remove all registered hooks
    Accumulator->>Caller: return dict{name -> mean MSE}
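The (sum, count) accumulation in the diagram above can be sketched in plain Python. `MSEAccumulator` is an illustrative stand-in for the hook-side bookkeeping; the real hooks operate on tensors under no-grad:

```python
class MSEAccumulator:
    """Accumulates per-quantizer (sum of squared error, element count) pairs."""

    def __init__(self):
        self.sq_err_sums = {}
        self.counts = {}

    def update(self, name, orig, quantized):
        # Called from a forward hook with the original and fake-quantized values.
        self.sq_err_sums[name] = self.sq_err_sums.get(name, 0.0) + sum(
            (o - q) ** 2 for o, q in zip(orig, quantized)
        )
        self.counts[name] = self.counts.get(name, 0) + len(orig)

    def averages(self):
        # Mean squared error per quantizer over all elements seen in the forward loop.
        return {n: self.sq_err_sums[n] / self.counts[n] for n in self.sq_err_sums}

acc = MSEAccumulator()
acc.update("fc1.weight_quantizer", [1.0, 2.0], [1.1, 1.9])
acc.update("fc1.weight_quantizer", [3.0], [3.0])
print(acc.averages())
```

Accumulating the running (sum, count) rather than a running mean keeps the final average exact regardless of how many batches the forward loop runs.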

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks | ✅ 3 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning — Docstring coverage is 68.75%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.

✅ Passed checks (3 passed)
  • Description Check ✅ — Check skipped; CodeRabbit's high-level summary is enabled.
  • Title check ✅ — The title clearly and concisely summarizes the main change: a new API to measure MSE (mean squared error) for quantizers, matching the primary functionality added across the modified files.
  • Security Anti-Patterns ✅ — Comprehensive security review found no critical anti-patterns from SECURITY.md in the modified Python files or the new compute_quantization_mse() function.



Comment @coderabbitai help to get the list of available commands and usage tips.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (2)
tests/unit/torch/quantization/test_compute_quantization_mse.py (1)

128-143: Add hook-cleanup coverage for exception path.

Current cleanup assertion only covers success. Add a test where forward_loop raises and verify hook counts are restored.

Proposed test addition
+    def test_hooks_removed_when_forward_loop_raises(self):
+        model, _ = _make_quantized_model()
+        sample = model.get_input()
+
+        hooks_before = sum(
+            len(m._forward_hooks)
+            for m in model.modules()
+            if isinstance(m, TensorQuantizer)
+        )
+
+        def failing_forward_loop(m):
+            m(sample)
+            raise RuntimeError("expected failure")
+
+        with pytest.raises(RuntimeError):
+            mtq.compute_quantization_mse(model, failing_forward_loop)
+
+        hooks_after = sum(
+            len(m._forward_hooks)
+            for m in model.modules()
+            if isinstance(m, TensorQuantizer)
+        )
+        assert hooks_after == hooks_before
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unit/torch/quantization/test_compute_quantization_mse.py` around lines
128 - 143, Add a new test that mirrors test_hooks_removed_after_call but
exercises the exception path: use _make_quantized_model() to get model and a
forward_loop that raises (e.g., a callable that raises an exception when
invoked), record hooks_before by summing len(m._forward_hooks) for m in
model.modules() if isinstance(m, TensorQuantizer), call
mtq.compute_quantization_mse(model, forward_loop) inside a try/except (catch the
expected exception), then recompute hooks_after the same way and assert
hooks_after == hooks_before to verify hooks are cleaned up on error; use the
same TensorQuantizer and compute_quantization_mse identifiers so the test
targets the same cleanup logic.
modelopt/torch/quantization/model_quant.py (1)

570-596: Short-circuit when no quantizers match wildcards.

If no hooks are registered, the function still executes the full forward loop and then returns {}. Early return avoids unnecessary compute and side effects.

Proposed fix
-    forward_loop(model)
+    if not hooks:
+        return {}
+
+    forward_loop(model)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/quantization/model_quant.py` around lines 570 - 596, The loop
registers hooks into TensorQuantizer modules but if none match the wildcards we
should short-circuit to avoid running forward_loop unnecessarily: after building
hooks/accumulators (the variables named hooks and accumulators in
compute_quantization_mse), check if hooks is empty, log a brief message with
print_rank_0 and return {} immediately; otherwise proceed to call
forward_loop(model) and the existing cleanup/aggregation logic.

ℹ️ Review info

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between dfe705a and 4365a72.

📒 Files selected for processing (2)
  • modelopt/torch/quantization/model_quant.py
  • tests/unit/torch/quantization/test_compute_quantization_mse.py

Comment on lines +595 to +599
forward_loop(model)

for h in hooks:
    h.remove()


⚠️ Potential issue | 🟠 Major

Preserve model state and always remove hooks on failure.

At Line 595, an exception in forward_loop(model) skips cleanup, leaving stale hooks registered. Run in eval mode and use try/finally so cleanup and state restoration are guaranteed.

Proposed fix
-    forward_loop(model)
-
-    for h in hooks:
-        h.remove()
+    is_training = model.training
+    model.eval()
+    try:
+        forward_loop(model)
+    finally:
+        for h in hooks:
+            h.remove()
+        model.train(is_training)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/quantization/model_quant.py` around lines 595 - 599, The
forward pass call forward_loop(model) can raise and currently skips cleanup;
wrap the forward_loop invocation in a try/finally so you always remove all hooks
in hooks (each h.remove()) and restore the model's original training/eval state;
also set model.eval() before running forward_loop and after the try/finally
restore model.train() if it was originally in training mode (use a saved flag
like was_training) to guarantee both hook removal and state restoration even on
exceptions.

Copilot AI left a comment

Pull request overview

Adds a user-facing API to measure per-quantizer mean-squared error (MSE) by running a provided forward_loop, plus unit tests validating wildcard/filter behavior and cleanup semantics.

Changes:

  • Introduce mtq.compute_quantization_mse(model, forward_loop, wildcards=...) to compute per-quantizer MSE during fake-quant.
  • Export the new API from the quantization public surface (model_quant.__all__ / mtq namespace).
  • Add unit tests covering wildcard matching, callable filters, disabled quantizer handling, parameter immutability, and hook cleanup.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 5 comments.

File Description
modelopt/torch/quantization/model_quant.py Adds the new compute_quantization_mse API and exports it publicly.
tests/unit/torch/quantization/test_compute_quantization_mse.py Adds unit tests for MSE computation, filtering, and hook cleanup.


Comment on lines +576 to +579
print_rank_0(
    f"[compute_quantization_mse] Skipping {name}: "
    f"_if_quant={module._if_quant}, _fake_quant={module._fake_quant}"
)
Copilot AI Feb 26, 2026

compute_quantization_mse prints a message for every matched quantizer that is not in fake-quant mode. For large models this can be very noisy and introduces an unexpected stdout side-effect for a measurement API. Consider removing this print or gating it behind an explicit verbose flag / debug logging.

Suggested change (delete the print):
print_rank_0(
    f"[compute_quantization_mse] Skipping {name}: "
    f"_if_quant={module._if_quant}, _fake_quant={module._fake_quant}"
)

@meenchen meenchen force-pushed the weimingc/api_for_quantization_mse branch 2 times, most recently from a5b775e to 6e0fe4e Compare March 4, 2026 23:00
@coderabbitai coderabbitai bot left a comment

♻️ Duplicate comments (1)
modelopt/torch/quantization/model_quant.py (1)

610-613: ⚠️ Potential issue | 🟠 Major

Guard hook cleanup and mode restoration with try/finally.

At Line 610, an exception in forward_loop(model) skips hook removal, so stale hooks can leak into later runs. Also preserve and restore the model’s training state around measurement.

Proposed fix
-    forward_loop(model)
-
-    for h in hooks:
-        h.remove()
+    was_training = model.training
+    model.eval()
+    try:
+        forward_loop(model)
+    finally:
+        for h in hooks:
+            h.remove()
+        model.train(was_training)
#!/bin/bash
# Verify cleanup/mode-restore guards exist around compute_quantization_mse forward loop.
rg -n -C4 'def compute_quantization_mse|forward_loop\(model\)|try:|finally:|h\.remove\(|model\.eval\(|model\.train\(' \
  modelopt/torch/quantization/model_quant.py \
  tests/unit/torch/quantization/test_compute_quantization_mse.py
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/quantization/model_quant.py` around lines 610 - 613,
compute_quantization_mse currently calls forward_loop(model) and then removes
hooks, but an exception in forward_loop can skip hook cleanup and model mode
restoration; wrap the forward_loop invocation and subsequent cleanup in a
try/finally block that ensures for each hook in hooks you call h.remove() and
that the model's original training state (save current model.training) is
restored with model.train() or model.eval() as appropriate; locate the hooks
list and forward_loop(model) call in compute_quantization_mse and implement the
try/finally to guarantee h.remove() and restore model.training even on errors.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 27311daf-f499-4f92-8b43-e7f93acfbea5

📥 Commits

Reviewing files that changed from the base of the PR and between a5b775e and 6e0fe4e.

📒 Files selected for processing (3)
  • CHANGELOG.rst
  • modelopt/torch/quantization/model_quant.py
  • tests/unit/torch/quantization/test_compute_quantization_mse.py

@coderabbitai coderabbitai bot left a comment

♻️ Duplicate comments (1)
modelopt/torch/quantization/model_quant.py (1)

611-615: ⚠️ Potential issue | 🟠 Major

Guarantee hook cleanup (and mode restoration) when forward_loop fails.

At Line 611, an exception from forward_loop(model) skips cleanup, leaving stale hooks registered. Wrap execution in try/finally; also run in eval mode and restore original train/eval state to avoid side effects during measurement.

Proposed fix
-    forward_loop(model)
-
-    for h in hooks:
-        h.remove()
+    was_training = model.training
+    model.eval()
+    try:
+        forward_loop(model)
+    finally:
+        for h in hooks:
+            h.remove()
+        model.train(was_training)
#!/bin/bash
set -euo pipefail

TARGET="$(fd '^model_quant\.py$' | head -n1)"
export TARGET

python - <<'PY'
import ast, os, pathlib
p = pathlib.Path(os.environ["TARGET"])
tree = ast.parse(p.read_text())

fn = next(
    n for n in tree.body
    if isinstance(n, ast.FunctionDef) and n.name == "compute_quantization_mse"
)

has_try = any(isinstance(n, ast.Try) for n in ast.walk(fn))
has_eval = any(
    isinstance(n, ast.Call)
    and isinstance(n.func, ast.Attribute)
    and n.func.attr == "eval"
    for n in ast.walk(fn)
)
has_train_restore = any(
    isinstance(n, ast.Call)
    and isinstance(n.func, ast.Attribute)
    and n.func.attr == "train"
    for n in ast.walk(fn)
)

print({
    "has_try_finally": has_try,
    "has_eval_call": has_eval,
    "has_train_restore_call": has_train_restore,
})
PY
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/quantization/model_quant.py` around lines 611 - 615, Wrap the
forward_loop call in a try/finally so hooks are always removed and the model's
training/eval state is restored if forward_loop raises: before calling
forward_loop(model) capture the model's current training mode (e.g., mode_before
= model.training), call model.eval() to run measurement, then in a finally block
call for h in hooks: h.remove() and restore the original mode via
model.train(mode_before); ensure this change is implemented in
compute_quantization_mse (and references to forward_loop and hooks/h.remove())
so cleanup and mode restoration always run even on exceptions.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 6e3a0bb7-1a7c-4097-9edb-2f9807ef6afa

📥 Commits

Reviewing files that changed from the base of the PR and between 6e0fe4e and a386cde.

📒 Files selected for processing (1)
  • modelopt/torch/quantization/model_quant.py

@meenchen meenchen force-pushed the weimingc/api_for_quantization_mse branch from 1923cbd to fe8d7e4 Compare March 4, 2026 23:58
meenchen added 5 commits March 4, 2026 23:59
Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com>
Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com>
Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com>
Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com>
Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com>
@meenchen meenchen force-pushed the weimingc/api_for_quantization_mse branch from fe8d7e4 to 10ea8be Compare March 5, 2026 00:00
meenchen added 2 commits March 5, 2026 06:15
Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com>
Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com>
@codecov
Copy link

codecov bot commented Mar 5, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 72.14%. Comparing base (be6dfad) to head (28c7190).
⚠️ Report is 1 commit behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #940      +/-   ##
==========================================
+ Coverage   72.08%   72.14%   +0.05%     
==========================================
  Files         209      209              
  Lines       23638    23667      +29     
==========================================
+ Hits        17040    17075      +35     
+ Misses       6598     6592       -6     


@meenchen meenchen self-assigned this Mar 5, 2026
Signed-off-by: Wei-Ming Chen <17592131+meenchen@users.noreply.github.com>
@meenchen meenchen enabled auto-merge (squash) March 6, 2026 21:13
@meenchen meenchen merged commit 2242304 into main Mar 6, 2026
40 checks passed
@meenchen meenchen deleted the weimingc/api_for_quantization_mse branch March 6, 2026 22:02
