
[cute] Enable N-axis cluster (cluster_n=2, cluster_m=1) A-multicast for tiny-M fp8 #2845

Draft
yushangdi wants to merge 1 commit into yushangdi/stack/61 from yushangdi/stack/62
@yushangdi yushangdi commented Jun 23, 2026

Stacked PRs:


[cute] Enable N-axis cluster (cluster_n=2, cluster_m=1) A-multicast for tiny-M fp8

For a tiny-M fp8 GEMM (M<=64) the tcgen05 path runs one M tile across many N
tiles (e.g. M=16, N=14336 => 112 N tiles), and each N-tile CTA re-reads the
same small A operand (M x K) from DRAM. The CUTLASS kernel behind
torch._scaled_mm handles exactly this case by selecting an N-axis 2-CTA cluster:
see aten/src/ATen/native/cuda/RowwiseScaledMM.cu in PyTorch, where M=16, N>=3072
dispatches ClusterShape <1,2,1>. Two CTAs covering adjacent N tiles form a cluster and
TMA-multicast the shared A operand, cutting A traffic and improving B-stream
efficiency. (Clustering along M is useless here -- there is only one M tile;
an earlier cluster_m=2 attempt regressed for that reason.)
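
To put numbers on the motivation above, here is a rough sketch of how often the A operand is fetched with and without the N-axis cluster. It assumes bn=128 (from the 64x128 output tile mentioned in this description), one byte per fp8 element, and that each cluster issues a single multicast load of A per K step. The helper names are illustrative only. The byte totals are requested bytes, not measured DRAM traffic, since L2 will absorb part of the re-reads.

```python
# Rough count of A-operand fetches for the shape quoted in this PR.
# Assumptions (not taken from the Helion code): bn=128, 1 byte per fp8
# element, and one multicast A load shared by all CTAs of a cluster.

def n_tiles(n: int, bn: int) -> int:
    """Number of N tiles, using ceiling division."""
    return -(-n // bn)

def a_fetches(n: int, bn: int, cluster_n: int) -> int:
    """Times the full A operand is fetched when each cluster loads it once."""
    tiles = n_tiles(n, bn)
    assert tiles % cluster_n == 0, "N tiles must split evenly into clusters"
    return tiles // cluster_n

M, K, N, BN = 16, 4096, 14336, 128
a_bytes = M * K  # fp8 operand, one byte per element

tiles = n_tiles(N, BN)
fetches_baseline = a_fetches(N, BN, cluster_n=1)
fetches_clustered = a_fetches(N, BN, cluster_n=2)

print("N tiles:", tiles)
print("A fetches, cluster_n=1:", fetches_baseline, "->", a_bytes * fetches_baseline, "bytes")
print("A fetches, cluster_n=2:", fetches_clustered, "->", a_bytes * fetches_clustered, "bytes")
```

Under these assumptions the cluster halves the number of A fetches (112 to 56). A is tiny next to the 4096 x 14336 B operand, which is consistent with the modest end-to-end gain reported below.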

Helion previously allowed cluster_n>1 only inside the 4-CTA envelope
(cluster_m=2 AND cluster_n=2, CtaGroup.TWO). This enables a standalone
cluster_m=1, cluster_n=2, CtaGroup.ONE configuration where each CTA still owns
its own 64x128 output tile and the cluster exists solely to multicast A.

Changes:

  • cute_mma.py: accept an n-only cluster (cluster_m=1, cluster_n=2, flat TMA
    pipeline, non-persistent, single M tile m_size<=bm, N % (2*bn) == 0,
    K % bk == 0) as a second cluster_n=2 envelope alongside the 4-CTA one. Build
    the A TMA multicast atom + mcast_mode=2 mask and set the AB-pipeline consumer
    arrive count to num_mcast_ctas_a + num_mcast_ctas_b - 1 = 2 for this case
    (B stays per-CTA: cluster_m=1 means no M peers to share B). Launch
    cluster=(cluster_n,1,1) along the flat grid_x (consecutive grid_x = adjacent
    N tiles) while keeping the logical (1,2,1) cluster_layout_vmnk so the A mask
    (mode 2 = N) and block_idx_in_cluster->N-coord mapping are correct. The
    mainloop stays flat/owner-ungated (CtaGroup.ONE => every CTA is a cluster
    leader); the padded-M masked store is unchanged.
  • strategies.py: relax the config-spec invariant so cluster_n=2 with
    cluster_m=1 is admitted under role_local_monolithic + non_persistent, keeping
    codegen and config-spec consistent.
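
The acceptance conditions and the arrive-count arithmetic in the list above can be sketched in plain Python. This is not the code in cute_mma.py: `ClusterRequest`, `is_n_only_cluster`, `consumer_arrive_count` and `block_coords` are hypothetical names, bk=128 is an assumed value, and the mapping "A is multicast across cluster_n CTAs, B across cluster_m CTAs" is assumed to hold for the CtaGroup.ONE case described here.

```python
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class ClusterRequest:
    """Hypothetical bundle of the fields the envelope check looks at."""
    cluster_m: int
    cluster_n: int
    persistent: bool
    flat_tma_pipeline: bool
    m_size: int
    n_size: int
    k_size: int
    bm: int
    bn: int
    bk: int

def is_n_only_cluster(req: ClusterRequest) -> bool:
    """Conditions listed in the PR for the standalone cluster_n=2 envelope."""
    return (
        req.cluster_m == 1
        and req.cluster_n == 2
        and req.flat_tma_pipeline
        and not req.persistent
        and req.m_size <= req.bm              # single M tile
        and req.n_size % (2 * req.bn) == 0    # N tiles pair up into clusters
        and req.k_size % req.bk == 0
    )

def consumer_arrive_count(cluster_m: int, cluster_n: int) -> int:
    """num_mcast_ctas_a + num_mcast_ctas_b - 1, assuming A is shared along N
    and B along M (CtaGroup.ONE case only)."""
    num_mcast_ctas_a = cluster_n
    num_mcast_ctas_b = cluster_m
    return num_mcast_ctas_a + num_mcast_ctas_b - 1

def block_coords(grid_x: int, cluster_n: int) -> tuple[int, int]:
    """Flat grid_x -> (cluster id, block index in cluster). With one M tile,
    grid_x is assumed to equal the N tile index."""
    return divmod(grid_x, cluster_n)

req = ClusterRequest(cluster_m=1, cluster_n=2, persistent=False,
                     flat_tma_pipeline=True, m_size=16, n_size=14336,
                     k_size=4096, bm=64, bn=128, bk=128)
accepted = is_n_only_cluster(req)
rejected = is_n_only_cluster(replace(req, n_size=14336 + 128))  # odd tile count
arrive = consumer_arrive_count(req.cluster_m, req.cluster_n)
print(accepted, rejected, arrive, block_coords(5, req.cluster_n))
```

With cluster_m=1 and cluster_n=2 the arrive count is 2 + 1 - 1 = 2, matching the value stated above, and it falls back to 1 for the unclustered cluster_n=1 case.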

Verification (B200 sm_100, M=16 K=4096 N=14336, fp8 e4m3):

  • Correctness with range-filling fp8 + identity scales vs a float32 matmul:
    relerr ~1.7e-3 (fp8 floor), identical to the cluster_n=1 baseline, across 3
    seeds; per-row relerr uniform over all 16 rows, zero all-zero rows (no
    half-output/missing-CTA bug); left/right N halves equally correct.
    cute.nvgpu.tcgen05 present. No-crash sweep static_m in {16,32,64} x block_m
    in {64,128} x cluster_n in {1,2} all pass.
  • Performance (warm-L2, do_bench_wrapper, cudagraph, A/B in one session):
    cluster_n=1 16.27us / 3.64 TB/s -> cluster_n=2 14.87us / 3.99 TB/s (~1.1x;
    larger at locked/idle clocks). torch._scaled_mm 9.68us / 6.12 TB/s.
  • NCU: launch__cluster_dim_x=2 (cluster active), L2 read hit 11.0% -> 14.4%,
    DRAM %peak 42.2% -> 43.7%, registers/thread unchanged at 68 (no regression,
    unlike the cluster_m=2 attempt which rose to 96).
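
As a quick sanity check on the performance figures above, the following sketch recomputes the speedup, the remaining gap to torch._scaled_mm, and the bytes moved implied by each (time, bandwidth) pair. It uses only the numbers quoted in this description; the variable names are illustrative.

```python
# (time in microseconds, bandwidth in TB/s) as reported in this PR
runs = {
    "cluster_n=1": (16.27, 3.64),
    "cluster_n=2": (14.87, 3.99),
    "torch._scaled_mm": (9.68, 6.12),
}

speedup = runs["cluster_n=1"][0] / runs["cluster_n=2"][0]
gap_to_reference = runs["cluster_n=2"][0] / runs["torch._scaled_mm"][0]

# Implied traffic in MB: time [s] * bandwidth [B/s] / 1e6
implied_mb = {
    name: us * 1e-6 * tb_per_s * 1e12 / 1e6
    for name, (us, tb_per_s) in runs.items()
}

print(f"speedup cluster_n=2 over cluster_n=1: {speedup:.3f}x")
print(f"cluster_n=2 time relative to torch._scaled_mm: {gap_to_reference:.3f}x")
for name, mb in implied_mb.items():
    print(f"{name}: ~{mb:.1f} MB moved")
```

All three runs imply roughly the same ~59 MB of traffic, so the reported bandwidths are consistent with one fixed byte count divided by the measured time. The speedup works out to about 1.09x, and cluster_n=2 still takes about 1.54x as long as torch._scaled_mm.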

Adds test_matmul_mma_fp8_tiny_m_cluster_n2_a_multicast (3 seeds, per-row check,
asserts cluster shape (2,1,1), cross-checks vs the cluster_n=1 baseline) and
dot-requirements coverage for the newly accepted/rejected cluster configs.
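
The per-row and left/right-half checks mentioned above can be illustrated with a dependency-free sketch. This is not the test added by the PR, which runs the fp8 kernel on GPU tensors. The helper names and the toy data are made up, and relative L2 error is an assumed definition of "relerr".

```python
import math

def relerr(out_vals, ref_vals):
    """Relative L2 error between two flat sequences."""
    num = math.sqrt(sum((o - r) ** 2 for o, r in zip(out_vals, ref_vals)))
    den = math.sqrt(sum(r * r for r in ref_vals))
    return num / den if den > 0 else float("inf")

def per_row_relerr(out, ref):
    """One relative error per output row; a bad row stands out immediately."""
    return [relerr(o_row, r_row) for o_row, r_row in zip(out, ref)]

def all_zero_rows(out):
    """Indices of rows that were never written (all zeros)."""
    return [i for i, row in enumerate(out) if all(v == 0 for v in row)]

def half_relerr(out, ref):
    """Relative error of the left and right N halves, to catch a CTA of the
    cluster that produced no (or wrong) output."""
    h = len(ref[0]) // 2
    left = relerr([v for row in out for v in row[:h]],
                  [v for row in ref for v in row[:h]])
    right = relerr([v for row in out for v in row[h:]],
                   [v for row in ref for v in row[h:]])
    return left, right

ref = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]
good = [[v * 1.001 for v in row] for row in ref]       # small uniform error
missing_half = [row[:2] + [0.0, 0.0] for row in ref]   # right half never written
missing_row = [ref[0], [0.0, 0.0, 0.0, 0.0]]           # second row never written

print(per_row_relerr(good, ref))
print(half_relerr(missing_half, ref))
print(all_zero_rows(missing_row))
```

An aggregate error alone can hide a failure confined to one half or one row. Splitting the check by row and by N half makes a missing-CTA bug show up as a relative error of 1.0 in the affected region.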

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

yushangdi added a commit that referenced this pull request Jun 23, 2026
stack-info: PR: #2845, branch: yushangdi/stack/62
@yushangdi yushangdi force-pushed the yushangdi/stack/61 branch from e3394ea to e968b0b Compare June 23, 2026 01:22
@yushangdi yushangdi force-pushed the yushangdi/stack/62 branch from 39714d4 to 6af3fdc Compare June 23, 2026 01:22
@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Meta Open Source bot. label Jun 23, 2026