Skip to content

Commit 67e69cb

Browse files
gambletanclaude authored and committed
fix: compute image_seq_len from spatial dims, not channel dim in Lumina2 pipeline
Fixes #12913 `image_seq_len` was computed as `latents.shape[1]`, which gives the channel dimension (e.g. 16) since Lumina2 latents have shape `(batch, channels, height, width)` and are NOT packed/reshaped before this point. The Lumina2 transformer internally patchifies the latents with `patch_size=2`, so the correct spatial sequence length is `(H // patch_size) * (W // patch_size)`. This incorrect value was passed to `calculate_shift()`, which computes the `mu` parameter for the flow-matching scheduler. Using channel count instead of token count produces a completely wrong shift, degrading generation quality. The fix reads `patch_size` from `self.transformer.config.patch_size` and computes `image_seq_len` from the last two (spatial) dimensions of the latents tensor, matching how the transformer itself computes its input sequence length. For reference, the Flux pipeline correctly uses `latents.shape[1]` because Flux latents are pre-packed into `(batch, seq_len, channels)` before this computation. Lumina2 does not pre-pack, so the same indexing does not apply. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent e5aa719 commit 67e69cb

File tree

2 files changed

+64
-1
lines changed

2 files changed

+64
-1
lines changed

src/diffusers/pipelines/lumina2/pipeline_lumina2.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -696,7 +696,8 @@ def __call__(
696696

697697
# 5. Prepare timesteps
698698
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
699-
image_seq_len = latents.shape[1]
699+
patch_size = self.transformer.config.patch_size
700+
image_seq_len = (latents.shape[-2] // patch_size) * (latents.shape[-1] // patch_size)
700701
mu = calculate_shift(
701702
image_seq_len,
702703
self.scheduler.config.get("base_image_seq_len", 256),

tests/pipelines/lumina2/test_pipeline_lumina2.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,68 @@ def get_dummy_components(self):
9999
}
100100
return components
101101

102+
def test_image_seq_len_uses_spatial_dimensions(self):
103+
"""Test that image_seq_len is computed from spatial dims, not channel dim.
104+
105+
Lumina2 latents have shape (batch, channels, height, width) and are NOT
106+
packed before image_seq_len is computed. The transformer patchifies
107+
internally with patch_size=2, so the correct sequence length is
108+
(H // patch_size) * (W // patch_size).
109+
110+
Previously, the code used latents.shape[1] which gives the channel
111+
count (e.g. 4) instead of the spatial sequence length (e.g. 64 for
112+
16x16 latents with patch_size=2). This caused calculate_shift() to
113+
compute a completely wrong mu value for the scheduler.
114+
"""
115+
components = self.get_dummy_components()
116+
pipe = Lumina2Pipeline(**components)
117+
pipe.to(torch.device("cpu"))
118+
119+
patch_size = pipe.transformer.config.patch_size # 2
120+
121+
# Use height=32, width=32 -> latent size 4x4 (vae downscale 8x)
122+
# With patch_size=2: seq_len = (4//2)*(4//2) = 4
123+
# Channel dim = 4, which would be wrong if used as seq_len
124+
# Use a larger size to make the distinction clearer
125+
height, width = 64, 64
126+
latent_h, latent_w = height // 8, width // 8 # 8, 8
127+
expected_seq_len = (latent_h // patch_size) * (latent_w // patch_size) # 16
128+
129+
# The channel dimension is 4 (from vae latent_channels)
130+
# If the bug were present, image_seq_len would be 4 instead of 16
131+
channels = components["vae"].config.latent_channels # 4
132+
self.assertNotEqual(channels, expected_seq_len, "Test needs channels != expected_seq_len to be meaningful")
133+
134+
# Capture the mu value passed to the scheduler
135+
captured = {}
136+
original_set_timesteps = pipe.scheduler.set_timesteps
137+
138+
def capture_mu_set_timesteps(*args, **kwargs):
139+
captured["mu"] = kwargs.get("mu")
140+
return original_set_timesteps(*args, **kwargs)
141+
142+
pipe.scheduler.set_timesteps = capture_mu_set_timesteps
143+
144+
# Run pipeline with specific dimensions
145+
generator = torch.Generator(device="cpu").manual_seed(0)
146+
pipe(
147+
prompt="test",
148+
height=height,
149+
width=width,
150+
num_inference_steps=1,
151+
generator=generator,
152+
output_type="latent",
153+
)
154+
155+
# Verify mu was computed using spatial seq_len, not channel dim
156+
from diffusers.pipelines.lumina2.pipeline_lumina2 import calculate_shift
157+
158+
correct_mu = calculate_shift(expected_seq_len)
159+
wrong_mu = calculate_shift(channels)
160+
161+
self.assertAlmostEqual(captured["mu"], correct_mu, places=5, msg="mu should use spatial sequence length")
162+
self.assertNotAlmostEqual(captured["mu"], wrong_mu, places=5, msg="mu should NOT use channel dimension")
163+
102164
def get_dummy_inputs(self, device, seed=0):
103165
if str(device).startswith("mps"):
104166
generator = torch.manual_seed(seed)

0 commit comments

Comments (0)