Fix buffer dtype mismatch and causal mask recompute in export #18380

mergennachin merged 1 commit into main
Conversation
- Materialize state buffers as bf16 (not fp32) to match compute dtype, fixing the "Expected bfloat16 inputs" error in Triton SDPA during AOTI lowering.
- Use `hasattr(layer.attn, "mask")` instead of `isinstance(layer.attn, FullAttention)` for causal mask recompute: the `isinstance` check fails when the module is imported via different Python paths (`executorch.examples...` vs `examples...`).
- Remove the unused `FullAttention` import from `export.py`.
- Fix the model directory name in the README examples (`Qwen3.5-MoE-A3B` → `Qwen3.5-35B-A3B`).
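The second bullet rests on a Python subtlety: a module imported under two different paths produces two distinct class objects. The following self-contained sketch (with hypothetical module names, not the real ExecuTorch modules) demonstrates why the `isinstance` check breaks while an attribute check does not:

```python
import types

# Simulate importing the "same" class source under two module paths.
# Python creates two distinct class objects, so isinstance against one
# does not match instances of the other, while hasattr still works.
_SOURCE = """
class FullAttention:
    def __init__(self):
        self.mask = None
"""

def load_as(name):
    mod = types.ModuleType(name)
    exec(_SOURCE, mod.__dict__)
    return mod

m1 = load_as("examples.attn")             # e.g. imported as examples...
m2 = load_as("executorch.examples.attn")  # e.g. imported as executorch.examples...

obj = m1.FullAttention()
print(isinstance(obj, m2.FullAttention))  # False: distinct class objects
print(hasattr(obj, "mask"))               # True: duck typing is path-agnostic
```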
Pull request overview
This PR updates the Qwen3.5-MoE export flow to avoid dtype mismatches during AOTI/CUDA lowering and to make causal mask recomputation robust to modules imported from different Python paths.
Changes:
- Materialize meta-device state buffers as `bf16` (keeping masks `bool`) to match compute dtype and avoid Triton SDPA dtype errors.
- Recompute causal masks based on `hasattr(layer.attn, "mask")` instead of `isinstance(..., FullAttention)` to avoid class-identity mismatches across import paths.
- Documentation cleanup: fix the example model directory name and remove an unused import.
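The attribute-based recompute described above can be sketched as follows. `FullAttention` and `LinearAttention` here are stand-ins for the real modules, not the actual ExecuTorch classes; the point is duck-typing on the `mask` attribute rather than on class identity:

```python
import torch

class FullAttention(torch.nn.Module):
    def __init__(self, seq_len):
        super().__init__()
        self.register_buffer(
            "mask", torch.ones(seq_len, seq_len, dtype=torch.bool)
        )

class LinearAttention(torch.nn.Module):
    """Attention variant with no causal mask buffer."""

def recompute_causal_masks(attn_layers, seq_len):
    for attn in attn_layers:
        # Works no matter which import path produced the class object.
        if hasattr(attn, "mask"):
            attn.mask = torch.tril(
                torch.ones(seq_len, seq_len, dtype=torch.bool)
            )

layers = [FullAttention(4), LinearAttention()]
recompute_causal_masks(layers, 4)
```

Assigning to `attn.mask` replaces the registered buffer in place, since `nn.Module.__setattr__` routes tensor assignments to existing buffer names into `_buffers`.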
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| examples/models/qwen3_5_moe/export.py | Adjusts meta-buffer materialization dtype and makes causal mask recompute detection import-path-agnostic. |
| examples/models/qwen3_5_moe/README.md | Updates example model directory/tokenizer paths to match the referenced HF model name. |
```diff
 # State buffers (KV cache, conv/recurrent state) are bf16 to match
 # compute dtype. Masks stay bool, inv_freq stays float32.
 for fqn, buf in list(model.named_buffers()):
     if buf.device.type == "meta":
         dtype = torch.bfloat16 if buf.dtype != torch.bool else torch.bool
         parts = fqn.rsplit(".", 1)
         parent = model.get_submodule(parts[0]) if len(parts) > 1 else model
         parent.register_buffer(
             parts[-1],
-            torch.zeros(buf.shape, dtype=buf.dtype, device="cpu"),
+            torch.zeros(buf.shape, dtype=dtype, device="cpu"),
         )
```
The comment says "`inv_freq` stays float32", but the materialization loop currently casts every meta buffer that isn't bool to bf16. That means the meta placeholder for `rotary_emb.inv_freq` will be bf16 until the later recompute step overwrites it, which is a bit inconsistent and could break if the recompute logic changes. Consider explicitly preserving float32 for `inv_freq` (e.g., by checking the buffer name) or defaulting to the original `buf.dtype` except for known state buffers that must be bf16.
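The reviewer's suggestion could look like the following sketch (a hypothetical helper, not the actual PR code): preserve the original dtype for bool masks and for `inv_freq`, and cast only the remaining state buffers to bf16.

```python
import torch

def materialize_meta_buffers(model):
    # Sketch only: keep bool masks and float32 inv_freq at their original
    # dtype; cast the remaining meta state buffers to bf16 so they match
    # the compute dtype expected by the Triton SDPA kernels.
    for fqn, buf in list(model.named_buffers()):
        if buf.device.type != "meta":
            continue
        if buf.dtype == torch.bool or fqn.endswith("inv_freq"):
            dtype = buf.dtype  # preserve bool masks and float32 inv_freq
        else:
            dtype = torch.bfloat16  # state buffers match compute dtype
        parts = fqn.rsplit(".", 1)
        parent = model.get_submodule(parts[0]) if len(parts) > 1 else model
        # Re-registering under the same name replaces the meta placeholder.
        parent.register_buffer(
            parts[-1], torch.zeros(buf.shape, dtype=dtype, device="cpu")
        )
```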