API to measure MSE for target quantizers #940
Conversation
> Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. This behavior can be configured in the settings.
📝 Walkthrough

Adds a public API `mtq.compute_quantization_mse(model, forward_loop, wildcards=...)` that registers forward hooks on matched quantizers, runs the provided forward loop, and returns per-quantizer mean-squared error.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Caller as Caller
    participant Model as Model
    participant Hooks as ForwardHooks
    participant Quantizers as TensorQuantizerModules
    participant Accumulator as MSEAccumulator
    Caller->>Model: compute_quantization_mse(model, forward_loop, wildcards)
    Model->>Model: select quantizers by wildcard/callable
    Model->>Hooks: register forward hooks on matched quantizers
    loop forward_loop iterations
        Caller->>Model: run one forward iteration
        Model->>Quantizers: forward through hooked modules
        Quantizers->>Hooks: invoke hook with (orig_tensor, quantized_tensor)
        Hooks->>Accumulator: compute per-element MSE and accumulate (sum, count)
    end
    Model->>Hooks: remove all registered hooks
    Accumulator->>Caller: return dict{name -> mean MSE}
```
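The hook-and-accumulate flow in the diagram can be sketched with plain PyTorch forward hooks. `FakeQuant` and `measure_mse` below are illustrative stand-ins for this sketch, not the ModelOpt implementation:

```python
import torch
import torch.nn as nn


class FakeQuant(nn.Module):
    """Toy stand-in for a tensor quantizer: snaps inputs to a fixed grid."""

    def __init__(self, step: float = 0.1):
        super().__init__()
        self.step = step

    def forward(self, x):
        return torch.round(x / self.step) * self.step


def measure_mse(model: nn.Module, forward_loop) -> dict:
    """Accumulate per-module (squared-error sum, element count) via forward hooks."""
    stats = {}  # name -> [sq_err_sum, n_elements]
    hooks = []

    def make_hook(name):
        def hook(module, inputs, output):
            # inputs is a tuple; inputs[0] is the original tensor, output the quantized one.
            err = (output - inputs[0]).float()
            s, n = stats.get(name, (0.0, 0))
            stats[name] = [s + err.pow(2).sum().item(), n + err.numel()]
        return hook

    for name, module in model.named_modules():
        if isinstance(module, FakeQuant):
            hooks.append(module.register_forward_hook(make_hook(name)))
    try:
        forward_loop(model)  # caller drives however many iterations it wants
    finally:
        for h in hooks:
            h.remove()  # never leave stale hooks behind
    return {name: s / n for name, (s, n) in stats.items()}
```

For example, a single `FakeQuant(step=0.5)` fed `[0.2, 0.7]` snaps both values to the grid with an error of 0.2 each, giving an MSE of 0.04.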
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks: ✅ 3 | ❌ 1

❌ Failed checks (1 warning)
✅ Passed checks (3 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings.
Actionable comments posted: 1
🧹 Nitpick comments (2)
tests/unit/torch/quantization/test_compute_quantization_mse.py (1)
128-143: Add hook-cleanup coverage for exception path. The current cleanup assertion only covers success. Add a test where `forward_loop` raises and verify hook counts are restored.

Proposed test addition:
```diff
+    def test_hooks_removed_when_forward_loop_raises(self):
+        model, _ = _make_quantized_model()
+        sample = model.get_input()
+
+        hooks_before = sum(
+            len(m._forward_hooks)
+            for m in model.modules()
+            if isinstance(m, TensorQuantizer)
+        )
+
+        def failing_forward_loop(m):
+            m(sample)
+            raise RuntimeError("expected failure")
+
+        with pytest.raises(RuntimeError):
+            mtq.compute_quantization_mse(model, failing_forward_loop)
+
+        hooks_after = sum(
+            len(m._forward_hooks)
+            for m in model.modules()
+            if isinstance(m, TensorQuantizer)
+        )
+        assert hooks_after == hooks_before
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unit/torch/quantization/test_compute_quantization_mse.py` around lines 128-143: Add a new test that mirrors test_hooks_removed_after_call but exercises the exception path: use _make_quantized_model() to get model and a forward_loop that raises (e.g., a callable that raises an exception when invoked), record hooks_before by summing len(m._forward_hooks) for m in model.modules() if isinstance(m, TensorQuantizer), call mtq.compute_quantization_mse(model, forward_loop) inside a try/except (catch the expected exception), then recompute hooks_after the same way and assert hooks_after == hooks_before to verify hooks are cleaned up on error; use the same TensorQuantizer and compute_quantization_mse identifiers so the test targets the same cleanup logic.

modelopt/torch/quantization/model_quant.py (1)
570-596: Short-circuit when no quantizers match `wildcards`. If no hooks are registered, the function still executes the full forward loop and then returns `{}`. An early return avoids unnecessary compute and side effects.

Proposed fix:
```diff
-    forward_loop(model)
+    if not hooks:
+        return {}
+
+    forward_loop(model)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/quantization/model_quant.py` around lines 570 - 596, The loop registers hooks into TensorQuantizer modules but if none match the wildcards we should short-circuit to avoid running forward_loop unnecessarily: after building hooks/accumulators (the variables named hooks and accumulators in compute_quantization_mse), check if hooks is empty, log a brief message with print_rank_0 and return {} immediately; otherwise proceed to call forward_loop(model) and the existing cleanup/aggregation logic.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/quantization/model_quant.py`:
- Around line 595-599: The forward pass call forward_loop(model) can raise and
currently skips cleanup; wrap the forward_loop invocation in a try/finally so
you always remove all hooks in hooks (each h.remove()) and restore the model's
original training/eval state; also set model.eval() before running forward_loop
and after the try/finally restore model.train() if it was originally in training
mode (use a saved flag like was_training) to guarantee both hook removal and
state restoration even on exceptions.
---
Nitpick comments:
In `@modelopt/torch/quantization/model_quant.py`:
- Around line 570-596: The loop registers hooks into TensorQuantizer modules but
if none match the wildcards we should short-circuit to avoid running
forward_loop unnecessarily: after building hooks/accumulators (the variables
named hooks and accumulators in compute_quantization_mse), check if hooks is
empty, log a brief message with print_rank_0 and return {} immediately;
otherwise proceed to call forward_loop(model) and the existing
cleanup/aggregation logic.
In `@tests/unit/torch/quantization/test_compute_quantization_mse.py`:
- Around line 128-143: Add a new test that mirrors test_hooks_removed_after_call
but exercises the exception path: use _make_quantized_model() to get model and a
forward_loop that raises (e.g., a callable that raises an exception when
invoked), record hooks_before by summing len(m._forward_hooks) for m in
model.modules() if isinstance(m, TensorQuantizer), call
mtq.compute_quantization_mse(model, forward_loop) inside a try/except (catch the
expected exception), then recompute hooks_after the same way and assert
hooks_after == hooks_before to verify hooks are cleaned up on error; use the
same TensorQuantizer and compute_quantization_mse identifiers so the test
targets the same cleanup logic.
ℹ️ Review info
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- modelopt/torch/quantization/model_quant.py
- tests/unit/torch/quantization/test_compute_quantization_mse.py
```python
forward_loop(model)

for h in hooks:
    h.remove()
```
Preserve model state and always remove hooks on failure.
At Line 595, an exception in forward_loop(model) skips cleanup, leaving stale hooks registered. Run in eval mode and use try/finally so cleanup and state restoration are guaranteed.
Proposed fix
```diff
-    forward_loop(model)
-
-    for h in hooks:
-        h.remove()
+    is_training = model.training
+    model.eval()
+    try:
+        forward_loop(model)
+    finally:
+        for h in hooks:
+            h.remove()
+    model.train(is_training)
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
is_training = model.training
model.eval()
try:
    forward_loop(model)
finally:
    for h in hooks:
        h.remove()
model.train(is_training)
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/quantization/model_quant.py` around lines 595 - 599, The
forward pass call forward_loop(model) can raise and currently skips cleanup;
wrap the forward_loop invocation in a try/finally so you always remove all hooks
in hooks (each h.remove()) and restore the model's original training/eval state;
also set model.eval() before running forward_loop and after the try/finally
restore model.train() if it was originally in training mode (use a saved flag
like was_training) to guarantee both hook removal and state restoration even on
exceptions.
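The cleanup guarantee this comment asks for can be exercised in isolation. In this minimal sketch, `run_with_hooks` is a hypothetical helper, and `_forward_hooks` is a PyTorch-internal attribute used here only to count registered hooks:

```python
import torch.nn as nn


def run_with_hooks(model: nn.Module, forward_loop) -> None:
    """Register no-op hooks, then guarantee removal and mode restoration."""
    hooks = [m.register_forward_hook(lambda mod, inp, out: None) for m in model.modules()]
    was_training = model.training
    model.eval()  # measure without training-mode side effects (dropout, BN updates)
    try:
        forward_loop(model)
    finally:
        for h in hooks:
            h.remove()
        model.train(was_training)


model = nn.Linear(2, 2)
model.train()


def failing_loop(m):
    raise RuntimeError("expected failure")


try:
    run_with_hooks(model, failing_loop)
except RuntimeError:
    pass

# Hooks are removed and training mode is restored despite the exception.
assert len(model._forward_hooks) == 0
assert model.training
```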
Pull request overview
Adds a user-facing API to measure per-quantizer mean-squared error (MSE) by running a provided forward_loop, plus unit tests validating wildcard/filter behavior and cleanup semantics.
Changes:
- Introduce `mtq.compute_quantization_mse(model, forward_loop, wildcards=...)` to compute per-quantizer MSE during fake-quant.
- Export the new API from the quantization public surface (`model_quant.__all__` / `mtq` namespace).
- Add unit tests covering wildcard matching, callable filters, disabled quantizer handling, parameter immutability, and hook cleanup.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| modelopt/torch/quantization/model_quant.py | Adds the new compute_quantization_mse API and exports it publicly. |
| tests/unit/torch/quantization/test_compute_quantization_mse.py | Adds unit tests for MSE computation, filtering, and hook cleanup. |
```python
print_rank_0(
    f"[compute_quantization_mse] Skipping {name}: "
    f"_if_quant={module._if_quant}, _fake_quant={module._fake_quant}"
)
```
compute_quantization_mse prints a message for every matched quantizer that is not in fake-quant mode. For large models this can be very noisy and introduces an unexpected stdout side-effect for a measurement API. Consider removing this print or gating it behind an explicit verbose flag / debug logging.
Force-pushed from a5b775e to 6e0fe4e.
♻️ Duplicate comments (1)
modelopt/torch/quantization/model_quant.py (1)
610-613: ⚠️ Potential issue | 🟠 Major — Guard hook cleanup and mode restoration with `try/finally`.

At Line 610, an exception in `forward_loop(model)` skips hook removal, so stale hooks can leak into later runs. Also preserve and restore the model's training state around measurement.

Proposed fix:
```diff
-    forward_loop(model)
-
-    for h in hooks:
-        h.remove()
+    was_training = model.training
+    model.eval()
+    try:
+        forward_loop(model)
+    finally:
+        for h in hooks:
+            h.remove()
+    model.train(was_training)
```

```bash
#!/bin/bash
# Verify cleanup/mode-restore guards exist around compute_quantization_mse forward loop.
rg -n -C4 'def compute_quantization_mse|forward_loop\(model\)|try:|finally:|h\.remove\(|model\.eval\(|model\.train\(' \
  modelopt/torch/quantization/model_quant.py \
  tests/unit/torch/quantization/test_compute_quantization_mse.py
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/quantization/model_quant.py` around lines 610 - 613, compute_quantization_mse currently calls forward_loop(model) and then removes hooks, but an exception in forward_loop can skip hook cleanup and model mode restoration; wrap the forward_loop invocation and subsequent cleanup in a try/finally block that ensures for each hook in hooks you call h.remove() and that the model's original training state (save current model.training) is restored with model.train() or model.eval() as appropriate; locate the hooks list and forward_loop(model) call in compute_quantization_mse and implement the try/finally to guarantee h.remove() and restore model.training even on errors.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Duplicate comments:
In `@modelopt/torch/quantization/model_quant.py`:
- Around line 610-613: compute_quantization_mse currently calls
forward_loop(model) and then removes hooks, but an exception in forward_loop can
skip hook cleanup and model mode restoration; wrap the forward_loop invocation
and subsequent cleanup in a try/finally block that ensures for each hook in
hooks you call h.remove() and that the model's original training state (save
current model.training) is restored with model.train() or model.eval() as
appropriate; locate the hooks list and forward_loop(model) call in
compute_quantization_mse and implement the try/finally to guarantee h.remove()
and restore model.training even on errors.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 27311daf-f499-4f92-8b43-e7f93acfbea5
📒 Files selected for processing (3)
- CHANGELOG.rst
- modelopt/torch/quantization/model_quant.py
- tests/unit/torch/quantization/test_compute_quantization_mse.py
♻️ Duplicate comments (1)
modelopt/torch/quantization/model_quant.py (1)
611-615: ⚠️ Potential issue | 🟠 Major — Guarantee hook cleanup (and mode restoration) when `forward_loop` fails.

At Line 611, an exception from `forward_loop(model)` skips cleanup, leaving stale hooks registered. Wrap execution in `try/finally`; also run in eval mode and restore the original train/eval state to avoid side effects during measurement.

Proposed fix:
```diff
-    forward_loop(model)
-
-    for h in hooks:
-        h.remove()
+    was_training = model.training
+    model.eval()
+    try:
+        forward_loop(model)
+    finally:
+        for h in hooks:
+            h.remove()
+    model.train(was_training)
```

```bash
#!/bin/bash
set -euo pipefail
TARGET="$(fd '^model_quant\.py$' | head -n1)"
export TARGET
python - <<'PY'
import ast, os, pathlib
p = pathlib.Path(os.environ["TARGET"])
tree = ast.parse(p.read_text())
fn = next(
    n for n in tree.body
    if isinstance(n, ast.FunctionDef) and n.name == "compute_quantization_mse"
)
has_try = any(isinstance(n, ast.Try) for n in ast.walk(fn))
has_eval = any(
    isinstance(n, ast.Call) and isinstance(n.func, ast.Attribute) and n.func.attr == "eval"
    for n in ast.walk(fn)
)
has_train_restore = any(
    isinstance(n, ast.Call) and isinstance(n.func, ast.Attribute) and n.func.attr == "train"
    for n in ast.walk(fn)
)
print({
    "has_try_finally": has_try,
    "has_eval_call": has_eval,
    "has_train_restore_call": has_train_restore,
})
PY
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/quantization/model_quant.py` around lines 611 - 615, Wrap the forward_loop call in a try/finally so hooks are always removed and the model's training/eval state is restored if forward_loop raises: before calling forward_loop(model) capture the model's current training mode (e.g., mode_before = model.training), call model.eval() to run measurement, then in a finally block call for h in hooks: h.remove() and restore the original mode via model.train(mode_before); ensure this change is implemented in compute_quantization_mse (and references to forward_loop and hooks/h.remove()) so cleanup and mode restoration always run even on exceptions.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Duplicate comments:
In `@modelopt/torch/quantization/model_quant.py`:
- Around line 611-615: Wrap the forward_loop call in a try/finally so hooks are
always removed and the model's training/eval state is restored if forward_loop
raises: before calling forward_loop(model) capture the model's current training
mode (e.g., mode_before = model.training), call model.eval() to run measurement,
then in a finally block call for h in hooks: h.remove() and restore the original
mode via model.train(mode_before); ensure this change is implemented in
compute_quantization_mse (and references to forward_loop and hooks/h.remove())
so cleanup and mode restoration always run even on exceptions.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 6e3a0bb7-1a7c-4097-9edb-2f9807ef6afa
📒 Files selected for processing (1)
modelopt/torch/quantization/model_quant.py
Force-pushed from 1923cbd to fe8d7e4.
Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com>
Force-pushed from fe8d7e4 to 10ea8be.
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main     #940      +/-   ##
==========================================
+ Coverage   72.08%   72.14%   +0.05%
==========================================
  Files         209      209
  Lines       23638    23667      +29
==========================================
+ Hits        17040    17075      +35
+ Misses       6598     6592       -6
==========================================
```

☔ View full report in Codecov by Sentry.
Signed-off-by: Wei-Ming Chen <17592131+meenchen@users.noreply.github.com>
What does this PR do?
Type of change: New feature

Overview: Add an API to measure MSE for target quantizers given a forward loop.
Usage
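A hedged sketch of how the new API might be called, based on the signature discussed in this review. The dataloader and wildcard pattern are placeholders, and the exact keyword names should be checked against the released ModelOpt API:

```python
import modelopt.torch.quantization as mtq

def forward_loop(model):
    # Drive a few calibration batches through the quantized model.
    for batch in calib_dataloader:  # placeholder dataloader
        model(batch)

# Returns {quantizer_name: mean MSE} for quantizers matching the wildcard.
mse_per_quantizer = mtq.compute_quantization_mse(
    model, forward_loop, wildcards=["*weight_quantizer*"]
)
```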
Testing
Covered by unit tests and tested with HF PTQ.
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
New Features
Tests
Documentation