Skip to content

Comments

Enable scaled matmul tests on ROCm#678

Draft
magaonka-amd wants to merge 6 commits intoROCm:rocm-jaxlib-v0.8.0from
magaonka-amd:enable-scaled-matmul-tests-v0.8.0
Draft

Enable scaled matmul tests on ROCm#678
magaonka-amd wants to merge 6 commits intoROCm:rocm-jaxlib-v0.8.0from
magaonka-amd:enable-scaled-matmul-tests-v0.8.0

Conversation

@magaonka-amd
Copy link

Enable scaled_matmul_stablehlo_test on ROCm platform with support for both hipBLASLt and fallback execution paths.

This PR enables 4 test methods in ScaledMatmulTest to run on ROCm:

  • test_collectives
  • test_scaled_matmul_nvfp4
  • test_scaled_matmul
  • test_scaled_matmul_sharded

Changes

Upstream Improvements (Commits 1-2)

  • Add missing 9th sharding test case with nested partition specs ((None, ("dp", "tp"), None), (None, ("dp"), None))
  • Restructure expected_hlos from flat tuples to nested lists for better test granularity
  • Fix overlapping partition specs bug in custom partitioner that caused DuplicateSpecError
  • Add _are_specs_overlapping() helper to detect when specs share axis names

ROCm Enablement (Commits 3-6)

  • Skip cuDNN and Blackwell architecture checks on ROCm devices
  • Add platform-specific custom call targets (__cudnn$blockScaledDot for CUDA, __cublas$lt$matmul$mx for ROCm)
  • Add platform detection to all 4 test methods
  • Handle XLA fallback path when hipBLASLt requirements not met (batch_size > 1, alignment constraints)
  • Change decorators from @jtu.run_on_devices("cuda") to @jtu.run_on_devices("gpu")

Implementation Details

HLO Validation

Tests validate two execution paths:

Primary path (hipBLASLt):

  • Checks for custom-call with target __cublas$lt$matmul$mx
  • Used when hipBLASLt constraints are met

Fallback path (dequantize + matmul):

  • When hipBLASLt requirements not met, XLA uses dequantize then matmul approach
  • Checks for __triton_gemm or __cublas$gemm in HLO
  • test_collectives uses regex pattern matching to preserve collective operation validation
  • Other tests use simple string checks

TestResults

with this fix below tests pass:

root@0db19fc140b0:/workspace/debug_session_MI300/rocm-jax/jax# pytest -v -k "test_collectives or test_scaled_matmul" /workspace/debug_session_MI300/rocm-jax/jax/tests/scaled_matmul_stablehlo_test.py
Test session starting on GPU ?
================================================================ test session starts ================================================================
platform linux -- Python 3.11.13, pytest-9.0.2, pluggy-1.6.0 -- /usr/local/bin/python3.11
cachedir: .pytest_cache
hypothesis profile 'default'
metadata: {'Python': '3.11.13', 'Platform': 'Linux-6.8.0-87-generic-x86_64-with-glibc2.35', 'Packages': {'pytest': '9.0.2', 'pluggy': '1.6.0'}, 'Plugins': {'json-report': '1.5.0', 'rerunfailures': '16.1', 'reportlog': '1.0.0', 'hypothesis': '6.150.0', 'html': '4.1.1', 'metadata': '3.1.1'}}
rootdir: /workspace/debug_session_MI300/rocm-jax/jax
configfile: pyproject.toml
plugins: json-report-1.5.0, rerunfailures-16.1, reportlog-1.0.0, hypothesis-6.150.0, html-4.1.1, metadata-3.1.1
collected 81 items / 43 deselected / 38 selected                                                                                                    

tests/scaled_matmul_stablehlo_test.py::ScaledMatmulTest::test_collectives0 
tests/scaled_matmul_stablehlo_test.py::ScaledMatmulTest::test_collectives0 PASSED                                                             [  2%]
tests/scaled_matmul_stablehlo_test.py::ScaledMatmulTest::test_collectives1 
tests/scaled_matmul_stablehlo_test.py::ScaledMatmulTest::test_collectives1 PASSED                                                             [  5%]
tests/scaled_matmul_stablehlo_test.py::ScaledMatmulTest::test_collectives2 PASSED                                                             [  7%]
tests/scaled_matmul_stablehlo_test.py::ScaledMatmulTest::test_collectives3 
tests/scaled_matmul_stablehlo_test.py::ScaledMatmulTest::test_collectives3 PASSED                                                             [ 10%]
tests/scaled_matmul_stablehlo_test.py::ScaledMatmulTest::test_collectives4 
tests/scaled_matmul_stablehlo_test.py::ScaledMatmulTest::test_collectives4 PASSED                                                             [ 13%]
tests/scaled_matmul_stablehlo_test.py::ScaledMatmulTest::test_collectives5 
tests/scaled_matmul_stablehlo_test.py::ScaledMatmulTest::test_collectives5 PASSED                                                             [ 15%]
tests/scaled_matmul_stablehlo_test.py::ScaledMatmulTest::test_collectives6 
tests/scaled_matmul_stablehlo_test.py::ScaledMatmulTest::test_collectives6 PASSED                                                             [ 18%]
tests/scaled_matmul_stablehlo_test.py::ScaledMatmulTest::test_collectives7 
tests/scaled_matmul_stablehlo_test.py::ScaledMatmulTest::test_collectives7 PASSED                                                             [ 21%]
tests/scaled_matmul_stablehlo_test.py::ScaledMatmulTest::test_collectives8 
tests/scaled_matmul_stablehlo_test.py::ScaledMatmulTest::test_collectives8 PASSED                                                             [ 23%]
tests/scaled_matmul_stablehlo_test.py::ScaledMatmulTest::test_scaled_matmul0 PASSED                                                           [ 26%]
tests/scaled_matmul_stablehlo_test.py::ScaledMatmulTest::test_scaled_matmul1 PASSED                                                           [ 28%]
tests/scaled_matmul_stablehlo_test.py::ScaledMatmulTest::test_scaled_matmul2 PASSED                                                           [ 31%]
tests/scaled_matmul_stablehlo_test.py::ScaledMatmulTest::test_scaled_matmul3 PASSED                                                           [ 34%]
tests/scaled_matmul_stablehlo_test.py::ScaledMatmulTest::test_scaled_matmul4 PASSED                                                           [ 36%]
tests/scaled_matmul_stablehlo_test.py::ScaledMatmulTest::test_scaled_matmul5 PASSED                                                           [ 39%]
tests/scaled_matmul_stablehlo_test.py::ScaledMatmulTest::test_scaled_matmul6 PASSED                                                           [ 42%]
tests/scaled_matmul_stablehlo_test.py::ScaledMatmulTest::test_scaled_matmul7 PASSED                                                           [ 44%]
tests/scaled_matmul_stablehlo_test.py::ScaledMatmulTest::test_scaled_matmul8 PASSED                                                           [ 47%]
tests/scaled_matmul_stablehlo_test.py::ScaledMatmulTest::test_scaled_matmul9 PASSED                                                           [ 50%]
tests/scaled_matmul_stablehlo_test.py::ScaledMatmulTest::test_scaled_matmul_nvfp40 PASSED                                                     [ 52%]
tests/scaled_matmul_stablehlo_test.py::ScaledMatmulTest::test_scaled_matmul_nvfp41 PASSED                                                     [ 55%]
tests/scaled_matmul_stablehlo_test.py::ScaledMatmulTest::test_scaled_matmul_nvfp42 PASSED                                                     [ 57%]
tests/scaled_matmul_stablehlo_test.py::ScaledMatmulTest::test_scaled_matmul_nvfp43 PASSED                                                     [ 60%]
tests/scaled_matmul_stablehlo_test.py::ScaledMatmulTest::test_scaled_matmul_nvfp44 PASSED                                                     [ 63%]
tests/scaled_matmul_stablehlo_test.py::ScaledMatmulTest::test_scaled_matmul_nvfp45 PASSED                                                     [ 65%]
tests/scaled_matmul_stablehlo_test.py::ScaledMatmulTest::test_scaled_matmul_nvfp46 PASSED                                                     [ 68%]
tests/scaled_matmul_stablehlo_test.py::ScaledMatmulTest::test_scaled_matmul_nvfp47 PASSED                                                     [ 71%]
tests/scaled_matmul_stablehlo_test.py::ScaledMatmulTest::test_scaled_matmul_nvfp48 PASSED                                                     [ 73%]
tests/scaled_matmul_stablehlo_test.py::ScaledMatmulTest::test_scaled_matmul_nvfp49 PASSED                                                     [ 76%]
tests/scaled_matmul_stablehlo_test.py::ScaledMatmulTest::test_scaled_matmul_sharded0 PASSED                                                   [ 78%]
tests/scaled_matmul_stablehlo_test.py::ScaledMatmulTest::test_scaled_matmul_sharded1 PASSED                                                   [ 81%]
tests/scaled_matmul_stablehlo_test.py::ScaledMatmulTest::test_scaled_matmul_sharded2 PASSED                                                   [ 84%]
tests/scaled_matmul_stablehlo_test.py::ScaledMatmulTest::test_scaled_matmul_sharded3 PASSED                                                   [ 86%]
tests/scaled_matmul_stablehlo_test.py::ScaledMatmulTest::test_scaled_matmul_sharded4 PASSED                                                   [ 89%]
tests/scaled_matmul_stablehlo_test.py::ScaledMatmulTest::test_scaled_matmul_sharded5 PASSED                                                   [ 92%]
tests/scaled_matmul_stablehlo_test.py::ScaledMatmulTest::test_scaled_matmul_sharded6 PASSED                                                   [ 94%]
tests/scaled_matmul_stablehlo_test.py::ScaledMatmulTest::test_scaled_matmul_sharded7 PASSED                                                   [ 97%]
tests/scaled_matmul_stablehlo_test.py::ScaledMatmulTest::test_scaled_matmul_sharded8 PASSED                                                   [100%]

============================================== 38 passed, 43 deselected, 12 subtests passed in 40.04s ===============================================

Update test_collectives to include upstream test case for nested
sharding configuration ((None, ("dp", "tp"), None), (None, ("dp"), None)).
Also restructure expected_hlos from flat tuples to nested lists to
support multiple patterns per configuration and improve test granularity
with subTest loop.
Add _are_specs_overlapping() helper to correctly detect when partition
specs share axis names. Fixes DuplicateSpecError when using nested
sharding specs like ('dp', 'tp') that overlap with single specs like 'dp'.

Prevents creating invalid PartitionSpecs with duplicate axis names in
both input and output shardings.
Update ScaledMatmulTest setUp to only check cuDNN on CUDA devices.
ROCm uses hipBLASLt or XLA fallback path.
Define c_name_cuda and c_name_rocm for platform-specific matmul targets:
- CUDA: __cudnn$blockScaledDot
- ROCm: __cublas$lt$matmul$mx
Add platform detection to all 4 ScaledMatmulTest methods:
- test_collectives
- test_scaled_matmul_nvfp4
- test_scaled_matmul
- test_scaled_matmul_sharded

Primary path checks for hipBLASLt custom calls:
- __cublas$lt$matmul$mx (MX format)
- __cublas$lt$matmul (generic cublasLT)

Fallback path checks for dequantize + matmul operations:
- __triton_gemm
- __cublas$gemm
Enable 4 test methods in ScaledMatmulTest:
- test_collectives
- test_scaled_matmul_nvfp4
- test_scaled_matmul
- test_scaled_matmul_sharded
@magaonka-amd magaonka-amd force-pushed the enable-scaled-matmul-tests-v0.8.0 branch from 2a0e78e to 816d5f4 Compare January 30, 2026 22:58
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant