8 changes: 8 additions & 0 deletions .github/workflows/build_artifacts.yml
@@ -125,6 +125,14 @@ jobs:
- uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0
with:
persist-credentials: false
fetch-depth: 0
- name: Revert the problematic commit
run: |
git config --global --add safe.directory /__w/jax/jax
git config --global user.email "[email protected]"
git config --global user.name "Vlad Belitskiy"
git revert --no-edit 548eaa5b53afeba91518d4d9274f7198b55cc308
echo "Commit 548eaa5b53afeba91518d4d9274f7198b55cc308 reverted for this build."
- name: Configure Build Environment
shell: bash
run: |
12 changes: 6 additions & 6 deletions .github/workflows/pytest_tpu.yml
@@ -114,16 +114,16 @@ jobs:
echo "Unknown libtpu version type: ${{ inputs.libtpu-version-type }}"
exit 1
fi
# Halt for testing
- name: Wait For Connection
uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c
with:
halt-dispatch-input: ${{ inputs.halt-for-connection }}
- name: Run Pytest TPU tests
timeout-minutes: ${{ github.event_name == 'pull_request' && 30 || 210 }}
timeout-minutes: ${{ github.event_name == 'pull_request' && 210 || 210 }}
run: |
if [[ ${{ inputs.python }} == "3.13-nogil" && ${{ inputs.tpu-type }} == "v5e-8" ]]; then
echo "Uninstalling xprof as it is not compatible with python 3.13t."
$JAXCI_PYTHON -m uv pip uninstall xprof
fi
./ci/run_pytest_tpu.sh
# Halt for testing
- name: Wait For Connection
uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c
with:
halt-dispatch-input: ${{ inputs.tpu-type == 'v6e-8' && 'yes' || 'no' }}
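Note on the `${{ ... && ... || ... }}` expressions used in `timeout-minutes` and `halt-dispatch-input` above: GitHub Actions has no ternary operator, so workflows lean on the `cond && a || b` idiom, which only behaves like a ternary while the "true" branch is truthy. A rough Python model of that evaluation, for illustration only (not part of this change):

```python
def actions_ternary(condition, if_true, if_false):
    """Rough Python model of the GitHub Actions `cond && a || b` idiom.

    It acts like a ternary only while `if_true` is truthy; a falsy value
    (0, '', False) falls through to `if_false`, which is why the workflow
    keeps truthy values such as 210 and 'yes'/'no' in these branches.
    """
    return (condition and if_true) or if_false


# timeout-minutes: ${{ github.event_name == 'pull_request' && 210 || 210 }}
assert actions_ternary(True, 210, 210) == 210
# halt-dispatch-input: ${{ inputs.tpu-type == 'v6e-8' && 'yes' || 'no' }}
assert actions_ternary(False, "yes", "no") == "no"
# Pitfall: a falsy "true" branch silently selects the "false" branch.
assert actions_ternary(True, 0, 210) == 210
```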
179 changes: 13 additions & 166 deletions .github/workflows/wheel_tests_continuous.yml
@@ -24,14 +24,21 @@
# runs Bazel TPU tests with py_import.

name: CI - Wheel Tests (Continuous)
permissions:
contents: read

on:
schedule:
- cron: "0 */3 * * *" # Run once every 3 hours
workflow_dispatch: # allows triggering the workflow run manually

pull_request:
branches:
- main
push:
branches:
- main
- 'release/**'
permissions: {}

concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'main' }}
@@ -54,7 +61,7 @@ jobs:
fail-fast: false # don't cancel all jobs on failure
matrix:
# Runner OS and Python values need to match the matrix strategy in the CPU tests job
runner: ["linux-x86-n4-16", "linux-arm64-t2a-48", "windows-x86-n2-16"]
runner: ["linux-x86-n4-16"]
artifact: ["jaxlib"]
python: ["3.11"]
# Note: For reasons unknown, Github actions groups jobs with the same top-level name in the
@@ -69,164 +76,6 @@ jobs:
upload_artifacts_to_gcs: true
gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'

build-cuda-artifacts:
uses: ./.github/workflows/build_artifacts.yml
strategy:
fail-fast: false # don't cancel all jobs on failure
matrix:
# Python values need to match the matrix strategy in the CUDA tests job below
runner: ["linux-x86-n4-16"]
artifact: ["jax-cuda-plugin", "jax-cuda-pjrt"]
python: ["3.11",]
cuda-version: ["12", "13"]
name: "Build ${{ format('{0}', 'CUDA') }} artifacts"
with:
runner: ${{ matrix.runner }}
artifact: ${{ matrix.artifact }}
python: ${{ matrix.python }}
cuda-version: ${{ matrix.cuda-version }}
clone_main_xla: 1
upload_artifacts_to_gcs: true
gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'

run-pytest-cpu:
# Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated
# build job fails. E.g Windows build job fails but everything else succeeds. In this case, we
# still want to run the tests for other platforms.
if: ${{ !cancelled() }}
needs: [build-jax-artifact, build-jaxlib-artifact]
uses: ./.github/workflows/pytest_cpu.yml
strategy:
fail-fast: false # don't cancel all jobs on failure
matrix:
# Runner OS and Python values need to match the matrix strategy in the
# build_jaxlib_artifact job above
runner: ["linux-x86-n4-64", "linux-arm64-t2a-48", "windows-x86-n2-64"]
python: ["3.11",]
enable-x64: [1, 0]
name: "Pytest CPU (JAX artifacts version = ${{ format('{0}', 'head') }})"
with:
runner: ${{ matrix.runner }}
python: ${{ matrix.python }}
enable-x64: ${{ matrix.enable-x64 }}
gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}

run-pytest-cuda:
# Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated
# build job fails. E.g Windows build job fails but everything else succeeds. In this case, we
# still want to run the tests for other platforms.
if: ${{ !cancelled() }}
needs: [build-jax-artifact, build-jaxlib-artifact, build-cuda-artifacts]
uses: ./.github/workflows/pytest_cuda.yml
strategy:
fail-fast: false # don't cancel all jobs on failure
matrix:
# Python values need to match the matrix strategy in the artifact build jobs above
# See exclusions for what is fully tested
runner: ["linux-x86-g2-48-l4-4gpu", "linux-x86-a3-8g-h100-8gpu", "linux-x86-a4-224-b200-1gpu"]
python: ["3.11",]
cuda: [
{version: "12.1", use-nvidia-pip-wheels: false},
{version: "12.9", use-nvidia-pip-wheels: true},
{version: "13", use-nvidia-pip-wheels: true},
]
enable-x64: [1, 0]
exclude:
# H100 runs only a single config, CUDA 12.9 Enable x64 1
- runner: "linux-x86-a3-8g-h100-8gpu"
cuda:
version: "12.1"
- runner: "linux-x86-a3-8g-h100-8gpu"
enable-x64: "0"
# B200 runs only a single config, CUDA 12.9 Enable x64 1
- runner: "linux-x86-a4-224-b200-1gpu"
cuda:
version: "12.1"
- runner: "linux-x86-a4-224-b200-1gpu"
enable-x64: "0"

name: "Pytest CUDA (JAX artifacts version = ${{ format('{0}', 'head') }}, CUDA Pip packages = ${{ matrix.cuda.use-nvidia-pip-wheels }})"
with:
runner: ${{ matrix.runner }}
python: ${{ matrix.python }}
cuda-version: ${{ matrix.cuda.version }}
use-nvidia-pip-wheels: ${{ matrix.cuda.use-nvidia-pip-wheels }}
enable-x64: ${{ matrix.enable-x64 }}
# GCS upload URI is the same for both artifact build jobs
gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}

run-bazel-test-cpu-py-import:
uses: ./.github/workflows/bazel_cpu.yml
strategy:
fail-fast: false # don't cancel all jobs on failure
matrix:
runner: ["linux-x86-n4-16", "linux-arm64-t2a-48", "windows-x86-n2-16"]
python: ["3.11",]
enable-x64: [1, 0]
name: "Bazel CPU tests with ${{ format('{0}', 'build_jaxlib=wheel') }}"
with:
runner: ${{ matrix.runner }}
python: ${{ matrix.python }}
enable-x64: ${{ matrix.enable-x64 }}
build_jaxlib: "wheel"
build_jax: "wheel"

run-bazel-test-cuda:
# Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated
# build job fails. E.g Windows build job fails but everything else succeeds. In this case, we
# still want to run the tests for other platforms.
if: ${{ !cancelled() }}
needs: [build-jax-artifact, build-jaxlib-artifact, build-cuda-artifacts]
uses: ./.github/workflows/bazel_cuda.yml
strategy:
fail-fast: false # don't cancel all jobs on failure
matrix:
# Python values need to match the matrix strategy in the build artifacts job above
runner: ["linux-x86-g2-48-l4-4gpu",]
python: ["3.11",]
cuda-version: ["12", "13"]
jaxlib-version: ["head", "pypi_latest"]
enable-x64: [1, 0]
name: "Bazel CUDA Non-RBE with build_jaxlib=false, (jax version = ${{ format('{0}', 'head') }})"
with:
runner: ${{ matrix.runner }}
python: ${{ matrix.python }}
cuda-version: ${{ matrix.cuda-version }}
enable-x64: ${{ matrix.enable-x64 }}
jaxlib-version: ${{ matrix.jaxlib-version }}
# GCS upload URI is the same for both artifact build jobs
gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
build_jaxlib: "false"
build_jax: "false"
write_to_bazel_remote_cache: 1
run_multiaccelerator_tests: "true"

run-bazel-test-cuda-py-import:
# Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated
# build job fails. E.g Windows build job fails but everything else succeeds. In this case, we
# still want to run the tests for other platforms.
if: ${{ !cancelled() }}
uses: ./.github/workflows/bazel_cuda.yml
strategy:
fail-fast: false # don't cancel all jobs on failure
matrix:
# Python values need to match the matrix strategy in the build artifacts job above
runner: ["linux-x86-g2-48-l4-4gpu",]
python: ["3.11"]
cuda-version: ["12", "13"]
enable-x64: [1]
name: "Bazel CUDA Non-RBE with ${{ format('{0}', 'build_jaxlib=wheel') }}"
with:
runner: ${{ matrix.runner }}
python: ${{ matrix.python }}
cuda-version: ${{ matrix.cuda-version }}
enable-x64: ${{ matrix.enable-x64 }}
build_jaxlib: "wheel"
build_jax: "wheel"
jaxlib-version: "head"
write_to_bazel_remote_cache: 1
run_multiaccelerator_tests: "true"

run-pytest-tpu:
# Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated
# build job fails. E.g Windows build job fails but everything else succeeds. In this case, we
@@ -241,8 +90,7 @@ jobs:
tpu-specs: [
# {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available
{type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"},
{type: "v6e-8", cores: "8", runner: "linux-x86-ct6e-180-8tpu"},
{type: "v7x-8", cores: "8", runner: "linux-x86-tpu7x-224-4tpu"}
{type: "v6e-8", cores: "8", runner: "linux-x86-ct6e-180-8tpu"}
]
libtpu-version-type: ["nightly"]
name: "Pytest TPU (JAX artifacts version = ${{ format('{0}', 'head') }})"
@@ -266,8 +114,7 @@ jobs:
matrix:
python: ["3.11"]
tpu-specs: [
{type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"},
{type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"},
{type: "v6e-8", cores: "8", runner: "linux-x86-ct6e-180-8tpu"}
]
libtpu-version-type: ["nightly"]
name: "Bazel tests TPU (JAX artifacts version = ${{ format('{0}', 'head') }})"
@@ -281,4 +128,4 @@ jobs:
gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
build_jaxlib: "wheel"
build_jax: "wheel"
clone_main_xla: 1
clone_main_xla: 1
1 change: 1 addition & 0 deletions build/test-requirements.txt
@@ -12,3 +12,4 @@ rich
matplotlib
auditwheel
scipy-stubs
# pytest-timeout
2 changes: 1 addition & 1 deletion ci/run_pytest_tpu.sh
@@ -73,7 +73,7 @@ if [[ "$JAXCI_RUN_FULL_TPU_TEST_SUITE" == "1" ]]; then
--deselect=tests/pallas/tpu_pallas_call_print_test.py::PallasCallPrintTest \
--deselect=tests/pallas/tpu_sparsecore_pallas_test.py::DebugPrintTest \
--deselect=tests/pallas/tpu_pallas_interpret_thread_map_test.py::InterpretThreadMapTest::test_thread_map \
--maxfail=20 -m "not multiaccelerator" $IGNORE_FLAGS tests examples
--maxfail=20 --dist=loadfile -m "not multiaccelerator" $IGNORE_FLAGS tests examples

# Store the return value of the first command.
first_cmd_retval=$?
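The new `--dist=loadfile` flag changes how pytest-xdist distributes the TPU test suite: every test from a given file is sent to the same worker, so per-file ordering stays deterministic and lines up with the per-worker logging added to `conftest.py` below. A minimal sketch of the grouping rule, with hypothetical nodeids (illustration only, not the actual xdist scheduler):

```python
from collections import defaultdict


def group_by_file(nodeids, num_workers):
    """Sketch of pytest-xdist 'loadfile' scheduling: tests that share a
    file path stay together, and whole files are handed out to workers."""
    by_file = defaultdict(list)
    for nodeid in nodeids:
        # "tests/foo_test.py::TestX::test_y" -> "tests/foo_test.py"
        by_file[nodeid.split("::", 1)[0]].append(nodeid)

    workers = defaultdict(list)
    for i, (path, tests) in enumerate(sorted(by_file.items())):
        workers[f"gw{i % num_workers}"].extend(tests)
    return dict(workers)


if __name__ == "__main__":
    example = [
        "tests/pallas/ops_test.py::test_add",
        "tests/pallas/ops_test.py::test_mul",
        "tests/lax_test.py::test_reduce",
    ]
    print(group_by_file(example, num_workers=2))
```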
17 changes: 17 additions & 0 deletions conftest.py
@@ -14,6 +14,8 @@
"""pytest configuration"""

import os
import sys

import pytest


@@ -72,3 +74,18 @@ def pytest_collection() -> None:
os.environ.setdefault(
"CUDA_VISIBLE_DEVICES", str(xdist_worker_number % num_cuda_devices)
)


def pytest_runtest_logreport(report):
# Only look at the call phase (the test body itself)
if report.when == 'call':
# Get the worker ID
worker_id = getattr(report, "node", None)
if worker_id:
worker_id = worker_id.gateway.id
else:
worker_id = "master"

# Log to a file named after the worker
with open(f"test_order_{worker_id}.log", "a") as f:
f.write(f"{report.nodeid}\n")
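The `pytest_runtest_logreport` hook above appends each executed nodeid to a `test_order_<worker>.log` file, one per xdist worker (or `test_order_master.log` without xdist). A small sketch of how those files could be read back after a run to inspect per-worker execution order; the helper and its location are assumptions for illustration, not part of the change:

```python
import glob
import os


def collect_test_order(log_dir="."):
    """Return {worker_id: [nodeid, ...]} from the test_order_*.log files
    written by the conftest.py hook (assumed to be in log_dir)."""
    order = {}
    for path in sorted(glob.glob(os.path.join(log_dir, "test_order_*.log"))):
        worker_id = os.path.basename(path)[len("test_order_"):-len(".log")]
        with open(path) as f:
            order[worker_id] = [line.strip() for line in f if line.strip()]
    return order


if __name__ == "__main__":
    for worker, tests in collect_test_order().items():
        first = tests[0] if tests else "n/a"
        print(f"{worker}: {len(tests)} tests, starting with {first}")
```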
1 change: 1 addition & 0 deletions tests/pallas/tpu_sparsecore_pallas_debug_check_test.py
@@ -67,6 +67,7 @@ def setUp(self):

super().setUp()

@unittest.skip("Failing on all TPU versions: b/436509694")
def test_scalar_debug_check(self):
if not jtu.is_device_tpu_at_least(6):
# TODO: b/436509694 - Figure out why the test gets stuck on v5p.