Problem
On the Pallas/TPU backend, when one tile-indexed tensor's dim requires a strict
hardware alignment (e.g. an 8-element sublane alignment on a small fp32 dim),
Helion's adjust_block_size_constraints takes the max alignment requirement
across all tile-indexed tensors sharing that block_id and applies it
uniformly. Tensors that have no alignment requirement of their own get their
block size inflated, which can multiply their VMEM footprint and push the
kernel over the 64 MiB scoped-VMEM cap.
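The merging behavior described above can be modeled in a few lines. This is a minimal sketch, not Helion's code. TileUse, merged_block_sizes and per_tensor_block_sizes are hypothetical names. It also assumes the block size is rounded up to the next multiple of the required alignment, which matches the 2-to-8 inflation in the concrete example.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TileUse:
    """One tile-indexed tensor dim that uses a shared block_id (hypothetical model)."""
    name: str
    block_id: int
    alignment: int  # required multiple for this tensor's tiled dim; 1 means no requirement


def round_up(x: int, multiple: int) -> int:
    # Smallest multiple of `multiple` that is >= x.
    return -(-x // multiple) * multiple


def merged_block_sizes(uses: list[TileUse], requested: dict[int, int]) -> dict[str, int]:
    """Max-merge: every tensor on a block_id gets the strictest alignment of any of them."""
    strictest: dict[int, int] = {}
    for u in uses:
        strictest[u.block_id] = max(strictest.get(u.block_id, 1), u.alignment)
    return {u.name: round_up(requested[u.block_id], strictest[u.block_id]) for u in uses}


def per_tensor_block_sizes(uses: list[TileUse], requested: dict[int, int]) -> dict[str, int]:
    """Per-tensor: each tensor is rounded only to its own alignment."""
    return {u.name: round_up(requested[u.block_id], u.alignment) for u in uses}


# block_b is block_id 0 with a requested size of 2; only lse needs 8-alignment.
uses = [TileUse("lse", 0, 8)] + [TileUse(n, 0, 1) for n in ("q", "k", "v", "out")]
requested = {0: 2}
merged = merged_block_sizes(uses, requested)
per_tensor = per_tensor_block_sizes(uses, requested)
print(merged)      # every tensor gets 8
print(per_tensor)  # lse gets 8, q/k/v/out keep 2
```

Under the max-merge, the one strict tensor sets the block size for all five. Under the per-tensor rule, only lse pays for its alignment.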
Concrete example
examples/attention.py at B=8 H=32 S=8192 D=256 bf16, autotuned config
block_sizes=[2, 512, 2048], unroll, pb=False:
- lse is fp32[B*H, S]; its tile-indexed leading dim requires 8-element sublane alignment from Mosaic.
- Q/K/V/out are bf16[B*H, S, D]; they have no per-tensor alignment requirement on block_b.
- All five tensors (lse, Q, K, V, out) share the same block_id for block_b.
adjust_block_size_constraints propagates the 8-element requirement to Q/K/V/out:
| Variant | _BLOCK_SIZE_0 | K/V block shape | Total scoped VMEM |
|---|---|---|---|
| lse 2D [B*H, S] | 8 (inflated from 2) | (8, None, None) | 238 MB (fails to compile; exceeds 64 MiB cap) |
| lse 3D [B*H, S, 1] | 2 | (2, None, None) | 34 MB (compiles, 24.72 ms median) |
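Here is a rough check of why the inflation matters. Per-tensor block bytes scale linearly with _BLOCK_SIZE_0, so going from 2 to 8 quadruples each K/V block. The sketch assumes None in the K/V block shape means the full extent of that dim (S=8192, D=256). It counts only a single buffer each for K and V, ignoring double-buffering, Q/out, lse and scratch, so it is not meant to reproduce the 238 MB / 34 MB totals in the table.

```python
MiB = 1024 ** 2
VMEM_CAP = 64 * MiB  # scoped-VMEM cap quoted in the issue

S, D = 8192, 256
BF16_BYTES = 2


def kv_block_bytes(block_b: int) -> int:
    # One K (or V) block of shape (block_b, S, D) in bf16, assuming None = full extent.
    return block_b * S * D * BF16_BYTES


for block_b in (2, 8):
    k_plus_v = 2 * kv_block_bytes(block_b)
    print(f"block_b={block_b}: K+V single-buffered = {k_plus_v // MiB} MiB "
          f"({k_plus_v / VMEM_CAP:.0%} of the cap)")
```

Under these assumptions, K and V alone fill the whole 64 MiB cap at block_b=8 before any double-buffering. At block_b=2 they use a quarter of it.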
The 3D LSE workaround in #2743 sidesteps the inflation by carrying the
alignment-sensitive dim as a trailing size-1 axis, so the tile-indexed leading
dim is no longer the one Mosaic needs to align.
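You can check why the trailing size-1 axis helps with a simplified version of the block-shape rule from the Pallas TPU documentation. The rule is that the last two dims of a block must be divisible by 8 and 128 respectively, or equal the full array dims, while leading dims are unconstrained. The sketch below applies that rule to lse. It assumes lse is tiled as (block_b, block_m) = (2, 512) from the autotuned config. block_shape_ok is a hypothetical helper that ignores dtype-specific packing.

```python
def block_shape_ok(block: tuple[int, ...], array: tuple[int, ...],
                   sublane: int = 8, lane: int = 128) -> bool:
    """Simplified TPU block-shape rule: only the last two dims are constrained."""
    assert len(block) == len(array)
    # Last dim: multiple of `lane` or the full extent.
    if not (block[-1] % lane == 0 or block[-1] == array[-1]):
        return False
    # Second-to-last dim: multiple of `sublane` or the full extent.
    if len(block) >= 2 and not (block[-2] % sublane == 0 or block[-2] == array[-2]):
        return False
    return True


BH, S = 8 * 32, 8192
block_b, block_m = 2, 512

# 2D lse: block_b lands on the sublane dim, so 2 is rejected.
ok_2d = block_shape_ok((block_b, block_m), (BH, S))
# 3D lse: block_b moves to a leading dim; the trailing 1 equals the full extent.
ok_3d = block_shape_ok((block_b, block_m, 1), (BH, S, 1))
# Inflating block_b to 8 satisfies the 2D rule, which is what the max-merge does.
ok_2d_inflated = block_shape_ok((8, block_m), (BH, S))
print(ok_2d, ok_3d, ok_2d_inflated)
```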
Suggested direction
Refactor helion/_compiler/backend.py::adjust_block_size_constraints so the
alignment requirement applies per-tensor (via a per-tensor effective block
size at codegen) rather than getting max-merged into the shared block_id's
constraint. The tile loop's iteration count would still be driven by the
common block_id, but per-tensor BlockSpecs would expose only the alignment
each tensor actually needs.
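The sketch below shows one possible reading of "per-tensor effective block size". The grid still steps over B*H in units of the shared block size (2), while lse's BlockSpec uses its own aligned size (8), and each step works on a 2-row slice inside the current lse block. This is a hypothetical illustration, not Helion's design, and lse_block_and_offset is an invented name. The scheme assumes the lse block size is a multiple of the loop block size. It also assumes the grid axis runs sequentially, because consecutive steps revisit the same lse block.

```python
BH = 8 * 32        # B*H rows tiled by block_b
LOOP_BLOCK = 2     # shared block size that drives the iteration count
LSE_BLOCK = 8      # lse's own aligned block size

assert LSE_BLOCK % LOOP_BLOCK == 0, "sketch assumes lse block is a multiple of the loop block"


def lse_block_and_offset(step: int) -> tuple[int, int]:
    """Map a grid step to (lse block index, row offset inside that block)."""
    first_row = step * LOOP_BLOCK
    return first_row // LSE_BLOCK, first_row % LSE_BLOCK


num_steps = BH // LOOP_BLOCK  # iteration count is unchanged by lse's alignment
mapping = [lse_block_and_offset(s) for s in range(5)]
print(num_steps, mapping)
```

In this reading, Q/K/V/out would keep 2-row blocks as in the 3D-LSE variant. Only lse's index map would change.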
Workaround
When a Pallas kernel has a small alignment-sensitive output sharing a tile
dim with large tensors, restructure the small output so the alignment-sensitive
axis is not the tile-indexed axis. See #2743 for a concrete example.