
T5Gemma2: decoder self-attention fixed 4097-element mask at batch=1, fails on inputs >4094 tokens #45521

@arunkumarchithanar

Description


System Info

  • transformers 5.0.0 (T5Gemma 2 support shipped in PR #41834)
  • torch 2.8.0, CUDA 12.8.1, Python 3.12
  • Hardware: 1× NVIDIA H100 NVL 94 GB (the same bug reproduces on an A100 80 GB SXM)
  • Model: google/t5gemma-2-4b-4b (gated)
  • Base image: runpod/pytorch:1.0.3-cu1281-torch280-ubuntu2404

Who can help?

@ArthurZucker @gante — attention / generation internals for T5Gemma 2.

Reproduction

AutoModelForSeq2SeqLM.from_pretrained(..., attn_implementation="eager") + model.generate() at batch_size=1 raises a shape-mismatch error in the decoder self-attention as soon as the input exceeds ~4094 tokens. Batching is not required to trigger this; batch_size=1 is enough.

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

MODEL = "google/t5gemma-2-4b-4b"
tok = AutoTokenizer.from_pretrained(MODEL)
model = AutoModelForSeq2SeqLM.from_pretrained(
    MODEL, dtype=torch.bfloat16, attn_implementation="eager", device_map="auto",
).eval()

filler = "| a | b | c |\n|---|---|---|\n| 1 | 2 | 3 |\n"
for n in (3500, 4090, 4100, 5000, 8000):
    prompt = "Answer.\n\nTable:\n" + filler * (n // 20) + "\n\nQ: sum?"
    ids = tok(prompt, return_tensors="pt", truncation=True, max_length=n).input_ids.to(model.device)
    try:
        with torch.no_grad():
            out = model.generate(ids, max_new_tokens=16, do_sample=False)
        print(f"OK  len={ids.shape[-1]}")
    except RuntimeError as e:
        print(f"FAIL len={ids.shape[-1]}: {e}")

Expected behavior

T5Gemma 2's decoder config advertises max_position_embeddings: 131072, so generate() should handle inputs well beyond 4K tokens at batch=1. (Qwen 2.5 7B Instruct via vLLM on the same TReB prompts handles up to the dataset max of 28,117 tokens with no issue.)
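
As a quick check of that claim, the sketch below reads the values straight from the hub config. The nested decoder attribute path is an assumption (based on the T5Gemma-style encoder/decoder sub-configs); adjust it if the decoder config lives elsewhere.

from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("google/t5gemma-2-4b-4b")
# Assumption: the decoder sub-config is exposed as `cfg.decoder`; fall back to
# the top-level config if that attribute does not exist.
dec = getattr(cfg, "decoder", cfg)
print(dec.max_position_embeddings)  # expected: 131072
print(dec.sliding_window)           # expected: 1024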

Actual behavior

On real TReB English samples (JT-LM/JIUTIAN-TReB), measured 2026-04-20:

input tokens   result   RuntimeError "b" value
2497           OK
3525           OK
5015           FAIL     5018
6499           FAIL     6502
7493           FAIL     7496
9997           FAIL     10000
14967          FAIL     14970
19808          FAIL     19811
25135          FAIL     25138

Error (abridged):

File ".../transformers/models/t5gemma2/modeling_t5gemma2.py", line 544, in forward
    hidden_states, _, _ = self.self_attn(
File ".../transformers/models/t5gemma2/modeling_t5gemma2.py", line 431, in forward
    attn_output, attn_weights = attention_interface(
File ".../transformers/models/t5gemma2/modeling_t5gemma2.py", line 244, in eager_attention_forward
    attn_weights = attn_weights + attention_mask
RuntimeError: The size of tensor a (4097) must match the size of tensor b (5018) at non-singleton dimension 3

Key fingerprint: tensor a (4097) is constant across every failure, regardless of actual input length; tensor b (N) equals input_length + 3 exactly (three special tokens added by the tokenizer).
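
A minimal way to sanity-check the "+3" part of that fingerprint is to tokenize the same text with and without special tokens and compare lengths (this only verifies the claim; the exact count is not assumed):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("google/t5gemma-2-4b-4b")
sample = "Answer.\n\nTable:\n| a | b | c |\n\nQ: sum?"
with_specials = tok(sample).input_ids
without_specials = tok(sample, add_special_tokens=False).input_ids
# If the fingerprint above is right, this should print exactly 3.
print(len(with_specials) - len(without_specials))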

Hypothesis

Decoder config has:

sliding_window: 1024
_sliding_window_pattern: 6
max_position_embeddings: 131072

The constant 4097 = 4 × 1024 + 1 looks like a pre-allocated 4-window attention buffer + 1 query token. It isn't being resized when the input exceeds 4096 tokens — attn_weights keeps the fixed 4097 shape while attention_mask tracks the actual sequence length.

Looking at T5Gemma2Decoder.forward (~line 1070-1100 in modeling_t5gemma2.py), decoder self-attention masks come from create_causal_mask / create_sliding_window_causal_mask in transformers/masking_utils.py, called with self.config + inputs_embeds + position_ids. The mismatch between the (4097-wide) attention buffer and the mask that comes back from one of these helpers is probably where this originates. T5Gemma 2's decoder is also documented as using merged self+cross attention with the comment "we always need a mask during decoding for merged attention" (around line 1072) — this merging may be interacting with the sliding-window mask constructor in a way the other Gemma variants don't hit.
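
To help narrow that down without patching modeling_t5gemma2.py, here is a hedged diagnostic sketch: it hooks the decoder self-attention modules and prints the mask shape each one actually receives during generate(). The module-name filter is an assumption based on the traceback, and it also assumes the mask reaches self_attn as the attention_mask keyword argument.

import torch

def log_mask_shape(name):
    def hook(module, args, kwargs):
        mask = kwargs.get("attention_mask")  # assumes the mask is passed as a kwarg
        if mask is not None:
            print(f"{name}: attention_mask shape = {tuple(mask.shape)}")
    return hook

# `model` and `ids` come from the reproduction script above.
handles = [
    module.register_forward_pre_hook(log_mask_shape(name), with_kwargs=True)
    for name, module in model.named_modules()
    if "decoder" in name and name.endswith("self_attn")  # name pattern is an assumption
]

with torch.no_grad():
    model.generate(ids, max_new_tokens=1, do_sample=False)

for h in handles:
    h.remove()

If the hypothesis above is right, the printed mask width should track the full sequence length (the tensor b value) while the eager attention scores stay stuck at 4097.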

Workarounds tried

  • attn_implementation="sdpa" → same class of shape mismatch at longer lengths
  • attn_implementation="flash_attention_2"ValueError: T5Gemma2ForConditionalGeneration does not support Flash Attention 2 yet (from _flash_attn_can_dispatch)
  • max_input_tokens ≤ 4094 avoids the crash (see the truncation sketch after this list), but it caps an advertised-128K-context model at roughly 4K, even though sliding-window attention should not impose any such length limit
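
For completeness, the truncation workaround from the last bullet is just a hard cap at tokenization time, sketched here with the objects from the reproduction above (4094 is an empirical limit observed in these runs, not a documented one):

ids = tok(prompt, return_tensors="pt", truncation=True, max_length=4094).input_ids.to(model.device)
out = model.generate(ids, max_new_tokens=16, do_sample=False)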

Full test harness + data (public)

The reproducer above is the minimal version. If it helps, the full testing we did is in a public repo — real TReB English samples, specific IDs, eval driver and monitor: https://github.com/junos-ai-org/jiutian-treb/blob/experiment-setup/experiments/t5gemma_vs_qwen_treb/insights/transformers_bug_repro.py

The sample IDs in the table above come from the public HF dataset JT-LM/JIUTIAN-TReB (English split), so anyone can load them and reproduce end-to-end without additional setup. Raw prediction outputs and error logs from our diagnostic run are in experiments/t5gemma_vs_qwen_treb/results/t5gemma_base_threshold/ (gitignored, available on request).

Related (all closed as completed)

Same Gemma-family class of bug, each independently filed and fixed:

Thanks!
