
[Pallas] adjust_block_size_constraints max-propagates per-tensor TPU alignment across all tile-indexed tensors sharing a block_id #2842

Description

@norx1991

Problem

On the Pallas/TPU backend, when one tile-indexed tensor's dim requires a strict
hardware alignment (e.g. an 8-element sublane alignment on a small fp32 dim),
Helion's adjust_block_size_constraints takes the max alignment requirement
across all tile-indexed tensors sharing that block_id and applies it
uniformly. Tensors that have no alignment requirement of their own get their
block size inflated, which can multiply their VMEM footprint and push the
kernel over the 64 MiB scoped-VMEM cap.
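A toy model of the current behavior may make the inflation concrete. This is not Helion's code. `round_up`, `merged_block_size`, and the alignment dict are hypothetical. The sketch also assumes the inflation amounts to rounding the requested block size up to the strictest per-tensor alignment.

```python
def round_up(value: int, multiple: int) -> int:
    """Round value up to the next multiple of `multiple` (integer math)."""
    return ((value + multiple - 1) // multiple) * multiple


def merged_block_size(requested: int, alignments: dict[str, int]) -> int:
    """Hypothetical model of max-merging: one shared block size per block_id.

    The strictest (max) per-tensor alignment wins and is applied to every
    tensor that shares the block_id.
    """
    shared_alignment = max(alignments.values())
    return round_up(requested, shared_alignment)


# Hypothetical per-tensor alignment requirements on block_b (1 = none).
with_2d_lse = {"q": 1, "k": 1, "v": 1, "out": 1, "lse": 8}
with_3d_lse = {"q": 1, "k": 1, "v": 1, "out": 1, "lse": 1}

print(merged_block_size(2, with_2d_lse))  # 8: Q/K/V/out inherit lse's alignment
print(merged_block_size(2, with_3d_lse))  # 2: no tensor forces inflation
```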

Concrete example

examples/attention.py at B=8 H=32 S=8192 D=256 bf16, autotuned config
block_sizes=[2, 512, 2048], unroll, pb=False:

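  • lse is fp32[B*H, S]: block_b tiles its leading dim, which is also its
    second-to-last (sublane) dim, so Mosaic requires 8-element alignment there.
  • Q/K/V/out are bf16[B*H, S, D]: block_b tiles a leading dim that is not one
    of the last two, so they have no per-tensor alignment requirement on
    block_b.
  • All five tensors (lse, Q, K, V, out) share the same block_id for block_b.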
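The per-tensor requirements above follow from Mosaic's block-shape rule. The plain-Python sketch below checks a simplified version of that rule. It makes three assumptions: the (8, 128) tiling that applies to 32-bit dtypes, a block dim of None meaning full extent, and the S dim left at full extent so that only block_b matters. `block_shape_ok` is a hypothetical helper, not a Pallas or Helion API.

```python
def block_shape_ok(array_shape, block_shape, sublane=8, lane=128):
    """Simplified Mosaic TPU block-shape rule (assumed, 32-bit dtypes).

    Each of the last two block dims must equal the full array dim or be a
    multiple of (sublane, lane). Leading dims are unconstrained.
    A block dim of None is treated as "full extent" (sketch assumption).
    """
    block = [full if b is None else b for full, b in zip(array_shape, block_shape)]
    ndim = len(array_shape)
    for axis, tile in ((ndim - 2, sublane), (ndim - 1, lane)):
        if axis < 0:
            continue  # simplification: 1-D arrays only get the lane check
        if block[axis] != array_shape[axis] and block[axis] % tile != 0:
            return False
    return True


BH, S, D = 8 * 32, 8192, 256

# lse 2D: block_b lands on the sublane dim, so 2 is rejected and 8 is accepted.
print(block_shape_ok((BH, S), (2, None)))           # False
print(block_shape_ok((BH, S), (8, None)))           # True
# Q/K/V/out: block_b tiles a leading dim, so 2 is fine.
print(block_shape_ok((BH, S, D), (2, None, None)))  # True
# lse 3D (the #2743 workaround): block_b now tiles a leading dim as well.
print(block_shape_ok((BH, S, 1), (2, None, None)))  # True
```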

adjust_block_size_constraints propagates the 8-element requirement to Q/K/V/out:

Variant               _BLOCK_SIZE_0          K/V block shape    Total scoped VMEM
lse 2D [B*H, S]       8 (inflated from 2)    (8, None, None)    238 MB (fails to compile: exceeds the 64 MiB cap)
lse 3D [B*H, S, 1]    2                      (2, None, None)    34 MB (compiles, 24.72 ms median)
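You can check the scale of the inflation by hand. The sketch below makes three assumptions: None in the K/V block shapes means the full S and D extents, bf16 takes 2 bytes, and the default Pallas pipelining double-buffers each input block. Under those assumptions, K and V alone need:

```python
MIB = 1024 ** 2
S, D = 8192, 256
BF16_BYTES = 2


def kv_block_bytes(block_b):
    """Bytes for one (block_b, full S, full D) block of a bf16[B*H, S, D] tensor."""
    return block_b * S * D * BF16_BYTES


for block_b in (2, 8):
    per_buffer = kv_block_bytes(block_b) // MIB
    # K and V, each double-buffered (assumed default pipelining).
    kv_total = 2 * 2 * per_buffer
    print(f"block_b={block_b}: {per_buffer} MiB per K/V buffer, {kv_total} MiB for K+V")
# block_b=2: 8 MiB per K/V buffer, 32 MiB for K+V
# block_b=8: 32 MiB per K/V buffer, 128 MiB for K+V
```

At block_b=8, K and V already exceed the 64 MiB cap before anything else is counted, which agrees with the compile failure in the table. This sketch does not model the rest of the table's totals.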

The 3D LSE workaround in #2743 sidesteps the inflation by appending a trailing
size-1 axis to lse. The tile-indexed B*H dim then becomes a leading dim instead
of the second-to-last (sublane) dim, so Mosaic no longer requires it to be
8-aligned.

Suggested direction

Refactor helion/_compiler/backend.py::adjust_block_size_constraints so the
alignment requirement applies per-tensor (via a per-tensor effective block
size at codegen) rather than getting max-merged into the shared block_id's
constraint. The tile loop's iteration count would still be driven by the
common block_id, but per-tensor BlockSpecs would expose only the alignment
each tensor actually needs.
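One possible shape for the proposal is sketched below. Nothing in it exists in Helion; every helper is hypothetical. The shared block_id keeps the autotuned block size, and that size drives the loop trip count. Each tensor gets an effective block size that satisfies only its own alignment. Each loop step then maps to a block index plus a row offset inside that tensor's block. The sketch uses lcm(block_size, alignment) as the effective size so that every loop tile falls inside a single per-tensor block. That choice is an assumption; the issue does not specify it.

```python
import math


def effective_block_sizes(block_size, alignments):
    """Hypothetical per-tensor effective block sizes for one shared block_id.

    lcm() keeps each effective size a multiple of block_size, so one loop
    tile never straddles two per-tensor blocks.
    """
    return {name: math.lcm(block_size, a) for name, a in alignments.items()}


def tile_location(step, block_size, effective):
    """Map a shared-loop step to (block index, row offset) in a tensor."""
    start = step * block_size
    return start // effective, start % effective


def trip_count(extent, block_size):
    """Loop iterations driven by the shared block_id (ceil division)."""
    return -(-extent // block_size)


alignments = {"q": 1, "k": 1, "v": 1, "out": 1, "lse": 8}
eff = effective_block_sizes(2, alignments)
print(eff)                               # {'q': 2, 'k': 2, 'v': 2, 'out': 2, 'lse': 8}
print(trip_count(8 * 32, 2))             # 128
print(tile_location(5, 2, eff["lse"]))   # (1, 2)
print(tile_location(5, 2, eff["q"]))     # (5, 0)
```

Under this scheme Q/K/V/out keep 2-row blocks and only lse pays for the 8-row alignment. At loop step 5, for example, the kernel would touch rows 2..3 of lse's block 1.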

Workaround

When a Pallas kernel has a small alignment-sensitive output that shares a tile
dim with large tensors, restructure the small output so the tile-indexed axis
no longer sits in an alignment-sensitive position. See #2743 for a concrete
example.
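At the array level, the restructuring is a reshape around the kernel call. The minimal NumPy sketch below covers only the shapes. The kernel itself is omitted, and `lse_3d` stands in for whatever the kernel writes. How #2743 restores the 2D view is not shown in this issue, so the squeeze step is an assumption.

```python
import numpy as np

B, H, S = 8, 32, 8192

# Kernel-facing layout: fp32[B*H, S, 1]. The tile-indexed B*H axis is now a
# leading axis, so it is no longer the sublane dim.
lse_3d = np.zeros((B * H, S, 1), dtype=np.float32)  # stand-in for kernel output

# Caller-facing layout (assumed): drop the trailing size-1 axis after the
# kernel runs, so downstream code still sees fp32[B*H, S].
lse = np.squeeze(lse_3d, axis=-1)
print(lse.shape)  # (256, 8192)
```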
