Feature request: Flash Attention 2 support for T5Gemma 2 #45522

@arunkumarchithanar

Description

Feature request

Add Flash Attention 2 support for T5Gemma2ForConditionalGeneration (and companion variants: encoder, decoder, etc., wherever attn_implementation="flash_attention_2" currently raises).

Currently, loading the model with FA2 fails at dispatch time:

```python
from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained(
    "google/t5gemma-2-4b-4b", attn_implementation="flash_attention_2"
)
```

```
ValueError: T5Gemma2ForConditionalGeneration does not support Flash Attention 2 yet.
Please request to add support where the model is hosted, on its model hub page:
https://huggingface.co/google/t5gemma-2-4b-4b/discussions/new
or in the Transformers GitHub repo: https://github.com/huggingface/transformers/issues/new
```

Raised from transformers.modeling_utils._flash_attn_can_dispatch (invoked via _check_and_adjust_attn_implementation). Filing here per the error message's instruction.

Motivation

T5Gemma 2 is advertised as a 128K-context encoder-decoder model; the technical report and model card claim 128K context support and publish Ruler-128K / MRCR-128K benchmark results. At those sequence lengths, FA2 is effectively table stakes for practical inference: eager attention materializes the full O(seq²) attention-score matrix and becomes infeasible past a few thousand tokens on modern GPUs.
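For a rough sense of scale, here is a back-of-envelope calculation of the memory needed to materialize one layer's attention-score matrix at batch size 1. The head count (8) and fp16 scores are illustrative assumptions, not values taken from the model config:

```python
def attn_score_bytes(seq_len: int, num_heads: int, bytes_per_el: int = 2) -> int:
    """Bytes to materialize one layer's attention-score matrix at batch=1:
    num_heads x seq_len x seq_len elements (fp16 by default).
    Head count and dtype here are illustrative assumptions."""
    return num_heads * seq_len * seq_len * bytes_per_el

GIB = 1024 ** 3
print(attn_score_bytes(4_096, 8) / GIB)    # 0.25 GiB per layer at 4K tokens
print(attn_score_bytes(131_072, 8) / GIB)  # 256.0 GiB per layer at 128K tokens
```

Even generous rounding doesn't change the conclusion: quadratic score materialization cannot reach 128K on any single GPU, while FA2 never materializes the matrix at all.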

Related: we filed #45521 for a separate bug in the eager attention path (fails above ~4K tokens at batch=1). Even when that's fixed, eager/sdpa-only long-context inference will be memory-bound well before 128K. FA2 would unblock the advertised context window in practice.

Why this is non-trivial (and what might already be reusable)

T5Gemma 2 uses merged self+cross attention in the decoder (§2 of the paper, Section on "T5Gemma 2 architecture"): decoder self-attention and cross-attention to the encoder output are fused into a single joint attention op per layer. This is the novel architectural contribution and the main blocker vs. other Gemma-family models that already have FA2.

Pieces that should be reusable from related model integrations:

  • Interleaved local/global + sliding window (5:1 ratio, sliding_window: 1024, _sliding_window_pattern: 6 in config) — Gemma 3 already supports this with FA2; the T5Gemma 2 decoder inherits the same patterns
  • RoPE with split base frequencies (local=10k, global=1M) — also standard Gemma 3
  • QK-norm + GQA — standard
  • Encoder side (bidirectional) — straightforward FA2 varlen usage
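As a minimal sketch of how the interleaved local/global pattern from the config could be interpreted, assuming Gemma 3-style semantics ("every Nth layer is global") carry over to the T5Gemma 2 decoder; this is an assumption, not confirmed behavior:

```python
def layer_attention_types(num_layers: int, sliding_window_pattern: int = 6) -> list:
    """With pattern N, every Nth layer uses global attention and the rest
    use local sliding-window attention (a 5:1 local:global ratio for N=6).
    Assumes Gemma 3-style pattern semantics; not confirmed for T5Gemma 2."""
    return [
        "global" if (i + 1) % sliding_window_pattern == 0 else "local"
        for i in range(num_layers)
    ]

print(layer_attention_types(12))
# five "local" layers followed by one "global" layer, repeated
```

If this mapping holds, the local layers are exactly the case Gemma 3's existing FA2 path already handles via its sliding-window argument.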

The merged self+cross path is what needs new integration work — likely segment-ids or varlen concatenation of (past_self_KV ∥ encoder_KV) with a two-region mask (causal+SWA on the self part, full on the cross part) piped into FA2.
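To make the two-region idea concrete, here is a purely illustrative boolean-mask construction over the concatenated (past_self_KV ∥ encoder_KV) key space. The function name, layout, and NumPy representation are all assumptions for exposition; an actual FA2 integration would express this via varlen/segment metadata rather than a dense mask:

```python
import numpy as np

def merged_attention_mask(q_len: int, self_kv_len: int, cross_kv_len: int,
                          sliding_window: int = 1024) -> np.ndarray:
    """Boolean mask (True = may attend) over concatenated [self_KV || encoder_KV].
    Self region: causal + sliding window; cross region: fully visible.
    Illustrative sketch only -- not the transformers or flash-attn API."""
    mask = np.zeros((q_len, self_kv_len + cross_kv_len), dtype=bool)
    offset = self_kv_len - q_len              # align last query with last self key
    for i in range(q_len):
        pos = i + offset                      # absolute position of query i
        lo = max(0, pos - sliding_window + 1) # sliding-window lower bound
        mask[i, lo:pos + 1] = True            # causal within the window
    mask[:, self_kv_len:] = True              # every query sees all encoder keys
    return mask

m = merged_attention_mask(q_len=4, self_kv_len=4, cross_kv_len=3, sliding_window=2)
```

A dense mask like this defeats FA2's purpose at 128K, of course; the point is only to pin down which (query, key) pairs the fused op must allow so that a segment-id or two-call varlen formulation can be checked against it.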

Your contribution

Happy to test against a PR branch end-to-end on real long-context data (TReB English split has samples up to 28K tokens; we have an existing harness that covers 2.5K / 3.5K / 5K / 6.5K / 7.5K / 10K / 15K / 20K / 25K token lengths with known-passing expected outputs from sdpa/eager below the #45521 threshold). Can provide throughput + memory numbers on H100 NVL before/after.

I'm not volunteering to author the integration myself: I don't have deep familiarity with the FA2 varlen / segment-id APIs, and the merged-attention masking logic needs a reviewer who knows T5Gemma 2's design intent.

Thanks!

Labels: Good Second Issue (issues that are more difficult to do than "Good First" issues - give it a try if you want!)
