Enable scaled matmul tests on ROCm #678
Draft
magaonka-amd wants to merge 6 commits into ROCm:rocm-jaxlib-v0.8.0 from
Update test_collectives to include upstream test case for nested
sharding configuration ((None, ("dp", "tp"), None), (None, ("dp"), None)).
Also restructure expected_hlos from flat tuples to nested lists to
support multiple patterns per configuration and improve test granularity
with subTest loop.
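The nested expected_hlos layout with a subTest loop can be sketched roughly as follows. This is a minimal illustration, not the test's actual code: the HLO strings, the pattern lists, and the class name NestedExpectedHlosSketch are all made up for the example, and only the overall shape (one inner list of patterns per sharding configuration, one subTest per pattern) follows the commit message.

```python
import re
import unittest

# Hypothetical stand-ins for compiled HLO text, one string per sharding
# configuration. Real HLO would come from lowering and compiling the function.
FAKE_HLO_PER_CONFIG = [
    'all-gather(...) custom-call(...), custom_call_target="__cudnn$blockScaledDot"',
    'custom-call(...), custom_call_target="__cudnn$blockScaledDot" reduce-scatter(...)',
]

# Nested lists: one inner list of regex patterns per configuration, so a
# configuration can require several patterns (illustrative values only).
EXPECTED_HLOS = [
    ["all-gather", r"__cudnn\$blockScaledDot"],
    [r"__cudnn\$blockScaledDot", "reduce-scatter"],
]


class NestedExpectedHlosSketch(unittest.TestCase):
    def test_patterns(self):
        for idx, (hlo_text, patterns) in enumerate(
            zip(FAKE_HLO_PER_CONFIG, EXPECTED_HLOS)
        ):
            for pattern in patterns:
                # Each (config, pattern) pair is reported separately on failure.
                with self.subTest(config=idx, pattern=pattern):
                    self.assertRegex(hlo_text, pattern)


def run_sketch():
    """Run the sketch test case and return the unittest result object."""
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(NestedExpectedHlosSketch)
    return unittest.TextTestRunner(verbosity=0).run(suite)
```

With subTest, one failing pattern does not hide failures in the remaining patterns or configurations, which is presumably the granularity gain the commit refers to.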
Add _are_specs_overlapping() helper to correctly detect when partition
specs share axis names. Fixes DuplicateSpecError when using nested
sharding specs like ('dp', 'tp') that overlap with single specs like 'dp'.
Prevents creating invalid PartitionSpecs with duplicate axis names in
both input and output shardings.
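The overlap check described above might look something like the sketch below. The real _are_specs_overlapping() signature is not shown in this PR text, so the function names axis_names and entries_overlap are hypothetical, and the sketch assumes each spec entry is None, a single axis name, or a tuple of axis names.

```python
def axis_names(entry):
    """Return the set of mesh axis names used by one partition-spec entry.

    Assumption: an entry is None (unsharded), a str (one axis), or a
    tuple/list of str (several axes, e.g. ("dp", "tp")).
    """
    if entry is None:
        return set()
    if isinstance(entry, str):
        return {entry}
    return set(entry)


def entries_overlap(a, b):
    """True when two spec entries share at least one axis name."""
    return bool(axis_names(a) & axis_names(b))
```

For example, entries_overlap(("dp", "tp"), "dp") is True because both entries use "dp", so combining them in one PartitionSpec would repeat an axis name.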
Update ScaledMatmulTest setUp to only check cuDNN on CUDA devices. ROCm uses hipBLASLt or XLA fallback path.
Define c_name_cuda and c_name_rocm for platform-specific matmul targets:
- CUDA: __cudnn$blockScaledDot
- ROCm: __cublas$lt$matmul$mx
Add platform detection to all 4 ScaledMatmulTest methods:
- test_collectives
- test_scaled_matmul_nvfp4
- test_scaled_matmul
- test_scaled_matmul_sharded
Primary path checks for hipBLASLt custom calls:
- __cublas$lt$matmul$mx (MX format)
- __cublas$lt$matmul (generic cublasLT)
Fallback path checks for dequantize + matmul operations:
- __triton_gemm
- __cublas$gemm
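A rough sketch of the platform-dependent check, using plain substring matching on HLO text. The target strings come from the commit message; the function name classify_hlo_path, the platform strings "cuda" and "rocm", and the returned labels are assumptions made for illustration and are not the test's actual helpers.

```python
# Custom-call targets named in the commit message.
CUDA_TARGET = "__cudnn$blockScaledDot"
ROCM_PRIMARY_TARGETS = ("__cublas$lt$matmul$mx", "__cublas$lt$matmul")
ROCM_FALLBACK_TARGETS = ("__triton_gemm", "__cublas$gemm")


def classify_hlo_path(hlo_text, platform):
    """Return which execution path the HLO text appears to use, or None.

    `platform` is assumed to be "cuda" or "rocm"; how the real test detects
    the platform is not shown here.
    """
    if platform == "cuda":
        return "cudnn" if CUDA_TARGET in hlo_text else None
    if platform == "rocm":
        # Primary path: hipBLASLt custom calls.
        if any(target in hlo_text for target in ROCM_PRIMARY_TARGETS):
            return "primary"
        # Fallback path: dequantize followed by a regular matmul.
        if any(target in hlo_text for target in ROCM_FALLBACK_TARGETS):
            return "fallback"
    return None
```

A ROCm test under this scheme would accept either "primary" or "fallback" and fail only when neither family of targets appears in the compiled HLO.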
Enable 4 test methods in ScaledMatmulTest:
- test_collectives
- test_scaled_matmul_nvfp4
- test_scaled_matmul
- test_scaled_matmul_sharded
2a0e78e to 816d5f4
Enable scaled_matmul_stablehlo_test on ROCm platform with support for both hipBLASLt and fallback execution paths.
This PR enables 4 test methods in ScaledMatmulTest to run on ROCm:
- test_collectives
- test_scaled_matmul_nvfp4
- test_scaled_matmul
- test_scaled_matmul_sharded
Changes
Upstream Improvements (Commits 1-2)
- Add upstream test case for nested sharding configuration ((None, ("dp", "tp"), None), (None, ("dp"), None))
- Restructure expected_hlos from flat tuples to nested lists for better test granularity
- Fix DuplicateSpecError for nested sharding specs
- Add _are_specs_overlapping() helper to detect when specs share axis names
ROCm Enablement (Commits 3-6)
- Define platform-specific matmul targets (__cudnn$blockScaledDot for CUDA, __cublas$lt$matmul$mx for ROCm)
- Change @jtu.run_on_devices("cuda") to @jtu.run_on_devices("gpu")
Implementation Details
HLO Validation
Tests validate two execution paths:
Primary path (hipBLASLt):
- custom-call with target __cublas$lt$matmul$mx
Fallback path (dequantize + matmul):
- __triton_gemm or __cublas$gemm in HLO
test_collectives uses regex pattern matching to preserve collective operation validation.
Test Results
With this fix, the tests below pass: