Fix JAX extension build with NVTE_UB_WITH_MPI=1 #2835
GaetanLepage wants to merge 2 commits into NVIDIA:main from
Greptile Summary
This PR fixes a runtime undefined symbol error that occurs when building Transformer Engine with […]. The fix extracts the MPI flag setup logic from […]
Confidence Score: 5/5
This PR is safe to merge: it correctly fixes a real build/runtime bug with a clean, minimal refactoring.
The change is a straightforward bug fix: the missing -DNVTE_UB_WITH_MPI flag in the JAX extension build is the confirmed root cause of the undefined symbol error, and the fix mirrors the existing, already-working PyTorch extension logic exactly. The refactoring into setup_mpi_flags() reduces duplication without any behavioural change. The only pre-existing caveat (empty MPI_HOME string bypassing the guard) was already flagged in a prior review thread and is out of scope here. No new issues were introduced.
No files require special attention.
Important Files Changed
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[Build with NVTE_UB_WITH_MPI=1?] -->|Yes| B[setup_mpi_flags called]
A -->|No| E[No MPI flags added]
B --> C{MPI_HOME set?}
C -->|No / None| D[Assert error:\nMPI_HOME must be set]
C -->|Yes| F[append MPI_HOME/include\nto include_dirs]
F --> G[append -DNVTE_UB_WITH_MPI\nto cxx_flags]
G --> H[JAX & PyTorch extensions\nboth compiled with MPI_Comm variant\nof CommOverlapP2PBase symbols]
E --> I[Extensions compiled with\nconst char* variant - only valid\nwhen core lib also built without MPI]
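The flow above can be sketched in Python roughly as follows. This is a minimal illustration and not the code from the PR: the signature of setup_mpi_flags (taking include_dirs, cxx_flags and an optional env mapping) is an assumption, as is parsing NVTE_UB_WITH_MPI as an integer.

```python
import os
from pathlib import Path


def setup_mpi_flags(include_dirs, cxx_flags, env=None):
    """Append the MPI include path and -DNVTE_UB_WITH_MPI when MPI is requested.

    Hypothetical signature: the helper in the PR may take different arguments.
    """
    env = os.environ if env is None else env
    # Assumption: the feature is toggled by an integer-valued variable.
    if not bool(int(env.get("NVTE_UB_WITH_MPI", "0"))):
        return  # "No" branch of the flowchart: no MPI flags added.
    mpi_home = env.get("MPI_HOME")
    # Mirrors the "MPI_HOME set?" node: fail early if the variable is unset.
    assert mpi_home is not None, (
        "MPI_HOME=/path/to/mpi must be set when compiling with NVTE_UB_WITH_MPI=1!"
    )
    include_dirs.append(Path(mpi_home) / "include")
    cxx_flags.append("-DNVTE_UB_WITH_MPI")
```

Calling one shared helper from both the JAX and the PyTorch build scripts would keep the two extensions compiled against the same variant of the CommOverlapP2PBase symbols.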
build_tools/jax.py
Outdated
assert (
    os.getenv("MPI_HOME") is not None
Empty MPI_HOME string bypasses the guard
os.getenv("MPI_HOME") returns None only when the variable is unset. If a user exports MPI_HOME="" (empty string), the assert passes (empty string is not None), and Path("") silently resolves to the current working directory — not a valid MPI installation — causing confusing compile errors downstream.
Consider checking for a non-empty value:
- assert (
-     os.getenv("MPI_HOME") is not None
+ mpi_home = os.getenv("MPI_HOME")
+ assert mpi_home, (
+     "MPI_HOME=/path/to/mpi must be set when compiling with NVTE_UB_WITH_MPI=1!"
+ )
+ mpi_path = Path(mpi_home)
This also avoids calling os.getenv("MPI_HOME") twice (once in the assert, once for Path(...)). Note: the same pattern exists in build_tools/pytorch.py, lines 71–74.
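The difference between the two guards can be checked in isolation. The following is a minimal sketch; resolve_mpi_home and its env parameter are hypothetical names introduced only for illustration and are not part of build_tools.

```python
import os
from pathlib import Path


def resolve_mpi_home(env=None):
    """Return MPI_HOME as a Path, rejecting both unset and empty values.

    Hypothetical helper; `env` defaults to the process environment.
    """
    env = os.environ if env is None else env
    mpi_home = env.get("MPI_HOME")
    # A truthiness check rejects None and "" alike, unlike `is not None`.
    assert mpi_home, (
        "MPI_HOME=/path/to/mpi must be set when compiling with NVTE_UB_WITH_MPI=1!"
    )
    return Path(mpi_home)


# Why the empty string matters: Path("") is the same as Path("."),
# i.e. the current working directory.
empty_path_is_cwd = Path("") == Path(".")
```

Note that assert statements are skipped when Python runs with -O, so a build script that must always enforce this check may prefer raising an exception explicitly.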
Signed-off-by: Gaetan Lepage <gaetan@glepage.com>
for more information, see https://pre-commit.ci
Description
When building Transformer Engine with NVTE_UB_WITH_MPI=1 and NVTE_FRAMEWORK=pytorch,jax, the JAX extension (transformer_engine_jax) fails to load at runtime with an undefined symbol error, while the PyTorch extension works fine.
In userbuffers.h, the ExtComm type is conditionally defined based on NVTE_UB_WITH_MPI: it is MPI_Comm when the macro is defined and const char * otherwise.
This type flows into ExtAllgatherOp and ExtBarrierOp, which are parameters of the CommOverlapP2PBase constructor.
This means the constructor has a different mangled symbol name depending on whether NVTE_UB_WITH_MPI is defined.
The core library (libtransformer_engine.so) is built via CMake, which correctly sets -DNVTE_UB_WITH_MPI.
The PyTorch extension also adds this flag.
However, the JAX extension is missing this flag entirely.
As a result, transformer_engine_jax.so is compiled expecting the const char * variant of the constructor, while libtransformer_engine.so only exports the MPI_Comm variant, causing an undefined symbol error at import time.
Type of change
Changes
This PR adds the MPI include path and -DNVTE_UB_WITH_MPI compile definition to the JAX extension build, mirroring the existing handling in build_tools/pytorch.py.
Checklist: