Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions modelopt/torch/export/quant_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,10 +631,14 @@ def _prefix_wildcard_summarize_exclude_modules(unquantized_layers, quantized_lay
"""

def all_matching_prefix_wildcards(name):
    """Return the set of prefix wildcards that match *name*.

    Includes the exact name itself, a wildcard at every '.' segment
    boundary (both "prefix*" — stopping just before the dot — and
    "prefix.*" — including it), and the full name with a trailing '*'.
    Restricting wildcards to segment boundaries keeps the candidate set
    small compared to emitting one wildcard per character position.

    Args:
        name: dotted module path, e.g. "model.layers.0.mlp".

    Returns:
        set[str]: all candidate prefix wildcards for *name*.
    """
    wildcards = {name}
    for i, ch in enumerate(name):
        if ch == ".":
            # Boundary wildcard without the dot ("prefix*") and with it
            # ("prefix.*") — both spellings may appear in user configs.
            wildcards.add(name[:i] + "*")
            wildcards.add(name[: i + 1] + "*")
    wildcards.add(name + "*")
    return wildcards

def next_formatted_matching_prefix_wildcards(name: str) -> Generator[list[str], None, None]:
Expand Down Expand Up @@ -782,11 +786,10 @@ def process_layer_quant_config(layer_config_dict):
per_layer_config["quant_algo"] = "MIXED_PRECISION"
elif len(quantization_formats) == 1 and quantization_config is not None:
per_layer_config.update(quantization_config)
per_layer_config["exclude_modules"] = sorted(
_prefix_wildcard_summarize_exclude_modules(
exclude_modules, per_layer_config["quantized_layers"].keys()
)
summarized = _prefix_wildcard_summarize_exclude_modules(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this a no-op change?

exclude_modules, per_layer_config["quantized_layers"].keys()
)
per_layer_config["exclude_modules"] = sorted(summarized)
per_layer_config.pop("quantized_layers")

return per_layer_config
Expand Down
1 change: 1 addition & 0 deletions modelopt/torch/quantization/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,7 @@
},
"algorithm": "max",
}

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: remove

MAMBA_MOE_NVFP4_CONSERVATIVE_CFG = {
"quant_cfg": {
"*weight_quantizer": {
Expand Down
62 changes: 62 additions & 0 deletions modelopt/torch/quantization/plugins/huggingface.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In my understanding, we dont need explicit registration for attention (it is caught by the attention AST patching). Is that correct?

Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,47 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return output


class _QuantNemotronHMOE(QuantModule):
    """Quantized MoE module for NemotronH (Nano/Super) with expert amax sync.

    Synchronizes activation quantizer amax across local experts in a layer via
    layer_sync_moe_local_experts_amax(), which is called by the calibration pipeline
    (model_calib.max_calibrate()) so that all experts share the same amax before
    distributed sync. Weight quantizers are left unchanged.
    """

    def _setup(self):
        # No module-level quantizers are added here; the experts' own
        # TensorQuantizers (inserted by the per-linear conversion) are used.
        pass

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Delegate to the wrapped module's forward unchanged."""
        return super().forward(hidden_states)

    def layer_sync_moe_local_experts_amax(self):
        """Sync activation quantizer amax across local experts in this MoE layer.

        Only activation quantizers are synchronized; weight quantizers are
        unchanged (any quantizer whose name contains "weight_quantizer" is
        skipped). Quantizers with no calibrated amax are also skipped.
        """
        # NOTE(review): assumes `self.experts` is an iterable of expert
        # sub-modules when present — the hasattr guard makes this a no-op
        # on instances without local experts.
        if not hasattr(self, "experts"):
            return
        # Pass 1: collect the element-wise maximum amax per quantizer name
        # across all experts.
        amax_dict = {}
        for expert in self.experts:
            for name, mod in expert.named_modules():
                if "weight_quantizer" in name:
                    continue
                if isinstance(mod, TensorQuantizer) and mod.amax is not None:
                    stored = amax_dict.get(name)
                    amax_tensor = mod.amax.detach().clone()
                    amax_dict[name] = (
                        amax_tensor if stored is None else torch.maximum(stored, amax_tensor)
                    )
        # Pass 2: write the synchronized amax back into every expert's
        # matching activation quantizer, on that quantizer's own device.
        for expert in self.experts:
            for name, mod in expert.named_modules():
                if "weight_quantizer" in name:
                    continue
                if isinstance(mod, TensorQuantizer) and mod.amax is not None and name in amax_dict:
                    mod.amax = amax_dict[name].detach().clone().to(mod.amax.device)


class _QuantLlama4TextExperts(QuantModule):
def _setup(self):
self.gate_up_proj_input_quantizer = TensorQuantizer()
Expand Down Expand Up @@ -982,6 +1023,26 @@ def register_dbrx_moe_on_the_fly(model):
QuantModuleRegistry.register({moe_type: moe_type.__name__})(_QuantDbrxExpertGLU)


def register_nemotron_h_moe_on_the_fly(model):
    """Register NemotronH MoE modules (Nano/Super) as _QuantNemotronHMOE.

    NemotronH MoE is used in NVIDIA Nemotron-3-Nano and Super architectures and may be
    loaded via trust_remote_code with a class named NemotronHMOE.

    Matching is by class *name* ("NemotronHMOE") rather than identity, since
    trust_remote_code yields a dynamically created class per checkpoint.
    Module types already present in QuantModuleRegistry are left untouched.
    """
    # Track types already handled this call so each dynamic class is
    # registered at most once even if many layers share it.
    visited_types = set()
    for name, module in model.named_modules():
        mod_type = type(module)
        # Skip types we've seen here or that are already registered.
        if mod_type in visited_types or QuantModuleRegistry.get(mod_type) is not None:
            continue
        if mod_type.__name__ == "NemotronHMOE":
            visited_types.add(mod_type)
            # \033[1m ... \033[0m renders the notice in bold on ANSI terminals.
            print(
                f"\033[1mDetected NemotronH MOE module '{name}', "
                f"registering with _QuantNemotronHMOE.\033[0m"
            )
            QuantModuleRegistry.register({mod_type: f"hf.{mod_type.__name__}"})(_QuantNemotronHMOE)


def register_falcon_linears_on_the_fly(model):
"""Register Falcon linear modules as a QUANT_MODULE.

Expand Down Expand Up @@ -1112,6 +1173,7 @@ def _is_param_grad_enabled_for_auto_quantize(pname, model):
[
register_falcon_linears_on_the_fly,
register_dbrx_moe_on_the_fly,
register_nemotron_h_moe_on_the_fly,
register_sparse_moe_on_the_fly,
register_hf_attentions_on_the_fly,
convert_hf_parallel_linears_on_the_fly,
Expand Down