Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions jax/_src/cudnn/scaled_matmul_stablehlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,12 @@ def _enable_all_reduce(lhs, rhs):
_, n_spec, rhs_k_spec = rhs.spec
return lhs_k_spec != None and lhs_k_spec == rhs_k_spec and n_spec == None

def _are_specs_overlapping(lhs, rhs):
if lhs is None or rhs is None:
return False
lhs = (lhs,) if isinstance(lhs, str) else lhs
rhs = (rhs,) if isinstance(rhs, str) else rhs
return not set(lhs).isdisjoint(rhs)

def _get_output_sharding(shardings):
lhs, rhs = shardings[0], shardings[1]
Expand Down Expand Up @@ -241,7 +247,8 @@ def named_sharding(lhs, rhs, lhs_specs, rhs_specs):
lhs_specs[2] = None
rhs_specs[2] = None
m_spec, n_spec = lhs_specs[1], rhs_specs[1]
if m_spec == n_spec:
# Check if m_spec and n_spec share any axis names to avoid duplicates
if _are_specs_overlapping(m_spec, n_spec):
rhs_specs[1] = None

return named_sharding(lhs_sharding, rhs_sharding, lhs_specs, rhs_specs)
Expand All @@ -259,7 +266,8 @@ def _supported_out_sharding(lhs, rhs, reduce_scatter_dim):
out_n_spec = k_spec
else:
out_m_spec = m_spec
out_n_spec = n_spec if m_spec != n_spec else None
# Check if m_spec and n_spec share any axis names to avoid duplicates
out_n_spec = n_spec if not _are_specs_overlapping(m_spec, n_spec) else None

return [NamedSharding(lhs.mesh, P(batch_spec, out_m_spec, out_n_spec))]

Expand Down
150 changes: 116 additions & 34 deletions tests/scaled_matmul_stablehlo_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,21 @@
((None, "dp", "tp"), (None, "dp", "tp")),
((None, "tp", None), (None, "tp", None)),
((None, None, "tp"), (None, "tp", None)),
((None, ("dp", "tp"), None), (None, ("dp"), None)),
]
c_name = "__cudnn$blockScaledDot"
c_name_cuda = "__cudnn$blockScaledDot"
c_name_rocm = "__cublas$lt$matmul$mx"
c_name = c_name_cuda
expected_hlos = [
(c_name, "all-reduce", "f32[1,512,512]", "replica_groups={{0,1},{2,3}}"),
("all-gather", "f8e4m3fn[512,512]", "replica_groups=[2,2]<=[4]", c_name),
("all-gather", "f8e4m3fn[512,512]", "replica_groups=[2,2]<=[4]", c_name),
(c_name,),
("all-gather", "f8e4m3fn[256,1024]", "replica_groups=[2,2]<=[4]", c_name),
(c_name,),
("all-gather", "f8e4m3fn[2,512,1024]", "replica_groups=[2,2]<=[4]", c_name),
("all-gather", "f8e4m3fn[2,512,512]", "replica_groups=[2,2]<=[4]", c_name),
[("all-reduce", "f32[1,512,512]", "replica_groups={{0,1},{2,3}}"), (c_name,)],
[("all-gather", "f8e4m3fn[512,512]", "replica_groups=[2,2]<=[4]"), (c_name,)],
[("all-gather", "f8e4m3fn[512,512]", "replica_groups=[2,2]<=[4]"), (c_name,)],
[(c_name,)],
[("all-gather", "f8e4m3fn[256,1024]", "replica_groups=[2,2]<=[4]"), (c_name,)],
[(c_name,)],
[("all-gather", "f8e4m3fn[2,512,1024]", "replica_groups=[2,2]<=[4]"), (c_name,)],
[("all-gather", "f8e4m3fn[2,512,512]", "replica_groups=[2,2]<=[4]"), (c_name,)],
[("all-gather", "f8e4m3fn[2,256,1024]", "replica_groups=[2,2]<=[2,2]"), (c_name,)],
]
expected_output_spec = [
PartitionSpec('dp',),
Expand All @@ -65,13 +69,14 @@
PartitionSpec(None, 'dp'),
PartitionSpec(None, 'tp', None),
PartitionSpec(None, None, 'tp'),
PartitionSpec(None, ('dp', 'tp'), None),
]

# The GSPMD sharding logic inserts additional reduce-scatters which don't exist
# in Shardy.
if not config.use_shardy_partitioner.value:
expected_output_spec[5] = PartitionSpec(None, 'dp', 'tp')
expected_hlos[5] += ("reduce-scatter", "f32[2,256,512]", "replica_groups={{0,1},{2,3}}")
expected_hlos[5] += [("reduce-scatter", "f32[2,256,512]", "replica_groups={{0,1},{2,3}}")]

sharding_configs = {
input_sharding: (hlo, output_spec)
Expand Down Expand Up @@ -269,40 +274,72 @@ class ScaledMatmulTest(jtu.JaxTestCase):

def setUp(self):
super().setUp()
try:
check_cudnn_version()
except RuntimeError as e:
self.skipTest(str(e))
if not jtu.is_cuda_compute_capability_at_least("10.0"):
self.skipTest("Requires at least Blackwell arch")
if jtu.test_device_matches(["cuda"]):
try:
check_cudnn_version()
except RuntimeError as e:
self.skipTest(str(e))
if not jtu.is_cuda_compute_capability_at_least("10.0"):
self.skipTest("Requires at least Blackwell arch")

mxfp8_configs = create_mxfp8_configs()

@jtu.sample_product(
in_shardings=sharding_configs,
block_scale_configs=[mxfp8_configs,],
)
@jtu.run_on_devices("cuda")
@jtu.run_on_devices("gpu")
def test_collectives(self, in_shardings, block_scale_configs):
if jtu.device_under_test() != "gpu" or len(jax.local_devices()) < 4:
self.skipTest("Partition Test enabled for at least 4 GPUs")

if jtu.test_device_matches(["rocm"]):
platform_c_name = c_name_rocm
else:
platform_c_name = c_name_cuda

expected_hlo = sharding_configs[in_shardings][0]
expected_hlo = [
tuple(platform_c_name if x == c_name else x for x in pattern)
for pattern in expected_hlo
]

hlo_text = get_hlo_text(in_shardings, block_scale_configs)

hlo_pattern = re.compile(
r".*".join([re.escape(x) for x in expected_hlo]), flags=re.DOTALL
)
self.assertRegex(
hlo_text, hlo_pattern, msg=f"Failed to find pattern: {expected_hlo}"
)
for expected_hlo_patterns in expected_hlo:
hlo_pattern_str = r".*".join(map(re.escape, expected_hlo_patterns))
hlo_pattern = re.compile(hlo_pattern_str, flags=re.DOTALL)

if jtu.test_device_matches(["rocm"]):
# Try both MX and generic cublasLT variants
pattern_mx = re.compile(hlo_pattern_str, flags=re.DOTALL)
pattern_generic = re.compile(
r".*".join([re.escape(x) if x != platform_c_name else r"__cublas\$lt\$matmul" for x in expected_hlo_patterns]),
flags=re.DOTALL
)
primary_matched = re.search(pattern_mx, hlo_text) or re.search(pattern_generic, hlo_text)

if not primary_matched:
fallback_patterns = [
re.compile(r".*".join([re.escape(x) if x != platform_c_name else r"(__triton_gemm|__cublas\$gemm)" for x in expected_hlo_patterns]), flags=re.DOTALL)
]
pattern_matched = any(re.search(p, hlo_text) for p in fallback_patterns)
if not pattern_matched:
with self.subTest(pattern=hlo_pattern_str):
self.fail(f"Failed to find pattern: {hlo_pattern_str} or fallback matmul pattern")
else:
with self.subTest(pattern=hlo_pattern_str):
self.assertTrue(True)
else:
with self.subTest(pattern=hlo_pattern_str):
self.assertRegex(hlo_text, hlo_pattern, msg=f"Failed to find pattern: {hlo_pattern_str}")

@jtu.sample_product(
contract=[160, 96],
lhs_non_contract=[240, 100],
dtype=[jnp.float32, jnp.bfloat16, jnp.float16],
)
@jtu.run_on_devices("cuda")
@jtu.run_on_devices("gpu")
def test_scaled_matmul_nvfp4(
self, contract, lhs_non_contract, dtype,
):
Expand Down Expand Up @@ -335,10 +372,26 @@ def wrapper(lhs, rhs, lhs_scales, rhs_scales, out_type):
.compile()
.as_text()
)

if jtu.test_device_matches(["rocm"]):
platform_c_name = c_name_rocm
else:
platform_c_name = c_name_cuda

hlo_pattern = re.compile(
r".*".join([re.escape(x) for x in ("custom-call", c_name)])
)
self.assertRegex(hlo_text, hlo_pattern)
r".*".join([re.escape(x) for x in ("custom-call", platform_c_name)])
)

if jtu.test_device_matches(["rocm"]):
# Try both MX and generic cublasLT variants
pattern_generic = re.compile(r"custom\-call.*__cublas\$lt\$matmul", flags=re.DOTALL)
primary_matched = re.search(hlo_pattern, hlo_text) or re.search(pattern_generic, hlo_text)

if not primary_matched:
if "__triton_gemm" not in hlo_text and "__cublas$gemm" not in hlo_text:
self.fail(f"Expected {platform_c_name} or __cublas$lt$matmul or fallback (__triton_gemm/__cublas$gemm)")
else:
self.assertRegex(hlo_text, hlo_pattern)

out = j_scaled_matmul(a_q, b_q, a_s, b_s)
out_ref = jnp.einsum(
Expand All @@ -354,7 +407,7 @@ def wrapper(lhs, rhs, lhs_scales, rhs_scales, out_type):
dtype=[jnp.float16, jnp.bfloat16, jnp.float32],
block_scale_configs=[mxfp8_configs,],
)
@jtu.run_on_devices("cuda")
@jtu.run_on_devices("gpu")
def test_scaled_matmul(
self, contract, lhs_non_contract, dtype, block_scale_configs,
):
Expand All @@ -379,10 +432,26 @@ def wrapper(lhs, rhs, lhs_scales, rhs_scales, out_type):
.compile()
.as_text()
)

if jtu.test_device_matches(["rocm"]):
platform_c_name = c_name_rocm
else:
platform_c_name = c_name_cuda

hlo_pattern = re.compile(
r".*".join([re.escape(x) for x in ("custom-call", c_name)])
)
self.assertRegex(hlo_text, hlo_pattern)
r".*".join([re.escape(x) for x in ("custom-call", platform_c_name)])
)

if jtu.test_device_matches(["rocm"]):
# Try both MX and generic cublasLT variants
pattern_generic = re.compile(r"custom\-call.*__cublas\$lt\$matmul", flags=re.DOTALL)
primary_matched = re.search(hlo_pattern, hlo_text) or re.search(pattern_generic, hlo_text)

if not primary_matched:
if "__triton_gemm" not in hlo_text and "__cublas$gemm" not in hlo_text:
self.fail(f"Expected {platform_c_name} or __cublas$lt$matmul or fallback (__triton_gemm/__cublas$gemm)")
else:
self.assertRegex(hlo_text, hlo_pattern)

out = j_scaled_matmul(a_q, b_q, a_scales, b_scales)
out_ref = np.einsum(
Expand All @@ -396,7 +465,7 @@ def wrapper(lhs, rhs, lhs_scales, rhs_scales, out_type):
in_shardings=sharding_configs,
block_scale_configs=[mxfp8_configs,],
)
@jtu.run_on_devices("cuda")
@jtu.run_on_devices("gpu")
def test_scaled_matmul_sharded(self, in_shardings, block_scale_configs):
if len(jax.local_devices()) < 4:
self.skipTest("Require at least 4 devices to run sharding tests.")
Expand Down Expand Up @@ -427,10 +496,23 @@ def test_scaled_matmul_sharded(self, in_shardings, block_scale_configs):
scaled_matmul_wrapper, in_shardings=input_shardings
)
hlo_compiled = j_scaled_matmul.lower(*args).compile()
hlo_text = hlo_compiled.as_text()

if jtu.test_device_matches(["rocm"]):
platform_c_name = c_name_rocm
else:
platform_c_name = c_name_cuda

hlo_pattern = re.compile(
r".*".join([re.escape(x) for x in ("custom-call", c_name)])
r".*".join([re.escape(x) for x in ("custom-call", platform_c_name)])
)
self.assertRegex(hlo_compiled.as_text(), hlo_pattern)

if jtu.test_device_matches(["rocm"]) and not re.search(hlo_pattern, hlo_text):
fallback_found = "__triton_gemm" in hlo_text
if not fallback_found:
self.fail(f"Expected {platform_c_name} or fallback (__triton_gemm)")
else:
self.assertRegex(hlo_text, hlo_pattern)

j_ref = jax.jit(
partial(
Expand Down