221 changes: 211 additions & 10 deletions examples/megatron_bridge/prune_minitron.py
@@ -41,7 +41,12 @@
from megatron.bridge import AutoBridge
from megatron.bridge.models.mamba.mamba_provider import MambaModelProvider
from megatron.bridge.models.nemotronh.nemotron_h_provider import NemotronHModelProvider
from transformers import AutoConfig, AutoModelForCausalLM
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoProcessor

try:
from transformers import AutoModelForImageTextToText
except ImportError:
AutoModelForImageTextToText = None

import modelopt.torch.opt as mto
import modelopt.torch.prune as mtp
@@ -50,14 +55,90 @@
from modelopt.torch.utils.plugins.mbridge import (
get_hf_mbridge_calibration_loop,
load_mbridge_model_from_hf,
resolve_prunable_backbone,
)
from modelopt.torch.utils.plugins.megatron_mmlu import megatron_mmlu
from modelopt.torch.utils.plugins.megatron_mmmu import megatron_mmmu
from modelopt.torch.utils.vlm_dataset_utils import get_supported_vlm_datasets

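# Architecture-name suffixes that AutoBridge accepts in config.architectures;
# exported configs are normalized to one of these (see _ensure_bridge_supported_architectures).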
SUPPORTED_ARCH_SUFFIXES = (
"ForCausalLM",
"ForConditionalGeneration",
"NemotronH_Nano_VL_V2",
)


def _create_dummy_hf_model(hf_cfg, *, is_vlm_wrapper: bool, trust_remote_code: bool):
"""Create a dummy HF model from config for bridge-assisted weight export."""
auto_classes = [AutoModelForCausalLM]
if is_vlm_wrapper and AutoModelForImageTextToText is not None:
auto_classes.insert(0, AutoModelForImageTextToText)
auto_classes.append(AutoModel)
for auto_cls in auto_classes:
try:
return auto_cls.from_config(hf_cfg, trust_remote_code=trust_remote_code)
except (ValueError, KeyError):
continue
return None


def _ensure_bridge_supported_architectures(
hf_cfg,
*,
bridge: AutoBridge,
is_vlm_wrapper: bool,
) -> None:
"""Ensure exported HF config uses architecture suffixes accepted by AutoBridge."""
cfg_arches = list(getattr(hf_cfg, "architectures", []) or [])
if not cfg_arches or any(arch.endswith(SUPPORTED_ARCH_SUFFIXES) for arch in cfg_arches):
return

source_arches = list(
getattr(getattr(getattr(bridge, "hf_pretrained", None), "config", None), "architectures", [])
or []
)
if any(arch.endswith(SUPPORTED_ARCH_SUFFIXES) for arch in source_arches):
hf_cfg.architectures = source_arches
return

target_suffix = "ForConditionalGeneration" if is_vlm_wrapper else "ForCausalLM"
hf_cfg.architectures = [
f"{arch[:-len('Model')]}{target_suffix}" if arch.endswith("Model") else arch
for arch in cfg_arches
]


def _set_cfg_attr(cfg, name: str, value) -> None:
"""Set config field for both object-like and dict-like nested configs."""
if cfg is None:
return
if isinstance(cfg, dict):
cfg[name] = value
return
setattr(cfg, name, value)


def _sync_language_cfg_from_mcore(hf_cfg, mcore_cfg) -> None:
"""Keep nested language configs aligned with pruned Megatron config."""
nested_cfg_names = ("text_config", "language_config", "llm_config")
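# HF multimodal wrapper configs typically nest the language-model config under
# one of these attribute names (e.g. `text_config` in transformers VLM configs).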
for nested_name in nested_cfg_names:
nested_cfg = getattr(hf_cfg, nested_name, None)
if nested_cfg is None:
continue
_set_cfg_attr(nested_cfg, "hidden_size", mcore_cfg.hidden_size)
_set_cfg_attr(nested_cfg, "intermediate_size", mcore_cfg.ffn_hidden_size)
_set_cfg_attr(nested_cfg, "num_attention_heads", mcore_cfg.num_attention_heads)
_set_cfg_attr(nested_cfg, "head_dim", mcore_cfg.kv_channels)
_set_cfg_attr(nested_cfg, "num_key_value_heads", mcore_cfg.num_query_groups)
_set_cfg_attr(nested_cfg, "num_hidden_layers", mcore_cfg.num_layers)


def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--hf_model_name_or_path", type=str, required=True)
parser.add_argument("--trust_remote_code", action="store_true")
supported_text_datasets = get_supported_datasets()
supported_vlm_datasets = get_supported_vlm_datasets()

target_group = parser.add_mutually_exclusive_group(required=True)
target_group.add_argument(
@@ -89,8 +170,19 @@ def get_args() -> argparse.Namespace:
"--calib_dataset_name",
type=str,
default="nemotron-post-training-dataset-v2",
choices=get_supported_datasets(),
help="Dataset name for calibration",
help=(
"Dataset name for calibration. "
f"Text calibration datasets: {supported_text_datasets}. "
f"Image-text calibration datasets (--calib_with_images): {supported_vlm_datasets}."
),
)
parser.add_argument(
"--calib_with_images",
action="store_true",
help=(
"Use image+text calibration pipeline (for VLM pruning). "
"When enabled, calibration runs multimodal forward passes using the full VLM wrapper."
),
)
parser.add_argument(
"--calib_num_samples", type=int, default=1024, help="Number of samples for calibration"
@@ -137,11 +229,23 @@ def get_args() -> argparse.Namespace:
parser.add_argument(
"--prune_score_func",
type=str,
choices=["mmlu_5pct"],
default="mmlu_5pct",
choices=["mmlu_5pct", "mmmu_5pct"],
default=None,
help=(
"Score function to use for NAS-based pruning (--prune_target_params). Currently supported: "
"mmlu_5pct (MMLU on 5% sampled data per subject for faster eval). "
"mmlu_5pct (MMLU on 5% sampled data per subject for faster eval), "
"mmmu_5pct (MMMU on 5% sampled data per subject, recommended for VLM models). "
"If not specified, defaults to mmlu_5pct for text-only calibration and "
"mmmu_5pct for image-text calibration (--calib_with_images)."
),
)
parser.add_argument(
"--mmmu_few_shots",
type=int,
default=0,
help=(
"Number of dev examples to prepend as few-shot context in MMMU scoring "
"(used only when --prune_score_func mmmu_5pct)."
),
)
parser.add_argument(
@@ -206,6 +310,9 @@ def get_args() -> argparse.Namespace:
raise ValueError("--prune_export_config must parse to a dictionary.")
args.prune_export_config = prune_export_config

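# Resolve the score-function default here so it can track --calib_with_images.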
if args.prune_score_func is None:
args.prune_score_func = "mmmu_5pct" if args.calib_with_images else "mmlu_5pct"

print_rank_0("\n==================== Arguments ====================")
for k, v in args.__dict__.items():
print_rank_0(f"{k:<35} {v}")
@@ -239,6 +346,51 @@ def main(args: argparse.Namespace):
},
init_model_parallel=True,
)
# Keep full wrapper for VLM-aware score functions (MMMU image+text).
full_unwrapped_model = unwrapped_model
# For wrapper models that expose a GPT-style language_model, prune only the language model.
# Non-language components (e.g., vision towers) are preserved.
unwrapped_model, is_vlm_wrapper, wrapper_name = resolve_prunable_backbone(unwrapped_model)
if is_vlm_wrapper and wrapper_name is not None:
print_rank_0(
f"Detected VLM wrapper ({wrapper_name}), "
f"extracting language_model ({type(unwrapped_model).__name__}) for pruning"
)
supported_text_datasets = set(get_supported_datasets())
supported_vlm_datasets = set(get_supported_vlm_datasets())
if args.calib_with_images:
if not is_vlm_wrapper:
raise ValueError(
"--calib_with_images is only supported for VLM wrapper models "
"(models exposing `.language_model`)."
)
if args.calib_dataset_name not in supported_vlm_datasets:
raise ValueError(
f"Unsupported VLM calibration dataset: {args.calib_dataset_name}. "
f"Supported VLM datasets: {sorted(supported_vlm_datasets)}"
)
elif args.calib_dataset_name not in supported_text_datasets:
raise ValueError(
f"Unsupported text calibration dataset: {args.calib_dataset_name}. "
f"Supported text datasets: {sorted(supported_text_datasets)}. "
"For multimodal datasets, pass --calib_with_images."
)
if args.prune_score_func == "mmmu_5pct" and not is_vlm_wrapper:
raise ValueError(
"--prune_score_func mmmu_5pct requires a VLM wrapper model "
"(model exposing `.language_model`)."
)
if is_vlm_wrapper and args.prune_score_func == "mmmu_5pct":
# For VLM wrappers, multimodal projector outputs are tied to the original hidden_size.
# If hidden_size is pruned, vision embeddings and language hidden states mismatch.
if "hidden_size" not in args.hparams_to_skip:
warn_rank_0(
"VLM MMMU(image+text) scoring requires fixed hidden_size to keep "
"vision-language embedding dimensions aligned. "
"Adding 'hidden_size' to --hparams_to_skip."
)
args.hparams_to_skip.append("hidden_size")

print_rank_0(f"\nPruning {unwrapped_model=}")
print_rank_0(
f"Original model params: {num2hrb(mtp.mcore_minitron.get_mcore_param_count(unwrapped_model))}"
@@ -254,6 +406,7 @@ def main(args: argparse.Namespace):
num_samples=args.calib_num_samples,
micro_batch_size=args.calib_mbs,
global_batch_size=args.calib_gbs,
calib_with_images=args.calib_with_images,
)

pruning_config = {
@@ -279,8 +432,38 @@

def score_func_mmlu(m):
return megatron_mmlu(m, tokenizer, percentage=0.05)
if args.prune_score_func == "mmlu_5pct":
if is_vlm_wrapper:
warn_rank_0(
"Using mmlu_5pct for a VLM model. "
"For multimodal capability-aware scoring, prefer --prune_score_func mmmu_5pct."
)
pruning_config["score_func"] = score_func_mmlu
elif args.prune_score_func == "mmmu_5pct":
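# Load the HF processor up front; MMMU image+text scoring needs its image
# preprocessing and prompt handling for multimodal inputs.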
mmmu_processor = AutoProcessor.from_pretrained(
args.hf_model_name_or_path, trust_remote_code=args.trust_remote_code
)

pruning_config["score_func"] = score_func_mmlu
def score_func_mmmu(m):
# MMMU image+text scoring requires the full VLM wrapper.
eval_model = m
if hasattr(full_unwrapped_model, "language_model"):
full_unwrapped_model.language_model = m
eval_model = full_unwrapped_model
return megatron_mmmu(
eval_model,
processor=mmmu_processor,
few_shots=args.mmmu_few_shots,
percentage=0.05,
use_images=True,
)

pruning_config["score_func"] = score_func_mmmu
else:
raise ValueError(
f"Unsupported --prune_score_func: {args.prune_score_func}. "
"Supported values are: ['mmlu_5pct', 'mmmu_5pct']."
)
pruning_config["max_width_pruning"] = args.max_width_pruning
pruning_config["max_depth_pruning"] = args.max_depth_pruning
pruning_config["hparams_to_skip"] = args.hparams_to_skip
@@ -349,6 +532,7 @@ def score_func_mmlu(m):
hf_cfg.num_attention_heads = mcore_cfg.num_attention_heads
hf_cfg.head_dim = mcore_cfg.kv_channels
hf_cfg.num_key_value_heads = mcore_cfg.num_query_groups
_sync_language_cfg_from_mcore(hf_cfg, mcore_cfg)
if hasattr(hf_cfg, "mamba_num_heads"):
hf_cfg.mamba_num_heads = mcore_cfg.mamba_num_heads
if hasattr(hf_cfg, "mamba_head_dim"):
@@ -377,9 +561,26 @@ def score_func_mmlu(m):
hf_cfg.num_hidden_layers = mcore_cfg.num_layers

# Save dummy pruned HF model to get the correct bridge for saving pruned weights
AutoModelForCausalLM.from_config(
hf_cfg, trust_remote_code=args.trust_remote_code
).save_pretrained(args.output_hf_path, trust_remote_code=args.trust_remote_code)
dummy_model = _create_dummy_hf_model(
hf_cfg,
is_vlm_wrapper=is_vlm_wrapper,
trust_remote_code=args.trust_remote_code,
)
assert dummy_model is not None, f"Failed to create dummy model from config: {hf_cfg}"
dummy_model.save_pretrained(args.output_hf_path, trust_remote_code=args.trust_remote_code)

# AutoBridge validates config.architectures and only accepts a subset of suffixes.
# Some multimodal configs can end up with '*Model' after dummy export.
before_arches = list(getattr(hf_cfg, "architectures", []) or [])
_ensure_bridge_supported_architectures(
hf_cfg,
bridge=bridge,
is_vlm_wrapper=is_vlm_wrapper,
)
after_arches = list(getattr(hf_cfg, "architectures", []) or [])
if after_arches != before_arches:
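# Re-save only the corrected config.json; the dummy weights exported above are untouched.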
hf_cfg.save_pretrained(args.output_hf_path)

pruned_bridge = AutoBridge.from_hf_pretrained(
args.output_hf_path, trust_remote_code=args.trust_remote_code
)