update a test for checking zero ROCm GPU event #585

Open
cj401-amd wants to merge 11 commits into rocm-jaxlib-v0.8.0 from ci_cj_profiler_test_rocm-jaxlib-v0.8.0


@cj401-amd cj401-amd commented Dec 18, 2025

Motivation

The kernel_details test requires jaxlib to be built as follows; otherwise the trace file might be missing kernel_details.

python3 build/build.py build \
    --clang_path=/lib/llvm-18/bin/clang-18 \
    --wheels="jaxlib" \
    --target_cpu_features=native \
    --rocm_path=/opt/rocm \
    --rocm_version=7 \
    --rocm_amdgpu_targets=gfx942 \
    --verbose --bazel_options="--override_repository=xla=/your/xla"

Run specific ROCm profiler test

cd jax &&  python tests/profiler_test.py 
  • test envs
root@smc300x-clt-r4c6-18:/work/jax/dist# pip3 list | grep jax
jax                            0.8.0.dev20260127
jax-rocm7-pjrt                 0.8.0.dev20260127
jax-rocm7-plugin               0.8.0.dev20260127
jaxlib                         0.8.0.dev0+selfbuilt
[ RUN      ] ProfilerTest.test_rocm_kernel_details_in_trace_json
I0127 13:03:20.886666  883898 profiler_session.cc:103] Profiler session initializing.
I0127 13:03:20.886681  883898 profiler_session.cc:118] Profiler session started.
I0127 13:03:20.886698  883898 device_tracer_rocm.cc:51] GpuTracer created.
I0127 13:03:21.203999  883898 dot_search_space.cc:229] All configs were filtered out because none of them sufficiently match the hints. Maybe the hints set does not contain a good representative set of valid configs? Working around this by using the full hints set instead.
I0127 13:03:22.105101  883898 dot_search_space.cc:229] All configs were filtered out because none of them sufficiently match the hints. Maybe the hints set does not contain a good representative set of valid configs? Working around this by using the full hints set instead.
I0127 13:03:24.033801  883898 profiler_session.cc:68] Profiler session collecting data.
I0127 13:03:24.209389  883898 save_profile.cc:150] Collecting XSpace to repository: /tmp/tmpaasexs46/plugins/profile/2026_01_27_13_03_24/smc300x-clt-r4c6-18.xplane.pb
I0127 13:03:24.567035  883898 save_profile.cc:123] Creating directory: /tmp/tmpaasexs46/plugins/profile/2026_01_27_13_03_24

I0127 13:03:25.019298  883898 save_profile.cc:129] Dumped gzipped tool data for trace.json.gz to /tmp/tmpaasexs46/plugins/profile/2026_01_27_13_03_24/smc300x-clt-r4c6-18.trace.json.gz
I0127 13:03:25.077462  883898 profiler_session.cc:136] Profiler session tear down.
[       OK ] ProfilerTest.test_rocm_kernel_details_in_trace_json
[ RUN      ] ProfilerTest.test_rocm_profiling
I0127 13:03:25.167948  883898 profiler_session.cc:103] Profiler session initializing.
I0127 13:03:25.167988  883898 profiler_session.cc:118] Profiler session started.
I0127 13:03:25.168018  883898 device_tracer_rocm.cc:51] GpuTracer created.
I0127 13:03:26.219322  883898 dot_search_space.cc:229] All configs were filtered out because none of them sufficiently match the hints. Maybe the hints set does not contain a good representative set of valid configs? Working around this by using the full hints set instead.
I0127 13:03:26.792129  883898 dot_search_space.cc:229] All configs were filtered out because none of them sufficiently match the hints. Maybe the hints set does not contain a good representative set of valid configs? Working around this by using the full hints set instead.
I0127 13:03:27.985999  883898 dot_search_space.cc:229] All configs were filtered out because none of them sufficiently match the hints. Maybe the hints set does not contain a good representative set of valid configs? Working around this by using the full hints set instead.
I0127 13:03:28.260907  883898 profiler_session.cc:68] Profiler session collecting data.
I0127 13:03:28.451078  883898 save_profile.cc:150] Collecting XSpace to repository: /tmp/tmpcebxqk6h/plugins/profile/2026_01_27_13_03_28/smc300x-clt-r4c6-18.xplane.pb
I0127 13:03:28.805635  883898 save_profile.cc:123] Creating directory: /tmp/tmpcebxqk6h/plugins/profile/2026_01_27_13_03_28

I0127 13:03:29.257322  883898 save_profile.cc:129] Dumped gzipped tool data for trace.json.gz to /tmp/tmpcebxqk6h/plugins/profile/2026_01_27_13_03_28/smc300x-clt-r4c6-18.trace.json.gz
I0127 13:03:29.318556  883898 profiler_session.cc:136] Profiler session tear down.
[       OK ] ProfilerTest.test_rocm_profiling
----------------------------------------------------------------------
Ran 22 tests in 47.640s
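
For reference, each run above dumps a trace.json.gz under the profile directory. A minimal sketch of a "zero GPU events" check on such a file could look like the following. It assumes the file follows the Chrome trace event layout (a top-level `traceEvents` list, `ph: "M"` / `process_name` metadata rows, and `ph: "X"` complete events) and that GPU process names contain "GPU". `count_gpu_events` and `gpu_marker` are hypothetical names used for illustration, not the helper this PR actually adds.

```python
import gzip
import json


def count_gpu_events(trace_path, gpu_marker="GPU"):
    """Count complete ("X") events that belong to GPU processes in a trace.json.gz.

    Assumption: GPU processes are announced by process_name metadata events
    whose args.name contains `gpu_marker` (e.g. "/device:GPU:0").
    """
    with gzip.open(trace_path, "rt", encoding="utf-8") as f:
        trace = json.load(f)
    events = trace.get("traceEvents", [])

    # Collect the pids of every process whose declared name looks like a GPU.
    gpu_pids = {
        e.get("pid")
        for e in events
        if e.get("ph") == "M"
        and e.get("name") == "process_name"
        and gpu_marker in e.get("args", {}).get("name", "")
    }

    # Count only complete events emitted by those processes.
    return sum(1 for e in events if e.get("ph") == "X" and e.get("pid") in gpu_pids)
```

A test could then fail when `count_gpu_events(path) == 0`, which is the "no ROCm GPU profiling" case this PR is meant to catch.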

Comment on lines 77 to 78
a = jnp.ones((1024, 1024), dtype=jnp.float16)
b = jnp.ones((1024, 1024), dtype=jnp.float16)

oh, just curious: is it important to check this on large matrices? I mean, the goal of the test is just to check whether GPU events are seen, so it should be irrelevant which matrices it uses. So, IIUC, by using large matrices here we only prolong the test runtime and waste resources?

@i-chaochen i-chaochen Dec 22, 2025

IIRC we used to have no gpu events on small matrices?

Instead of just one big size, maybe it's better to test a range of sizes, e.g. 1024, 512, 128, 64 and 32. This gives us a better guarantee of profiling robustness.
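
A rough sketch of what such a size sweep could look like, assuming the test keeps running the workload in a subprocess as it does now. `SIZES`, `make_matmul_snippet` and `run_all_sizes` are hypothetical names for illustration; the float16 `jnp.ones` inputs mirror the lines under review, and `jax.profiler.trace` is the public JAX context manager for writing a profile to a directory.

```python
import subprocess
import sys
import tempfile
import textwrap

# Sizes suggested in the comment above (an assumption, not what the PR uses).
SIZES = (1024, 512, 128, 64, 32)


def make_matmul_snippet(n, logdir):
    """Return Python source that profiles one n x n float16 matmul."""
    return textwrap.dedent(f"""\
        import jax
        import jax.numpy as jnp
        a = jnp.ones(({n}, {n}), dtype=jnp.float16)
        b = jnp.ones(({n}, {n}), dtype=jnp.float16)
        with jax.profiler.trace({logdir!r}):
            jnp.dot(a, b).block_until_ready()
        """)


def run_all_sizes(env=None):
    """Profile every size in a fresh interpreter; raises if any run fails."""
    for n in SIZES:
        with tempfile.TemporaryDirectory() as logdir:
            code = make_matmul_snippet(n, logdir)
            subprocess.run([sys.executable, "-c", code], env=env,
                           capture_output=True, text=True, check=True)
            # The trace files under `logdir` would be inspected for GPU events here.
```

Running each size in its own interpreter keeps one profiler session per size, so a missing flush for a small workload cannot be hidden by events from a larger one.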

@Arech8 Arech8 Dec 22, 2025

hmm, so there might be some dynamic routing to a backend depending on the size of the operation?

jax.jit() has a backend parameter, which could be set to "gpu" to explicitly request that the GPU does the work. This should never be re-routed to a CPU, for example, so it should keep the code simple, clean, and still very lightweight. Is this correct?

@i-chaochen i-chaochen Dec 22, 2025

See ticket SWDEV-568283 and https://github.com/ROCm/xla/blob/rocm-jaxlib-v0.7.1/xla/backends/profiler/gpu/rocm_profiler_sdk.cc#L427: there is a pre-fixed size there and we weren't flushing properly. This should be fixed by openxla/xla#34968, which is backported to 0.8.0 as well. So it's best to check more matrix sizes in the UTs, just in case of future changes.

@Arech8 Arech8 Dec 22, 2025

oh, thanks... so it happened because of profiler integration bugs? Oh, then absolutely agree, there should be several tests for different sizes starting from the smallest...

ADDED: removing the approval until this is resolved. Thanks Chao for noticing!


And since these are GPU kernels, even small matrices will have similar GPU kernel launches anyway...


r = subprocess.run([sys.executable, "-c", code],
                   env=env, capture_output=True, text=True)
if r.returncode != 0:

Not necessary for this PR, but for the future: subprocess.run (https://docs.python.org/3.11/library/subprocess.html#subprocess.run) has a check=True argument that does the same return-code check and raises under the hood, so you don't have to.
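
A small sketch of that suggestion, assuming the same call shape as the lines under review (`run_checked` is a hypothetical wrapper name). With `check=True`, a nonzero exit raises `subprocess.CalledProcessError`, which carries `returncode`, `stdout` and `stderr`.

```python
import subprocess
import sys


def run_checked(code, env=None):
    """Run `code` in a fresh interpreter; raise CalledProcessError on nonzero exit."""
    return subprocess.run([sys.executable, "-c", code],
                          env=env, capture_output=True, text=True, check=True)


try:
    run_checked("import sys; sys.stderr.write('boom'); sys.exit(3)")
except subprocess.CalledProcessError as e:
    # The default exception message omits the captured stderr, so surface it here.
    print(f"child failed with code {e.returncode}: {e.stderr}")
```

One thing to keep in mind: the exception's default message only shows the command and exit status, so if the test currently prints `r.stderr` on failure, that detail would need to be pulled from the exception as above.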

@Arech8 Arech8 left a comment

Thanks Chunyu for this test, it's really important to have it!

I'm approving it, but before merging, please consider my comment about matrix size: if I didn't miss anything, we could save quite a bit of compute by making it something trivially small like 8x8 instead of 1024x1024.

After the merge, please (this is really important!) also make a PR into the upstream jax-ml/jax. You will probably want to make the test disabled by default if the profiling support it requires isn't merged into upstream XLA yet. We'll enable it later, once the XLA changes propagate to the XLA commit used by upstream JAX.

@Arech8 Arech8 added the open-upstream and cherry-pick-candidate labels Dec 22, 2025
@Arech8 Arech8 self-requested a review December 22, 2025 16:27
@i-chaochen i-chaochen removed their request for review January 19, 2026 14:04
@gulsumgudukbay

@cj401-amd @Arech8 is this PR ready to merge?

@gulsumgudukbay
@cj401-amd I noticed that your PR description shows a failing test. Is that the current outcome of this test?

@cj401-amd (Author)

> @cj401-amd @Arech8 is this PR ready to merge?

@gulsumgudukbay

Yes, it's ready to be merged.

Previously, I posted a message showing a failed test, which indicates no ROCm GPU profiling, so the test can help catch the case of no GPU events.

@gulsumgudukbay

> > @cj401-amd @Arech8 is this PR ready to merge?
>
> @gulsumgudukbay
>
> Yes, it's ready to be merged.
>
> Previously, I posted a message showing a failed test, which indicates no ROCm GPU profiling, so the test can help catch the case of no GPU events.

Does it require the local XLA path for those GPU events to show up? I want to merge this, but I have to be sure that it also works if we don't specify a local XLA path.

@cj401-amd (Author)

> > > @cj401-amd @Arech8 is this PR ready to merge?
> >
> > @gulsumgudukbay
> > Yes, it's ready to be merged.
> > Previously, I posted a message showing a failed test, which indicates no ROCm GPU profiling, so the test can help catch the case of no GPU events.
>
> Does it require the local XLA path for those GPU events to show up? I want to merge this, but I have to be sure that it also works if we don't specify a local XLA path.

I believe so. The upstream is here: jax-ml#34135.


Labels

cherry-pick-candidate: Mark a PR to be cherry-picked into the next ROCm JAX. Remove IIF the latest upstream contain the PR.
open-upstream: Tag when you want a copy of this PR to be opened on upstream


4 participants