diff --git a/docs/source/guides/5_speculative_decoding.rst b/docs/source/guides/5_speculative_decoding.rst index 001808e35..40f8655e6 100644 --- a/docs/source/guides/5_speculative_decoding.rst +++ b/docs/source/guides/5_speculative_decoding.rst @@ -2,127 +2,59 @@ Speculative Decoding ==================== -Introduction -============ - ModelOpt's Speculative Decoding module (:mod:`modelopt.torch.speculative `) -enables your model to generate multiple tokens in each generate step. This can be useful for reducing the -latency of your model and speeds up inference. - -Below are the speculative decoding algorithms supported by ModelOpt: -- Medusa -- EAGLE - - -Follow the steps described below to obtain a model with Medusa or EAGLE speculative decoding using ModelOpt's -Speculative Decoding module :mod:`modelopt.torch.speculative`: - -#. **Convert your model via** :meth:`mtsp.convert `: - Add Medusa heads or EAGLE module to your model. -#. **Fine-tune Medusa heads or EAGLE module**: Fine-tune the Medusa heads or EAGLE module. - The base model is recommended to be frozen. -#. **Checkpoint and re-load**: Save the model via :meth:`mto.save ` and - restore via :meth:`mto.restore ` - -.. _speculative_conversion: - -Convert -======= - -You can convert your model to a speculative decoding model using :meth:`mtsp.convert() -`. - -Example usage: - -.. code-block:: python +enables your model to generate multiple tokens in each generation step, reducing inference latency. - import torch - from transformers import AutoModelForCausalLM, AutoTokenizer - import modelopt.torch.speculative as mtsp +ModelOpt implements the **EAGLE3** algorithm, which attaches a lightweight autoregressive draft +module to a frozen base model. The draft module operates at the feature level, predicting future +hidden states rather than tokens directly, to achieve high acceptance rates at low compute cost.
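The verify-and-accept loop underlying all speculative decoding can be sketched in a few lines of plain Python. This is a greedy-decoding toy for intuition only, not ModelOpt's implementation; ``target_next_token`` is a hypothetical stand-in for a single forward pass of the base model over the whole sequence:

```python
def speculative_step(target_next_token, context, draft_tokens):
    """One speculative decoding step with greedy verification (toy sketch).

    ``target_next_token(seq)`` stands in for a single base-model forward
    pass: it returns the model's greedy next-token choice after every
    prefix of ``seq``, so one call scores all drafted positions at once.
    """
    seq = context + draft_tokens
    preds = target_next_token(seq)  # preds[i] is the choice following seq[: i + 1]
    # Accept the longest draft prefix that matches the target's own choices.
    n_accept = 0
    while (
        n_accept < len(draft_tokens)
        and draft_tokens[n_accept] == preds[len(context) - 1 + n_accept]
    ):
        n_accept += 1
    # Each step always yields at least one token: the target's prediction
    # at the first mismatch (or a bonus token when the whole draft matched).
    bonus = preds[len(context) - 1 + n_accept]
    return draft_tokens[:n_accept] + [bonus]
```

With an accurate draft module most positions match, so several tokens are committed per base-model forward pass instead of one.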
- # User-defined model - model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0") - tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0") - tokenizer.pad_token_id = tokenizer.eos_token_id - if mode == "medusa": - # Configure and convert to medusa - config = { - "medusa_num_heads": 2, - "medusa_num_layers": 1, - } - elif mode == "eagle": - config = { - "eagle_num_layers": 1 - } - mtsp.convert(model, [(mode, config)]) +.. toctree:: + :maxdepth: 1 + :caption: Module Guide -Fine-tune speculative decoding model and store/restore the model ----------------------------------------------------------------- + ./_speculative_module_guide.rst -After converting to a speculative decoding model, you need to fine-tune the decoding module: +.. toctree:: + :maxdepth: 1 + :caption: EAGLE -.. code-block:: python + ./_eagle_workflow.rst + ./_eagle_config_reference.rst + ./_eagle_best_practices.rst - import os - from transformers import Trainer - import modelopt.torch.opt as mto - - mto.enable_huggingface_checkpointing() - - trainer = Trainer(model=model, processing_class=tokenizer, args=training_args, **data_module) - trainer._move_model_to_device(model, trainer.args.device) - - trainer.train(resume_from_checkpoint=checkpoint) - trainer.save_state() - trainer.save_model("") - - -To restore the saved speculative decoding model: - -.. code-block:: python - - model = AutoModelForCausalLM.from_pretrained("") .. _speculative-concepts: Speculative Decoding Concepts -============================= - -Below, we will provide an overview of ModelOpt's speculative decoding feature as well as its basic -concepts and terminology. +============================== Speculative decoding -------------------- + The standard way of generating text from a language model is with autoregressive decoding: one token is generated each step and appended to the input context for the next token generation. 
This means to generate *K* tokens it will take *K* serial runs of the model. Inference from large autoregressive models like Transformers can be slow and expensive. Therefore, various *speculative decoding* algorithms have been proposed to accelerate text generation, especially in latency critical applications. -Typically, a short draft of length *K* is generated using a faster, auto-regressive model, called draft -model. This can be attained with either a parallel model or by calling the draft model *K* times. -Then, a larger and more powerful model, called target model, is used to score the draft. Last, a sampling -scheme is used to decide which draft to accept by the target model, recovering the distribution of the -target model in the process. +Typically, a short draft of length *K* is generated using a faster model, called the *draft model*. +Then, a larger and more powerful model, called the *target model*, verifies the draft in a single +forward pass. A sampling scheme decides which draft tokens to accept, recovering the output +distribution of the target model in the process. -Medusa algorithm +EAGLE3 algorithm ---------------- -There are many ways to achieve speculative decoding. A popular approach is Medusa where instead of -using an additional draft model, it introduces a few additional decoding heads to predict multiple -future tokens simultaneously. During generation, these heads each produce multiple likely words for -the corresponding position. These options are then combined and processed using a tree-based attention -mechanism. Finally, a typical acceptance scheme is employed to pick the longest plausible prefix from -the candidates for further decoding. Since the draft model is the target model itself, this guarantees -the output distribution is the same as that of the target model. 
- -EAGLE algorithm ---------------- - -Unlike Medusa that predicts future tokens based on the base model hidden states, EAGLE predicts -future hidden states through a lightweight autoregressive decoder, which is then used to -predict the future tokens. Since autoregression at the feature (hidden states) level is simpler -than at the token level, EAGLE can predict future tokens more accurately than Medusa. Therefore, it -achieves higher speedup. +EAGLE3 attaches a lightweight autoregressive decoder (the draft module) to a frozen base model. +Unlike token-level autoregression, the draft module operates at the *feature level*: it predicts +future hidden states, which are then projected to token logits. Autoregression over hidden states +is an easier task than over tokens, so the draft module achieves high prediction accuracy with low +compute cost. + +Compared to earlier EAGLE versions, EAGLE3 uses auxiliary hidden states from **multiple intermediate +layers** of the base model as additional input to the draft decoder, not just the final layer output. +This richer signal enables the draft module to more accurately predict the base model's next-layer +representations, resulting in higher token acceptance rates and greater inference speedup. diff --git a/docs/source/guides/_eagle_best_practices.rst b/docs/source/guides/_eagle_best_practices.rst new file mode 100644 index 000000000..cce343e37 --- /dev/null +++ b/docs/source/guides/_eagle_best_practices.rst @@ -0,0 +1,107 @@ +.. _eagle-best-practices: + +Best Practices +==================== + +This page collects practical recommendations for achieving the best results when training EAGLE +speculative decoding models. + + +.. _eagle-best-practices-data-synthesis: + +Data Synthesis +-------------- + +Training on conversations **generated by the base model** rather than human-authored datasets +significantly improves token acceptance rates. 
The draft module learns to predict the target +model's actual output distribution, not just surface-level text patterns. + +To prepare synthetic training data, launch an inference server with the base model: + +.. code-block:: bash + + pip install vllm + vllm serve meta-llama/Llama-3.2-1B-Instruct \ + --api-key token-abc123 \ + --port 8000 \ + --tensor-parallel-size 1 + +.. note:: + + For quantized models, add ``--quantization=modelopt`` to the ``vllm serve`` command. + +Then generate conversations using prompts from your training dataset: + +.. code-block:: bash + + python scripts/server_generate.py \ + --data_path input_conversations/daring-anteater.jsonl \ + --output_path synthetic/train.jsonl + +Use ``--system_prompt `` to inject a system prompt into every conversation. + +For large-scale generation across multiple GPUs, see the +`SLURM data preparation guide `_. + + +.. _eagle-best-practices-configure: + +Configuring the Draft Model +---------------------------- + +ModelOpt ships with sensible default architectures for EAGLE‑1 and EAGLE‑3. See +:ref:`eagle-config-reference` for the full list of configurable fields. + +When launching training via ``launch_train.sh``, pass a JSON override file with +``--eagle_config ``. Only the fields you want to change need to be specified; omitted +fields fall back to the built-in defaults. For example, to use a 2-layer draft decoder with a +larger MLP: + +.. code-block:: json + + { + "num_hidden_layers": 2, + "intermediate_size": 8192 + } + +See :ref:`eagle-config-reference` for more details. + + +.. _eagle-best-practices-vocab: + +Draft Vocabulary Compression +----------------------------- + +By default, the draft model shares the full vocabulary of the base model. For large vocabularies +(e.g., 128,256 tokens in Llama‑3) you can compress the draft vocabulary to a smaller working set, +reducing embedding table size and speeding up both training and inference.
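The offset-table idea can be illustrated with a toy sketch. This is not the logic of ``calibrate_draft_vocab.py`` itself — the frequency ranking and ordering here are assumptions made for illustration — but the recovery identity ``target = draft + d2t[draft]`` matches the mapping described below:

```python
from collections import Counter

def build_d2t(train_token_ids, draft_vocab_size):
    """Toy draft-to-target offset table.

    Keeps the ``draft_vocab_size`` most frequent target tokens from the
    training corpus (in ascending target-id order, an illustrative
    choice), storing each as an offset so that ``target = i + d2t[i]``.
    """
    keep = [tok for tok, _ in Counter(train_token_ids).most_common(draft_vocab_size)]
    return [target - draft for draft, target in enumerate(sorted(keep))]
```

Only the embedding rows for the kept tokens need to exist in the draft model, which is where the training and inference savings come from.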
+ +**Step 1 — Calibrate a vocabulary mapping** + +Find the most frequently used tokens in your training set and save a ``d2t.pt`` mapping file: + +.. code-block:: bash + + python scripts/calibrate_draft_vocab.py \ + --model meta-llama/Llama-3.2-1B-Instruct \ + --data input_conversations/daring-anteater.jsonl \ + --draft_vocab_size 32000 \ + --save_dir draft_vocab_cache + +The ``d2t.pt`` file maps each compressed draft token index to its offset in the target vocabulary. +During inference the target token is recovered as: + +.. code-block:: text + + target_token = draft_token_index + d2t[draft_token_index] + +**Step 2 — Enable compressed vocabulary in training** + +Add the following to your ``eagle_config.json``: + +.. code-block:: json + + {"draft_vocab_size": 32000} + +Then pass ``--draft_vocab_cache `` when running ``./launch_train.sh``. The draft +model will use the compressed vocabulary table during both training and export. diff --git a/docs/source/guides/_eagle_config_reference.rst b/docs/source/guides/_eagle_config_reference.rst new file mode 100644 index 000000000..bc58e6b9a --- /dev/null +++ b/docs/source/guides/_eagle_config_reference.rst @@ -0,0 +1,102 @@ +.. _eagle-config-reference: + +Configuration Reference +=============================== + +EAGLE3 is configured through a dict passed to :meth:`mtsp.convert() +`. The top-level keys correspond to +fields of :class:`EagleConfig `, with +``eagle_architecture_config`` containing a nested dict of draft module architecture settings. + +.. code-block:: python + + config = { + # --- EagleConfig top-level fields --- + "eagle_decoder_type": "llama", + "eagle_freeze_base_model": True, + "eagle_self_logit_distillation": True, + "eagle_offline": False, + "eagle_loss_decay_factor": 0.9, + + # --- Draft module architecture --- + "eagle_architecture_config": { + "num_hidden_layers": 1, + "intermediate_size": 8192, + ... 
+ }, + } + mtsp.convert(model, [("eagle", config)]) + + +EagleConfig fields +------------------ + +``eagle_decoder_type`` (*str*, default: ``"llama"``) + Draft decoder architecture. Use ``"llama"`` for most models; ``"kimik2"`` for Kimi-K2 models. + +``eagle_freeze_base_model`` (*bool*, default: ``True``) + Keep the base model weights frozen during training. Disabling this allows joint fine-tuning + but significantly increases memory usage. + +``eagle_self_logit_distillation`` (*bool*, default: ``True``) + Apply logit-level distillation loss in addition to hidden-state regression. Improves token + acceptance rates without extra inference cost. + +``eagle_offline`` (*bool*, default: ``False``) + Use pre-computed hidden states from disk instead of running the base model forward pass at + each training step. Required for large models (70B+) that cannot be co-located with the + draft module in GPU memory. See :ref:`Offline Training `. + +``eagle_loss_decay_factor`` (*float*, default: ``0.9``) + Exponential decay applied to losses at successive draft steps, weighting earlier steps more + heavily during training. + +``eagle_architecture_config`` (*dict*, default: ``{}``) + Overrides for the draft module architecture. See `eagle_architecture_config fields`_ below. + ``hidden_size``, ``vocab_size``, and ``max_position_embeddings`` are inferred from the base + model and should not be set here. + + +eagle_architecture_config fields +--------------------------------- + +These keys override the default draft module architecture. Only set the fields you need to +change; unspecified fields fall back to the defaults listed below (for ``eagle_decoder_type="llama"``). + +``num_hidden_layers`` (*int*, default: ``1``) + Number of transformer layers in the draft decoder. Increasing this improves acceptance rates + at the cost of higher draft latency. + +``intermediate_size`` (*int*, default: inferred from base model) + Feed-forward intermediate dimension of the draft decoder MLP. 
+ +``num_attention_heads`` (*int*, default: ``32``) + Number of attention heads in the draft decoder. + +``num_key_value_heads`` (*int*, default: ``8``) + Number of key/value heads (grouped-query attention). Set equal to ``num_attention_heads`` + to disable GQA. + +``hidden_act`` (*str*, default: ``"silu"``) + Activation function used in the MLP layers. + +``use_aux_hidden_state`` (*bool*, default: ``False``) + Feed auxiliary hidden states from intermediate base model layers into the draft decoder—the + key EAGLE3 feature. Set to ``True`` for EAGLE3; ``False`` gives EAGLE1 behaviour. + +``eagle_aux_hidden_state_layer_ids`` (*list[int]*, default: ``[]``) + Indices of base model layers whose hidden states are used as auxiliary inputs. Populated + automatically when ``use_aux_hidden_state=True``; override only for custom layer selection. + +``use_last_layernorm`` (*bool*, default: ``False``) + Apply a layer-norm after the last draft decoder layer. Required when + ``use_aux_hidden_state=True`` (i.e., EAGLE3 mode). + +``parallel_draft_step`` (*int*, default: ``1``) + Number of tokens drafted in parallel per step. Values greater than 1 enable parallel + speculative decoding and can further reduce latency on suitable hardware. + +.. note:: + + The complete set of architecture fields and their defaults can be found in + :mod:`modelopt.torch.speculative.eagle.default_config`. diff --git a/docs/source/guides/_eagle_workflow.rst b/docs/source/guides/_eagle_workflow.rst new file mode 100644 index 000000000..b4784ac32 --- /dev/null +++ b/docs/source/guides/_eagle_workflow.rst @@ -0,0 +1,204 @@ +Workflow +======== + +Follow the steps below to obtain a model with EAGLE3 speculative decoding: + +#. **Convert your model via** :meth:`mtsp.convert `: + Attach the EAGLE3 draft module to your base model. +#. **Fine-tune the EAGLE3 module**: Fine-tune the draft module using online or offline training. + The base model is frozen throughout. +#. 
**Checkpoint and re-load**: Save the model via :meth:`mto.save ` and + restore via :meth:`mto.restore ` +#. **Export**: Export the checkpoint to a deployment-compatible format using + :func:`export_speculative_decoding `. +#. **Deploy**: Serve the exported model with TRT-LLM, vLLM, or SGLang. + +.. _speculative_conversion: + +Convert +------- + +You can convert your model to an EAGLE3 speculative decoding model using :meth:`mtsp.convert() +`. + +Example usage: + +.. code-block:: python + + import torch + from transformers import AutoModelForCausalLM, AutoTokenizer + import modelopt.torch.speculative as mtsp + + # Load base model + model = AutoModelForCausalLM.from_pretrained( + "meta-llama/Llama-3.2-1B-Instruct", + torch_dtype="auto", + device_map="cpu", # load on CPU first to avoid OOM; moved to GPU by the trainer + ) + tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct") + tokenizer.pad_token_id = tokenizer.eos_token_id + + # Configure and convert to EAGLE3 + config = { + "eagle_decoder_type": "llama", # decoder architecture; use "kimik2" for Kimi-K2 models + "eagle_architecture_config": { + "num_hidden_layers": 1, # depth of the draft decoder (default: 1) + # "intermediate_size": 8192, # MLP hidden size (default: inferred from base model) + }, + } + mtsp.convert(model, [("eagle", config)]) + +.. note:: + + ``hidden_size``, ``vocab_size``, and ``max_position_embeddings`` are automatically inferred + from the base model and do not need to be set manually. See :ref:`eagle-config-reference` + for all available options. + + +Fine-tune EAGLE3 module +----------------------- + +After conversion, fine-tune the draft module. The base model weights are frozen by default +(controlled by ``eagle_freeze_base_model`` in ``EagleConfig``). ModelOpt supports two training +modes depending on hardware constraints: + +Online Training +^^^^^^^^^^^^^^^ + +In online training the base model and draft module are co-located in GPU memory. 
The base model +performs forward passes at each training step to produce hidden states for the draft module. +This is recommended for smaller base models (e.g., 1B–8B parameters). + +.. code-block:: python + + from transformers import Trainer + import modelopt.torch.opt as mto + + mto.enable_huggingface_checkpointing() + + trainer = Trainer(model=model, processing_class=tokenizer, args=training_args, **data_module) + trainer._move_model_to_device(model, trainer.args.device) + + trainer.train(resume_from_checkpoint=checkpoint) + trainer.save_state() + trainer.save_model("") + +.. note:: + + `FSDP2 `_ is used by default for distributed training. + For long-context training, context parallelism can be enabled by setting ``cp_size > 1`` in + ``ParallelismConfig``. + +.. tip:: + + Training on conversations generated by the base model (rather than human-written data) improves + acceptance rates, since the draft module learns to mimic the target distribution more closely. + See the ``scripts/server_generate.py`` script in the example directory for data synthesis. + +Offline Training +^^^^^^^^^^^^^^^^ + +For large base models (e.g., 70B+ parameters) that cannot be co-located with the draft module in +GPU memory, you can pre-compute and dump hidden states to disk, then train only the draft module. +This decouples base model inference from draft model training and significantly reduces GPU memory +requirements. + +.. note:: + + Offline training requires several to tens of terabytes of disk storage depending on dataset + size and base model hidden dimension. + +**Step 1 — Dump hidden states to disk** + +Two backends are supported: + +.. 
code-block:: bash + + # Recommended: TRT-LLM backend (higher throughput) + python collect_hidden_states/compute_hidden_states_trtllm.py \ + --model $BASE_MODEL \ + --input-file input_conversations/daring-anteater.jsonl \ + --output-dir $HIDDEN_STATES_DIR + + # Alternative: HuggingFace backend + python collect_hidden_states/compute_hidden_states_hf.py \ + --model $BASE_MODEL \ + --input-file input_conversations/daring-anteater.jsonl \ + --output-dir $HIDDEN_STATES_DIR + +Each output ``.pt`` file contains the tokenized input and the corresponding hidden states from +all required layers of the base model. + +**Step 2 — Convert and train using pre-computed hidden states** + +Set ``eagle_offline=True`` in the config to enable offline mode, then point the data loader to +the pre-computed hidden state files: + +.. code-block:: python + + config = { + "eagle_decoder_type": "llama", + "eagle_offline": True, # use pre-computed hidden states; skip base model forward pass + "eagle_architecture_config": { + "num_hidden_layers": 1, + }, + } + mtsp.convert(model, [("eagle", config)]) + + # Then train as usual, with offline_data_path pointing to $HIDDEN_STATES_DIR + + +Checkpoint and re-load +----------------------- + +To restore the saved EAGLE3 model: + +.. code-block:: python + + model = AutoModelForCausalLM.from_pretrained("") + + +Validation +---------- + +After online training, you can evaluate the acceptance rate (AR) on MT-Bench to gauge draft quality +before exporting: + +.. code-block:: bash + + python scripts/ar_validate.py --model_path $ONLINE_CKPT + +.. note:: + + In-framework AR evaluation is supported only for online training checkpoints. For offline + training checkpoints, export the model first and evaluate with a serving framework. + + +Export +------ + +After training, export the ModelOpt checkpoint to a deployment-compatible format that can be +served by TRT-LLM, vLLM, or SGLang: + +.. 
code-block:: python + + from modelopt.torch.export import export_speculative_decoding + + export_speculative_decoding(model, export_dir="") + +Alternatively, use the provided export script from the example directory: + +.. code-block:: bash + + python scripts/export_hf_checkpoint.py \ + --model_path $OUTPUT_DIR \ + --export_path $EXPORT_PATH + +The exported checkpoint separates the base model and draft module weights into independent +directories that serving frameworks can load directly. + + +Deployment +---------- + +The exported checkpoint can be deployed on TRT-LLM, vLLM, or SGLang. diff --git a/docs/source/guides/_speculative_module_guide.rst b/docs/source/guides/_speculative_module_guide.rst new file mode 100644 index 000000000..57265aefc --- /dev/null +++ b/docs/source/guides/_speculative_module_guide.rst @@ -0,0 +1,149 @@ +Module Guide +============ + +This page describes the public API of :mod:`modelopt.torch.speculative` (imported as ``mtsp``) +and explains how the conversion pipeline works under the hood. + + +Public API +---------- + +``import modelopt.torch.speculative as mtsp`` exposes the following: + +``mtsp.convert(model, mode)`` + Main entry point. Converts a base model into a speculative decoding model. + See `mtsp.convert`_ below. + +``mtsp.MedusaConfig``, ``mtsp.EagleConfig`` + Configuration dataclasses for each algorithm. + See the respective algorithm's configuration reference for field details. + +``mtsp.EAGLE1_DEFAULT_CFG``, ``mtsp.EAGLE3_DEFAULT_CFG``, ``mtsp.EAGLE_MTP_DEFAULT_CFG`` + Built-in preset dicts for common EAGLE variants, ready to pass directly to ``mtsp.convert()``. + + +mtsp.convert +------------ + +.. code-block:: python + + mtsp.convert(model: nn.Module, mode: str | list | dict) -> nn.Module + +Converts ``model`` in-place into a speculative decoding model and returns it. + +**Parameters** + +``model`` + A ``torch.nn.Module`` (typically loaded from HuggingFace) to be converted. 
+ +``mode`` + Specifies the algorithm and its configuration. Accepted forms: + + .. code-block:: python + + # 1. (algorithm, config_dict) tuple inside a list — most common + mtsp.convert(model, [("eagle", {"eagle_decoder_type": "llama", ...})]) + + # 2. Preset dict directly + from modelopt.torch.speculative.config import EAGLE3_DEFAULT_CFG + mtsp.convert(model, [("eagle", EAGLE3_DEFAULT_CFG["config"])]) + + # 3. Algorithm name only — uses all defaults + mtsp.convert(model, "medusa") + +**What convert() does** + +Internally, ``convert()`` delegates to the ``SpeculativeDecodingModeRegistry``, which routes +each algorithm name to its registered conversion function: + +.. code-block:: text + + convert(model, [("eagle", config)]) + │ + ├─ looks up "eagle" in SpeculativeDecodingModeRegistry + │ → EagleModeDescriptor.convert → convert_to_eagle_model() + │ + └─ convert_to_eagle_model(): + 1. Resolves model class in EagleDMRegistry + 2. Merges user config with built-in architecture defaults + 3. Wraps model as EagleModel (a DynamicModule subclass) + 4. Calls eagle_model.modify() to store all config as attributes + 5. Attaches the draft module; freezes base model if configured + +The result is the same Python object as the input—``convert()`` modifies the model in-place +and also returns it. After conversion the model's ``forward()`` is replaced with a +training-compatible forward that computes speculative decoding losses. + +.. note:: + + ``convert()`` is designed to be called once, immediately after loading the base model and + before moving it to GPU. The trainer then moves the converted model to the target device. + + +Supported algorithms +-------------------- + +Two algorithms are currently registered: + +``"eagle"`` + EAGLE-family speculative decoding. The draft module is a lightweight autoregressive + decoder operating at the *feature level* (predicts hidden states, not tokens directly). + Configured via :class:`EagleConfig `. 
+ Variants—EAGLE1, EAGLE3, EAGLE-MTP—are selected through ``eagle_architecture_config`` + fields; see the :ref:`EAGLE config reference `. + +``"medusa"`` + Medusa speculative decoding. Adds *K* independent prediction heads on top of the base + model, each predicting a future token position in parallel. + Configured via :class:`MedusaConfig ` + with two fields: ``medusa_num_heads`` (default: ``2``) and ``medusa_num_layers`` + (default: ``1``). + + +Model state after conversion +----------------------------- + +After ``convert()`` returns, the model object gains the following attributes and behaviours: + +- The original model weights are preserved and accessible as before. +- A draft module (``eagle_module`` or medusa heads) is attached to the model. +- ``model.forward()`` is replaced with a training forward that returns speculative + decoding losses in addition to the normal LM loss. +- If ``eagle_freeze_base_model=True`` (the default), base model parameters have + ``requires_grad=False``. Only draft module parameters are updated during training. +- The model remains compatible with ``transformers.Trainer`` and FSDP2. + + +Save and restore +----------------- + +ModelOpt tracks the conversion so that checkpoints can be restored to the same +speculative decoding state: + +.. code-block:: python + + import modelopt.torch.opt as mto + + # Enable HuggingFace-compatible checkpointing before training + mto.enable_huggingface_checkpointing() + + # After training: save + trainer.save_model("") # saves base + draft module together + + # Restore in a new session + model = AutoModelForCausalLM.from_pretrained("") + # model is already an EagleModel / MedusaModel — no explicit mto.restore() needed + +Alternatively, if you use ``mto.save()`` / ``mto.restore()`` directly: + +.. code-block:: python + + mto.save(model, "") + model = AutoModelForCausalLM.from_pretrained("") + mto.restore(model, "") + +.. 
note:: + + ``mto.restore()`` re-runs the same ``convert_to_eagle_model()`` pipeline (with the saved + config) and then loads the saved weights. No manual ``mtsp.convert()`` call is needed + after restoration. diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index d7aadf994..c35feffe3 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -51,8 +51,10 @@ import modelopt.torch.sparsity as mts from modelopt.torch.export import ( export_hf_checkpoint, + export_speculative_decoding, export_tensorrt_llm_checkpoint, get_model_type, + has_spec_opt, save_expert_token_count_table, ) from modelopt.torch.export.model_utils import get_language_model_from_vl, is_multimodal_model @@ -566,6 +568,13 @@ def export_quantized( export_path = args.export_path + # Early exit for speculative decoding checkpoints + # No tokenizer saving needed for spec ckpts + if has_spec_opt(full_model): + export_speculative_decoding(full_model, export_dir=export_path) + print(f"Quantized speculative decoding checkpoint exported to: {export_path}") + return + # Check if the model is a multimodal/VLM model is_vlm = is_multimodal_model(full_model) diff --git a/examples/speculative_decoding/README.md b/examples/speculative_decoding/README.md index 7e9c855cb..6d37dbf94 100644 --- a/examples/speculative_decoding/README.md +++ b/examples/speculative_decoding/README.md @@ -19,7 +19,7 @@ This example focuses on training with Hugging Face. 
To train with Megatron‑LM, | Online Training | Train draft model alongside base model in GPU memory | \[[Link](#training-draft-model-with-online-base-model)\] | | Offline Training | Train draft model using pre-computed hidden states | \[[Link](#training-draft-model-with-offline-base-model)\] | | After Training | Evaluation, export and deployment | \[[Link](#model-validation)\] | -| Advanced Usage | Data synthesis, vocab compression, and configuration | \[[Link](#advanced-usage)\] | +| Custom Datasets | Other dataset options and custom data format | \[[Link](#custom-datasets)\] | | Support Matrix | Supported models for speculative decoding training | \[[Link](#support-matrix)\] | | Speculation Module Checkpoints | View pre-trained speculation modules ready to deploy! | \[[Link](#speculation-module-checkpoints)\] | | Resources | Extra links to relevant resources | \[[Link](#resources)\] | @@ -51,7 +51,9 @@ We use [Daring-Anteater](https://huggingface.co/datasets/nvidia/Daring-Anteater) python prepare_input_conversations/add_daring_anteater.py ``` -See [other-datasets](#other-datasets) section for other dataset options and instruction for user-provided data. +See the [Custom Datasets](#custom-datasets) section for other dataset options and instructions for user-provided data. + +For higher acceptance rates, consider training on **model-generated conversations** — see [Best Practices: Data Synthesis](https://nvidia.github.io/Model-Optimizer/guides/5_speculative_decoding.html#data-synthesis) for step-by-step instructions. ## Getting Started: Simplified Workflow @@ -81,6 +83,8 @@ For small base models that fit in GPU memory, we can collocate them with draft m FSDP2 is used by default. To enable context parallelism for long-context training, specify `--cp_size n`. The saved modelopt checkpoint is similar in architecture to HF models. It can be further optimized through **ModelOpt**, e.g., PTQ and QAT. 
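As a quick reference, a draft-architecture override file for `--eagle_config` (only the fields you change need to appear; field names follow the built-in EAGLE defaults) might look like:

```json
{
  "num_hidden_layers": 2,
  "intermediate_size": 8192
}
```

Pass it with `--eagle_config eagle_config.json` when running `./launch_train.sh`.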
+To customize the draft model architecture (number of layers, MLP size, etc.) or enable draft vocabulary compression, see [Best Practices: Configuring the Draft Model](https://nvidia.github.io/Model-Optimizer/guides/5_speculative_decoding.html#configuring-the-draft-model) and [Draft Vocabulary Compression](https://nvidia.github.io/Model-Optimizer/guides/5_speculative_decoding.html#draft-vocabulary-compression). + ## Training Draft Model with Offline Base Model For large models, you can export intermediate hidden states to disk and train only the draft model. This significantly reduces GPU memory requirements, but requires several to tens of terabytes of disk storage depending on dataset size. @@ -195,9 +199,7 @@ One can also use [examples/specdec_bench](../specdec_bench) to validate the trai See more details on deployment of quantized model to TRTLLM [here](../llm_ptq/README.md). -## Advanced Usage - -### Other Datasets +## Custom Datasets In addition to `daring-anteater`, we provide scripts for adding several other commonly used datasets in `prepare_input_conversations`: @@ -219,106 +221,6 @@ To use your own datasets, please preprocess your data into a `.jsonl` file with } ``` -### Data Synthesis - -To achieve higher acceptance rates during speculative decoding, it is beneficial to use conversations generated by the base model as training data. This ensures that the draft model's output distribution closely aligns with that of the base model. - -To prepare such data, we launch an inference server with the base model: - -```bash -pip install vllm -vllm serve meta-llama/Llama-3.2-1B-Instruct --api-key token-abc123 --port 8000 --tensor-parallel-size 1 -``` - -Note: Add `--quantization=modelopt` flag for quantized models. 
- -Then, we generate conversations with the base model using prompts from Daring-Anteater: - -```bash -python scripts/server_generate.py --data_path input_conversations/daring-anteater.jsonl --output_path synthetic/train.jsonl -``` - -To add a system prompt, use the `--system_prompt ` argument. - -For large scale data generation, please see [SLURM prepare data](SLURM_prepare_data.md) for SLURM support. - -### Configuring Draft Model - -For EAGLE‑1 and EAGLE‑3 we provide a [default model architecture config](https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt/torch/speculative/config.py#L37) in ModelOpt. You can override default settings by providing an additional JSON dict. E.g. To use 2-layer eagle with 8192 intermediate size for MLP, set `eagle_config.json` to: - -```json -{ - "num_hidden_layers": 2, - "intermediate_size":8192 -} -``` - -### Draft Vocabulary Compression - -We can optionally use smaller vocab size for the draft model for faster training and inference. E.g. Llama3.2-1B has a vocab size of 128256. In this example, we construct a draft vocab mapping of size 32k by finding the most commonly appeared vocabs in our training set: - -```bash -python scripts/calibrate_draft_vocab.py --model meta-llama/Llama-3.2-1B-Instruct --data input_conversations/daring-anteater.jsonl --draft_vocab_size 32000 --save_dir draft_vocab_cache -``` - -This will produce a `d2t.pt` file in `save_dir`, which is the mapping from draft token to target token. During inference, draft tokens can be mapped back to target tokens by `target_token = draft_token + d2t[draft_token]`. - -Then, simply set `{"draft_vocab_size":32000}` in `eagle_config.json` and include `--draft_vocab_cache ` when running `./launch_train.sh`. The draft model will use this provided vocab table during training and export. - -### Interact with `modelopt.torch.speculative` - -`main.py` provides an example for converting a HF base model for speculative decoding and training it. 
It consists of a few simple steps: -First, load the base model and tokenizer from Hugging Face: - -```python -model = transformers.AutoModelForCausalLM.from_pretrained( - "" -) -``` - -Then, load default eagle config and make necessary overwrites: - -```python -# Load default config -config = { - "eagle1": EAGLE1_DEFAULT_CFG, - "eagle3": EAGLE3_DEFAULT_CFG, -}[training_args.mode]["config"] - -# overwrite config with custom config -config["eagle_architecture_config"].update({"": ""}) - -# Mandatory: hidden size, vocab size and max position embeddings must match base model -config["eagle_architecture_config"].update( - { - "hidden_size": model.config.hidden_size, - "vocab_size": model.config.vocab_size, - "max_position_embeddings": model.config.max_position_embeddings, - } -) -``` - -Then, we convert model to a speculative decoding model: - -```python -mtsp.convert(model, [("eagle", config)]) -``` - -This will modify the model in-place with eagle training forward, making it compatible with HF trainer: - -```python -# Create a trainer -trainer = transformers.Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module) -trainer._move_model_to_device(model, trainer.args.device) - -# Enable HF checkpointing so that the saved model will contain the speculative decoding module -mto.enable_huggingface_checkpointing() - -trainer.train(resume_from_checkpoint=checkpoint) -trainer.save_state() -trainer.save_model("") -``` - ## Support Matrix | Model | Medusa | EAGLE1/2 | EAGLE3 | diff --git a/examples/speculative_decoding/scripts/export_hf_checkpoint.py b/examples/speculative_decoding/scripts/export_hf_checkpoint.py index fc3421583..23a7560f7 100644 --- a/examples/speculative_decoding/scripts/export_hf_checkpoint.py +++ b/examples/speculative_decoding/scripts/export_hf_checkpoint.py @@ -20,7 +20,7 @@ import torch import modelopt.torch.opt as mto -from modelopt.torch.export import export_hf_checkpoint +from modelopt.torch.export import export_speculative_decoding 
from modelopt.torch.speculative.utils import load_vlm_or_llm_with_kwargs @@ -41,7 +41,7 @@ def parse_args(): _, model = load_vlm_or_llm_with_kwargs(args.model_path, torch_dtype="auto") model.eval() with torch.inference_mode(): - export_hf_checkpoint( + export_speculative_decoding( model, export_dir=args.export_path, ) diff --git a/modelopt/torch/export/plugins/hf_spec_configs.py b/modelopt/torch/export/plugins/hf_spec_configs.py new file mode 100644 index 000000000..b78dfadd4 --- /dev/null +++ b/modelopt/torch/export/plugins/hf_spec_configs.py @@ -0,0 +1,149 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Template config for speculative decoding exporting.""" + +llama_eagle_template_config = { + "architectures": ["LlamaForCausalLMEagle3"], + "bos_token_id": None, + "eos_token_id": None, + "hidden_act": None, + "hidden_size": None, + "initializer_range": None, + "intermediate_size": None, + "max_position_embeddings": None, + "model_type": "llama", + "num_attention_heads": None, + "num_key_value_heads": None, + "num_hidden_layers": None, + "pad_token_id": None, + "rms_norm_eps": None, + "tie_word_embeddings": False, + "torch_dtype": None, + "transformers_version": None, + "use_cache": None, + "vocab_size": None, + "draft_vocab_size": None, + "rope_scaling": None, + "attention_bias": None, + "attention_dropout": None, + "head_dim": None, + "mlp_bias": None, + "pretraining_tp": None, + "rope_theta": None, + "eagle_config": { + "eagle_aux_hidden_state_layer_ids": None, + "use_aux_hidden_state": None, + "use_input_layernorm_in_first_layer": None, + "use_last_layernorm": None, + "use_mtp_layernorm": None, + "next_layer_regular": True, + "parallel_draft_step": None, + "parallel_draft_heads_num_layers": None, + }, +} + +kimik2_eagle_template_config = { + "architectures": ["Eagle3DeepseekV2ForCausalLM"], + "attention_bias": None, + "attention_dropout": None, + "aux_loss_alpha": None, + "bos_token_id": None, + "chunk_size_feed_forward": None, + "diversity_penalty": None, + "do_sample": None, + "early_stopping": None, + "encoder_no_repeat_ngram_size": None, + "eos_token_id": None, + "ep_size": None, + "first_k_dense_replace": None, + "forced_bos_token_id": None, + "forced_eos_token_id": None, + "hidden_act": None, + "hidden_size": None, + "id2label": None, + "initializer_range": None, + "intermediate_size": None, + "is_decoder": None, + "is_encoder_decoder": None, + "kv_lora_rank": None, + "label2id": None, + "length_penalty": None, + "max_length": None, + "max_position_embeddings": None, + "min_length": None, + "model_type": "kimi_k2", + "moe_intermediate_size": None, + 
"moe_layer_freq": None, + "n_group": None, + "n_routed_experts": None, + "n_shared_experts": None, + "no_repeat_ngram_size": None, + "norm_topk_prob": None, + "num_attention_heads": None, + "num_beam_groups": None, + "num_beams": None, + "num_experts_per_tok": None, + "num_hidden_layers": None, + "num_key_value_heads": None, + "num_nextn_predict_layers": None, + "num_return_sequences": None, + "output_attentions": None, + "output_hidden_states": None, + "output_scores": None, + "pad_token_id": None, + "pretraining_tp": None, + "pruned_heads": None, + "q_lora_rank": None, + "qk_nope_head_dim": None, + "qk_rope_head_dim": None, + "remove_invalid_values": None, + "repetition_penalty": None, + "return_dict": None, + "return_dict_in_generate": None, + "rms_norm_eps": None, + "rope_scaling": None, + "rope_theta": None, + "routed_scaling_factor": None, + "scoring_func": None, + "sep_token_id": None, + "seq_aux": None, + "temperature": None, + "tf_legacy_loss": None, + "tie_encoder_decoder": None, + "tie_word_embeddings": None, + "top_k": None, + "top_p": None, + "topk_group": None, + "topk_method": None, + "torch_dtype": None, + "torchscript": None, + "transformers_version": None, + "typical_p": None, + "use_bfloat16": None, + "use_cache": None, + "v_head_dim": None, + "vocab_size": None, + "eagle_config": { + "eagle_aux_hidden_state_layer_ids": None, + "use_aux_hidden_state": None, + "use_input_layernorm_in_first_layer": None, + "use_last_layernorm": None, + "use_mtp_layernorm": None, + "next_layer_regular": True, + "parallel_draft_step": None, + "parallel_draft_heads_num_layers": None, + }, +} diff --git a/modelopt/torch/export/plugins/hf_spec_export.py b/modelopt/torch/export/plugins/hf_spec_export.py index cdb009003..b25a630ab 100644 --- a/modelopt/torch/export/plugins/hf_spec_export.py +++ b/modelopt/torch/export/plugins/hf_spec_export.py @@ -15,123 +15,226 @@ """Modify state_dict and config for exporting speculative decoding in official format.""" +import json 
import re -from copy import copy +from abc import abstractmethod +from copy import deepcopy +from pathlib import Path import torch import torch.nn as nn +from safetensors.torch import save_file + +from .hf_spec_configs import kimik2_eagle_template_config, llama_eagle_template_config + +ALL_SPEC_MODES = ["eagle"] LLAMA_EAGLE_SINGLE_LAYER = { "required": { - "midlayer.self_attn.q_proj.weight", - "midlayer.self_attn.k_proj.weight", - "midlayer.self_attn.v_proj.weight", - "midlayer.self_attn.o_proj.weight", - "midlayer.mlp.gate_proj.weight", - "midlayer.mlp.up_proj.weight", - "midlayer.mlp.down_proj.weight", - "midlayer.hidden_norm.weight", - "midlayer.input_layernorm.weight", - "midlayer.post_attention_layernorm.weight", - "norm.weight", - "fc.weight", + "layers.0.self_attn.q_proj", + "layers.0.self_attn.k_proj", + "layers.0.self_attn.v_proj", + "layers.0.self_attn.o_proj", + "layers.0.mlp.gate_proj", + "layers.0.mlp.up_proj", + "layers.0.mlp.down_proj", + "layers.0.hidden_norm", + "layers.0.input_layernorm", + "layers.0.post_attention_layernorm", + "norm", + "fc", }, - "optional": {"d2t", "lm_head.weight"}, + "optional": {"d2t", "lm_head"}, } KIMIK2_EAGLE_SINGLE_LAYER = { "required": { - "midlayer.self_attn.kv_a_layernorm.weight", - "midlayer.self_attn.q_a_layernorm.weight", - "midlayer.self_attn.q_a_proj.weight", - "midlayer.self_attn.q_b_proj.weight", - "midlayer.self_attn.kv_a_proj_with_mqa.weight", - "midlayer.self_attn.kv_b_proj.weight", - "midlayer.self_attn.o_proj.weight", - "midlayer.mlp.gate_proj.weight", - "midlayer.mlp.up_proj.weight", - "midlayer.mlp.down_proj.weight", - "midlayer.hidden_norm.weight", - "midlayer.input_layernorm.weight", - "midlayer.post_attention_layernorm.weight", - "norm.weight", - "fc.weight", + "layers.0.self_attn.kv_a_layernorm", + "layers.0.self_attn.q_a_layernorm", + "layers.0.self_attn.q_a_proj", + "layers.0.self_attn.q_b_proj", + "layers.0.self_attn.kv_a_proj_with_mqa", + "layers.0.self_attn.kv_b_proj", + 
"layers.0.self_attn.o_proj",
+        "layers.0.mlp.gate_proj",
+        "layers.0.mlp.up_proj",
+        "layers.0.mlp.down_proj",
+        "layers.0.hidden_norm",
+        "layers.0.input_layernorm",
+        "layers.0.post_attention_layernorm",
+        "norm",
+        "fc",
     },
     "optional": {
         "d2t",
-        "lm_head.weight",
+        "lm_head",
     },
 }
 
 
-def _check_valid_sd(state_dict: dict, eagle_decoder_type: str, num_hidden_layers: int):
-    """Check the export state dict is valid, otherwise raise Exception."""
-    expected_keys_single_layer = {
-        "llama": LLAMA_EAGLE_SINGLE_LAYER,
-        "kimik2": KIMIK2_EAGLE_SINGLE_LAYER,
-    }[eagle_decoder_type]
-    # Check that export sd has required keys
-    if num_hidden_layers == 1:
-        for key in expected_keys_single_layer["required"]:
-            assert key in state_dict, f"Missing required key: {key}"
-    else:
+def has_spec_opt(model: nn.Module):
+    """Check if the model has speculative decoding optimization."""
+    opt_modes = getattr(model, "_modelopt_state", [])
+    return any(mode[0] in ALL_SPEC_MODES for mode in opt_modes)
+
+
+def has_quant_opt(model: nn.Module):
+    """Check if the model has quantization optimization."""
+    opt_modes = getattr(model, "_modelopt_state", [])
+    return any(mode[0] == "quantize" for mode in opt_modes)
+
+
+class SpeculativeDecodingExporter:
+    """Export a ModelOpt speculative decoding checkpoint to deployment format."""
+
+    def __init__(self, model: nn.Module):
+        """Initialize the SpeculativeDecodingExporter."""
+        self.model = model
+
+    @abstractmethod
+    def export(self, export_dir: Path | str, dtype: torch.dtype | None = None):
+        """Export the model to the deployment format."""
+        raise NotImplementedError("Subclasses must implement this method.")
+
+
+class EagleExporter(SpeculativeDecodingExporter):
+    """Draft model exporter for Eagle."""
+
+    def __init__(self, model: nn.Module):
+        """Initialize the EagleExporter."""
+        super().__init__(model)
+        self.eagle_decoder_type = model.eagle_config.eagle_decoder_type
+        self.num_hidden_layers = model.eagle_config.num_hidden_layers
+
+    def 
_check_valid_sd(self, export_sd: dict): + """Check the export state dict is valid, otherwise raise Exception.""" + expected_keys_single_layer = { + "llama": LLAMA_EAGLE_SINGLE_LAYER, + "kimik2": KIMIK2_EAGLE_SINGLE_LAYER, + }[self.eagle_decoder_type] + # Check that export sd has required keys for key in expected_keys_single_layer["required"]: - assert key.replace("midlayer", "midlayer.0") in state_dict, ( - f"Missing required key: {key}" - ) - for i in range(1, num_hidden_layers): + assert f"{key}.weight" in export_sd, f"Missing required key: {key}.weight" + for i in range(1, self.num_hidden_layers): for key in expected_keys_single_layer["required"] - { - "midlayer.hidden_norm.weight", - "midlayer.input_layernorm.weight", - "norm.weight", - "fc.weight", + "layers.0.hidden_norm", + "layers.0.input_layernorm", + "norm", + "fc", }: - assert key.replace("midlayer", f"midlayer.{i}") in state_dict, ( - f"Missing required key: {key}" + assert f"{key}.weight".replace("layers.0", f"layers.{i}") in export_sd, ( + f"Missing required key: {key}.weight" ) - # Check that export sd has no unexpected keys - allowed_keys_single_layer = ( - expected_keys_single_layer["required"] | expected_keys_single_layer["optional"] - ) - if num_hidden_layers == 1: - for key in state_dict: - assert key in allowed_keys_single_layer, f"Unexpected key: {key}" - else: - for key in state_dict: - assert re.sub(r"midlayers\.\d+\.", "", key) in { - k.replace("midlayer.", "") for k in allowed_keys_single_layer - }, f"Unexpected key: {key}" - - -def spec_opt_only(model: nn.Module): - """Check if the model have only speculative decoding optimization.""" - opt_modes = getattr(model, "_modelopt_state", None) - return ( - isinstance(opt_modes, (list, tuple)) and len(opt_modes) == 1 and opt_modes[0][0] == "eagle" - ) - - -def export_spec_ckpt_state_dict(model: nn.Module): - """Only return the state dict of the draft model in official format and ignore the base model.""" - # check the model has only speculative 
decoding - assert spec_opt_only(model), "Not purely eagle model." - - # Rename layers to midlayer - if model.eagle_config.num_hidden_layers == 1: - model.eagle_module.midlayer = model.eagle_module._modules.pop("layers")[0] - else: - model.eagle_module.midlayer = model.eagle_module._modules.pop("layers") - export_sd = copy(model.eagle_module.state_dict()) - - # Use base model's lm head if draft model doesn't have one - if "lm_head.weight" not in export_sd: - export_sd["lm_head.weight"] = model.state_dict()["lm_head.weight"] - - # Rename parallel draft weights - if model.eagle_config.parallel_draft_step > 1: - for i in range(model.eagle_config.parallel_draft_step - 1): - for j in range(model.eagle_config.parallel_draft_heads_num_layers): + # Check that export sd has no unexpected keys + # Note that quantized eagle are allowed to have scales + allowed_keys_single_layer = ( + expected_keys_single_layer["required"] | expected_keys_single_layer["optional"] + ) + for key in export_sd: + assert ( + re.sub(r"layers\.\d+\.", "layers.0.", key.rsplit(".", 1)[0]) + in allowed_keys_single_layer + ), f"Unexpected key: {key}" + + def _extract_state_dict(self, full_state_dict: dict): + """Extract and return eagle state dict in deployment format.""" + export_sd = {} + for key in full_state_dict: + if "eagle_module" in key or "lm_head" in key: + export_key = key.replace("eagle_module.", "") + export_sd[export_key] = full_state_dict[key].clone() + # Use base model's lm head if draft model doesn't have one + if "lm_head.weight" not in export_sd: + export_sd["lm_head.weight"] = full_state_dict["lm_head.weight"] + + self._check_valid_sd(export_sd) + + return export_sd + + def _export_config(self): + """Export config.json in deployment format.""" + template_config: dict = { + "llama": llama_eagle_template_config, + "kimik2": kimik2_eagle_template_config, + }[self.model.eagle_config.eagle_decoder_type] + template_config = deepcopy(template_config) + + def 
_get_config_from_draft_or_base(key: str, model: nn.Module):
+            if getattr(model._draft_model_config, key, None) is not None:
+                return getattr(model._draft_model_config, key)
+            elif getattr(model.config, key, None) is not None:
+                return getattr(model.config, key)
+            else:
+                return None
+
+        for key in template_config:
+            value = template_config[key]
+            if isinstance(value, dict):
+                # For the nested eagle config, resolve each entry from the draft or base config
+                for sub_key in value:
+                    if value[sub_key] is None:
+                        value[sub_key] = _get_config_from_draft_or_base(sub_key, self.model)
+            elif value is None:
+                # First, we try to load from the eagle config.
+                new_value = _get_config_from_draft_or_base(key, self.model)
+                # If the value is a torch.dtype, we convert to string for serialization.
+                if isinstance(new_value, torch.dtype):
+                    new_value = str(new_value).replace("torch.", "")
+                template_config[key] = new_value
+
+        return template_config
+
+    def export(self, export_dir: Path | str, dtype: torch.dtype | None = None):
+        """Export the model to the deployment format."""
+        # Make export dir
+        export_dir = Path(export_dir)
+        export_dir.mkdir(parents=True, exist_ok=True)
+
+        # Export quantized modules
+        if has_quant_opt(self.model):
+            from ..unified_export_hf import _export_transformers_checkpoint
+
+            full_sd, hf_quant_config = _export_transformers_checkpoint(self.model, dtype)
+        else:
+            full_sd, hf_quant_config = self.model.state_dict(), None
+
+        # Export state dict
+        drafter_sd = self._extract_state_dict(full_sd)
+        save_file(drafter_sd, f"{export_dir}/model.safetensors")
+
+        # Export config
+        drafter_config = self._export_config()
+        if hf_quant_config is not None:
+            drafter_config["quantization_config"] = hf_quant_config
+        with open(f"{export_dir}/config.json", "w") as file:
+            json.dump(drafter_config, file, indent=4)
+
+        # Export hf_quant_config for backward compatibility
+        if hf_quant_config is not None:
+            with open(f"{export_dir}/hf_quant_config.json", "w") as file:
+                json.dump(hf_quant_config, file, 
indent=4) + + +class EagleMedusaExporter(EagleExporter): + """Draft model exporter for EagleMedusa.""" + + def __init__(self, model: nn.Module): + """Initialize the EagleMedusaExporter.""" + super().__init__(model) + self.parallel_draft_step = model.eagle_config.parallel_draft_step + self.parallel_draft_heads_num_layers = model.eagle_config.parallel_draft_heads_num_layers + # NOTE: tmp: bypassing format check for parallel draft + self._check_valid_sd = lambda *args, **kwargs: None + + def _extract_state_dict(self, full_state_dict: dict): + """Extract the state dict of the draft model in deployment format.""" + export_sd = super()._extract_state_dict(full_state_dict) + if self.parallel_draft_step <= 1: + return export_sd + + for i in range(self.parallel_draft_step - 1): + for j in range(self.parallel_draft_heads_num_layers): export_sd[f"parallel_draft_heads.{i}.medusa_layers.{j}.linear.weight"] = ( export_sd.pop(f"parallel_draft_heads.medusa_heads.{i}.{j}.linear.weight") ) @@ -143,180 +246,4 @@ def export_spec_ckpt_state_dict(model: nn.Module): export_sd["parallel_draft_heads.lm_head.weight"] = export_sd.pop( "parallel_draft_heads.lm_head.weight" ) - # NOTE: tmp: bypassing format check for parallel draft return export_sd - - _check_valid_sd( - export_sd, model.eagle_config.eagle_decoder_type, model.eagle_config.num_hidden_layers - ) - - return export_sd - - -def export_spec_ckpt_config(model: nn.Module): - """Return the config of draft model in official format.""" - assert spec_opt_only(model), "Not purely eagle model." - - # This is the config keys in official checkpoint. 
- llama_eagle_template_config = { - "architectures": ["LlamaForCausalLMEagle3"], - "bos_token_id": None, - "eos_token_id": None, - "hidden_act": None, - "hidden_size": None, - "initializer_range": None, - "intermediate_size": None, - "max_position_embeddings": None, - "model_type": "llama", - "num_attention_heads": None, - "num_key_value_heads": None, - "num_hidden_layers": None, - "pad_token_id": None, - "rms_norm_eps": None, - "tie_word_embeddings": False, - "torch_dtype": None, - "transformers_version": None, - "use_cache": None, - "vocab_size": None, - "draft_vocab_size": None, - "rope_scaling": None, - "attention_bias": None, - "attention_dropout": None, - "head_dim": None, - "mlp_bias": None, - "pretraining_tp": None, - "rope_theta": None, - "eagle_config": { - "eagle_aux_hidden_state_layer_ids": None, - "use_aux_hidden_state": None, - "use_input_layernorm_in_first_layer": None, - "use_last_layernorm": None, - "use_mtp_layernorm": None, - "next_layer_regular": True, - "parallel_draft_step": None, - "parallel_draft_heads_num_layers": None, - }, - } - - kimik2_eagle_template_config = { - "architectures": ["Eagle3DeepseekV2ForCausalLM"], - "attention_bias": None, - "attention_dropout": None, - "aux_loss_alpha": None, - "bos_token_id": None, - "chunk_size_feed_forward": None, - "diversity_penalty": None, - "do_sample": None, - "early_stopping": None, - "encoder_no_repeat_ngram_size": None, - "eos_token_id": None, - "ep_size": None, - "first_k_dense_replace": None, - "forced_bos_token_id": None, - "forced_eos_token_id": None, - "hidden_act": None, - "hidden_size": None, - "id2label": None, - "initializer_range": None, - "intermediate_size": None, - "is_decoder": None, - "is_encoder_decoder": None, - "kv_lora_rank": None, - "label2id": None, - "length_penalty": None, - "max_length": None, - "max_position_embeddings": None, - "min_length": None, - "model_type": "kimi_k2", - "moe_intermediate_size": None, - "moe_layer_freq": None, - "n_group": None, - 
"n_routed_experts": None, - "n_shared_experts": None, - "no_repeat_ngram_size": None, - "norm_topk_prob": None, - "num_attention_heads": None, - "num_beam_groups": None, - "num_beams": None, - "num_experts_per_tok": None, - "num_hidden_layers": None, - "num_key_value_heads": None, - "num_nextn_predict_layers": None, - "num_return_sequences": None, - "output_attentions": None, - "output_hidden_states": None, - "output_scores": None, - "pad_token_id": None, - "pretraining_tp": None, - "pruned_heads": None, - "q_lora_rank": None, - "qk_nope_head_dim": None, - "qk_rope_head_dim": None, - "remove_invalid_values": None, - "repetition_penalty": None, - "return_dict": None, - "return_dict_in_generate": None, - "rms_norm_eps": None, - "rope_scaling": None, - "rope_theta": None, - "routed_scaling_factor": None, - "scoring_func": None, - "sep_token_id": None, - "seq_aux": None, - "temperature": None, - "tf_legacy_loss": None, - "tie_encoder_decoder": None, - "tie_word_embeddings": None, - "top_k": None, - "top_p": None, - "topk_group": None, - "topk_method": None, - "torch_dtype": None, - "torchscript": None, - "transformers_version": None, - "typical_p": None, - "use_bfloat16": None, - "use_cache": None, - "v_head_dim": None, - "vocab_size": None, - "eagle_config": { - "eagle_aux_hidden_state_layer_ids": None, - "use_aux_hidden_state": None, - "use_input_layernorm_in_first_layer": None, - "use_last_layernorm": None, - "use_mtp_layernorm": None, - "next_layer_regular": True, - "parallel_draft_step": None, - "parallel_draft_heads_num_layers": None, - }, - } - - template_config: dict = { - "llama": llama_eagle_template_config, - "kimik2": kimik2_eagle_template_config, - }[model.eagle_config.eagle_decoder_type] - - def _get_config_from_eagle_config_or_base_config(key: str, model: nn.Module): - if getattr(model.eagle_config, key, None) is not None: - return getattr(model.eagle_config, key) - elif getattr(model.config, key, None) is not None: - return getattr(model.config, key) - 
else: - return None - - for key in template_config: - value = template_config[key] - if isinstance(value, dict): - # for eagle config, we find it in model.eagle_config - for sub_key in value: - if value[sub_key] is None: - value[sub_key] = _get_config_from_eagle_config_or_base_config(sub_key, model) - elif value is None: - # First, we try to load fron eagle config. - new_value = _get_config_from_eagle_config_or_base_config(key, model) - # If the value is a torch.dtype, we convert to string for serialization. - if isinstance(new_value, torch.dtype): - new_value = str(new_value).replace("torch.", "") - template_config[key] = new_value - - return template_config diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index ca80cb450..4b983bef9 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -81,7 +81,7 @@ QUANTIZATION_W4A8_NVFP4_FP8, ) from .model_utils import get_language_model_from_vl, is_multimodal_model -from .plugins import export_spec_ckpt_config, export_spec_ckpt_state_dict, spec_opt_only +from .plugins import SpeculativeDecodingExporter, has_spec_opt from .quant_utils import ( fuse_prequant_layernorm, fuse_prequant_to_linear, @@ -98,7 +98,7 @@ to_quantized_weight, ) -__all__ = ["export_hf_checkpoint"] +__all__ = ["export_hf_checkpoint", "export_speculative_decoding"] def _is_enabled_quantizer(quantizer): @@ -978,6 +978,18 @@ def _export_diffusers_checkpoint( print(f"Export complete. Saved to: {export_dir}") +def export_speculative_decoding( + model: torch.nn.Module, + export_dir: Path | str, + dtype: torch.dtype | None = None, +) -> None: + """Export speculative decoding HuggingFace model checkpoint.""" + assert has_spec_opt(model), "Model is not optimized for speculative decoding." 
+ + exporter: SpeculativeDecodingExporter = model.get_exporter() + exporter.export(export_dir, dtype) + + def export_hf_checkpoint( model: Any, dtype: torch.dtype | None = None, @@ -1013,15 +1025,6 @@ def export_hf_checkpoint( _export_diffusers_checkpoint(model, dtype, export_dir, components) return - # Transformers model export - # NOTE: (hg) Early exit for speculative decoding models - # This is a temp workaround to avoid error with offline spec ckpt during export - if spec_opt_only(model): - save_file(export_spec_ckpt_state_dict(model), f"{export_dir}/model.safetensors") - with open(f"{export_dir}/config.json", "w") as file: - json.dump(export_spec_ckpt_config(model), file, indent=4) - return - try: post_state_dict, hf_quant_config = _export_transformers_checkpoint(model, dtype) diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py index f8b7e33df..fb7da52c7 100644 --- a/modelopt/torch/speculative/plugins/transformers.py +++ b/modelopt/torch/speculative/plugins/transformers.py @@ -50,6 +50,11 @@ from transformers.utils import ModelOutput from transformers.utils.quantization_config import QuantizationMethod +from ...export.plugins.hf_spec_export import ( + EagleExporter, + EagleMedusaExporter, + SpeculativeDecodingExporter, +) from ..eagle.conversion import EagleDMRegistry from ..eagle.eagle_model import EagleModel from ..eagle.utils import expand_mask, make_causal_mask @@ -450,6 +455,18 @@ def _base_llm_config(self): or self.config ) + @property + def _draft_model_config(self): + """Return the llm config for the draft model.""" + return self.eagle_config + + def get_exporter(self) -> SpeculativeDecodingExporter: + """Get the exporter for the draft model.""" + exporter_cls = ( + EagleExporter if self.eagle_config.parallel_draft_step <= 1 else EagleMedusaExporter + ) + return exporter_cls(self) + def _find_base_model_parts(self): """Find model parts from different models and set base_{part}_path 
attributes.""" base_model_parts_mapping = { diff --git a/tests/examples/speculative_decoding/test_eagle.py b/tests/examples/speculative_decoding/test_eagle.py index 4f80692ca..9c73ea96a 100644 --- a/tests/examples/speculative_decoding/test_eagle.py +++ b/tests/examples/speculative_decoding/test_eagle.py @@ -145,7 +145,7 @@ def test_export_hf_checkpoint(eagle_output_dir): # Check the exported checkpoints have required keys state_dict = safetensors.torch.load_file(eagle_output_dir / "eagle-tinyllama-export" / "model.safetensors") for required_key in LLAMA_EAGLE_SINGLE_LAYER["required"]: - assert required_key in state_dict, f"Missing key '{required_key}' in state_dict" + assert f"{required_key}.weight" in state_dict, f"Missing key '{required_key}.weight' in state_dict" def test_convert_to_vllm_ckpt(tiny_llama_path, eagle_output_dir):
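
Reviewer note: replacing `spec_opt_only` with `has_spec_opt`/`has_quant_opt` changes the export contract — a model carrying both an `eagle` mode and a `quantize` mode is now accepted, whereas the old check required exactly one `eagle` entry. A dependency-free sketch of the new detection logic (the `_Dummy` container is hypothetical, standing in for an `nn.Module` that holds `_modelopt_state`):

```python
ALL_SPEC_MODES = ["eagle"]


def has_spec_opt(model):
    """Mirror of the PR's check: any applied mode that is a spec-decoding mode."""
    opt_modes = getattr(model, "_modelopt_state", [])
    return any(mode[0] in ALL_SPEC_MODES for mode in opt_modes)


def has_quant_opt(model):
    """A 'quantize' mode may now coexist with the eagle mode."""
    opt_modes = getattr(model, "_modelopt_state", [])
    return any(mode[0] == "quantize" for mode in opt_modes)


class _Dummy:
    """Hypothetical stand-in for an nn.Module carrying ModelOpt state."""


quantized_eagle = _Dummy()
quantized_eagle._modelopt_state = [("quantize", {}), ("eagle", {})]
plain = _Dummy()  # no _modelopt_state attribute at all

# Under the old spec_opt_only check, quantized_eagle would have been rejected
# (len(opt_modes) == 1 was required); the new checks accept it.
assert has_spec_opt(quantized_eagle) and has_quant_opt(quantized_eagle)
assert not has_spec_opt(plain) and not has_quant_opt(plain)
```

This is what lets `EagleExporter.export` route quantized models through `_export_transformers_checkpoint` instead of bailing out early.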
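
The config-resolution order in `_export_config` can be illustrated in miniature: every `None`-valued key in the template is filled from the draft config first, falling back to the base model config. Here `_Cfg`, `draft_cfg`, and `base_cfg` are hypothetical stand-ins for `model._draft_model_config` and `model.config`; the real code walks `llama_eagle_template_config` / `kimik2_eagle_template_config`, including nested `eagle_config` dicts.

```python
from copy import deepcopy

# Template with None placeholders, as in hf_spec_configs.py (trimmed to 3 keys).
template = {"hidden_size": None, "model_type": "llama", "vocab_size": None}


class _Cfg:
    """Hypothetical attribute bag mimicking a HF PretrainedConfig."""

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)


draft_cfg = _Cfg(vocab_size=32000)  # draft overrides, e.g. a compressed draft vocab
base_cfg = _Cfg(hidden_size=2048, vocab_size=128256)


def _resolve(key):
    # Draft config wins; base config is the fallback; otherwise stay None.
    for cfg in (draft_cfg, base_cfg):
        if getattr(cfg, key, None) is not None:
            return getattr(cfg, key)
    return None


filled = deepcopy(template)  # never mutate the shared template
for key, value in filled.items():
    if value is None:
        filled[key] = _resolve(key)

assert filled == {"hidden_size": 2048, "model_type": "llama", "vocab_size": 32000}
```

Note the `deepcopy` added in this PR: the previous code mutated the module-level template dict in place, so a second export in the same process would have seen already-filled values.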