Skip to content

Commit 1923cbd

Browse files
committed
address comments
1 parent a386cde commit 1923cbd

File tree

1 file changed

+6
-12
lines changed

1 file changed

+6
-12
lines changed

modelopt/torch/quantization/model_quant.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from modelopt.torch.opt.utils import forward_with_reshard
3131
from modelopt.torch.quantization.config import QuantizeConfig
3232
from modelopt.torch.quantization.conversion import set_quantizer_by_cfg
33-
from modelopt.torch.utils import atomic_print, print_rank_0
33+
from modelopt.torch.utils import atomic_print
3434

3535
from .algorithms import AutoQuantizeGradientSearcher, AutoQuantizeKLDivSearcher, QuantRecipe
3636
from .config import QuantizeAlgoCfgType
@@ -588,13 +588,6 @@ def _matches(name: str) -> bool:
588588
continue
589589
if not _matches(name):
590590
continue
591-
if not (module._if_quant and module._fake_quant):
592-
print_rank_0(
593-
f"[compute_quantization_mse] Skipping {name}: "
594-
f"_if_quant={module._if_quant}, _fake_quant={module._fake_quant}"
595-
)
596-
continue
597-
598591
accumulators[name] = {"sum": 0.0, "count": 0}
599592

600593
def _make_hook(acc):
@@ -608,10 +601,11 @@ def hook(mod, inp, out):
608601

609602
hooks.append(module.register_forward_hook(_make_hook(accumulators[name])))
610603

611-
forward_loop(model)
612-
613-
for h in hooks:
614-
h.remove()
604+
try:
605+
forward_loop(model)
606+
finally:
607+
for h in hooks:
608+
h.remove()
615609

616610
return {
617611
name: acc["sum"] / acc["count"] for name, acc in accumulators.items() if acc["count"] > 0

0 commit comments

Comments
 (0)