
Fix buffer dtype mismatch and causal mask recompute in export#18380

Merged
mergennachin merged 1 commit into main from mnachin/qwen3_5_moe on Mar 20, 2026
Conversation

@mergennachin
Contributor

- Materialize state buffers as bf16 (not fp32) to match compute dtype,
  fixing "Expected bfloat16 inputs" error in Triton SDPA during AOTI
  lowering
- Use hasattr(layer.attn, "mask") instead of isinstance(layer.attn,
  FullAttention) for causal mask recompute — the isinstance check fails
  when the module is imported via different Python paths
  (executorch.examples... vs examples...)
- Remove unused FullAttention import from export.py
- Fix model directory name in README examples (Qwen3.5-MoE-A3B →
  Qwen3.5-35B-A3B)
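The second bullet's failure mode is easy to reproduce: when the same module is reachable under two import paths, Python creates two distinct class objects, so `isinstance` against the "other" copy returns False even for a genuine `FullAttention` instance. A duck-typed `hasattr` check sidesteps class identity entirely. The sketch below illustrates the pattern with hypothetical stand-in classes, not the actual executorch modules:

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for the example's attention modules; the real
# classes live in examples/models/qwen3_5_moe, these are illustrative only.
class FullAttention(nn.Module):
    def __init__(self, seq_len: int):
        super().__init__()
        # Causal mask buffer: True at and below the diagonal.
        self.register_buffer(
            "mask", torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
        )

class LinearAttention(nn.Module):
    # No mask buffer, so the recompute loop should skip it.
    pass

def recompute_causal_masks(attn_layers, seq_len: int) -> None:
    for attn in attn_layers:
        # Duck typing: hasattr() keeps working even when FullAttention was
        # imported under a different module path (executorch.examples...
        # vs examples...), where isinstance() would compare against a
        # distinct class object and return False.
        if hasattr(attn, "mask"):
            attn.mask = torch.tril(
                torch.ones(seq_len, seq_len, dtype=torch.bool)
            )

layers = [FullAttention(4), LinearAttention()]
recompute_causal_masks(layers, seq_len=8)
```

Assigning to `attn.mask` updates the registered buffer in place via `nn.Module.__setattr__`, so the recomputed mask still shows up in `named_buffers()`.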
@mergennachin mergennachin requested a review from lucylq as a code owner March 20, 2026 20:28
Copilot AI review requested due to automatic review settings March 20, 2026 20:28
@pytorch-bot

pytorch-bot bot commented Mar 20, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18380

Note: Links to docs will display an error until the docs builds have been completed.

⚠️ 1 Awaiting Approval, 101 Pending

As of commit f922e70 with merge base 1cbc24c:

AWAITING APPROVAL - The following workflow needs approval before CI can run:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

meta-cla bot added the CLA Signed label on Mar 20, 2026
@github-actions

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Contributor

Copilot AI left a comment


Pull request overview

This PR updates the Qwen3.5-MoE export flow to avoid dtype mismatches during AOTI/CUDA lowering and to make causal mask recomputation robust to modules imported from different Python paths.

Changes:

  • Materialize meta-device state buffers as bf16 (keeping masks bool) to match compute dtype and avoid Triton SDPA dtype errors.
  • Recompute causal masks based on hasattr(layer.attn, "mask") instead of isinstance(..., FullAttention) to avoid class-identity mismatches across import paths.
  • Documentation cleanup: fix the example model directory name and remove an unused import.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.

File Description
examples/models/qwen3_5_moe/export.py Adjusts meta-buffer materialization dtype and makes causal mask recompute detection import-path-agnostic.
examples/models/qwen3_5_moe/README.md Updates example model directory/tokenizer paths to match the referenced HF model name.


Comment on lines +132 to 142:

```diff
 # State buffers (KV cache, conv/recurrent state) are bf16 to match
 # compute dtype. Masks stay bool, inv_freq stays float32.
 for fqn, buf in list(model.named_buffers()):
     if buf.device.type == "meta":
         dtype = torch.bfloat16 if buf.dtype != torch.bool else torch.bool
         parts = fqn.rsplit(".", 1)
         parent = model.get_submodule(parts[0]) if len(parts) > 1 else model
         parent.register_buffer(
             parts[-1],
-            torch.zeros(buf.shape, dtype=buf.dtype, device="cpu"),
+            torch.zeros(buf.shape, dtype=dtype, device="cpu"),
         )
```

Copilot AI Mar 20, 2026


The comment says “inv_freq stays float32”, but the materialization loop currently casts every meta buffer that isn’t bool to bf16. That means the meta placeholder for rotary_emb.inv_freq will be bf16 until the later recompute step overwrites it, which is a bit inconsistent and could break if the recompute logic changes. Consider explicitly preserving float32 for inv_freq (e.g., by checking the buffer name) or defaulting to the original buf.dtype except for known state buffers that must be bf16.
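The reviewer's suggestion could be sketched as follows. This is not the PR's actual code: the buffer names (`kv_cache`, `attn_mask`, `inv_freq`) and the `endswith` check are assumptions used for illustration, and the tiny model stands in for the real export flow.

```python
import torch
import torch.nn as nn

def materialize_meta_buffers(model: nn.Module) -> None:
    # Sketch of the suggestion above (buffer names are assumptions, not the
    # actual export.py code): default state buffers to bf16, but preserve
    # the original dtype for bool masks and rotary inv_freq buffers.
    for fqn, buf in list(model.named_buffers()):
        if buf.device.type != "meta":
            continue
        if buf.dtype == torch.bool or fqn.endswith("inv_freq"):
            dtype = buf.dtype  # keep bool masks / float32 inv_freq as-is
        else:
            dtype = torch.bfloat16  # compute dtype expected by Triton SDPA
        parts = fqn.rsplit(".", 1)
        parent = model.get_submodule(parts[0]) if len(parts) > 1 else model
        # Re-registering under the same name replaces the meta placeholder
        # with a real CPU tensor of the chosen dtype.
        parent.register_buffer(
            parts[-1], torch.zeros(buf.shape, dtype=dtype, device="cpu")
        )

class _Rotary(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer(
            "inv_freq", torch.empty(4, device="meta", dtype=torch.float32)
        )

class _TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.rotary_emb = _Rotary()
        self.register_buffer(
            "kv_cache", torch.empty(2, 8, device="meta", dtype=torch.float32)
        )
        self.register_buffer(
            "attn_mask", torch.empty(8, 8, device="meta", dtype=torch.bool)
        )

model = _TinyModel()
materialize_meta_buffers(model)
```

With this variant, `inv_freq` is float32 immediately after materialization instead of relying on a later recompute step to restore its dtype.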

@mergennachin mergennachin merged commit 07b7c7e into main Mar 20, 2026
146 of 150 checks passed
@mergennachin mergennachin deleted the mnachin/qwen3_5_moe branch March 20, 2026 20:52

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.


3 participants