[autotuner] reduction seed: budgeted r_block + liveness-aware persistent/looped decision #2828

Merged: calebmkim merged 1 commit into main from calebmkim/stack/6 on Jul 1, 2026

calebmkim (Contributor) commented Jun 20, 2026

Stacked PRs:


[autotuner] reduction seed: budgeted r_block + liveness-aware persistent/looped decision

Stage 1 of 3 in the reduction-seed heuristic stack. This is the "tighten what already
works" stage: the merged reduction seed (#2762) already produces good configs for the 9
core reduction kernels and transfers cleanly to unseen kernels, so most of this commit is
cleanliness + robustness hardening, with one behavioral change that matters (a liveness
analysis that loops instead of persists when a reduction body would spill).

What changes

  1. Fold the reduction-chunk size AND the persistent/looped verdict into one budgeted
    _reduction_rblock -> (r_block, persistent). Behavior-preserving (zero config diff for
    the 9+8 kernels below); just removes a duplicated budget computation.

  2. Rename "Band B" -> a descriptive name (the accumulator-residency band); no behavior change.

  3. A full-width M_BLOCK persist-cap (_m_block_cap) so a huge-M grid-size autotuner_min
    raise can no longer starve occupancy. This cap only ever lowers an over-raised floor
    and fires on very few shapes — it is mainly defensive for future kernels/shapes, not a
    tuned win on today's curriculum.

  4. The one real behavioral change — a conservative looping bias driven by a new walker
    liveness fact ReductionFact.body_live_tiles (peak simultaneously-live rdim-shaped tiles,
    computed once in the collect pass):

    • Err toward looping over persistent when the two are ~equivalent. Persistent had
      been chosen on near-ties, but staying persistent risks a register spill with a large perf
      cliff, whereas looping is safe — so ties now go to looping.
    • A multi-tile spill ceiling: when a body holds many live full-width tiles, persisting
      spills the register file. The liveness fact routes such bodies to a looped reduction.

    This is chiefly for fused_linear_jsd, whose softmax->log_softmax->KL->grad chain holds
    ~7 live full-width fp32 tiles; persisting spilled (n_spills 480 -> 0 after looping).
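A minimal sketch of the decision in item 4, not the code from this PR. Only `ReductionFact.body_live_tiles` and the `(r_block, persistent)` return shape come from the description above; the function name `choose_reduction_mode`, the per-program element budget, and the power-of-two chunking are assumptions made for illustration.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ReductionFact:
    # Peak number of simultaneously-live rdim-shaped tiles in the reduction body.
    body_live_tiles: int


def next_pow2(n: int) -> int:
    """Smallest power of two >= n (1 for n <= 1)."""
    return 1 if n <= 1 else 1 << (n - 1).bit_length()


def choose_reduction_mode(
    fact: ReductionFact,
    rdim: int,
    m_block: int,
    elem_budget: int = 16384,  # hypothetical budget of live elements per program
) -> tuple[int, bool]:
    """Return (r_block, persistent) under a single shared budget."""
    full = next_pow2(rdim)
    tiles = max(1, fact.body_live_tiles)

    # Persist only when every live full-width tile fits strictly under the
    # budget. The strict comparison sends exact ties to the looped path.
    if tiles * m_block * full < elem_budget:
        return full, True

    # Looped: pick the largest power-of-two chunk that keeps the live
    # footprint strictly under the same budget.
    per_tile = max(1, (elem_budget - 1) // (tiles * m_block))
    r_block = 1 << (per_tile.bit_length() - 1)
    return min(r_block, full), False
```

With this assumed budget, a body with one live tile over `rdim=4096` persists, while a body with 7 live tiles (the `fused_linear_jsd` case described above) is routed to a looped reduction with a smaller chunk. The real thresholds in the PR may differ.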

The 9 core + 8 robustness kernels are byte-identical across this commit, except
fused_linear_jsd (the liveness change above) and sum at N>=18432 (flips persistent->looped,
perf ~unchanged); see the tables. This PR does not modify examples/ (welford is left at base).
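For readers who want to reproduce a "zero config diff" check, here is a minimal sketch. It assumes each emitted config can be represented as a plain JSON-serializable dict; the harness used for this PR is not shown here, and the helper names and the config keys in the usage example are illustrative placeholders.

```python
import json
from typing import Optional


def config_bytes(config: Optional[dict]) -> bytes:
    """Canonical byte form of a config (sorted keys, no extra whitespace)."""
    return json.dumps(config, sort_keys=True, separators=(",", ":")).encode("utf-8")


def changed_kernels(before: dict, after: dict) -> list:
    """Names whose canonical config bytes differ, or that exist on one side only."""
    names = before.keys() | after.keys()
    return sorted(
        name
        for name in names
        if config_bytes(before.get(name)) != config_bytes(after.get(name))
    )


# Illustrative configs only; these are not the configs emitted by the PR.
before = {
    "rms_norm": {"block_sizes": [1], "reduction_loops": [None]},
    "fused_linear_jsd": {"block_sizes": [1], "reduction_loops": [None]},
}
after = {
    "rms_norm": {"reduction_loops": [None], "block_sizes": [1]},
    "fused_linear_jsd": {"block_sizes": [1], "reduction_loops": [2048]},
}
print(changed_kernels(before, after))
```

Sorting keys before serializing means a reordered but otherwise equal config still compares as identical.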

Robustness / transfer kernels (8)

Beyond the 9 core reduction kernels, the heuristic was stress-tested on 8 kernels it was
not tuned on (a genuine transfer test). Two are in this repo's examples/; the other six
live in the lab harness on the fork (transfer_kernels.py,
https://github.com/calebmkim/helion/blob/reduction-3stage-stack/_lab/transfer/transfer_kernels.py):

  fused_linear_jsd        examples/fused_linear_jsd.py
  grpo                    examples/grpo_loss.py
  fused_add_rmsnorm       _lab/transfer/transfer_kernels.py   (fork @ reduction-3stage-stack)
  fused_add_layernorm     _lab/transfer/transfer_kernels.py   (same)
  gated_rmsnorm           _lab/transfer/transfer_kernels.py   (same)
  scaled_masked_softmax   _lab/transfer/transfer_kernels.py   (same)
  cross_entropy_ls_zloss  _lab/transfer/transfer_kernels.py   (same)
  dynamic_quant           _lab/transfer/transfer_kernels.py   (same)

Performance

Geomean G = torch.compile_us / helion_us across the shape suite (>1 = Helion faster). H100
80GB SXM5, cold-L2 do_bench median-of-11, accuracy-gated before timing. n = # shapes
benchmarked per kernel (each timed at both bf16 and fp32). "=" in the "after" column means
the emitted config is byte-identical to before (so the kernel and its perf are unchanged);
the point of these tables is that almost everything is "=".
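The metric can be written out as follows. This is a sketch of the arithmetic only, assuming one median timing per shape for each backend; the helper names are hypothetical and the timings in the usage lines are made up, not measurements from this PR.

```python
import math
import statistics


def median_us(samples_us):
    """Median of repeated timings for one shape (the PR uses 11 repeats)."""
    return statistics.median(samples_us)


def geomean_speedup(compile_us, helion_us):
    """Geometric mean of torch.compile_us / helion_us over a shape suite.

    Values above 1 mean Helion is faster on average.
    """
    if not compile_us or len(compile_us) != len(helion_us):
        raise ValueError("need two non-empty timing lists of equal length")
    logs = [math.log(c / h) for c, h in zip(compile_us, helion_us)]
    return math.exp(sum(logs) / len(logs))


# Made-up timings: one shape 2x faster, one shape 2x slower.
print(geomean_speedup([20.0, 5.0], [10.0, 10.0]))
```

A geometric mean weights a 2x win and a 2x loss equally, so the made-up pair above nets out to roughly 1.0, which an arithmetic mean of the ratios would not.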

9 core reduction kernels (test split), before (BASE) -> after (stage1):

  kernel          n    bf16              fp32
  rms_norm        8    1.036 -> =        1.014 -> =
  layer_norm      8    1.041 -> =        1.017 -> =
  softmax         8    1.352 -> =        1.158 -> =
  sum             7    1.050 -> 1.049    0.995 -> 0.995   (N>=18432 flips persistent->looped, perf ~unchanged)
  long_sum        7    0.925 -> =        0.926 -> =
  cross_entropy   7    1.143 -> =        0.826 -> =
  kl_div          7    1.088 -> =        1.086 -> =
  jsd             7    0.802 -> =        1.037 -> =

8 transfer / robustness kernels (all shapes), before (BASE) -> after (stage1):

  kernel                   n    bf16              fp32
  fused_add_rmsnorm        12   0.993 -> =        1.012 -> =
  fused_add_layernorm      12   1.085 -> =        0.986 -> =
  gated_rmsnorm            12   1.015 -> =        1.000 -> =
  scaled_masked_softmax    11   1.110 -> =        1.003 -> =
  cross_entropy_ls_zloss   11   1.014 -> =        1.075 -> =
  dynamic_quant            8    1.131 -> =        1.058 -> =
  grpo                     7    1.253 -> =        1.023 -> =
  fused_linear_jsd         7    0.609 -> 0.835    0.857 -> 1.141   <-- the liveness/looping change

Takeaway: apart from fused_linear_jsd and the sum shapes with N>=18432 noted in the table,
the emitted configs are byte-identical before/after (perf unchanged, as intended for a
cleanliness stage). The one perf win is fused_linear_jsd from the liveness analysis
(bf16 +37%, fp32 +33%).
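The percentages follow directly from the fused_linear_jsd row of the transfer table. A short check, using only the geomeans reported there (the helper name is illustrative):

```python
def relative_gain_pct(before: float, after: float) -> float:
    """Percent change in the geomean G from before to after."""
    return (after / before - 1.0) * 100.0


bf16_gain = relative_gain_pct(0.609, 0.835)
fp32_gain = relative_gain_pct(0.857, 1.141)
print(round(bf16_gain), round(fp32_gain))
```

Note that the fp32 geomean also crosses 1.0 (0.857 -> 1.141), so on that dtype the kernel goes from slower than torch.compile to faster on average.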

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

@meta-cla meta-cla bot added the "CLA Signed" label Jun 20, 2026
@calebmkim calebmkim force-pushed the calebmkim/stack/6 branch from d90fcc0 to ce9cfc7 Compare June 20, 2026 20:21
@calebmkim calebmkim force-pushed the calebmkim/stack/6 branch from ce9cfc7 to 3df5c88 Compare June 20, 2026 20:23
@calebmkim calebmkim changed the title [stage1] reduction seed: budgeted r_block + liveness-aware persistent/looped decision [autotuner] reduction seed: budgeted r_block + liveness-aware persistent/looped decision Jun 20, 2026
@calebmkim calebmkim force-pushed the calebmkim/stack/6 branch from 3df5c88 to 4457e3b Compare June 22, 2026 05:02
@calebmkim calebmkim marked this pull request as ready for review June 22, 2026 05:10
@calebmkim calebmkim marked this pull request as draft June 23, 2026 05:01
@calebmkim calebmkim marked this pull request as ready for review June 23, 2026 05:03
@calebmkim calebmkim force-pushed the calebmkim/stack/6 branch from 4457e3b to 9125838 Compare July 1, 2026 20:23
@calebmkim calebmkim merged commit 81e49d5 into main Jul 1, 2026
54 of 56 checks passed