Skip to content

Fix JAX extension build with NVTE_UB_WITH_MPI=1#2835

Open
GaetanLepage wants to merge 2 commits intoNVIDIA:mainfrom
GaetanLepage:main
Open

Fix JAX extension build with NVTE_UB_WITH_MPI=1#2835
GaetanLepage wants to merge 2 commits intoNVIDIA:mainfrom
GaetanLepage:main

Conversation

@GaetanLepage
Copy link
Copy Markdown
Contributor

Description

When building Transformer Engine with NVTE_UB_WITH_MPI=1 and NVTE_FRAMEWORK=pytorch,jax, the JAX extension (transformer_engine_jax) fails to load at runtime with an undefined symbol error, while the PyTorch extension works fine.

In userbuffers.h, the ExtComm type is conditionally defined based on NVTE_UB_WITH_MPI:

#ifdef NVTE_UB_WITH_MPI
#define ExtComm MPI_Comm
#else
#define ExtComm const char *
#endif

This type flows into ExtAllgatherOp and ExtBarrierOp, which are parameters of the CommOverlapP2PBase constructor.
This means the constructor has a different mangled symbol name depending on whether NVTE_UB_WITH_MPI is defined.
The core library (libtransformer_engine.so) is built via CMake, which correctly sets -DNVTE_UB_WITH_MPI.
The PyTorch extension also adds this flag.
However, the JAX extension is missing this flag entirely.
As a result, transformer_engine_jax.so is compiled expecting the const char * variant of the constructor, while libtransformer_engine.so only exports the MPI_Comm variant, causing an undefined symbol error at import time.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

This PR adds the MPI include path and -DNVTE_UB_WITH_MPI compile definition to the JAX extension build, mirroring the existing handling in build_tools/pytorch.py.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps bot commented Apr 4, 2026

Greptile Summary

This PR fixes a runtime undefined symbol error that occurs when building Transformer Engine with NVTE_UB_WITH_MPI=1 and the JAX framework. The root cause is that the ExtComm typedef in userbuffers.h is conditionally defined based on NVTE_UB_WITH_MPI, which changes the mangled C++ symbol name of CommOverlapP2PBase constructor. The core library and PyTorch extension were compiled with -DNVTE_UB_WITH_MPI, producing the MPI_Comm-variant symbols, while the JAX extension was compiled without the flag, expecting const char *-variant symbols — causing an import-time link failure.

The fix extracts the MPI flag setup logic from build_tools/pytorch.py into a shared setup_mpi_flags() helper in build_tools/utils.py, and then calls it from both the JAX and PyTorch extension setup functions. This ensures both extensions receive the same -DNVTE_UB_WITH_MPI define and MPI include path when the flag is enabled.

  • build_tools/utils.py: Adds setup_mpi_flags(include_dirs, cxx_flags) utility function extracted from PyTorch extension setup
  • build_tools/pytorch.py: Replaces inline MPI flag logic with a call to setup_mpi_flags()
  • build_tools/jax.py: Adds setup_mpi_flags() call before defining the Pybind11Extension — the core fix

Confidence Score: 5/5

This PR is safe to merge — it correctly fixes a real build/runtime bug with a clean, minimal refactoring.

The change is a straightforward bug fix: the missing -DNVTE_UB_WITH_MPI flag in the JAX extension build is the confirmed root cause of the undefined symbol error, and the fix mirrors the existing, already-working PyTorch extension logic exactly. The refactoring into setup_mpi_flags() reduces duplication without any behavioural change. The only pre-existing caveat (empty MPI_HOME string bypassing the guard) was already flagged in a prior review thread and is out of scope here. No new issues were introduced.

No files require special attention.

Important Files Changed

Filename Overview
build_tools/jax.py Imports and calls setup_mpi_flags() before constructing the Pybind11Extension — the primary bug fix ensuring -DNVTE_UB_WITH_MPI and the MPI include path are applied to the JAX extension when the flag is set
build_tools/utils.py Adds setup_mpi_flags() helper that centralises the MPI include-path and compile-definition logic; logic is a direct lift from the existing pytorch.py code with no functional changes
build_tools/pytorch.py Replaces inline MPI flag block with a call to the new shared setup_mpi_flags() helper — pure refactoring, no behavioural change

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[Build with NVTE_UB_WITH_MPI=1?] -->|Yes| B[setup_mpi_flags called]
    A -->|No| E[No MPI flags added]
    B --> C{MPI_HOME set?}
    C -->|No / None| D[Assert error:\nMPI_HOME must be set]
    C -->|Yes| F[append MPI_HOME/include\nto include_dirs]
    F --> G[append -DNVTE_UB_WITH_MPI\nto cxx_flags]
    G --> H[JAX & PyTorch extensions\nboth compiled with MPI_Comm variant\nof CommOverlapP2PBase symbols]
    E --> I[Extensions compiled with\nconst char* variant - only valid\nwhen core lib also built without MPI]
Loading

Greploops — Automatically fix all review issues by running /greploops in Claude Code. It iterates: fix, push, re-review, repeat until 5/5 confidence.
Use the Greptile plugin for Claude Code to query reviews, search comments, and manage custom context directly from your terminal.

Reviews (2): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile

Comment on lines +105 to +106
assert (
os.getenv("MPI_HOME") is not None
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Empty MPI_HOME string bypasses the guard

os.getenv("MPI_HOME") returns None only when the variable is unset. If a user exports MPI_HOME="" (empty string), the assert passes (empty string is not None), and Path("") silently resolves to the current working directory — not a valid MPI installation — causing confusing compile errors downstream.

Consider checking for a non-empty value:

Suggested change
assert (
os.getenv("MPI_HOME") is not None
mpi_home = os.getenv("MPI_HOME")
assert mpi_home, (
"MPI_HOME=/path/to/mpi must be set when compiling with NVTE_UB_WITH_MPI=1!"
)
mpi_path = Path(mpi_home)

This also avoids calling os.getenv("MPI_HOME") twice (once in the assert, once for Path(...)). Note: the same pattern exists in build_tools/pytorch.py line 71–74.

Signed-off-by: Gaetan Lepage <gaetan@glepage.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant