Skip to content

[autotuner] composed-fact seed for fused matmul + reduction-epilogue#2846

Open
calebmkim wants to merge 1 commit into
calebmkim/stack/8from
calebmkim/stack/23
Open

[autotuner] composed-fact seed for fused matmul + reduction-epilogue#2846
calebmkim wants to merge 1 commit into
calebmkim/stack/8from
calebmkim/stack/23

Conversation

@calebmkim

@calebmkim calebmkim commented Jun 23, 2026

Copy link
Copy Markdown
Contributor

Stacked PRs:


[autotuner] composed-fact seed for fused matmul + reduction-epilogue

The first composed fact. MatmulWithReductionEpilogueFact holds a MatmulFact + the
over-output-axis epilogue ReductionFact for a fused matmul + reduction (matmul_rms_norm /
matmul_layernorm / matmul_softmax / matmul_l2_normalize / matmul_logsumexp / ...).
build_matmul_reduction_epilogue_facts composes the two when exactly one of each co-occurs;
TritonMatmulReductionEpilogueHeuristic then seeds a footprint-aware tile. Fires only on the
fused family — a pure matmul and a pure reduction do not compose, so the reduction curriculum stays
byte-identical.

Resident (row-in-registers) only. The seed assumes the whole reduced row lives in the
register-resident [M_BLOCK, N] fp32 accumulator (N is hl.specialize'd → n_block_id is None);
is_eligible fires only on that shape. The earlier looped/tiled-N branch (perf-unverified) was
removed per review.

Depends on #2829 (stage 2a) — this is the first consumer of its aperture. 2a removed the
if spec.matmul_facts: return guard in register_user_tiled_reductions, which lets the materialized
branch register the over-N epilogue reduction even when a matmul is present. Without that, the
composed fact never forms. It does not depend on 2b/#2830's per_feature_accumulator (stacked on
#2830 for ordering; re-parentable to #2829).

What changes

  • config_spec.pyMatmulWithReductionEpilogueFact (composes the two sub-facts + n_extent +
    m/k_block_id; holds them, does not re-derive).
  • device_ir.pybuild_matmul_reduction_epilogue_facts (Phase 4): composes iff exactly 1
    MatmulFact + 1 ReductionFact.
  • triton.pyTritonMatmulReductionEpilogueHeuristic: footprint-aware M_BLOCK (largest pow2
    whose [M_BLOCK, N] fp32 accumulator fits ACC_BUDGET_BYTES=128 KiB, dtype-scaled, capped at 64),
    K_BLOCK=32, num_stages=3, num_warps ramp, plus a staged-SMEM byte-cap that extends the
    feasible-N range (below).
  • __init__.py — registers the heuristic. Reduction heuristics gate on not matmul_facts, so the
    fused family is seeded by this one alone; pure matmul keeps its skinny-gemm seed.

Staged-SMEM byte-cap (feasible-N extension)

The resident kernel's SMEM is dominated by the staged [K_BLOCK, N] y-operand
(num_stages·K_BLOCK·N·itemsize); past the small-N range [K_BLOCK=32, num_stages=3] overflows it.
The seed caps that footprint to SMEM_STAGED_BUDGET_BYTES=192 KiB: when over, drop K_BLOCK 32→16
first
(halves the staged bytes and sidesteps a measured non-monotonic K_BLOCK=32 ptxas cliff),
then drop num_stages only if still over. Small-N stays byte-identical (the cap doesn't fire).
Because n_extent is the next power of two of N, the emitted bands are KB32/st3KB16/st3
KB16/st1 (the KB16/st2 band never occurs).

Net: the seed compiles to N≤4096 bf16 / N≤2048 fp32 — roughly 2× the default's wall
(N≤2048 / N≤1024) — and fails closed beyond (clean OutOfResources → the autotuner falls back;
the seed is never the forced default, promote_seed_to_default=False). Where the default OOMs but the
seed still compiles (bf16 N=3072/4096, fp32 N=1536/2048, the 2048³ fp32 square), the seed is a
feasibility rescue.

Performance

H100 sm90, fwd-only, cold-L2 CUDA-graph device time (256 MiB L2 flush per replay), accuracy-gated
(relerr vs an fp32 reference). seed = the heuristic's emitted config (extracted faithfully via
kernel.bindget_seed_config); default = config_spec.default_config() (the bar); tc =
torch.compile with native-dtype matmul + fp32 accumulate. >1 = seed faster. "rescue" = default
uncompilable (OOM), seed runs. The seed is epilogue-blind, so the 5 kernels track each other at
matched shapes.

kernel (M, K, N) dtype regime seed/default seed/tc
layernorm 131072, 512, 512 bf16 small-N 1.82× 1.15×
layernorm 131072, 512, 512 fp32 small-N 1.55× 2.92×
layernorm 131072, 512, 2048 bf16 larger-N 14.57× 0.37×
layernorm 131072, 512, 4096 bf16 larger-N rescue 0.13×
layernorm 131072, 4096, 2048 bf16 larger-K 13.65× 0.19×
layernorm 512³ bf16 square 1.72× 0.59×
layernorm 1024³ bf16 square 5.95× 0.34×
layernorm 1024³ fp32 square 26.94× 0.99×
layernorm 2048³ fp32 square rescue 0.77×
rmsnorm 1024³ bf16 square 5.69× 0.34×
rmsnorm 131072, 512, 4096 bf16 larger-N rescue 0.12×
rmsnorm 131072, 8192, 2048 bf16 larger-K 12.34× 0.15×
softmax 1024³ fp32 square 29.69× 0.98×
softmax 131072, 512, 2048 fp32 larger-N rescue 1.22×
softmax 131072, 2048, 2048 bf16 larger-K 13.63× 0.23×
l2_normalize 2048³ bf16 square 13.88× 0.18×
l2_normalize 131072, 512, 4096 bf16 larger-N rescue 0.12×
l2_normalize 131072, 6144, 1024 bf16 larger-K 3.31× 0.39×
logsumexp 1024³ bf16 square 5.54× 0.36×
logsumexp 131072, 512, 4096 bf16 larger-N rescue 0.10×
logsumexp 131072, 3072, 2048 bf16 larger-K 3.63× 0.19×

vs the Helion default (the bar): geomean 6.93× over the 15 both-compile cells (1.55–29.69×), no
regressions, plus 6 feasibility rescues.
The default is catastrophic at large N/K — it spills the
huge fp32 accumulator to local memory (e.g. 257 ms / 509 ms at K=4096/8192 vs the seed's 19 / 41 ms),
or OOMs outright — which the seed's num_warps ramp + byte-cap avoid.

vs torch.compile: fp32 is competitive-to-winning (0.77–2.92×; the softmax N=2048 fp32 rescue beats
it 1.22× while the default can't even run); bf16 loses at high arithmetic intensity
(large-N /
large-K / squares, 0.10–0.59×, where cuBLAS bf16 tensor-core GEMM dominates) and wins only at small N.
This is the expected, documented regime: the resident matmul+epilogue is GEMM/occupancy-bound at large
N, so the seed's role is to be a far better seed and feasibility backstop than the Helion default
(which it is everywhere), not to beat cuBLAS.

Scope / limitations

  • Resident structure → feasibility is N-bounded (seed N≤4096 bf16 / N≤2048 fp32; fails closed
    beyond, autotuner falls back). Large-square / large-N is occupancy- and GEMM-bound and loses to
    cuBLAS — a structural ceiling of the resident pattern (documented in the table), not a seed defect;
    a competitive large-N kernel needs a tiled/looped 2D-GEMM structure (separate work).

Disjointness / no-regression

  • Composed fact fires only on the fused family: fused → composed=1; pure matmul → 0 (keeps
    skinny-gemm); pure reduction → 0. The 9 reduction + 8 transfer curriculum kernels have no matmul, so
    the fact never forms — config_recorder byte-identical. Small-N matmul+epilogue configs are
    byte-identical with/without the byte-cap (it doesn't fire there).

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Meta Open Source bot. label Jun 23, 2026
calebmkim pushed a commit that referenced this pull request Jun 23, 2026
The first **composed** fact. `MatmulWithReductionEpilogueFact` holds a `MatmulFact` + the
over-output-axis epilogue `ReductionFact` for a fused matmul + reduction (`matmul_rms_norm` /
`matmul_layernorm` / `matmul_softmax` / `matmul_l2_normalize` / `matmul_sum` / ...).
`build_matmul_reduction_epilogue_facts` composes the two when exactly one of each co-occurs;
`TritonMatmulReductionEpilogueHeuristic` then seeds a footprint-aware tile. Fires **only** on the
fused family — a pure matmul and a pure reduction do not compose, so the reduction curriculum stays
byte-identical.

**Depends on #2829 (stage 2a) — this is the first consumer of its aperture.** 2a removed the
`if spec.matmul_facts: return` guard in `register_user_tiled_reductions`, which lets the materialized
branch register the over-N epilogue reduction even when a matmul is present. Without that, the
composed fact never forms. It does **not** depend on 2b/#2830's `per_feature_accumulator` (stacked on
#2830 for ordering; re-parentable to #2829).

## What changes
- **`config_spec.py`** — `MatmulWithReductionEpilogueFact` (composes the two sub-facts + `n_extent` +
  `m`/`k_block_id`; holds them, does not re-derive).
- **`device_ir.py`** — `build_matmul_reduction_epilogue_facts` (Phase 4): composes iff exactly 1
  `MatmulFact` + 1 `ReductionFact`.
- **`triton.py`** — `TritonMatmulReductionEpilogueHeuristic`: footprint-aware `M_BLOCK` (largest pow2
  whose `[M_BLOCK, N]` fp32 accumulator fits `ACC_BUDGET_BYTES=128 KiB`, dtype-scaled, capped at 64),
  `K_BLOCK=32`, `num_stages=3`, `num_warps` ramp.
- **`__init__.py`** — registers the heuristic. Reduction heuristics gate on `not matmul_facts`, so the
  fused family is seeded by this one alone; pure matmul keeps its skinny-gemm seed.

## Performance — H100 sm90, fwd-only, cold-L2 `do_bench`, accuracy-gated
`seeded_vs_default` = vs the N-blind `[32,32]` default; `seed/grid-best` = vs the best config from a
**quick Helion grid sweep** (`tile_m`×`tile_k`×`num_warps`, `num_stages=3`) — a cheap stand-in for the
Helion optimum, *not* torch.compile and *not* full autotune; `seed_vs_tc` = the seed vs torch.compile
**default mode** (the apples-to-apples baseline for a one-shot seed).
**The fact is general (whole family, all N); the measured claim is small N.**

| kernel | shape (M,K,N) | seeded_vs_default | seed/grid-best | seed_vs_tc (tc default) |
|---|---|---|---|---|
| matmul_rms_norm | 131072,256,256 | 1.43× | 0.99 | **1.74×** |
| matmul_rms_norm | 131072,256,512 | 1.88× | 0.96 | **1.53×** |
| matmul_rms_norm | 65536,512,512 | 2.40× | 0.89 | **1.21×** |
| matmul_layernorm | 131072,256,256 | 1.31× | 0.96 | **1.50×** |
| matmul_softmax | 131072,256,512 | 1.44× | 0.95 | **1.26×** |
| matmul_l2_normalize | 131072,256,256 | 1.45× | 0.999 | **1.79×** |
| matmul_sum | 131072,256,256 | 1.73× | 0.988 | **1.96×** |
| matmul_logsumexp *(held out)* | 131072,256,256 | 1.65× | 0.998 | **1.72×** |
| matmul_max *(held out)* | 131072,256,256 | 1.70× | 0.991 | **1.97×** |

~1.3–2.4× over the default (up to 2.9× at one shape), within 0–11% of the grid optimum, generalizing
across 7 epilogues incl. held-out `logsumexp`/`max` (the fact is blind to which reduction it is).
**Beats torch.compile default by 1.21–1.97× on every measured shape** (all N≤512, acc-gated cold-L2,
seed `block_sizes=[64,32]`); at N≥768 Helion no longer beats tc (`best_vs_tc≈0.72` at N=768) though it
still beats the Helion default. bf16/fp16 ~1.3–2.9×; fp32 ~1.4–1.5× — validated on `matmul_rms_norm`;
the other epilogues are benched bf16-only (the dtype-aware `M_BLOCK` ceiling is itemsize-faithful, so
it carries by construction).

## Looped / tiled-N variant — perf-unverified
The composed fact also matches the looped shape (a manual `hl.tile(N)` reduction, e.g. large-vocab
fused linear+logsumexp / FLCE), where the seed pins `M_BLOCK` and footprint-budgets the N tile
(resident at small N, chunked at large N) instead of the previous floor-to-1 mis-seed. This branch is
**correctness-verified but perf-unverified** — a footprint-justified educated seed that beats the
catastrophic N-blind default by construction; as a seed it is an autotune candidate (validated at
runtime), never the forced default. The measured claim above is the resident path only.

## Scope / limitations
- **Resident, large N:** the `hl.specialize(N)` resident structure is SMEM-infeasible for *every*
  config beyond ~N=2048 (the `[K_BLOCK, N]` operand overflows SMEM) — a limit of the resident pattern,
  not this heuristic (the default config OOMs there too). Large-N matmul+reduction should be written in
  the tiled/looped form. This seed is never forced (`promote_seed_to_default=False`), so it neither
  causes nor can repair that.
- fp32 coverage as noted above.

## Disjointness / no-regression
- Composed fact fires only on the fused family: fused → `composed=1`; pure matmul → 0 (keeps
  skinny-gemm); pure reduction → 0. The 9 reduction + 8 transfer curriculum kernels have no matmul, so
  the fact never forms — config_recorder byte-identical.
- Gate verdicts (fresh-context adversarial agents): **D** faithful/disjoint, **F** footprint
  mechanism (boundary verified at N=1024/2048), **H** general fact + narrow claim, **A** adversarial
  repro (a brand-new `matmul_var` kernel gets the identical seed — no kernel-name fence), **R**
  byte-identity, **E** held-out generalization (logsumexp/max + a held-out test shape).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

stack-info: PR: #2846, branch: calebmkim/stack/23
@calebmkim calebmkim force-pushed the calebmkim/stack/23 branch from a751f96 to 7d17c7e Compare June 23, 2026 05:03
@calebmkim calebmkim marked this pull request as ready for review June 23, 2026 16:50
@ethche ethche self-requested a review June 24, 2026 02:27

@ethche ethche left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed, let's assume that for such kernels the entire row is in registers. Also let's see benchmarking for other (even square shapes) where matmul + layernorm is worse than inductor, but we should still compare against the existing helion default config

calebmkim pushed a commit that referenced this pull request Jun 24, 2026
The first **composed** fact. `MatmulWithReductionEpilogueFact` holds a `MatmulFact` + the
over-output-axis epilogue `ReductionFact` for a fused matmul + reduction (`matmul_rms_norm` /
`matmul_layernorm` / `matmul_softmax` / `matmul_l2_normalize` / `matmul_sum` / ...).
`build_matmul_reduction_epilogue_facts` composes the two when exactly one of each co-occurs;
`TritonMatmulReductionEpilogueHeuristic` then seeds a footprint-aware tile. Fires **only** on the
fused family — a pure matmul and a pure reduction do not compose, so the reduction curriculum stays
byte-identical.

**Depends on #2829 (stage 2a) — this is the first consumer of its aperture.** 2a removed the
`if spec.matmul_facts: return` guard in `register_user_tiled_reductions`, which lets the materialized
branch register the over-N epilogue reduction even when a matmul is present. Without that, the
composed fact never forms. It does **not** depend on 2b/#2830's `per_feature_accumulator` (stacked on
#2830 for ordering; re-parentable to #2829).

## What changes
- **`config_spec.py`** — `MatmulWithReductionEpilogueFact` (composes the two sub-facts + `n_extent` +
  `m`/`k_block_id`; holds them, does not re-derive).
- **`device_ir.py`** — `build_matmul_reduction_epilogue_facts` (Phase 4): composes iff exactly 1
  `MatmulFact` + 1 `ReductionFact`.
- **`triton.py`** — `TritonMatmulReductionEpilogueHeuristic`: footprint-aware `M_BLOCK` (largest pow2
  whose `[M_BLOCK, N]` fp32 accumulator fits `ACC_BUDGET_BYTES=128 KiB`, dtype-scaled, capped at 64),
  `K_BLOCK=32`, `num_stages=3`, `num_warps` ramp.
- **`__init__.py`** — registers the heuristic. Reduction heuristics gate on `not matmul_facts`, so the
  fused family is seeded by this one alone; pure matmul keeps its skinny-gemm seed.

## Performance — H100 sm90, fwd-only, cold-L2 `do_bench`, accuracy-gated
`seeded_vs_default` = vs the N-blind `[32,32]` default; `seed/grid-best` = vs the best config from a
**quick Helion grid sweep** (`tile_m`×`tile_k`×`num_warps`, `num_stages=3`) — a cheap stand-in for the
Helion optimum, *not* torch.compile and *not* full autotune; `seed_vs_tc` = the seed vs torch.compile
**default mode** (the apples-to-apples baseline for a one-shot seed).
**The fact is general (whole family, all N); the measured claim is small N.**

| kernel | shape (M,K,N) | seeded_vs_default | seed/grid-best | seed_vs_tc (tc default) |
|---|---|---|---|---|
| matmul_rms_norm | 131072,256,256 | 1.43× | 0.99 | **1.74×** |
| matmul_rms_norm | 131072,256,512 | 1.88× | 0.96 | **1.53×** |
| matmul_rms_norm | 65536,512,512 | 2.40× | 0.89 | **1.21×** |
| matmul_layernorm | 131072,256,256 | 1.31× | 0.96 | **1.50×** |
| matmul_softmax | 131072,256,512 | 1.44× | 0.95 | **1.26×** |
| matmul_l2_normalize | 131072,256,256 | 1.45× | 0.999 | **1.79×** |
| matmul_sum | 131072,256,256 | 1.73× | 0.988 | **1.96×** |
| matmul_logsumexp *(held out)* | 131072,256,256 | 1.65× | 0.998 | **1.72×** |
| matmul_max *(held out)* | 131072,256,256 | 1.70× | 0.991 | **1.97×** |

~1.3–2.4× over the default (up to 2.9× at one shape), within 0–11% of the grid optimum, generalizing
across 7 epilogues incl. held-out `logsumexp`/`max` (the fact is blind to which reduction it is).
**Beats torch.compile default by 1.21–1.97× on every measured shape** (all N≤512, acc-gated cold-L2,
seed `block_sizes=[64,32]`); at N≥768 Helion no longer beats tc (`best_vs_tc≈0.72` at N=768) though it
still beats the Helion default. bf16/fp16 ~1.3–2.9×; fp32 ~1.4–1.5× — validated on `matmul_rms_norm`;
the other epilogues are benched bf16-only (the dtype-aware `M_BLOCK` ceiling is itemsize-faithful, so
it carries by construction).

## Looped / tiled-N variant — perf-unverified
The composed fact also matches the looped shape (a manual `hl.tile(N)` reduction, e.g. large-vocab
fused linear+logsumexp / FLCE), where the seed pins `M_BLOCK` and footprint-budgets the N tile
(resident at small N, chunked at large N) instead of the previous floor-to-1 mis-seed. This branch is
**correctness-verified but perf-unverified** — a footprint-justified educated seed that beats the
catastrophic N-blind default by construction; as a seed it is an autotune candidate (validated at
runtime), never the forced default. The measured claim above is the resident path only.

## Scope / limitations
- **Resident, large N:** the `hl.specialize(N)` resident structure is SMEM-infeasible for *every*
  config beyond ~N=2048 (the `[K_BLOCK, N]` operand overflows SMEM) — a limit of the resident pattern,
  not this heuristic (the default config OOMs there too). Large-N matmul+reduction should be written in
  the tiled/looped form. This seed is never forced (`promote_seed_to_default=False`), so it neither
  causes nor can repair that.
- fp32 coverage as noted above.

## Disjointness / no-regression
- Composed fact fires only on the fused family: fused → `composed=1`; pure matmul → 0 (keeps
  skinny-gemm); pure reduction → 0. The 9 reduction + 8 transfer curriculum kernels have no matmul, so
  the fact never forms — config_recorder byte-identical.
- Gate verdicts (fresh-context adversarial agents): **D** faithful/disjoint, **F** footprint
  mechanism (boundary verified at N=1024/2048), **H** general fact + narrow claim, **A** adversarial
  repro (a brand-new `matmul_var` kernel gets the identical seed — no kernel-name fence), **R**
  byte-identity, **E** held-out generalization (logsumexp/max + a held-out test shape).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

stack-info: PR: #2846, branch: calebmkim/stack/23
@calebmkim calebmkim force-pushed the calebmkim/stack/23 branch from 7d17c7e to 9583704 Compare June 24, 2026 05:43
@calebmkim

Copy link
Copy Markdown
Contributor Author

Looked into this. Basically, because of SMEM budget, this heuristic will fail to compile for many shapes. The compile limit is driven by N: when it gets too large we run out of smem. So when it fails to compile, we could just make it fall back to the default config.

Benchmarking on the compileable shapes (cold-L2, seed vs the existing Helion default):

Shape A = 131072×4096×1024 bf16 (large M, large K)
Shape B = 131072×1024×512 fp32 (large M, medium K)
Shape C = 1024×1024×1024 bf16 (square)

epilogue A seed/def A seed/tc B seed/def B seed/tc C seed/def C seed/tc
layernorm 3.55× 0.44× 1.58× 2.83× 5.96× 0.33×
rms_norm 3.38× 0.44× 1.77× 2.86× 5.74× 0.34×
softmax 3.17× 0.44× 1.97× 2.80× 5.32× 0.33×
l2_normalize 3.36× 0.44× 1.76× 2.86× 5.77× 0.33×
logsumexp* 3.17× 0.43× 1.57× 2.87× 5.57× 0.35×

The seed beats the existing default on every compileable shape — including the square shapes / large k shapes where we lose to inductor, but still beat the default easily.

So I'd say there are two options:

  1. Expand to square/large-N. (1a) = re-writing the kernel to include looped N dimension, but we agreed to abandon this option. (1b) = within the heuristic, we should realize when we're running out of smem budget and change the config. We'll still hit an N ceiling eventually, but we could probably push it up. The default's cap is N=2048 (bf16) / N=1024 (fp32), whereas my seed's is N = 1024 bf15 / N = 512 (fp32) ... so we could push for that N=1024-2048 perf if you want, I think there's some genuine opportunity for improvement.
  2. Leave as-is — a small-N seed that's strictly better than the default across its whole feasible domain and falls back gracefully elsewhere (maybe we can reject the config by calculating smem usage up front so we don't have to fail -- that's a small change).

Lmk your thoughts.

calebmkim pushed a commit that referenced this pull request Jun 24, 2026
The first **composed** fact. `MatmulWithReductionEpilogueFact` holds a `MatmulFact` + the
over-output-axis epilogue `ReductionFact` for a fused matmul + reduction (`matmul_rms_norm` /
`matmul_layernorm` / `matmul_softmax` / `matmul_l2_normalize` / `matmul_logsumexp` / ...).
`build_matmul_reduction_epilogue_facts` composes the two when exactly one of each co-occurs;
`TritonMatmulReductionEpilogueHeuristic` then seeds a footprint-aware tile. Fires **only** on the
fused family — a pure matmul and a pure reduction do not compose, so the reduction curriculum stays
byte-identical.

**Resident (row-in-registers) only.** The seed assumes the whole reduced row lives in the
register-resident `[M_BLOCK, N]` fp32 accumulator (N is `hl.specialize`'d → `n_block_id is None`);
`is_eligible` fires only on that shape. The earlier looped/tiled-N branch (perf-unverified) was
removed per review.

**Depends on #2829 (stage 2a) — this is the first consumer of its aperture.** 2a removed the
`if spec.matmul_facts: return` guard in `register_user_tiled_reductions`, which lets the materialized
branch register the over-N epilogue reduction even when a matmul is present. Without that, the
composed fact never forms. It does **not** depend on 2b/#2830's `per_feature_accumulator` (stacked on
#2830 for ordering; re-parentable to #2829).

## What changes
- **`config_spec.py`** — `MatmulWithReductionEpilogueFact` (composes the two sub-facts + `n_extent` +
  `m`/`k_block_id`; holds them, does not re-derive).
- **`device_ir.py`** — `build_matmul_reduction_epilogue_facts` (Phase 4): composes iff exactly 1
  `MatmulFact` + 1 `ReductionFact`.
- **`triton.py`** — `TritonMatmulReductionEpilogueHeuristic`: footprint-aware `M_BLOCK` (largest pow2
  whose `[M_BLOCK, N]` fp32 accumulator fits `ACC_BUDGET_BYTES=128 KiB`, dtype-scaled, capped at 64),
  `K_BLOCK=32`, `num_stages=3`, `num_warps` ramp, plus a staged-SMEM byte-cap that extends the
  feasible-N range (below).
- **`__init__.py`** — registers the heuristic. Reduction heuristics gate on `not matmul_facts`, so the
  fused family is seeded by this one alone; pure matmul keeps its skinny-gemm seed.

## Staged-SMEM byte-cap (feasible-N extension)
The resident kernel's SMEM is dominated by the staged `[K_BLOCK, N]` y-operand
(`num_stages·K_BLOCK·N·itemsize`); past the small-N range `[K_BLOCK=32, num_stages=3]` overflows it.
The seed caps that footprint to `SMEM_STAGED_BUDGET_BYTES=192 KiB`: when over, drop **`K_BLOCK 32→16`
first** (halves the staged bytes *and* sidesteps a measured non-monotonic `K_BLOCK=32` ptxas cliff),
then drop `num_stages` only if still over. Small-N stays **byte-identical** (the cap doesn't fire).
Because `n_extent` is the next power of two of `N`, the emitted bands are `KB32/st3` → `KB16/st3` →
`KB16/st1` (the `KB16/st2` band never occurs).

Net: the seed compiles to **N≤4096 bf16 / N≤2048 fp32** — roughly **2× the default's wall**
(N≤2048 / N≤1024) — and **fails closed** beyond (clean `OutOfResources` → the autotuner falls back;
the seed is never the forced default, `promote_seed_to_default=False`). Where the default OOMs but the
seed still compiles (bf16 N=3072/4096, fp32 N=1536/2048, the 2048³ fp32 square), the seed is a
**feasibility rescue**.

## Performance
H100 sm90, fwd-only, **cold-L2 CUDA-graph device time** (256 MiB L2 flush per replay), accuracy-gated
(relerr vs an fp32 reference). `seed` = the heuristic's emitted config (extracted faithfully via
`kernel.bind` → `get_seed_config`); `default` = `config_spec.default_config()` (the bar); `tc` =
`torch.compile` with native-dtype matmul + fp32 accumulate. `>1` = seed faster. "rescue" = default
uncompilable (OOM), seed runs. The seed is epilogue-blind, so the 5 kernels track each other at
matched shapes.

| kernel | (M, K, N) | dtype | regime | seed/default | seed/tc |
|---|---|---|---|---|---|
| layernorm | 131072, 512, 512 | bf16 | small-N | 1.82× | 1.15× |
| layernorm | 131072, 512, 512 | fp32 | small-N | 1.55× | 2.92× |
| layernorm | 131072, 512, 2048 | bf16 | larger-N | 14.57× | 0.37× |
| layernorm | 131072, 512, 4096 | bf16 | larger-N | rescue | 0.13× |
| layernorm | 131072, 4096, 2048 | bf16 | larger-K | 13.65× | 0.19× |
| layernorm | 512³ | bf16 | square | 1.72× | 0.59× |
| layernorm | 1024³ | bf16 | square | 5.95× | 0.34× |
| layernorm | 1024³ | fp32 | square | 26.94× | 0.99× |
| layernorm | 2048³ | fp32 | square | rescue | 0.77× |
| rmsnorm | 1024³ | bf16 | square | 5.69× | 0.34× |
| rmsnorm | 131072, 512, 4096 | bf16 | larger-N | rescue | 0.12× |
| rmsnorm | 131072, 8192, 2048 | bf16 | larger-K | 12.34× | 0.15× |
| softmax | 1024³ | fp32 | square | 29.69× | 0.98× |
| softmax | 131072, 512, 2048 | fp32 | larger-N | rescue | 1.22× |
| softmax | 131072, 2048, 2048 | bf16 | larger-K | 13.63× | 0.23× |
| l2_normalize | 2048³ | bf16 | square | 13.88× | 0.18× |
| l2_normalize | 131072, 512, 4096 | bf16 | larger-N | rescue | 0.12× |
| l2_normalize | 131072, 6144, 1024 | bf16 | larger-K | 3.31× | 0.39× |
| logsumexp | 1024³ | bf16 | square | 5.54× | 0.36× |
| logsumexp | 131072, 512, 4096 | bf16 | larger-N | rescue | 0.10× |
| logsumexp | 131072, 3072, 2048 | bf16 | larger-K | 3.63× | 0.19× |

**vs the Helion default (the bar): geomean 6.93× over the 15 both-compile cells (1.55–29.69×), no
regressions, plus 6 feasibility rescues.** The default is catastrophic at large N/K — it spills the
huge fp32 accumulator to local memory (e.g. 257 ms / 509 ms at K=4096/8192 vs the seed's 19 / 41 ms),
or OOMs outright — which the seed's `num_warps` ramp + byte-cap avoid.

**vs torch.compile: fp32 is competitive-to-winning (0.77–2.92×; the softmax N=2048 fp32 rescue beats
it 1.22× while the default can't even run); bf16 loses at high arithmetic intensity** (large-N /
large-K / squares, 0.10–0.59×, where cuBLAS bf16 tensor-core GEMM dominates) and wins only at small N.
This is the expected, documented regime: the resident matmul+epilogue is GEMM/occupancy-bound at large
N, so the seed's role is to be a far better *seed and feasibility backstop than the Helion default*
(which it is everywhere), not to beat cuBLAS.

## Scope / limitations
- **Resident structure → feasibility is N-bounded** (seed N≤4096 bf16 / N≤2048 fp32; fails closed
  beyond, autotuner falls back). Large-square / large-N is occupancy- and GEMM-bound and loses to
  cuBLAS — a structural ceiling of the resident pattern (documented in the table), not a seed defect;
  a competitive large-N kernel needs a tiled/looped 2D-GEMM structure (separate work).

## Disjointness / no-regression
- Composed fact fires only on the fused family: fused → `composed=1`; pure matmul → 0 (keeps
  skinny-gemm); pure reduction → 0. The 9 reduction + 8 transfer curriculum kernels have no matmul, so
  the fact never forms — config_recorder byte-identical. Small-N matmul+epilogue configs are
  byte-identical with/without the byte-cap (it doesn't fire there).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

stack-info: PR: #2846, branch: calebmkim/stack/23
@calebmkim calebmkim force-pushed the calebmkim/stack/23 branch from 9583704 to 05c44ef Compare June 24, 2026 17:47
@calebmkim

Copy link
Copy Markdown
Contributor Author

Actually I just decided to change the heuristic to fire on larger n shapes. It still outperforms the default on basically all configs. ("rescue" means that the default config errored / wasn't compilable), even thouhg it starts to lose pretty badly to torch.compiile() for larger shapes.

kernel (M, K, N) dtype regime seed/default seed/tc
layernorm 131072, 512, 512 bf16 small-N 1.82× 1.15×
layernorm 131072, 512, 512 fp32 small-N 1.55× 2.92×
layernorm 131072, 512, 2048 bf16 larger-N 14.57× 0.37×
layernorm 131072, 512, 4096 bf16 larger-N rescue 0.13×
layernorm 131072, 4096, 2048 bf16 larger-K 13.65× 0.19×
layernorm 512³ bf16 square 1.72× 0.59×
layernorm 1024³ bf16 square 5.95× 0.34×
layernorm 1024³ fp32 square 26.94× 0.99×
layernorm 2048³ fp32 square rescue 0.77×
rmsnorm 1024³ bf16 square 5.69× 0.34×
rmsnorm 131072, 512, 4096 bf16 larger-N rescue 0.12×
rmsnorm 131072, 8192, 2048 bf16 larger-K 12.34× 0.15×
softmax 1024³ fp32 square 29.69× 0.98×
softmax 131072, 512, 2048 fp32 larger-N rescue 1.22×
softmax 131072, 2048, 2048 bf16 larger-K 13.63× 0.23×
l2_normalize 2048³ bf16 square 13.88× 0.18×
l2_normalize 131072, 512, 4096 bf16 larger-N rescue 0.12×
l2_normalize 131072, 6144, 1024 bf16 larger-K 3.31× 0.39×
logsumexp 1024³ bf16 square 5.54× 0.36×
logsumexp 131072, 512, 4096 bf16 larger-N rescue 0.10×
logsumexp 131072, 3072, 2048 bf16 larger-K 3.63× 0.19×

The first **composed** fact. `MatmulWithReductionEpilogueFact` holds a `MatmulFact` + the
over-output-axis epilogue `ReductionFact` for a fused matmul + reduction (`matmul_rms_norm` /
`matmul_layernorm` / `matmul_softmax` / `matmul_l2_normalize` / `matmul_logsumexp` / ...).
`build_matmul_reduction_epilogue_facts` composes the two when exactly one of each co-occurs;
`TritonMatmulReductionEpilogueHeuristic` then seeds a footprint-aware tile. Fires **only** on the
fused family — a pure matmul and a pure reduction do not compose, so the reduction curriculum stays
byte-identical.

**Resident (row-in-registers) only.** The seed assumes the whole reduced row lives in the
register-resident `[M_BLOCK, N]` fp32 accumulator (N is `hl.specialize`'d → `n_block_id is None`);
`is_eligible` fires only on that shape. The earlier looped/tiled-N branch (perf-unverified) was
removed per review.

**Depends on #2829 (stage 2a) — this is the first consumer of its aperture.** 2a removed the
`if spec.matmul_facts: return` guard in `register_user_tiled_reductions`, which lets the materialized
branch register the over-N epilogue reduction even when a matmul is present. Without that, the
composed fact never forms. It does **not** depend on 2b/#2830's `per_feature_accumulator` (stacked on
#2830 for ordering; re-parentable to #2829).

## What changes
- **`config_spec.py`** — `MatmulWithReductionEpilogueFact` (composes the two sub-facts + `n_extent` +
  `m`/`k_block_id`; holds them, does not re-derive).
- **`device_ir.py`** — `build_matmul_reduction_epilogue_facts` (Phase 4): composes iff exactly 1
  `MatmulFact` + 1 `ReductionFact`.
- **`triton.py`** — `TritonMatmulReductionEpilogueHeuristic`: footprint-aware `M_BLOCK` (largest pow2
  whose `[M_BLOCK, N]` fp32 accumulator fits `ACC_BUDGET_BYTES=128 KiB`, dtype-scaled, capped at 64),
  `K_BLOCK=32`, `num_stages=3`, `num_warps` ramp, plus a staged-SMEM byte-cap that extends the
  feasible-N range (below).
- **`__init__.py`** — registers the heuristic. Reduction heuristics gate on `not matmul_facts`, so the
  fused family is seeded by this one alone; pure matmul keeps its skinny-gemm seed.

## Staged-SMEM byte-cap (feasible-N extension)
The resident kernel's SMEM is dominated by the staged `[K_BLOCK, N]` y-operand
(`num_stages·K_BLOCK·N·itemsize`); past the small-N range `[K_BLOCK=32, num_stages=3]` overflows it.
The seed caps that footprint to `SMEM_STAGED_BUDGET_BYTES=192 KiB`: when over, drop **`K_BLOCK 32→16`
first** (halves the staged bytes *and* sidesteps a measured non-monotonic `K_BLOCK=32` ptxas cliff),
then drop `num_stages` only if still over. Small-N stays **byte-identical** (the cap doesn't fire).
Because `n_extent` is the next power of two of `N`, the emitted bands are `KB32/st3` → `KB16/st3` →
`KB16/st1` (the `KB16/st2` band never occurs).

Net: the seed compiles to **N≤4096 bf16 / N≤2048 fp32** — roughly **2× the default's wall**
(N≤2048 / N≤1024) — and **fails closed** beyond (clean `OutOfResources` → the autotuner falls back;
the seed is never the forced default, `promote_seed_to_default=False`). Where the default OOMs but the
seed still compiles (bf16 N=3072/4096, fp32 N=1536/2048, the 2048³ fp32 square), the seed is a
**feasibility rescue**.

## Performance
H100 sm90, fwd-only, **cold-L2 CUDA-graph device time** (256 MiB L2 flush per replay), accuracy-gated
(relerr vs an fp32 reference). `seed` = the heuristic's emitted config (extracted faithfully via
`kernel.bind` → `get_seed_config`); `default` = `config_spec.default_config()` (the bar); `tc` =
`torch.compile` with native-dtype matmul + fp32 accumulate. `>1` = seed faster. "rescue" = default
uncompilable (OOM), seed runs. The seed is epilogue-blind, so the 5 kernels track each other at
matched shapes.

| kernel | (M, K, N) | dtype | regime | seed/default | seed/tc |
|---|---|---|---|---|---|
| layernorm | 131072, 512, 512 | bf16 | small-N | 1.82× | 1.15× |
| layernorm | 131072, 512, 512 | fp32 | small-N | 1.55× | 2.92× |
| layernorm | 131072, 512, 2048 | bf16 | larger-N | 14.57× | 0.37× |
| layernorm | 131072, 512, 4096 | bf16 | larger-N | rescue | 0.13× |
| layernorm | 131072, 4096, 2048 | bf16 | larger-K | 13.65× | 0.19× |
| layernorm | 512³ | bf16 | square | 1.72× | 0.59× |
| layernorm | 1024³ | bf16 | square | 5.95× | 0.34× |
| layernorm | 1024³ | fp32 | square | 26.94× | 0.99× |
| layernorm | 2048³ | fp32 | square | rescue | 0.77× |
| rmsnorm | 1024³ | bf16 | square | 5.69× | 0.34× |
| rmsnorm | 131072, 512, 4096 | bf16 | larger-N | rescue | 0.12× |
| rmsnorm | 131072, 8192, 2048 | bf16 | larger-K | 12.34× | 0.15× |
| softmax | 1024³ | fp32 | square | 29.69× | 0.98× |
| softmax | 131072, 512, 2048 | fp32 | larger-N | rescue | 1.22× |
| softmax | 131072, 2048, 2048 | bf16 | larger-K | 13.63× | 0.23× |
| l2_normalize | 2048³ | bf16 | square | 13.88× | 0.18× |
| l2_normalize | 131072, 512, 4096 | bf16 | larger-N | rescue | 0.12× |
| l2_normalize | 131072, 6144, 1024 | bf16 | larger-K | 3.31× | 0.39× |
| logsumexp | 1024³ | bf16 | square | 5.54× | 0.36× |
| logsumexp | 131072, 512, 4096 | bf16 | larger-N | rescue | 0.10× |
| logsumexp | 131072, 3072, 2048 | bf16 | larger-K | 3.63× | 0.19× |

**vs the Helion default (the bar): geomean 6.93× over the 15 both-compile cells (1.55–29.69×), no
regressions, plus 6 feasibility rescues.** The default is catastrophic at large N/K — it spills the
huge fp32 accumulator to local memory (e.g. 257 ms / 509 ms at K=4096/8192 vs the seed's 19 / 41 ms),
or OOMs outright — which the seed's `num_warps` ramp + byte-cap avoid.

**vs torch.compile: fp32 is competitive-to-winning (0.77–2.92×; the softmax N=2048 fp32 rescue beats
it 1.22× while the default can't even run); bf16 loses at high arithmetic intensity** (large-N /
large-K / squares, 0.10–0.59×, where cuBLAS bf16 tensor-core GEMM dominates) and wins only at small N.
This is the expected, documented regime: the resident matmul+epilogue is GEMM/occupancy-bound at large
N, so the seed's role is to be a far better *seed and feasibility backstop than the Helion default*
(which it is everywhere), not to beat cuBLAS.

## Scope / limitations
- **Resident structure → feasibility is N-bounded** (seed N≤4096 bf16 / N≤2048 fp32; fails closed
  beyond, autotuner falls back). Large-square / large-N is occupancy- and GEMM-bound and loses to
  cuBLAS — a structural ceiling of the resident pattern (documented in the table), not a seed defect;
  a competitive large-N kernel needs a tiled/looped 2D-GEMM structure (separate work).

## Disjointness / no-regression
- Composed fact fires only on the fused family: fused → `composed=1`; pure matmul → 0 (keeps
  skinny-gemm); pure reduction → 0. The 9 reduction + 8 transfer curriculum kernels have no matmul, so
  the fact never forms — config_recorder byte-identical. Small-N matmul+epilogue configs are
  byte-identical with/without the byte-cap (it doesn't fire there).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

stack-info: PR: #2846, branch: calebmkim/stack/23
@calebmkim calebmkim force-pushed the calebmkim/stack/23 branch from 05c44ef to cc338df Compare June 24, 2026 17:51
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants