Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions .bazelrc
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,32 @@ common:rbe_windows_amd64 --nobuild_python_zip

common:rbe_windows_amd64 --config=ci_windows_amd64

# RBE configs for ROCm
# Everything below is only active when `--config=rocm_rbe` is passed on the
# command line.

# Client certificate and private key presented to the remote endpoints.
# NOTE(review): both are bare relative file names -- presumably they must exist
# in the directory Bazel is invoked from; confirm where CI places them.
build:rocm_rbe --tls_client_certificate="ci-cert.crt"
build:rocm_rbe --tls_client_key="ci-cert.key"

# Build event service, remote executor and remote cache all point at the same
# EngFlow cluster host.
build:rocm_rbe --bes_backend="grpcs://wardite.cluster.engflow.com"
build:rocm_rbe --bes_results_url="https://wardite.cluster.engflow.com/invocation/"
build:rocm_rbe --remote_executor="grpcs://wardite.cluster.engflow.com"
build:rocm_rbe --remote_cache="grpcs://wardite.cluster.engflow.com"
# Host, execution and target platform are all the manylinux platform for builds.
build:rocm_rbe --host_platform="//platform/linux:manylinux"
build:rocm_rbe --extra_execution_platforms="//platform/linux:manylinux"
build:rocm_rbe --platforms="//platform/linux:manylinux"
build:rocm_rbe --bes_timeout=600s
# Default spawn strategy is local even though a remote executor is configured;
# the test section below opts the TestRunner mnemonic into remote execution.
build:rocm_rbe --spawn_strategy=local
build:rocm_rbe --grpc_keepalive_time=30s
build:rocm_rbe --repo_env=REMOTE_GPU_TESTING=1

# Test-only overrides: the three platform flags are switched from manylinux to
# the ubuntu_gpu platform.
test:rocm_rbe --host_platform="//platform/linux:ubuntu_gpu"
test:rocm_rbe --extra_execution_platforms="//platform/linux:ubuntu_gpu"
test:rocm_rbe --platforms="//platform/linux:ubuntu_gpu"
test:rocm_rbe --remote_timeout=3600
test:rocm_rbe --jobs=200
# Sharding is turned off for tests under this config.
test:rocm_rbe --test_sharding_strategy=disabled
# TestRunner actions try the remote strategy first, then fall back to local.
test:rocm_rbe --strategy=TestRunner=remote,local
test:rocm_rbe --worker_sandboxing=true
# NOTE(review): the same --repo_env value is already set for build:rocm_rbe
# above; this line looks redundant if test inherits build options -- confirm.
test:rocm_rbe --repo_env=REMOTE_GPU_TESTING=1

# #############################################################################
# Cross-compile config options below. Native RBE support does not exist for
# Linux Aarch64 and Mac x86. So, we use a cross-compile toolchain to build
Expand Down
81 changes: 81 additions & 0 deletions .github/actions/download-jax-rocm-wheels/action.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Composite action to download the jax, jaxlib, and the ROCM plugin wheels.
#
# - The jax wheel is copied from a GCS bucket with `gcloud storage cp`, but only when
#   `download-jax-from-gcs` is '1'.
# - The jaxlib, ROCM PJRT and ROCM plugin wheels are fetched with `wget` from a GitHub release
#   unless `skip-download-jaxlib-and-rocm-plugins-from-gh` is '1'.
#
# All wheels are placed in `$(pwd)/dist`.
name: Download JAX ROCM wheels

inputs:
  python:
    description: "Which python version should the artifact be downloaded for?"
    type: string
    required: true
  rocm-version:
    description: "Which rocm version should the artifact be downloaded for?"
    type: string
    default: "7"
  download-jax-from-gcs:
    # Declared so the download script can actually read it. The default of '0' keeps the
    # behavior the action had while this input was undeclared (the jax wheel was never fetched).
    description: "Whether to download the jax wheel from GCS"
    default: '0'
    type: string
  skip-download-jaxlib-and-rocm-plugins-from-gh:
    description: "Whether to skip downloading the jaxlib and rocm plugins from GitHub (e.g for testing a jax only release)"
    default: '0'
    type: string
  gcs_download_uri:
    description: "GCS location prefix from where the jax wheel should be downloaded"
    default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
    type: string
  gh_download_uri:
    # NOTE(review): GitHub release assets normally live under `<prefix>/<tag>/<file>`; the default
    # has no tag component, so callers presumably need to pass a prefix that includes it - confirm.
    description: "GitHub release URL prefix from where the jaxlib and rocm plugin wheels should be downloaded"
    default: 'https://github.com/ROCm/rocm-jax/releases/download'
    type: string
permissions: {}
runs:
  using: "composite"

  steps:
    # Note that certain envs such as JAXCI_HERMETIC_PYTHON_VERSION are set by the calling workflow.
    - name: Set env vars for use in artifact download URL
      shell: bash
      env:
        # Inputs are passed through the environment instead of being interpolated into the script.
        FULL_ROCM_VERSION: "${{ inputs.rocm-version }}"
      run: |
        os=$(uname -s | awk '{print tolower($0)}')
        arch=$(uname -m)

        # Get the major and minor version of Python.
        # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.11, then python_major_minor=311
        # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.13-nogil, then python_major_minor=313t
        python_major_minor=$(echo "${JAXCI_HERMETIC_PYTHON_VERSION//-nogil/t}" | tr -d '.')

        echo "OS=${os}" >> "$GITHUB_ENV"
        echo "ARCH=${arch}" >> "$GITHUB_ENV"
        # Python wheels follow a naming convention: standard wheels use the pattern
        # `*-cp<py_version>-cp<py_version>-*`, while free-threaded wheels use
        # `*-cp<py_version>-cp<py_version>t-*`.
        # The exported value already contains the trailing dash, e.g. `cp312-cp312-`.
        echo "PYTHON_MAJOR_MINOR=cp${python_major_minor%t}-cp${python_major_minor}-" >> "$GITHUB_ENV"

        # Get the ROCM major version only
        echo "JAXCI_ROCM_VERSION=${FULL_ROCM_VERSION%%.*}" >> "$GITHUB_ENV"
    - name: Download wheels
      shell: bash
      id: download-wheel-artifacts
      # Set continue-on-error to true to prevent actions from failing the workflow if this step
      # fails. Instead, we verify the outcome in the next step so that we can print a more
      # informative error message.
      continue-on-error: true
      env:
        DOWNLOAD_JAX_FROM_GCS: "${{ inputs.download-jax-from-gcs }}"
        GCS_DOWNLOAD_URI: "${{ inputs.gcs_download_uri }}"
        SKIP_JAXLIB_AND_PLUGINS: "${{ inputs.skip-download-jaxlib-and-rocm-plugins-from-gh }}"
        GH_DOWNLOAD_URI: "${{ inputs.gh_download_uri }}"
        ROCM_VERSION: "${{ inputs.rocm-version }}"
      run: |
        dist_dir="$(pwd)/dist"
        mkdir -p "${dist_dir}"
        if [[ "${DOWNLOAD_JAX_FROM_GCS}" == "1" ]]; then
          gcloud storage cp -r "${GCS_DOWNLOAD_URI}"/jax*py3*none*any.whl "${dist_dir}/"
        else
          echo "JAX wheel won't be downloaded, only jaxlib pre-built wheel is tested."
        fi

        # Do not download the jaxlib and ROCM plugin artifacts if we are testing a jax only
        # release.
        if [[ "${SKIP_JAXLIB_AND_PLUGINS}" == "1" ]]; then
          echo "JAX only release. Only downloading the jax wheel from the release bucket."
        else
          # Release version of the pre-built wheels; part of every file name below.
          jax_version="0.8.2"
          local_version="${jax_version}+rocm${ROCM_VERSION}"
          # PYTHON_MAJOR_MINOR (exported by the previous step) is the full
          # `cp<ver>-cp<ver>[t]-` tag pair including its trailing dash.
          # NOTE(review): the PJRT wheel is assumed to be interpreter independent (`py3-none`),
          # which is also what the previous, unset-variable expansion produced - confirm.
          wget -P "${dist_dir}/" "${GH_DOWNLOAD_URI}/jaxlib-${local_version}-${PYTHON_MAJOR_MINOR}manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl"
          wget -P "${dist_dir}/" "${GH_DOWNLOAD_URI}/jax_rocm${JAXCI_ROCM_VERSION}_pjrt-${local_version}-py3-none-manylinux_2_28_x86_64.whl"
          wget -P "${dist_dir}/" "${GH_DOWNLOAD_URI}/jax_rocm${JAXCI_ROCM_VERSION}_plugin-${local_version}-${PYTHON_MAJOR_MINOR}manylinux_2_28_x86_64.whl"
        fi
    - name: Skip the test run if the wheel artifacts were not downloaded successfully
      shell: bash
      if: steps.download-wheel-artifacts.outcome == 'failure'
      run: |
        echo "Failed to download wheel artifacts. Please check if the wheels were"
        echo "built successfully by the artifact build jobs and are available in the GCS bucket"
        echo "or GitHub release they are downloaded from."
        echo "Skipping the test run."
        exit 1
120 changes: 120 additions & 0 deletions .github/workflows/bazel_rocm.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
# CI - Bazel ROCM tests
#
# This workflow runs the ROCM tests with Bazel. It can only be triggered by other workflows via
# `workflow_call`. It is used by the `CI - Bazel ROCM tests (RBE)`,`CI - Wheel Tests (Continuous)`
# and `CI - Wheel Tests (Nightly/Release)` workflows to run the Bazel ROCM tests.
#
# It consists of the following job:
# run-tests:
#    - Runs the `download-jax-rocm-wheels` action if build_jaxlib is `false`.
#      Otherwise, the artifacts are built from source.
#    - If `run_multiaccelerator_tests` is `false`, executes the `run_bazel_test_rocm_rbe.sh` script,
#      which performs the following actions:
#      - `build_jaxlib=wheel`: Runs the Bazel ROCM tests with py_import dependencies.
#      - `build_jaxlib=false`: Runs the Bazel ROCM tests with downloaded wheel dependencies.
#      - `build_jaxlib=true`: Runs the Bazel ROCM tests with individual Bazel target dependencies.
#    - If `run_multiaccelerator_tests` is `true`, executes the `run_bazel_test_rocm_non_rbe.sh`
#      script, which performs the following actions:
#      - `build_jaxlib=wheel`: Runs the Bazel ROCM tests with py_import dependencies.
#      - `build_jaxlib=false`: Runs the Bazel ROCM tests with downloaded wheel dependencies.

name: CI - Bazel ROCM tests

on:
  workflow_call:
    inputs:
      runner:
        description: "Which runner should the workflow run on?"
        type: string
        default: "linux-x86-n4-16"
      python:
        description: "Which python version to test?"
        type: string
        default: "3.12"
      rocm-version:
        description: "Which ROCM version to test?"
        type: string
        default: "7"
      jaxlib-version:
        # Only used to label the job; declared because the job name reads it and callers pass it.
        description: "Which jaxlib version is being tested? Shown in the job name."
        type: string
        default: "head"
      enable-x64:
        description: "Should x64 mode be enabled?"
        type: string
        default: "0"
      download-jax-from-gcs:
        description: "Whether to download the jax wheel from GCS"
        default: '1'
        type: string
      skip-download-jaxlib-and-rocm-plugins-from-gh:
        description: "Whether to skip downloading the jaxlib and rocm plugins from GH (e.g for testing a jax only release)"
        default: '0'
        type: string
      gh_download_uri:
        description: "GH location URI from where the artifacts should be downloaded"
        default: 'https://github.com/ROCm/rocm-jax/releases/download'
        type: string
      build_jaxlib:
        description: 'Should jaxlib be built from source?'
        required: true
        type: string
      build_jax:
        description: 'Should jax be built from source?'
        required: true
        type: string
      write_to_bazel_remote_cache:
        description: 'Whether to enable writing to the Bazel remote cache bucket'
        required: false
        default: '0'
        type: string
      run_multiaccelerator_tests:
        description: 'Whether to run multi-accelerator tests'
        required: false
        default: 'false'
        type: string
      halt-for-connection:
        description: 'Should this workflow run wait for a remote connection?'
        type: string
        default: 'no'
permissions: {}
jobs:
  run-tests:
    defaults:
      run:
        # Explicitly set the shell to bash
        shell: bash
    runs-on: ${{ inputs.runner }}
    container: "us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest"

    env:
      JAXCI_HERMETIC_PYTHON_VERSION: ${{ inputs.python }}
      JAXCI_ENABLE_X64: ${{ inputs.enable-x64 }}
      JAXCI_ROCM_VERSION: ${{ inputs.rocm-version }}
      JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE: ${{ inputs.write_to_bazel_remote_cache }}
      JAXCI_BUILD_JAX: ${{ inputs.build_jax }}
      JAXCI_BUILD_JAXLIB: ${{ inputs.build_jaxlib }}
    # Begin Presubmit Naming Check - name modification requires internal check to be updated
    # NOTE(review): the version label reads ROCM= (not CUDA=) since it prints rocm-version; the
    # check tag below still says cuda - confirm which internal check should track this job.
    name: "${{ (contains(inputs.runner, 'linux-x86') && 'linux x86') ||
        (contains(inputs.runner, 'linux-arm64') && 'linux arm64') ||
        (contains(inputs.runner, 'windows-x86') && 'windows x86') }}, jaxlib=${{ inputs.jaxlib-version }}, ROCM=${{ inputs.rocm-version }}, Python=${{ inputs.python }}, x64=${{ inputs.enable-x64 }}, build_jax=${{ inputs.build_jax }}, build_jaxlib=${{ inputs.build_jaxlib }}"
    # End Presubmit Naming Check github-cuda-presubmits
    steps:
      - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3  # v6.0.0
        with:
          persist-credentials: false
      - name: Download JAX ROCM wheels
        if: inputs.build_jaxlib == 'false'
        uses: ./.github/actions/download-jax-rocm-wheels
        with:
          python: ${{ inputs.python }}
          rocm-version: ${{ inputs.rocm-version }}
          # Forward the input this workflow actually declares (`download-jax-from-gcs`).
          download-jax-from-gcs: ${{ inputs.download-jax-from-gcs }}
          skip-download-jaxlib-and-rocm-plugins-from-gh: ${{ inputs.skip-download-jaxlib-and-rocm-plugins-from-gh }}
          gh_download_uri: ${{ inputs.gh_download_uri }}
      # Halt for testing
      - name: Wait For Connection
        uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c
        with:
          halt-dispatch-input: ${{ inputs.halt-for-connection }}
      - name: "Bazel ROCM tests with build_jax=${{ inputs.build_jax }}, build_jaxlib=${{ inputs.build_jaxlib }}"
        timeout-minutes: 60
        # Single-accelerator runs use the RBE script; multi-accelerator runs use the non-RBE one.
        run: ${{ ((inputs.run_multiaccelerator_tests == 'false') && './ci/run_bazel_test_rocm_rbe.sh') || './ci/run_bazel_test_rocm_non_rbe.sh' }}
24 changes: 24 additions & 0 deletions .github/workflows/wheel_tests_nightly_release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,30 @@ jobs:
write_to_bazel_remote_cache: 1
run_multiaccelerator_tests: "true"

run-bazel-test-rocm:
  # Job id is ROCM specific: the previous id `run-bazel-test-cuda` names the CUDA job, and a
  # repeated id under `jobs:` is a duplicate mapping key.
  uses: ./.github/workflows/bazel_rocm.yml
  strategy:
    fail-fast: false  # don't cancel all jobs on failure
    matrix:
      # Runner OS and Python values need to match the matrix strategy of our internal CI jobs
      # that build the wheels.
      # NOTE(review): this runner label looks like an NVIDIA L4 machine - confirm that a
      # ROCM capable runner is intended here.
      runner: ["linux-x86-g2-48-l4-4gpu"]
      python: ["3.11", "3.12", "3.13", "3.13-nogil", "3.14", "3.14-nogil"]
      # Quoted because the called workflow declares `rocm-version` as a string input.
      rocm-version: ["7"]
      enable-x64: [0]
  name: "Bazel ROCM Non-RBE with ${{ format('{0}', 'build_jaxlib=false') }}"
  with:
    runner: ${{ matrix.runner }}
    python: ${{ matrix.python }}
    rocm-version: ${{ matrix.rocm-version }}
    enable-x64: ${{ matrix.enable-x64 }}
    halt-for-connection: ${{inputs.halt-for-connection}}
    build_jaxlib: "false"
    build_jax: "false"
    jaxlib-version: "head"
    write_to_bazel_remote_cache: 1
    run_multiaccelerator_tests: "true"

run-pytest-tpu:
uses: ./.github/workflows/pytest_tpu.yml
strategy:
Expand Down
2 changes: 2 additions & 0 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ python_init_repositories(
"jaxlib*",
"jax_cuda*",
"jax-cuda*",
"jax_rocm*",
"jax-rocm*",
],
local_wheel_workspaces = ["//jaxlib:jax.bzl"],
requirements = {
Expand Down
2 changes: 1 addition & 1 deletion ci/envs/default.env
Original file line number Diff line number Diff line change
Expand Up @@ -89,4 +89,4 @@ export JAXCI_BUILD_JAX=${JAXCI_BUILD_JAX:-true}
export JAXCI_BAZEL_OUTPUT_BASE=${JAXCI_BAZEL_OUTPUT_BASE:-}

# Controls whether to build or run CPU test targets.
export JAXCI_BAZEL_CPU_RBE_MODE=${JAXCI_BAZEL_CPU_RBE_MODE:-"test"}
export JAXCI_BAZEL_CPU_RBE_MODE=${JAXCI_BAZEL_CPU_RBE_MODE:-"test"}
76 changes: 76 additions & 0 deletions ci/run_bazel_test_rocm_rbe.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
#!/bin/bash
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Bazel GPU test driver for RBE. Only single accelerator tests are run, each
# with one GPU on RBE.
#
# Shell options enabled below:
#   errexit   - stop at the first failing command
#   nounset   - referencing an undefined variable is an error
#   xtrace    - log every command
#   history   - record shell history
#   allexport - export all functions and variables so subscripts can see them
set -o errexit -o nounset -o xtrace -o history -o allexport

# Load the default JAXCI_* environment variables.
source ci/envs/default.env

# Without a path to a local XLA checkout, ask for XLA to be cloned at HEAD.
[[ -n "$JAXCI_XLA_GIT_DIR" ]] || export JAXCI_CLONE_MAIN_XLA=1

# Prepare the build environment.
source "ci/utilities/setup_build_environment.sh"

# Wheel size tests are added only for artifacts that are not taken as
# pre-built downloads (JAXCI_BUILD_* other than "false").
# NOTE(review): the jaxlib-side targets are the CUDA wheel size tests even
# though this is the ROCm script -- confirm whether ROCm equivalents exist.
WHEEL_SIZE_TESTS=""
if [[ "$JAXCI_BUILD_JAXLIB" != "false" ]]; then
  WHEEL_SIZE_TESTS="//jaxlib/tools:jax_cuda_plugin_wheel_size_test \
    //jaxlib/tools:jax_cuda_pjrt_wheel_size_test \
    //jaxlib/tools:jaxlib_wheel_size_test"
fi
if [[ "$JAXCI_BUILD_JAX" != "false" ]]; then
  WHEEL_SIZE_TESTS+=" //:jax_wheel_size_test"
fi

# Extra library flag: empty unless jaxlib is built from individual Bazel
# targets (JAXCI_BUILD_JAXLIB == "true").
# Disabled alternative for the other modes: --config=cuda_libraries_from_stubs
cuda_libs_flag=""
if [[ "$JAXCI_BUILD_JAXLIB" == "true" ]]; then
  cuda_libs_flag="--@local_config_cuda//cuda:override_include_cuda_libs=true"
fi

# Run Bazel GPU tests with RBE (single accelerator tests with one GPU apiece).
echo "Running RBE GPU tests..."

# Arguments are collected in an array in the exact order they are passed to
# Bazel. $cuda_libs_flag and $WHEEL_SIZE_TESTS are left unquoted on purpose:
# an empty value must vanish and a multi-target value must split into words.
bazel_test_args=(
  --config=rocm_rbe
  --config=rocm
  --repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION"
  --override_repository=xla="${JAXCI_XLA_GIT_DIR}"
  --test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform
  --test_output=errors
  --test_env=TF_CPP_MIN_LOG_LEVEL=0
  --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow
  --test_tag_filters=-multiaccelerator
  --test_env=JAX_SKIP_SLOW_TESTS=true
  --action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64"
  --color=yes
  $cuda_libs_flag
  --//jax:build_jaxlib=$JAXCI_BUILD_JAXLIB
  --//jax:build_jax=$JAXCI_BUILD_JAX
  //tests:gpu_tests //tests:backend_independent_tests
  //tests/pallas:gpu_tests //tests/pallas:backend_independent_tests
  $WHEEL_SIZE_TESTS
)

bazel test "${bazel_test_args[@]}"
9 changes: 6 additions & 3 deletions jaxlib/jax.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -192,10 +192,13 @@ def _gpu_test_deps():
"//jaxlib/rocm:gpu_only_test_deps",
"//jax_plugins:gpu_plugin_only_test_deps",
],
"//jax:config_build_jaxlib_false": [
"//jax:config_build_jaxlib_false": if_cuda_is_configured([
"//jaxlib/tools:pypi_jax_cuda_plugin_with_cuda_deps",
"//jaxlib/tools:pypi_jax_cuda_pjrt_with_cuda_deps",
],
]) + if_rocm_is_configured([
"//jaxlib/tools:rocm_plugin_kernels_wheel",
"//jaxlib/tools:rocm_plugin_pjrt_wheel",
]),
"//jax:config_build_jaxlib_wheel": [
"//jaxlib/tools:jax_cuda_plugin_py_import",
"//jaxlib/tools:jax_cuda_pjrt_py_import",
Expand Down Expand Up @@ -303,7 +306,7 @@ def jax_multiplatform_test(
shard_count = test_shards,
tags = test_tags,
main = main,
exec_properties = tf_exec_properties({"tags": test_tags}),
exec_properties = {} #tf_exec_properties({"tags": test_tags}),
)

def jax_generate_backend_suites(backends = []):
Expand Down
Loading