
Proposal: add sdpa_memeff attn_implementation for shape combinations no fast backend covers #45636

@dvdimitrov13

Description

Summary

Proposal to add a new attn_implementation="sdpa_memeff" that pins torch's SDPA dispatcher to SDPBackend.EFFICIENT_ATTENTION by wrapping the existing sdpa_attention_forward in sdpa_kernel([EFFICIENT_ATTENTION]). Filing as an issue to validate the design direction before opening a PR.

Motivation — two independent failure modes that attn_implementation="sdpa" doesn't handle well

1. head_dim > 256

On every CUDA arch I've tested (RTX 5090 sm_120; same library caps apply on H100 sm_90):

| Backend | head_dim=320 (Gemma 4) |
| --- | --- |
| FLASH | REJECT (library cap: head_dim ≤ 256) |
| CUDNN | REJECT (library cap: head_dim ≤ 256) |
| EFFICIENT (CUTLASS mem-eff) | ACCEPT |
| MATH | ACCEPT (but O(seq²) fp32 softmax: 74 GB at seq=32k, OOMs on a 96 GB card) |

So for Gemma 4 (google/gemma-4-31B, gemma-4-26B-A4B, gemma-4-E4B, gemma-4-E2B — all head_dim=320) there's only one fast-path backend, and stock sdpa has no way to pin it reliably.
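For reference, a minimal diagnostic sketch (not part of the proposed change) that reproduces the table above by pinning the dispatcher to one backend at a time. Shapes are illustrative; it assumes a CUDA device and a recent PyTorch that exposes the CUDNN backend enum:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Illustrative Gemma-4-like geometry: head_dim=320, bf16, short seq to keep MATH cheap.
q, k, v = (torch.randn(1, 8, 1024, 320, device="cuda", dtype=torch.bfloat16) for _ in range(3))

for backend in (SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION,
                SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH):
    try:
        # Pinning the dispatcher to a single backend makes SDPA raise
        # if that backend cannot handle the given shapes.
        with sdpa_kernel([backend]):
            F.scaled_dot_product_attention(q, k, v, is_causal=True)
        print(backend, "ACCEPT")
    except RuntimeError as err:
        print(backend, "REJECT:", err)
```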

2. Input layouts that disqualify FLASH at any head_dim

pytorch/pytorch#44928: Qwen3.5 RLHF training sees dense 4D mask materialization (from 3D position_ids), which disqualifies FLASH's is_causal=True short-circuit. The dispatcher falls through to MATH, which triggers NaN gradients in bf16.

The same class of problem as (1), but driven by layout rather than head_dim. A manual with sdpa_kernel([EFFICIENT_ATTENTION]): ... workaround exists but requires model-code edits, and the process-global torch.backends.cuda.enable_math_sdp(False) toggle has side effects; both are sketched below.
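For contrast, a minimal sketch of the two existing workarounds (tensors are illustrative; assumes a CUDA device):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = k = v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16)

# (a) Scoped: pins the dispatcher only inside the context manager. Effective,
#     but it has to enclose the model's own SDPA call, i.e. model-code edits.
with sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION]):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# (b) Process-global: removes the MATH fallback for every SDPA call in the
#     process, which can affect unrelated modules and libraries.
torch.backends.cuda.enable_math_sdp(False)
```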

Proposed fix

Add a new file src/transformers/integrations/sdpa_memeff.py that registers "sdpa_memeff" in ALL_ATTENTION_FUNCTIONS and copies the existing sdpa_attention_forward with two additions (sketched after this list):

  1. Wraps the actual SDPA call in a with sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION]): ... block.
  2. Applies GQA repeat_kv unconditionally (EFFICIENT_ATTENTION rejects dense GQA where num_heads_q != num_heads_kv).
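
A condensed sketch of the idea, assuming the forward mirrors the core of the stock sdpa_attention_forward (the actual PR would copy it verbatim, keeping all of its mask/causal edge cases). The function name is a placeholder, and the registration shown uses the public AttentionInterface.register hook; an in-tree backend would instead be added to the global mapping alongside "sdpa" and "flash_attention_2":

```python
# src/transformers/integrations/sdpa_memeff.py (condensed sketch)
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

from transformers import AttentionInterface
from transformers.integrations.sdpa_attention import repeat_kv


def sdpa_memeff_attention_forward(
    module, query, key, value, attention_mask=None, dropout=0.0, scaling=None, is_causal=None, **kwargs
):
    # Addition 2: unconditional GQA expansion. EFFICIENT_ATTENTION rejects dense
    # GQA where num_heads_q != num_heads_kv, so expand KV heads up front.
    if hasattr(module, "num_key_value_groups"):
        key = repeat_kv(key, module.num_key_value_groups)
        value = repeat_kv(value, module.num_key_value_groups)

    if attention_mask is not None:
        attention_mask = attention_mask[:, :, :, : key.shape[-2]]
    if is_causal is None:
        is_causal = query.shape[2] > 1 and attention_mask is None

    # Addition 1: pin torch's SDPA dispatcher to the mem-efficient (CUTLASS) backend.
    with sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION]):
        attn_output = F.scaled_dot_product_attention(
            query, key, value,
            attn_mask=attention_mask,
            dropout_p=dropout,
            scale=scaling,
            is_causal=is_causal,
        )
    return attn_output.transpose(1, 2).contiguous(), None


AttentionInterface.register("sdpa_memeff", sdpa_memeff_attention_forward)
```

With the registration imported, usage is the standard knob: AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="sdpa_memeff").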

Verified locally on the full 8B Gemma-4-E4B-it checkpoint (bf16): 100% top-1 agreement with stock sdpa in configurations where stock works, and correct execution in configurations where stock falls back to MATH and OOMs.
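
The agreement check was of roughly this shape (a sketch rather than the exact script; the prompts are placeholders, and it assumes the sdpa_memeff registration above has been imported):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-4-E4B-it"  # checkpoint from the issue; any head_dim > 256 model works
prompts = ["The capital of France is", "def fibonacci(n):"]
tok = AutoTokenizer.from_pretrained(model_id)

top1 = {}
for impl in ("sdpa", "sdpa_memeff"):
    model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.bfloat16, attn_implementation=impl
    ).to("cuda").eval()
    preds = []
    for p in prompts:
        ids = tok(p, return_tensors="pt").to("cuda")
        with torch.no_grad():
            # Compare only the next-token prediction for each prompt.
            preds.append(model(**ids).logits[0, -1].argmax().item())
    top1[impl] = preds
    del model
    torch.cuda.empty_cache()

matches = sum(a == b for a, b in zip(top1["sdpa"], top1["sdpa_memeff"]))
print(f"top-1 agreement: {matches}/{len(prompts)}")
```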

Questions for maintainers

  1. Shape of the contribution. Dedicated "sdpa_memeff" name vs. a config knob on sdpa (e.g. sdpa_preferred_backends=["EFFICIENT"]) vs. the pipe syntax introduced by PR #39823 ("[attn_implementation] remove recursive, allows custom kernels with wrappers"), i.e. "sdpa|efficient_attention"?

  2. Overlap with pytorch-side fixes. pytorch's SDPA dispatcher is missing sm_120 support for CUDNN head_dim=256 (filed separately). Once that lands, CUDNN handles head_dim ≤ 256 on sm_120, but head_dim=320 (Gemma 4) is still library-capped on both FLASH and CUDNN, and the layout-induced FLASH disqualification from #44928 ("[Bug] Catastrophic gradient explosion (NaN) in RLHF with Qwen3.5 due to 3D position_ids forcing SDPA Math fallback and BF16 collapse") is unaffected. So sdpa_memeff isn't made redundant.

  3. Test coverage you'd want. Convergence (loss-curve equivalence with stock sdpa on a small model)? Numerical (per-batch logit diff)? Both?

  4. Anything first-time-contributor I should know — CLA, preferred test location, maintainer norms for adding to ALL_ATTENTION_FUNCTIONS, etc.?

Happy to open a PR if this direction is right.


🤖 Drafted with Claude Code (Claude Opus 4.7), reviewed and posted by me.
