
Commit 9fea797

chochowski, j-rausch, and kevalmorabia97 authored and committed
gpt-oss 20b support (#889)
## What does this PR do?

Adds gpt-oss-20b support for Puzzle any-model pruning.

**Type of change:** new feature

**Overview:** Adds the descriptor, converter, and YAML configuration files for expert removal. Introduces slight changes to conversion to account for the MXFP4-quantized checkpoint of gpt-oss.

---------

Signed-off-by: mchochowski <mchochowski@nvidia.com>
Signed-off-by: jrausch <jrausch@nvidia.com>
Signed-off-by: chochowski <Marcin.Chochowski@gmail.com>
Co-authored-by: J Rausch <38429553+j-rausch@users.noreply.github.com>
Co-authored-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
Signed-off-by: Daniel Korzekwa <dkorzekwa@nvidia.com>
1 parent 5ce7362 commit 9fea797

File tree

19 files changed
+969 −79 lines changed

.pre-commit-config.yaml

Lines changed: 2 additions & 1 deletion

```diff
@@ -109,7 +109,8 @@ repos:
             examples/speculative_decoding/main.py|
             examples/speculative_decoding/medusa_utils.py|
             examples/speculative_decoding/server_generate.py|
-            examples/puzzletron/evaluation/hf_deployable_anymodel\.py|
+            examples/puzzletron/evaluation/lm_eval_anymodel.py|
+            modelopt/torch/puzzletron/anymodel/models/gpt_oss_20b/gpt_oss_pruned_to_mxfp4.py|
             modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_.*\.py|
         )$
```

examples/puzzletron/GPTOSS.md

Lines changed: 14 additions & 0 deletions
## GptOss - 20b

With this release, the Puzzle algorithm supports only expert removal for `Gpt-Oss-20b`.

This model ships as a quantized checkpoint, i.e. the MoE expert matrices are quantized in the _MXFP4_ format.
During the pruning steps, Puzzle uses the decompressed model (back in BF16) for statistics and score computation.
This means that during conversion to the Puzzle format we decompress the model and store it in BF16.
Once pruning is done, i.e. the experts to remove have been identified and the process has finished, you may want to restore the _MXFP4_ format of the checkpoint.
To do so, an additional script takes the original and the pruned checkpoints and outputs the pruned checkpoint in _MXFP4_ format:

```bash
python -m modelopt.torch.puzzletron.anymodel.models.gpt_oss_20b.gpt_oss_pruned_to_mxfp4 \
    --student-path /workspaces/any_model_gpt_oss_20b/mip/puzzle_solutions/stats_num_params_18014757184/solutions--checkpoints/solution_0/ \
    --original-path /workspaces/source_model_checkpoints/openai_gpt-oss-20b/ \
    --output-path /workspaces/any_model_gpt_oss_20b/mip/puzzle_solutions/stats_num_params_18014757184/solutions--checkpoints/mxfp4-ckpt/ \
    --num-layers 24
```
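The decompress step above can be illustrated with a toy sketch of MXFP4 block dequantization. This is not the actual `gpt_oss_pruned_to_mxfp4` converter, just an illustration of the format: MXFP4 packs weights as 4-bit E2M1 values in blocks (of 32 in the MX spec), with each block sharing one power-of-two (E8M0) scale.

```python
# Illustrative MXFP4 dequantization sketch (NOT ModelOpt's converter).
# The 16 representable E2M1 values: sign bit plus magnitudes {0, 0.5, 1, 1.5, 2, 3, 4, 6}.
FP4_E2M1_VALUES = [
    0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,          # codes 0-7 (positive)
    -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,  # codes 8-15 (negative)
]

def dequantize_mxfp4_block(codes: list[int], scale_exponent: int) -> list[float]:
    """Decode one block of 4-bit codes using its shared power-of-two scale."""
    scale = 2.0 ** scale_exponent
    return [FP4_E2M1_VALUES[c] * scale for c in codes]

# Example: a few codes from one block whose shared scale is 2**-1
block = [0, 1, 2, 7, 9, 15]
print(dequantize_mxfp4_block(block, -1))  # [0.0, 0.25, 0.5, 3.0, -0.25, -3.0]
```

The real converter additionally has to re-pack the pruned BF16 experts back into this block layout, which is why it needs the original checkpoint as a template.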

examples/puzzletron/README.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -9,7 +9,7 @@ The supported modifications are:
 
 To use the Puzzle algorithm effectively, we need to specify the target number of parameters and/or the memory. The final stage is based on Mixed-Integer Programming (MIP) algorithm to find the most optimal combination of layer modifications that satisfy the target requirements.
 
-In this example, we compress the [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) model reducing GPU memory usage from 113 GiB to 96 GiB (15% reduction) with less than 1% regression in the token_accuracy_top_10 metric.
+In this example, we compress the [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) model reducing GPU memory usage from 113 GiB to 96 GiB (15% reduction) with less than 1% regression in the token_accuracy_top_10 metric. Other supported models should be compressed in a similar way. For GptOss there is one [additional step to be performed](GPTOSS.md).
 
 > **Note:** Other models are also supported. See the [configs](./configs/) directory for additional model configurations (e.g., Llama-3.2-3B-Instruct on 1x H100, Qwen2.5-7B-Instruct on 1x H100, Qwen3-8B on 1x H100, Nemotron-Nano-12B-v2 on 1x H100, Mistral-Small-24B-Instruct-2501 on 4x H100). For information on adding support for new models, see the [AnyModel Guide](../../modelopt/torch/puzzletron/anymodel/README.md).
```
Lines changed: 110 additions & 0 deletions
```yaml
defaults:
  - pruning: ffn_pruning
  - scoring: ../validate_solutions_defaults
  - realize_model: ../validate_solutions_defaults
  - bypass:
  - override hydra/hydra_logging: disabled
  - _self_

puzzle_dir: ???
descriptor: gpt_oss_20b
teacher_dir: ${puzzle_dir}/ckpts/teacher/
replacement_library_path: ${puzzle_dir}/replacement_library.json
dataset_path: ??? # path to Nemotron-Post-Training-Dataset-v2

skip_realize_model: false

build_replacement_library:
  add_ffn_no_ops: true
  add_attention_no_ops: true

calc_subblock_stats:
  batch_sizes: [64, 96, 128]
  prefill_seq_len: 4096
  generation_seq_len: 4096
  num_active_tokens_override: # Optional override for sequence lengths
  prefill_queue_size: 0
  allocate_prefill_query: false
  benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking
  merge_with_existing_stats: false
  subblock_stats_filename: "subblock_stats.json"
  moe_stats_filename: "moe_stats.json"
  runtime_stats:
    backend: trt_torch

scoring:
  descriptor: ${descriptor}
  solutions_to_validate:
    skip_existing_solutions: true

  replacement_library_path: ${replacement_library_path}
  solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json}
  teacher_dir: ${to_path:${teacher_dir}}
  output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation

  eval_samples: 128
  micro_batch_size: 1
  seed: 42
  shuffle_seed: 444
  dataset_path: ${dataset_path}

mip:
  single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}}
  subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}}
  output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions}
  gathered_metrics_path:
  puzzle_profile:

  # puzzle_profile:
  objective: metrics.cosine_embedding_loss_hidden_states
  bigger_is_better: false

  subblock_stats_args:
    - batch_size: 96
      weights_dtype: torch.bfloat16
      activations_dtype: torch.bfloat16
      kv_cache_dtype: torch.bfloat16

  report_additional_costs:
    - stats.memory_mib
    - stats.num_params
    - stats.num_kv_heads
    - stats.has_attention
    - stats.has_ffn
    - stats.kv_cache_memory_mib
    - stats.attention_memory_mib
    - stats.ffn_memory_mib
    - stats.ffn_num_params
    - stats.attention_num_params

  human_constraints:
    target_memory: 45_000
    num_params: 3_000_000_000

  mip_constraints:
    metric_overrides:
    max_seconds_per_solution: 60

realize_model:
  descriptor: ${descriptor}
  teacher_dir: ${to_path:${teacher_dir}}
  tokenizer_name: ${to_path:${teacher_dir}}
  replacement_library_path: ${replacement_library_path}
  save_models: true
  solutions_path: # Filled dynamically

  # Validate params
  skip_validation: false # Set to true to skip validation of the model solution
  eval_samples: 128
  micro_batch_size: 1
  seed: 42
  shuffle_seed: 444
  dataset_path: ${dataset_path}

nccl_timeout_minutes: ${timedelta_minutes:10}

# This section redirects Hydra outputs
hydra:
  run:
    dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S}
```
Lines changed: 17 additions & 0 deletions
```yaml
defaults:
  - gptoss-20b
  - _self_

# Input Hugging Face model to compress
input_hf_model_path: /workspace/hf_models/openai/gpt-oss-20b

# Dataset path for pruning and NAS scoring
dataset_path: /workspace/datasets/Nemotron-Post-Training-Dataset-v2

# Working directory for compression outputs
puzzle_dir: /workspace/puzzle_dir

# MIP memory constraint (in MiB)
mip:
  human_constraints:
    target_memory: 16_000 # 16 GiB
```
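This override file composes onto the base `gptoss-20b` config through Hydra's defaults list, so only the keys it sets (paths and the tighter memory constraint) change. As a minimal pure-Python sketch of the merge semantics (Hydra/OmegaConf do this for real, including `${...}` interpolation, which is omitted here):

```python
# Toy deep-merge illustrating Hydra-style config composition (not Hydra itself).
def deep_merge(base: dict, override: dict) -> dict:
    """Recursively merge `override` onto `base`; override wins on conflicts."""
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged

base = {
    "puzzle_dir": "???",
    "mip": {"human_constraints": {"target_memory": 45_000, "num_params": 3_000_000_000}},
}
override = {
    "puzzle_dir": "/workspace/puzzle_dir",
    "mip": {"human_constraints": {"target_memory": 16_000}},
}

merged = deep_merge(base, override)
print(merged["mip"]["human_constraints"])  # {'target_memory': 16000, 'num_params': 3000000000}
```

Note how `num_params` survives from the base while `target_memory` is overridden.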
Lines changed: 21 additions & 0 deletions
```yaml
defaults:
  - pruning_defaults

eval_samples: 2500
activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/expert_removal/${pruning.experiment_id}

pruning_mixin:
  _target_: modelopt.torch.puzzletron.pruning.expert_removal_pruning_mixin.ExpertRemovalPruningMixIn
  layer_descriptor:
    _target_: modelopt.torch.puzzletron.anymodel.models.gpt_oss_20b.gpt_oss_20b_model_descriptor.GptOss20bExpertRemovalLayerDescriptor
    target_name: "mlp.router"

hook_class: ${get_object:modelopt.torch.nas.plugins.megatron_hooks.base_hooks.RankedChoiceVotingHook}
activation_hooks_kwargs: # Additional kwargs to pass to the hook init

num_experts_to_keep_list: [24, 16, 8] # num_experts in the gpt-oss-20b teacher is 32
mlp_init_mode: "ExpertRemoval"
mlp_init_config_yaml:
  expert_scores_key: "expert_ranks"
  layer_prefix_template: "model.layers.{layer_idx}.mlp.router"
```
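A hypothetical sketch of the selection step this config drives: the router hook logs per-expert scores, and pruning keeps the `num_experts_to_keep` best-ranked experts in each layer. The actual logic lives in `ExpertRemovalPruningMixIn` and `RankedChoiceVotingHook`; the helper below is illustrative only.

```python
# Illustrative top-k expert selection (not ModelOpt's actual implementation).
def select_experts_to_keep(expert_scores: list[float], num_experts_to_keep: int) -> list[int]:
    """Return indices of the highest-scoring experts, sorted ascending."""
    # Rank expert indices by score, best first, then keep the top k.
    ranked = sorted(range(len(expert_scores)), key=lambda i: expert_scores[i], reverse=True)
    return sorted(ranked[:num_experts_to_keep])

# Example: 6 experts, keep the 3 with the highest recorded scores
scores = [0.9, 0.1, 0.5, 0.7, 0.3, 0.8]
print(select_experts_to_keep(scores, 3))  # [0, 3, 5]
```

In the real flow this runs per layer, keyed by `layer_prefix_template`, for each candidate size in `num_experts_to_keep_list`.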
Lines changed: 34 additions & 0 deletions
```yaml
defaults:
  - /validate_model_defaults

model_name_or_path: ${teacher_dir}
experiment_id: ${pruning.eval_samples}samples_diverse_mini
activations_log_dir: ???
activation_hooks_kwargs: ???

descriptor: ${descriptor}

# Data:
eval_samples: 10_000
micro_batch_size: 1
dataset_path: ${dataset_path}
val_dataset_name: train

# Pruned ckpts
pruned_ckpts_output_dir: ${puzzle_dir}/pruning/${pruning.experiment_id}

## FFN pruning
ffn_list:
mlp_init_mode: "Truncate" # PruneByActivationsLog

## KV-heads pruning
n_heads_in_group_list:
gqa_init_mode: "AverageKV"

## Hidden dimension pruning
hidden_size_list:
hidden_size_init_mode: "PruneByChannelRanking"
linear_init_mode: "FromTeacher"

mlp_init_config_yaml:
  activations_log_dir: ${pruning.activations_log_dir}
```
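To illustrate the `"Truncate"` init mode named above: a naive FFN-pruning baseline simply keeps the leading channels of the intermediate dimension, whereas `PruneByActivationsLog` would rank channels by logged activation statistics instead. A toy sketch, using nested lists in place of real weight tensors:

```python
# Toy "Truncate" FFN pruning sketch (assumed semantics, not ModelOpt's code):
# shrink the MLP intermediate dimension by keeping its first `new_size` channels.
def truncate_ffn_weights(up_proj, down_proj, new_size):
    """up_proj: [intermediate, hidden]; down_proj: [hidden, intermediate]."""
    pruned_up = up_proj[:new_size]                       # drop trailing intermediate rows
    pruned_down = [row[:new_size] for row in down_proj]  # drop the matching columns
    return pruned_up, pruned_down

up = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # intermediate=3, hidden=2
down = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]  # hidden=2, intermediate=3
pu, pd = truncate_ffn_weights(up, down, 2)
print(len(pu), len(pd[0]))  # 2 2
```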
Lines changed: 18 additions & 0 deletions
```yaml
model_dtype: torch.bfloat16 # dtype to cast the model for validate_model
autocast_dtype: torch.bfloat16 # dtype for torch.autocast for validate_model
block_size: 8192
bos_rate: 0.5
data_column: messages
val_dataset_name: valid
shuffle_seed: 81436
seed: 42
fim_rate: 0
fim_spm_rate: 0
source_datasets_to_discard:
varlen: false
write_results: false
calc_losses_on_cpu: false
activations_log_dir:
model_name_or_path:
load_dataset_fn: ${get_object:modelopt.torch.puzzletron.utils.data.dataloaders.load_from_disk_fn}
```
Lines changed: 11 additions & 0 deletions
```yaml
defaults:
  - /validate_model_defaults
  - _self_

solutions_to_validate:
skip_validation: false
save_models: false
bigger_is_better: false
sort_solutions_by:
calculate_full_score_ablations: false
```

examples/puzzletron/evaluation/hf_deployable_anymodel.py

Lines changed: 36 additions & 13 deletions
```diff
@@ -1,10 +1,35 @@
-# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/tree/aa457edc3d64d81530159cd3a182932320c78f8c
+
+# MIT License
+#
+# Copyright (c) 2020 EleutherAI
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
-#    http://www.apache.org/licenses/LICENSE-2.0
+# http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
@@ -13,6 +38,7 @@
 # limitations under the License.
 
 
+import json
 import logging
 from typing import Any
 
@@ -28,6 +54,11 @@
 from peft import PeftModel
 from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
 
+from modelopt.torch.puzzletron.anymodel.model_descriptor.model_descriptor_factory import (
+    resolve_descriptor_from_pretrained,
+)
+from modelopt.torch.puzzletron.anymodel.puzzformer import deci_x_patcher
+
 try:
     from pytriton.decorators import batch
     from pytriton.model_config import Tensor
@@ -139,18 +170,12 @@ def _load(
         # Wraps model loading with deci_x_patcher for heterogeneous layer configs.
         # See: modelopt/torch/puzzletron/anymodel/puzzformer/utils.py
         # =========================================================================
-        import os
-        import sys
 
-        modelopt_workdir = os.environ.get("MODELOPT_WORKDIR") or os.environ.get(
-            "PUZZLE_WORKDIR"
+        descriptor = resolve_descriptor_from_pretrained(
+            self.hf_model_id_path, trust_remote_code=hf_kwargs.get("trust_remote_code", False)
         )
-        if modelopt_workdir and modelopt_workdir not in sys.path:
-            sys.path.insert(0, modelopt_workdir)
-        from modelopt.torch.puzzletron.anymodel.models.llama import LlamaModelDescriptor
-        from modelopt.torch.puzzletron.anymodel.puzzformer import deci_x_patcher
 
-        with deci_x_patcher(model_descriptor=LlamaModelDescriptor):
+        with deci_x_patcher(model_descriptor=descriptor):
             self.model = AutoModelForCausalLM.from_pretrained(
                 self.hf_model_id_path,
                 torch_dtype=torch_dtype,
@@ -587,8 +612,6 @@ def ray_infer_fn(self, inputs: dict[Any, Any]):
             - log_probs: Optional list of log probabilities if compute_logprob is True
             - top_logprobs: Optional list of top log probabilities if n_top_logprobs > 0
         """
-        import json
-
         try:
             prompts = inputs.pop("prompts")
             temperature = inputs.pop("temperature", 1.0)
```

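The diff above replaces a hard-coded `LlamaModelDescriptor` import (plus `sys.path` manipulation via environment variables) with `resolve_descriptor_from_pretrained`, which derives the descriptor from the checkpoint itself. A hypothetical sketch of that pattern follows; the registry contents and function name below are illustrative, not ModelOpt's actual API.

```python
# Illustrative descriptor-resolution pattern: map a checkpoint's model_type
# (from its Hugging Face config.json) to a registered descriptor name.
import json
from pathlib import Path

# Hypothetical registry; real descriptor classes live in modelopt.torch.puzzletron.
DESCRIPTOR_REGISTRY = {
    "llama": "LlamaModelDescriptor",
    "gpt_oss": "GptOss20bModelDescriptor",
}

def resolve_descriptor_name(checkpoint_dir: str) -> str:
    """Read config.json from a checkpoint directory and look up its descriptor."""
    config = json.loads((Path(checkpoint_dir) / "config.json").read_text())
    model_type = config["model_type"]
    if model_type not in DESCRIPTOR_REGISTRY:
        raise ValueError(f"No descriptor registered for model_type={model_type!r}")
    return DESCRIPTOR_REGISTRY[model_type]
```

Resolving from the checkpoint keeps evaluation code model-agnostic: adding gpt-oss support no longer requires touching the loader, only registering a descriptor.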
0 commit comments
