Fix/pallas tests shared memory#581

Open
AratiGanesh wants to merge 1 commit into rocm-jaxlib-v0.8.0 from fix/pallas-tests-shared-memory

Conversation

@AratiGanesh
Motivation

Pallas GPU tests (gpu_ops_test.py and ops_test.py) were crashing with RESOURCE_EXHAUSTED errors on devices with limited shared memory (e.g. MI250). This PR estimates the required shared memory and ensures that tests either adjust their configuration to fit the available memory or are skipped when they cannot run.

Technical Details

  1. gpu_ops_test.py - Uses logic similar to https://github.com/ROCm/jax/pull/559/files.
    Implements a shared-memory estimation function for fused attention. Tests now automatically reduce block sizes (128x128 → 32x32) when the original configuration exceeds device limits. The backward pass also scales its gradient blocks proportionally to the forward pass.

  2. ops_test.py - Estimates the shared memory and skips the test if the matrices cannot fit.
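The adjust-or-skip logic above can be sketched roughly as follows. This is a minimal illustration, not the PR's actual code: the function names, the tile accounting, and the 64 KiB LDS-per-workgroup limit assumed for MI250 are all illustrative assumptions.

```python
# Hypothetical sketch of the shared-memory check described in the PR.
# The 64 KiB limit and all helper names are assumptions for illustration.

SHARED_MEMORY_LIMIT = 64 * 1024  # assumed LDS available per workgroup on MI250


def estimate_attention_shared_memory(block_q, block_k, head_dim, dtype_bytes=2):
    """Rough estimate for a fused-attention kernel: Q, K, V tiles
    plus the block_q x block_k score matrix, in bytes."""
    q_tile = block_q * head_dim * dtype_bytes
    k_tile = block_k * head_dim * dtype_bytes
    v_tile = block_k * head_dim * dtype_bytes
    scores = block_q * block_k * dtype_bytes
    return q_tile + k_tile + v_tile + scores


def adjust_blocks(block_q, block_k, head_dim, limit=SHARED_MEMORY_LIMIT):
    """Halve the block sizes (e.g. 128x128 -> 64x64 -> 32x32) until the
    estimate fits; return None to signal that the test should be skipped."""
    while block_q >= 32 and block_k >= 32:
        if estimate_attention_shared_memory(block_q, block_k, head_dim) <= limit:
            return block_q, block_k
        block_q //= 2
        block_k //= 2
    return None  # does not fit even at 32x32: caller calls self.skipTest(...)
```

With `head_dim=64` and fp16 tiles, a 128x128 block estimates to 80 KiB and is halved to 64x64 (32 KiB), which fits the assumed limit; a config that still overflows at 32x32 returns `None` and the test is skipped.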

Test Plan

Rerun:
pytest tests/pallas/gpu_ops_test.py -v
pytest jax/tests/pallas/gpu_ops_test.py -v

Test Result

All tests pass.

Submission Checklist

@AratiGanesh AratiGanesh requested a review from a team as a code owner December 15, 2025 15:39
@AratiGanesh AratiGanesh changed the base branch from main to rocm-jaxlib-v0.8.0 December 15, 2025 15:43
@AratiGanesh AratiGanesh requested review from mminutoli and removed request for a team December 15, 2025 15:46
@mminutoli mminutoli left a comment

These are really minor comments, with the intent of making this easier to review from upstream. Please fix them, and let's also create a PR for upstream jax.

Can we take this out of this PR?

I'd like this to be an "atomic" change fixing only one issue that we can send to upstream jax.

block_k_orig = original_blocks["block_k"]
dtype = jnp.float16

adjusted_q, adjusted_k = self._adjust_mha_params_for_shared_memory(

I am OK with this as is, but I am wondering if they'll want this done only for GPUs.

@AratiGanesh AratiGanesh force-pushed the fix/pallas-tests-shared-memory branch 2 times, most recently from 9f43140 to 7b0b014 Compare December 29, 2025 22:09
…ally reduce the block sizes until it fits the memory.

- Fix `ops_test.py`: Add accurate shared memory estimation code and skip test if memory is exceeded
@AratiGanesh AratiGanesh force-pushed the fix/pallas-tests-shared-memory branch from 7b0b014 to cff5d16 Compare December 29, 2025 22:16