[cute] Enable N-axis cluster (cluster_n=2, cluster_m=1) A-multicast for tiny-M fp8#2845
Draft
yushangdi wants to merge 1 commit into
Draft
[cute] Enable N-axis cluster (cluster_n=2, cluster_m=1) A-multicast for tiny-M fp8#2845yushangdi wants to merge 1 commit into
yushangdi wants to merge 1 commit into
Conversation
yushangdi
added a commit
that referenced
this pull request
Jun 23, 2026
…or tiny-M fp8
For a tiny-M fp8 GEMM (M<=64) the tcgen05 path runs one M tile across many N
tiles (e.g. M=16, N=14336 => 112 N tiles), and each N-tile CTA re-reads the
same small A operand (M x K) from DRAM. NVIDIA's torch._scaled_mm CUTLASS
kernel handles exactly this case by selecting an N-axis 2-CTA cluster: see
aten/src/ATen/native/cuda/RowwiseScaledMM.cu, where M=16,N>=3072 dispatches
ClusterShape <1,2,1>. Two CTAs covering adjacent N tiles form a cluster and
TMA-multicast the shared A operand, cutting A traffic and improving B-stream
efficiency. (Clustering along M is useless here -- there is only one M tile;
an earlier cluster_m=2 attempt regressed for that reason.)
Helion previously allowed cluster_n>1 only inside the 4-CTA envelope
(cluster_m=2 AND cluster_n=2, CtaGroup.TWO). This enables a standalone
cluster_m=1, cluster_n=2, CtaGroup.ONE configuration where each CTA still owns
its own 64x128 output tile and the cluster exists solely to multicast A.
Changes:
- cute_mma.py: accept an n-only cluster (cluster_m=1, cluster_n=2, flat TMA
pipeline, non-persistent, single M tile m_size<=bm, N % (2*bn) == 0,
K % bk == 0) as a second cluster_n=2 envelope alongside the 4-CTA one. Build
the A TMA multicast atom + mcast_mode=2 mask and set the AB-pipeline consumer
arrive count to num_mcast_ctas_a + num_mcast_ctas_b - 1 = 2 for this case
(B stays per-CTA: cluster_m=1 means no M peers to share B). Launch
cluster=(cluster_n,1,1) along the flat grid_x (consecutive grid_x = adjacent
N tiles) while keeping the logical (1,2,1) cluster_layout_vmnk so the A mask
(mode 2 = N) and block_idx_in_cluster->N-coord mapping are correct. The
mainloop stays flat/owner-ungated (CtaGroup.ONE => every CTA is a cluster
leader); the padded-M masked store is unchanged.
- strategies.py: relax the config-spec invariant so cluster_n=2 with
cluster_m=1 is admitted under role_local_monolithic + non_persistent, keeping
codegen and config-spec consistent.
Verification (B200 sm_100, M=16 K=4096 N=14336, fp8 e4m3):
- Correctness with range-filling fp8 + identity scales vs a float32 matmul:
relerr ~1.7e-3 (fp8 floor), identical to the cluster_n=1 baseline, across 3
seeds; per-row relerr uniform over all 16 rows, zero all-zero rows (no
half-output/missing-CTA bug); left/right N halves equally correct.
cute.nvgpu.tcgen05 present. No-crash sweep static_m in {16,32,64} x block_m
in {64,128} x cluster_n in {1,2} all pass.
- Performance (warm-L2, do_bench_wrapper, cudagraph, A/B in one session):
cluster_n=1 16.27us / 3.64 TB/s -> cluster_n=2 14.87us / 3.99 TB/s (~1.1x;
larger at locked/idle clocks). torch._scaled_mm 9.68us / 6.12 TB/s.
- NCU: launch__cluster_dim_x=2 (cluster active), L2 read hit 11.0% -> 14.4%,
DRAM %peak 42.2% -> 43.7%, registers/thread unchanged at 68 (no regression,
unlike the cluster_m=2 attempt which rose to 96).
Adds test_matmul_mma_fp8_tiny_m_cluster_n2_a_multicast (3 seeds, per-row check,
asserts cluster shape (2,1,1), cross-checks vs the cluster_n=1 baseline) and
dot-requirements coverage for the newly accepted/rejected cluster configs.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
stack-info: PR: #2845, branch: yushangdi/stack/62
e3394ea to
e968b0b
Compare
39714d4 to
6af3fdc
Compare
This was referenced Jun 23, 2026
…or tiny-M fp8
For a tiny-M fp8 GEMM (M<=64) the tcgen05 path runs one M tile across many N
tiles (e.g. M=16, N=14336 => 112 N tiles), and each N-tile CTA re-reads the
same small A operand (M x K) from DRAM. NVIDIA's torch._scaled_mm CUTLASS
kernel handles exactly this case by selecting an N-axis 2-CTA cluster: see
aten/src/ATen/native/cuda/RowwiseScaledMM.cu, where M=16,N>=3072 dispatches
ClusterShape <1,2,1>. Two CTAs covering adjacent N tiles form a cluster and
TMA-multicast the shared A operand, cutting A traffic and improving B-stream
efficiency. (Clustering along M is useless here -- there is only one M tile;
an earlier cluster_m=2 attempt regressed for that reason.)
Helion previously allowed cluster_n>1 only inside the 4-CTA envelope
(cluster_m=2 AND cluster_n=2, CtaGroup.TWO). This enables a standalone
cluster_m=1, cluster_n=2, CtaGroup.ONE configuration where each CTA still owns
its own 64x128 output tile and the cluster exists solely to multicast A.
Changes:
- cute_mma.py: accept an n-only cluster (cluster_m=1, cluster_n=2, flat TMA
pipeline, non-persistent, single M tile m_size<=bm, N % (2*bn) == 0,
K % bk == 0) as a second cluster_n=2 envelope alongside the 4-CTA one. Build
the A TMA multicast atom + mcast_mode=2 mask and set the AB-pipeline consumer
arrive count to num_mcast_ctas_a + num_mcast_ctas_b - 1 = 2 for this case
(B stays per-CTA: cluster_m=1 means no M peers to share B). Launch
cluster=(cluster_n,1,1) along the flat grid_x (consecutive grid_x = adjacent
N tiles) while keeping the logical (1,2,1) cluster_layout_vmnk so the A mask
(mode 2 = N) and block_idx_in_cluster->N-coord mapping are correct. The
mainloop stays flat/owner-ungated (CtaGroup.ONE => every CTA is a cluster
leader); the padded-M masked store is unchanged.
- strategies.py: relax the config-spec invariant so cluster_n=2 with
cluster_m=1 is admitted under role_local_monolithic + non_persistent, keeping
codegen and config-spec consistent.
Verification (B200 sm_100, M=16 K=4096 N=14336, fp8 e4m3):
- Correctness with range-filling fp8 + identity scales vs a float32 matmul:
relerr ~1.7e-3 (fp8 floor), identical to the cluster_n=1 baseline, across 3
seeds; per-row relerr uniform over all 16 rows, zero all-zero rows (no
half-output/missing-CTA bug); left/right N halves equally correct.
cute.nvgpu.tcgen05 present. No-crash sweep static_m in {16,32,64} x block_m
in {64,128} x cluster_n in {1,2} all pass.
- Performance (warm-L2, do_bench_wrapper, cudagraph, A/B in one session):
cluster_n=1 16.27us / 3.64 TB/s -> cluster_n=2 14.87us / 3.99 TB/s (~1.1x;
larger at locked/idle clocks). torch._scaled_mm 9.68us / 6.12 TB/s.
- NCU: launch__cluster_dim_x=2 (cluster active), L2 read hit 11.0% -> 14.4%,
DRAM %peak 42.2% -> 43.7%, registers/thread unchanged at 68 (no regression,
unlike the cluster_m=2 attempt which rose to 96).
Adds test_matmul_mma_fp8_tiny_m_cluster_n2_a_multicast (3 seeds, per-row check,
asserts cluster shape (2,1,1), cross-checks vs the cluster_n=1 baseline) and
dot-requirements coverage for the newly accepted/rejected cluster configs.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
stack-info: PR: #2845, branch: yushangdi/stack/62
6af3fdc to
9849188
Compare
This was referenced Jun 24, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Stacked PRs:
[cute] Enable N-axis cluster (cluster_n=2, cluster_m=1) A-multicast for tiny-M fp8
For a tiny-M fp8 GEMM (M<=64) the tcgen05 path runs one M tile across many N
tiles (e.g. M=16, N=14336 => 112 N tiles), and each N-tile CTA re-reads the
same small A operand (M x K) from DRAM. NVIDIA's torch._scaled_mm CUTLASS
kernel handles exactly this case by selecting an N-axis 2-CTA cluster: see
aten/src/ATen/native/cuda/RowwiseScaledMM.cu, where M=16,N>=3072 dispatches
ClusterShape <1,2,1>. Two CTAs covering adjacent N tiles form a cluster and
TMA-multicast the shared A operand, cutting A traffic and improving B-stream
efficiency. (Clustering along M is useless here -- there is only one M tile;
an earlier cluster_m=2 attempt regressed for that reason.)
Helion previously allowed cluster_n>1 only inside the 4-CTA envelope
(cluster_m=2 AND cluster_n=2, CtaGroup.TWO). This enables a standalone
cluster_m=1, cluster_n=2, CtaGroup.ONE configuration where each CTA still owns
its own 64x128 output tile and the cluster exists solely to multicast A.
Changes:
pipeline, non-persistent, single M tile m_size<=bm, N % (2*bn) == 0,
K % bk == 0) as a second cluster_n=2 envelope alongside the 4-CTA one. Build
the A TMA multicast atom + mcast_mode=2 mask and set the AB-pipeline consumer
arrive count to num_mcast_ctas_a + num_mcast_ctas_b - 1 = 2 for this case
(B stays per-CTA: cluster_m=1 means no M peers to share B). Launch
cluster=(cluster_n,1,1) along the flat grid_x (consecutive grid_x = adjacent
N tiles) while keeping the logical (1,2,1) cluster_layout_vmnk so the A mask
(mode 2 = N) and block_idx_in_cluster->N-coord mapping are correct. The
mainloop stays flat/owner-ungated (CtaGroup.ONE => every CTA is a cluster
leader); the padded-M masked store is unchanged.
cluster_m=1 is admitted under role_local_monolithic + non_persistent, keeping
codegen and config-spec consistent.
Verification (B200 sm_100, M=16 K=4096 N=14336, fp8 e4m3):
relerr ~1.7e-3 (fp8 floor), identical to the cluster_n=1 baseline, across 3
seeds; per-row relerr uniform over all 16 rows, zero all-zero rows (no
half-output/missing-CTA bug); left/right N halves equally correct.
cute.nvgpu.tcgen05 present. No-crash sweep static_m in {16,32,64} x block_m
in {64,128} x cluster_n in {1,2} all pass.
cluster_n=1 16.27us / 3.64 TB/s -> cluster_n=2 14.87us / 3.99 TB/s (~1.1x;
larger at locked/idle clocks). torch._scaled_mm 9.68us / 6.12 TB/s.
DRAM %peak 42.2% -> 43.7%, registers/thread unchanged at 68 (no regression,
unlike the cluster_m=2 attempt which rose to 96).
Adds test_matmul_mma_fp8_tiny_m_cluster_n2_a_multicast (3 seeds, per-row check,
asserts cluster shape (2,1,1), cross-checks vs the cluster_n=1 baseline) and
dot-requirements coverage for the newly accepted/rejected cluster configs.
Co-Authored-By: Claude Opus 4.8 (1M context) noreply@anthropic.com