Use plain bazel to test jax, use hermetic rocm dependency#584
Use plain bazel to test jax, use hermetic rocm dependency#584alekstheod wants to merge 4 commits intorocm-jaxlib-v0.8.0from
Conversation
e959b62 to
0cc81a0
Compare
0cc81a0 to
24d5725
Compare
There was a problem hiding this comment.
Pull request overview
This PR introduces plain Bazel-based testing and build workflows for JAX with hermetic ROCm dependencies. The changes enable running JAX unit tests and building JAX wheels using Bazel commands instead of relying on external build systems.
Key changes:
- Added Bazel configuration and scripts for ROCm-based JAX testing and wheel deployment
- Configured hermetic Python 3.12 and ROCm 7.10.0 dependencies with specific GPU targets
- Exposed JAX wheel as a public filegroup for downstream consumption
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 4 comments.
Show a summary per file
| File | Description |
|---|---|
| build/rocm/run_jax_ut.sh | Script to execute JAX unit tests with ROCm configuration via Bazel |
| build/rocm/rocm.bazelrc | Bazel configuration defining ROCm build settings, compiler paths, and Python version |
| build/rocm/deploy_wheel.sh | Script to build and deploy JAX wheel to a specified location |
| build/rocm/BUILD | Build rules for wheel deployment script generation and binary target |
| BUILD.bazel | Exposed jax_wheel as a public filegroup for build/rocm targets |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
0ff21d1 to
5cf95d5
Compare
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
5cf95d5 to
6694656
Compare
There was a problem hiding this comment.
I'm not really sure what the purpose of this PR is. We already have code that runs unit tests in Bazel. So, I don't understand what the purpose of making a POC to implement something we already have is.
Aside from that, unit tests need to get run with rocm/rocm-jax as the starting point. Running the tests this way might exercise some of our jaxlib changes, but it won't test the kernels wheel and PJRT code that needs to be tested.
If we're trying to improve upon the Bazel unit test setup, we should use the scripts introduced in ROCm/rocm-jax#206 as a starting point.
This PR is not something to be merged in. The purpose is to show how we can do everything we do without a need of an additional wrapper, so using plain bazel commands. The goal is to get your feedback so if you think it would be more convenient to use this instead of the build.py script. |
This PR is a POC that we can use plain bazel to run the tests or build the jax wheels.
Deployment script of the wheel can be used like a following:
./build/rocm/deploy_wheel.sh /tmp/jax.whl