Fix/pallas tests shared memory#581
Open
AratiGanesh wants to merge 1 commit intorocm-jaxlib-v0.8.0from
Open
Conversation
mminutoli
requested changes
Dec 15, 2025
mminutoli
left a comment
There was a problem hiding this comment.
it's really minor comments with the intent of making it so it is easier to review from upstream. Please, fix them and let's create a PR also for upstream jax.
build/rocm-test-requirements.txt
Outdated
There was a problem hiding this comment.
Can we take this out of this PR?
I'd like to be an "atomic" change fixing only 1 issue that we can send to upstream jax.
| block_k_orig = original_blocks["block_k"] | ||
| dtype = jnp.float16 | ||
|
|
||
| adjusted_q, adjusted_k = self._adjust_mha_params_for_shared_memory( |
There was a problem hiding this comment.
I am ok with this like this, but I am wondering if they'll want this done only for GPUs.
9f43140 to
7b0b014
Compare
…ally reduce the block sizes until it fits the memory. - Fix `ops_test.py`: Add accurate shared memory estimation code and skip test if memory is exceeded
7b0b014 to
cff5d16
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Motivation
Pallas GPU tests (gpu_ops_test.py and ops_test.py) were crashing with RESOURCE_EXHAUSTED errors on devices (MI250) with limited shared memory. This PR aims to estimate the shared memory and ensure that tests either adjust to fit the available memory or skip if they cannot run.
Technical Details
gpu_ops_test.py - Uses similar logic as https://github.com/ROCm/jax/pull/559/files.
Implements a shared memory estimate function for fused attention. Tests now automatically reduce block sizes (128x128 → 32x32) if the original config exceeds device limits. The backward pass logic also scales its gradients blocks proportionally to the forward pass.
ops_test.py - Estimates the share memory and if it cannot fit matrices, it skips the test.
Test Plan
Rerun
pytest tests/pallas/gpu_ops_test.py -v
pytest jax/tests/pallas/gpu_ops_test.py -v
Test Result
All tests pass.
Submission Checklist