AutoModelForSeq2SeqLM.from_pretrained(..., attn_implementation="eager") + model.generate() at batch_size=1 raises a shape mismatch in the decoder self-attention as soon as the input exceeds ~4094 tokens. Batching is not required to trigger this — batch=1 is enough.
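Roughly, a minimal reproducer looks like this (the checkpoint and the ~4094-token threshold are the ones from this report; the placeholder prompt is not a real TReB sample, any text that tokenizes past ~4094 tokens should do):

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_id = "google/t5gemma-2-4b-4b"  # gated checkpoint used in this report

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_id,
    attn_implementation="eager",
).to("cuda")

# Placeholder long input (not a real TReB sample): anything that tokenizes to
# more than ~4094 tokens is enough; shorter inputs generate fine.
long_prompt = "Answer the question from the table. " + "row value | " * 3000

inputs = tokenizer(long_prompt, return_tensors="pt").to("cuda")
print("input tokens:", inputs["input_ids"].shape[-1])

# batch_size=1 — raises
#   RuntimeError: The size of tensor a (4097) must match the size of tensor b (N)
# once the input exceeds ~4094 tokens.
out = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```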
Expected behavior
T5Gemma 2's decoder config advertises max_position_embeddings: 131072, so generate() should handle inputs well beyond 4K tokens at batch=1. (Qwen 2.5 7B Instruct via vLLM on the same TReB prompts handles up to the dataset max of 28,117 tokens with no issue.)
Actual behavior
On real TReB English samples (JT-LM/JIUTIAN-TReB), measured 2026-04-20:
| input tokens | result | RuntimeError "b" value |
|---|---|---|
| 2497 | OK | — |
| 3525 | OK | — |
| 5015 | FAIL | 5018 |
| 6499 | FAIL | 6502 |
| 7493 | FAIL | 7496 |
| 9997 | FAIL | 10000 |
| 14967 | FAIL | 14970 |
| 19808 | FAIL | 19811 |
| 25135 | FAIL | 25138 |
Error (abridged):
File ".../transformers/models/t5gemma2/modeling_t5gemma2.py", line 544, in forward
hidden_states, _, _ = self.self_attn(
File ".../transformers/models/t5gemma2/modeling_t5gemma2.py", line 431, in forward
attn_output, attn_weights = attention_interface(
File ".../transformers/models/t5gemma2/modeling_t5gemma2.py", line 244, in eager_attention_forward
attn_weights = attn_weights + attention_mask
RuntimeError: The size of tensor a (4097) must match the size of
tensor b (5018) at non-singleton dimension 3
Key fingerprint: tensor a (4097) is constant across every failure, regardless of actual input length; tensor b (N) equals input_length + 3 exactly (three special tokens added by the tokenizer).
Hypothesis
The constant 4097 = 4 × 1024 + 1 looks like a pre-allocated 4096-wide (sliding-window) attention buffer plus 1 query token. It isn't being resized when the input exceeds 4096 tokens — attn_weights keeps the fixed 4097 shape while attention_mask tracks the actual sequence length.
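As a quick sanity check of that fingerprint against the table above (the sliding_window = 4096 value is an assumption about the decoder config; the number pairs are the measured ones):

```python
# Measured (input_tokens -> RuntimeError "b" value) pairs from the table above.
failures = {
    5015: 5018, 6499: 6502, 7493: 7496, 9997: 10000,
    14967: 14970, 19808: 19811, 25135: 25138,
}

SLIDING_WINDOW = 4096  # assumed decoder sliding_window value

# "b" always equals input length + 3 special tokens ...
assert all(b == n + 3 for n, b in failures.items())
# ... while "a" is stuck at sliding_window + 1 query token, regardless of input.
assert SLIDING_WINDOW + 1 == 4097
```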
Looking at T5Gemma2Decoder.forward (~line 1070-1100 in modeling_t5gemma2.py), decoder self-attention masks come from create_causal_mask / create_sliding_window_causal_mask in transformers/masking_utils.py, called with self.config + inputs_embeds + position_ids. The mismatch between the (4097-wide) attention buffer and the mask that comes back from one of these helpers is probably where this originates. T5Gemma 2's decoder is also documented as using merged self+cross attention with the comment "we always need a mask during decoding for merged attention" (around line 1072) — this merging may be interacting with the sliding-window mask constructor in a way the other Gemma variants don't hit.
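If it helps triage, a throwaway patch along these lines can log the colliding shapes right before the failing add. It assumes the usual transformers attention-interface call order (module, query, key, value, attention_mask, ...) and is only a debugging sketch, not a fix:

```python
from transformers.models.t5gemma2 import modeling_t5gemma2 as m

_orig_eager = m.eager_attention_forward

def eager_attention_forward_logged(module, query, key, value, attention_mask, *args, **kwargs):
    # Print the shapes that collide in `attn_weights = attn_weights + attention_mask`.
    if attention_mask is not None:
        print(
            "layer:", type(module).__name__,
            "query:", tuple(query.shape),
            "key:", tuple(key.shape),
            "mask:", tuple(attention_mask.shape),
        )
    return _orig_eager(module, query, key, value, attention_mask, *args, **kwargs)

m.eager_attention_forward = eager_attention_forward_logged
```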
Workarounds tried
attn_implementation="sdpa" → same class of shape mismatch at longer lengths
attn_implementation="flash_attention_2" → ValueError: T5Gemma2ForConditionalGeneration does not support Flash Attention 2 yet (from _flash_attn_can_dispatch)
max_input_tokens ≤ 4094 avoids the crash but caps an advertised-128K-context model at 4K, which is much less than the sliding_window config would suggest
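A rough sketch of that truncation cap, reusing tokenizer, model, and long_prompt from the reproducer above (4094 plus the 3 special tokens stays within the 4097-wide buffer):

```python
# Workaround sketch: cap the tokenized input so length + special tokens <= 4097.
inputs = tokenizer(
    long_prompt,
    return_tensors="pt",
    truncation=True,
    max_length=4094,  # ~4094 input tokens + 3 special tokens stays under 4097
).to("cuda")

out = model.generate(**inputs, max_new_tokens=32)
```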
The 9 failing sample IDs above come from the public HF dataset JT-LM/JIUTIAN-TReB (English split), so anyone can load them and reproduce end-to-end without additional setup. Raw prediction outputs and error logs from our diagnostic run are in experiments/t5gemma_vs_qwen_treb/results/t5gemma_base_threshold/ (gitignored, available on request).
Related (all closed as completed)
Same Gemma-family class of bug, each independently filed and fixed:
System Info
- transformers: 5.0.0 (T5Gemma 2 support shipped in T5gemma2 #41834)
- torch: 2.8.0, CUDA 12.8.1, Python 3.12
- Model: google/t5gemma-2-4b-4b (gated)
- Container: runpod/pytorch:1.0.3-cu1281-torch280-ubuntu2404

Who can help?
@ArthurZucker @gante — attention / generation internals for T5Gemma 2.
Reproduction
attn_implementation="sdpa"→ same class of shape mismatch at longer lengthsattn_implementation="flash_attention_2"→ValueError: T5Gemma2ForConditionalGeneration does not support Flash Attention 2 yet(from_flash_attn_can_dispatch)max_input_tokens ≤ 4094avoids the crash but caps an advertised-128K-context model at 4K, which is much less than thesliding_windowconfig would suggestFull test harness + data (public)
The reproducer above is the minimal version. If it helps, the full testing we did is in a public repo — real TReB English samples, specific IDs, eval driver and monitor: https://github.com/junos-ai-org/jiutian-treb/blob/experiment-setup/experiments/t5gemma_vs_qwen_treb/insights/transformers_bug_repro.py
Thanks!