Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/models/qwen3_5_moe/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ recommended — the model is too large to fit in VRAM at bf16.

```bash
python export.py \
--model-dir ~/models/Qwen3.5-MoE-A3B \
--model-dir ~/models/Qwen3.5-35B-A3B \
--output-dir ./qwen35_moe_exports \
--qlinear 4w \
--qembedding 8w
Expand Down Expand Up @@ -73,7 +73,7 @@ The runner requires:
cmake-out/examples/models/qwen3_5_moe/qwen3_5_moe_runner \
--model_path qwen35_moe_exports/model.pte \
--data_path qwen35_moe_exports/aoti_cuda_blob.ptd \
--tokenizer_path ~/models/Qwen3.5-MoE-A3B/tokenizer.json \
--tokenizer_path ~/models/Qwen3.5-35B-A3B/tokenizer.json \
--prompt "The meaning of life is" \
--max_new_tokens 128
```
Expand Down
9 changes: 6 additions & 3 deletions examples/models/qwen3_5_moe/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import torch
import torch.nn as nn

from executorch.examples.models.qwen3_5_moe.model import FullAttention, Qwen35MoE
from executorch.examples.models.qwen3_5_moe.model import Qwen35MoE


# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -129,13 +129,16 @@ def _materialize_buffers(model, config):
Replaces meta buffers with real tensors on CPU, recomputes RoPE
inv_freq and causal masks.
"""
# State buffers (KV cache, conv/recurrent state) are bf16 to match
# compute dtype. Masks stay bool, inv_freq stays float32.
for fqn, buf in list(model.named_buffers()):
if buf.device.type == "meta":
dtype = torch.bfloat16 if buf.dtype != torch.bool else torch.bool
parts = fqn.rsplit(".", 1)
parent = model.get_submodule(parts[0]) if len(parts) > 1 else model
parent.register_buffer(
parts[-1],
torch.zeros(buf.shape, dtype=buf.dtype, device="cpu"),
torch.zeros(buf.shape, dtype=dtype, device="cpu"),
)
Comment on lines +132 to 142
Copy link

Copilot AI Mar 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment says “inv_freq stays float32”, but the materialization loop casts every non-bool meta buffer to bf16. As a result, the zero-filled placeholder for rotary_emb.inv_freq is bf16 until the later recompute step overwrites it, which contradicts the comment and could break if the recompute logic changes. Consider explicitly preserving float32 for inv_freq (e.g., by checking the buffer name), or defaulting to the original buf.dtype for every buffer except the known state buffers that must be bf16.

Copilot uses AI. Check for mistakes.

# Recompute RoPE inv_freq (zero-fill above is wrong for these)
Expand All @@ -153,7 +156,7 @@ def _materialize_buffers(model, config):

# Recompute causal masks for full attention layers
for layer in model.layers:
if isinstance(layer.attn, FullAttention):
if hasattr(layer.attn, "mask"):
mask = torch.tril(
torch.ones(config.max_seq_len, config.max_seq_len, dtype=torch.bool)
)
Expand Down
Loading