Proposal to add a new attn_implementation="sdpa_memeff" that pins torch's SDPA dispatcher to SDPBackend.EFFICIENT_ATTENTION (via sdpa_kernel([EFFICIENT_ATTENTION]) wrapping the existing sdpa_attention_forward). Filing as an issue to validate design direction before opening a PR.
Motivation — two independent failure modes that attn_implementation="sdpa" doesn't handle well
1. head_dim > 256
On every CUDA arch I've tested (RTX 5090 sm_120; same library caps apply on H100 sm_90):
| Backend | head_dim=320 (Gemma 4) |
| --- | --- |
| FLASH | REJECT (library cap: head_dim ≤ 256) |
| CUDNN | REJECT (library cap: head_dim ≤ 256) |
| EFFICIENT (CUTLASS mem-eff) | ACCEPT |
| MATH | ACCEPT (but O(seq²) fp32 softmax — 74 GB at seq=32k, OOMs on 96 GB card) |
So for Gemma 4 (google/gemma-4-31B, gemma-4-26B-A4B, gemma-4-E4B, gemma-4-E2B — all head_dim=320) there's only one fast-path backend, and stock sdpa has no way to pin it reliably.
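The table above can be reproduced directly against the dispatcher. A minimal repro sketch (shapes are illustrative; CUDNN_ATTENTION assumes a PyTorch version recent enough to expose it):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# head_dim=320 as in Gemma 4; batch/heads/seq are illustrative
q = torch.randn(1, 8, 1024, 320, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)

for backend in (SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION,
                SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH):
    try:
        with sdpa_kernel([backend]):
            F.scaled_dot_product_attention(q, k, v, is_causal=True)
        print(backend, "ACCEPT")
    except RuntimeError:
        print(backend, "REJECT")  # FLASH / CUDNN: library cap head_dim <= 256
```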
2. Input layouts that disqualify FLASH at any head_dim
pytorch/pytorch#44928 — Qwen3.5 RLHF training sees dense 4D mask materialization (from 3D position_ids), which disqualifies FLASH's is_causal=True short-circuit. Dispatcher falls through to MATH and triggers NaN gradients in bf16.
Same root problem as (1), just triggered by layout rather than head_dim. A manual with sdpa_kernel([EFFICIENT_ATTENTION]): ... workaround exists but requires model-code edits; the process-global torch.backends.cuda.enable_math_sdp(False) toggle affects every SDPA call in the process. Both are sketched below.
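For concreteness, the two existing workarounds look roughly like this (q/k/v/mask are placeholders for whatever the model code already has):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = k = v = torch.randn(1, 8, 128, 320, device="cuda", dtype=torch.bfloat16)
mask = None  # placeholder for the dense 4D mask the model builds

# Local workaround: pin the backend around the call (requires editing model code).
with sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION]):
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

# Process-global alternative: forbids the MATH fallback for every SDPA call in the process.
torch.backends.cuda.enable_math_sdp(False)
```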
Proposed fix
New file src/transformers/integrations/sdpa_memeff.py. Registers "sdpa_memeff" in ALL_ATTENTION_FUNCTIONS. Copies the existing sdpa_attention_forward with two additions:
1. Wraps the actual SDPA call in with sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION]): ...
2. Always expands KV heads via repeat_kv, because EFFICIENT_ATTENTION rejects dense GQA where num_heads_q != num_heads_kv.
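A minimal sketch of what that forward could look like (the signature mirrors the existing sdpa_attention_forward; the KV-head expansion and the registration line are written from the description above, so treat the exact transformers internals as assumptions):

```python
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS  # assumed dict-like registry


def sdpa_memeff_attention_forward(
    module, query, key, value, attention_mask=None, dropout=0.0, scaling=None, is_causal=None, **kwargs
):
    # query/key/value: (batch, num_heads, seq_len, head_dim)
    # Addition 2: EFFICIENT_ATTENTION rejects dense GQA (num_heads_q != num_heads_kv),
    # so always expand KV heads (the repeat_kv step) instead of passing GQA through.
    num_kv_groups = query.shape[1] // key.shape[1]
    if num_kv_groups > 1:
        key = key.repeat_interleave(num_kv_groups, dim=1)
        value = value.repeat_interleave(num_kv_groups, dim=1)

    if attention_mask is not None:
        attention_mask = attention_mask[:, :, :, : key.shape[-2]]
    if is_causal is None:
        is_causal = query.shape[2] > 1 and attention_mask is None

    # Addition 1: pin the dispatcher to the CUTLASS memory-efficient backend.
    with sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION]):
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query, key, value,
            attn_mask=attention_mask,
            dropout_p=dropout,
            scale=scaling,
            is_causal=is_causal,
        )

    # Same output contract as sdpa_attention_forward: (batch, seq_len, num_heads, head_dim), no weights.
    return attn_output.transpose(1, 2).contiguous(), None


ALL_ATTENTION_FUNCTIONS["sdpa_memeff"] = sdpa_memeff_attention_forward
```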
Verified locally on full 8B Gemma-4-E4B-it (bf16, real checkpoint): 100% top-1 agreement vs stock sdpa where stock works, and correct execution where stock falls back to MATH and OOMs.
Questions for maintainers
1. Shape of the contribution. Dedicated "sdpa_memeff" name vs. a config knob on sdpa (e.g. sdpa_preferred_backends=["EFFICIENT"]) vs. the pipe syntax ("sdpa|efficient_attention") introduced by PR #39823 ([attn_implementation] remove recursive, allows custom kernels with wrappers)?
2. Overlap with pytorch-side fixes. pytorch's SDPA dispatcher is missing sm_120 support for CUDNN at head_dim=256 (filed separately). Once that lands, CUDNN handles head_dim ≤ 256 on sm_120 — but head_dim=320 (Gemma 4) is still library-capped on both FLASH and CUDNN, and the #44928-class layout-induced FLASH disqualification is unaffected. So sdpa_memeff isn't made redundant.
3. Test coverage you'd want. Convergence (loss-curve equivalence with stock sdpa on a small model)? Numerical (per-batch logit diff)? Both? See the sketch after this list for the numerical option.
4. Anything first-time-contributor I should know — CLA, preferred test location, maintainer norms for adding to ALL_ATTENTION_FUNCTIONS, etc.?

Happy to open a PR if this direction is right.
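For the numerical option in (3), a minimal sketch of the per-batch logit check I have in mind (the model id and tolerance are placeholders; assumes "sdpa_memeff" has been registered):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-4-E2B"  # placeholder: any checkpoint small enough for CI
tok = AutoTokenizer.from_pretrained(model_id)
inputs = tok(["The quick brown fox jumps over the lazy dog"], return_tensors="pt").to("cuda")

kwargs = dict(torch_dtype=torch.bfloat16, device_map="cuda")
ref = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="sdpa", **kwargs)
new = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="sdpa_memeff", **kwargs)

with torch.no_grad():
    ref_logits = ref(**inputs).logits
    new_logits = new(**inputs).logits

# Top-1 agreement plus a loose absolute bound on logit drift.
assert torch.equal(ref_logits.argmax(dim=-1), new_logits.argmax(dim=-1))
assert (ref_logits - new_logits).abs().max() < 0.5  # tolerance is a placeholder
```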
🤖 Drafted with Claude Code (Claude Opus 4.7), reviewed and posted by me.