Feature request
Add Flash Attention 2 support for `T5Gemma2ForConditionalGeneration` (and companion variants: encoder, decoder, etc., wherever `attn_implementation="flash_attention_2"` currently raises).
Currently, loading the model with FA2 fails at dispatch time:

```python
from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained(
    "google/t5gemma-2-4b-4b", attn_implementation="flash_attention_2"
)
```

```
ValueError: T5Gemma2ForConditionalGeneration does not support Flash Attention 2 yet.
Please request to add support where the model is hosted, on its model hub page:
https://huggingface.co/google/t5gemma-2-4b-4b/discussions/new
or in the Transformers GitHub repo: https://github.com/huggingface/transformers/issues/new
```
Raised from `transformers.modeling_utils._flash_attn_can_dispatch` (invoked via `_check_and_adjust_attn_implementation`). Filing here per the error message's instruction.
Motivation
T5Gemma 2 is advertised as a 128K-context encoder-decoder model; the technical report and model card claim 128K context support and publish Ruler-128K / MRCR-128K benchmark results. At those sequence lengths, FA2 is effectively table stakes for practical inference — eager attention is O(seq²) memory and becomes infeasible past a few K on modern GPUs.
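For a sense of scale, here is a rough back-of-envelope sketch; the head count, batch size, and dtype are illustrative placeholders, not values from the T5Gemma 2 config:

```python
# Back-of-envelope: bytes needed to materialize the full (seq x seq) attention
# score matrix per layer in eager mode. Head count / batch / dtype below are
# illustrative assumptions, not values read from the T5Gemma 2 config.
def attn_scores_gib(seq_len: int, num_heads: int = 8, batch: int = 1,
                    bytes_per_el: int = 2) -> float:
    return batch * num_heads * seq_len * seq_len * bytes_per_el / 2**30

for n in (4_096, 32_768, 131_072):
    print(f"{n:>7} tokens -> ~{attn_scores_gib(n):.1f} GiB of scores per layer")
# ~0.2 GiB at 4K, ~16 GiB at 32K, ~256 GiB at 128K (FA2 avoids materializing this)
```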
Related: we filed #45521 for a separate bug in the eager attention path (fails above ~4K tokens at batch=1). Even when that's fixed, eager/sdpa-only long-context inference will be memory-bound well before 128K. FA2 would unblock the advertised context window in practice.
Why this is non-trivial (and what might already be reusable)
T5Gemma 2 uses merged self+cross attention in the decoder (§2 of the paper, Section on "T5Gemma 2 architecture"): decoder self-attention and cross-attention to the encoder output are fused into a single joint attention op per layer. This is the novel architectural contribution and the main blocker vs. other Gemma-family models that already have FA2.
Pieces that should be reusable from related model integrations:
- Interleaved local/global + sliding window (5:1 ratio, `sliding_window: 1024`, `_sliding_window_pattern: 6` in config) — Gemma 3 already supports this with FA2; the T5Gemma 2 decoder inherits the same patterns (see the sliding-window sketch after this list)
- RoPE with split base frequencies (local=10k, global=1M) — also standard Gemma 3
- QK-norm + GQA — standard
- Encoder side (bidirectional) — straightforward FA2 `varlen` usage
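As a concrete reading of that 5:1 interleaving, assuming the Gemma 3 convention that every sixth decoder layer is global (the helper below is hypothetical, not the transformers implementation):

```python
# Hypothetical sketch of how a Gemma 3-style `_sliding_window_pattern: 6` maps
# decoder layers to local (sliding-window) vs. global attention. Illustrative
# only; the real layer typing lives in the model's config/modeling code.
def layer_attention_types(num_layers: int, sliding_window_pattern: int = 6):
    return [
        "global" if (i + 1) % sliding_window_pattern == 0 else "local_sliding"
        for i in range(num_layers)
    ]

print(layer_attention_types(12))
# five 'local_sliding' entries, then 'global', repeated (the 5:1 ratio above)
```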
The merged self+cross path is what needs new integration work — likely segment-ids or varlen concatenation of (past_self_KV ∥ encoder_KV) with a two-region mask (causal+SWA on the self part, full on the cross part) piped into FA2.
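A minimal sketch of that two-region mask in plain PyTorch, purely conceptual: the function name is made up here, and an actual FA2 path would express the same structure via segment-ids / varlen metadata rather than a dense boolean mask:

```python
import torch

# Conceptual two-region mask for the merged self+cross attention: queries see
# their own past decoder keys causally and within the sliding window, and see
# every encoder key. Dense boolean tensor for illustration only.
def merged_attention_mask(q_len: int, self_kv_len: int, enc_kv_len: int,
                          sliding_window: int = 1024) -> torch.Tensor:
    # Query positions are the last q_len positions of the decoder sequence.
    q_pos = torch.arange(self_kv_len - q_len, self_kv_len).unsqueeze(-1)
    k_pos = torch.arange(self_kv_len).unsqueeze(0)
    causal = k_pos <= q_pos                       # causal over decoder keys
    in_window = (q_pos - k_pos) < sliding_window  # SWA restriction (local layers)
    self_part = causal & in_window                # (q_len, self_kv_len)
    cross_part = torch.ones(q_len, enc_kv_len, dtype=torch.bool)  # full visibility
    return torch.cat([self_part, cross_part], dim=-1)

# 4 new decoder tokens, 16 cached decoder keys, 8 encoder keys -> (4, 24) mask
print(merged_attention_mask(q_len=4, self_kv_len=16, enc_kv_len=8).shape)
```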
Your contribution
Happy to test against a PR branch end-to-end on real long-context data (TReB English split has samples up to 28K tokens; we have an existing harness that covers 2.5K / 3.5K / 5K / 6.5K / 7.5K / 10K / 15K / 20K / 25K token lengths with known-passing expected outputs from sdpa/eager below the #45521 threshold). Can provide throughput + memory numbers on H100 NVL before/after.
Not volunteering to author the integration myself — I don't have deep familiarity with the FA2 `varlen` / segment-id APIs, and the merged-attention masking logic needs a reviewer who knows T5Gemma 2's design intent.
Related
#45521 (eager attention path fails above ~4K tokens at batch=1; filed separately, mentioned above)
Thanks!