Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion build_tools/build_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,27 @@ def run(self) -> None:
install_dir=install_dir,
)

# Build non-CMake extensions as usual
# Build non-CMake extensions as usual.
# Add cmake install/build dirs to library_dirs so the linker
# can find libtransformer_engine.so at link time.
cmake_lib_dirs = []
for ext in self.extensions:
if isinstance(ext, CMakeExtension):
package_path = Path(self.get_ext_fullpath(ext.name))
cmake_lib_dirs.append(str(package_path.resolve().parent))
build_dir = os.getenv("NVTE_CMAKE_BUILD_DIR")
if build_dir:
cmake_lib_dirs.append(str(Path(build_dir).resolve()))
else:
root_dir = Path(__file__).resolve().parent.parent
cmake_lib_dirs.append(str(root_dir / "build" / "cmake"))

all_extensions = self.extensions
self.extensions = [
ext for ext in self.extensions if not isinstance(ext, CMakeExtension)
]
for ext in self.extensions:
ext.library_dirs = cmake_lib_dirs + (ext.library_dirs or [])
super().run()
self.extensions = all_extensions

Expand Down
102 changes: 56 additions & 46 deletions build_tools/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,16 @@
import os
from pathlib import Path

from typing import List

import setuptools

from .utils import all_files_in_dir, cuda_version, get_cuda_include_dirs, debug_build_enabled
from typing import List
from .utils import all_files_in_dir, get_cuda_include_dirs, debug_build_enabled


def install_requirements() -> List[str]:
"""Install dependencies for TE/PyTorch extensions."""
return ["torch>=2.1", "einops", "onnxscript", "onnx", "packaging", "pydantic", "nvdlfw-inspect"]
return ["torch>=2.6", "einops", "onnxscript", "onnx", "packaging", "pydantic", "nvdlfw-inspect"]


def test_requirements() -> List[str]:
Expand All @@ -29,74 +30,83 @@ def test_requirements() -> List[str]:
]


def setup_pytorch_extension(
def setup_pytorch_stable_extension(
csrc_source_files,
csrc_header_files,
common_header_files,
) -> setuptools.Extension:
"""Setup CUDA extension for PyTorch support"""
"""Setup stable ABI extension for PyTorch support.

# Source files
sources = all_files_in_dir(Path(csrc_source_files), name_extension="cpp")
This extension uses only the PyTorch stable ABI (torch/csrc/stable/),
producing a binary that is compatible across PyTorch versions.
It does NOT use CppExtension to avoid pulling in unstable ATen headers.
"""
import torch

# Header files
# Source files from csrc/extensions/ directory
stable_dir = Path(csrc_source_files) / "extensions"
sources = all_files_in_dir(stable_dir, name_extension="cpp")
if not sources:
return None

# Include directories
include_dirs = get_cuda_include_dirs()
include_dirs.extend(
[
common_header_files,
common_header_files / "common",
common_header_files / "common" / "include",
csrc_header_files,
# PyTorch headers (for stable ABI only)
Path(torch.utils.cmake_prefix_path).parent.parent / "include",
]
)

# Compiler flags
cxx_flags = ["-O3", "-fvisibility=hidden"]
cxx_flags = ["-O3", "-fvisibility=hidden", "-std=c++17", "-DUSE_CUDA"]
if bool(int(os.environ.get("NVTE_ENABLE_NVSHMEM", "0"))):
cxx_flags.append("-DNVTE_ENABLE_NVSHMEM")
nvshmem_home = os.environ.get("NVSHMEM_HOME", "")
if nvshmem_home:
include_dirs.append(Path(nvshmem_home) / "include")
# Try system NVSHMEM paths (Debian/Ubuntu packages)
for nvshmem_inc in ["/usr/include/nvshmem_13", "/usr/local/include/nvshmem"]:
if os.path.isdir(nvshmem_inc):
include_dirs.append(Path(nvshmem_inc))
break
if debug_build_enabled():
cxx_flags.append("-g")
cxx_flags.append("-UNDEBUG")
else:
cxx_flags.append("-g0")

# Version-dependent CUDA options
try:
version = cuda_version()
except FileNotFoundError:
print("Could not determine CUDA version")
else:
if version < (12, 0):
raise RuntimeError("Transformer Engine requires CUDA 12.0 or newer")

if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))):
assert (
os.getenv("MPI_HOME") is not None
), "MPI_HOME=/path/to/mpi must be set when compiling with NVTE_UB_WITH_MPI=1!"
mpi_path = Path(os.getenv("MPI_HOME"))
include_dirs.append(mpi_path / "include")
cxx_flags.append("-DNVTE_UB_WITH_MPI")

library_dirs = []
libraries = []
if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", 0))):
assert (
os.getenv("NVSHMEM_HOME") is not None
), "NVSHMEM_HOME must be set when compiling with NVTE_ENABLE_NVSHMEM=1"
nvshmem_home = Path(os.getenv("NVSHMEM_HOME"))
include_dirs.append(nvshmem_home / "include")
library_dirs.append(nvshmem_home / "lib")
libraries.append("nvshmem_host")
cxx_flags.append("-DNVTE_ENABLE_NVSHMEM")
# Library directories and libraries
# Find the TE common library (libtransformer_engine.so)
te_lib_dir = Path(csrc_source_files).parent.parent.parent
cuda_home = os.environ.get("CUDA_HOME", os.environ.get("CUDA_PATH", "/usr/local/cuda"))
cuda_lib_dir = os.path.join(cuda_home, "lib64")
if not os.path.isdir(cuda_lib_dir):
cuda_lib_dir = os.path.join(cuda_home, "lib")
library_dirs = [
str(Path(torch.utils.cmake_prefix_path).parent.parent / "lib"),
str(te_lib_dir),
cuda_lib_dir,
]
libraries = ["torch", "torch_cpu", "c10", "cudart", "transformer_engine"]

# Construct PyTorch CUDA extension
sources = [str(path) for path in sources]
include_dirs = [str(path) for path in include_dirs]
from torch.utils.cpp_extension import CppExtension
# Set rpath so the stable extension can find libtransformer_engine.so at runtime.
# Use $ORIGIN for co-located libraries plus the absolute path for editable installs.
extra_link_args = [
"-Wl,-rpath,$ORIGIN",
f"-Wl,-rpath,{te_lib_dir.resolve()}",
]

return CppExtension(
name="transformer_engine_torch",
return setuptools.Extension(
name="transformer_engine.te_stable_abi",
sources=[str(src) for src in sources],
include_dirs=[str(inc) for inc in include_dirs],
extra_compile_args={"cxx": cxx_flags},
libraries=[str(lib) for lib in libraries],
library_dirs=[str(lib_dir) for lib_dir in library_dirs],
extra_compile_args=cxx_flags,
libraries=libraries,
library_dirs=library_dirs,
extra_link_args=extra_link_args,
)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# See LICENSE for license information.

[build-system]
requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "pip", "torch>=2.1", "jax>=0.5.0", "flax>=0.7.1"]
requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "pip", "torch>=2.6", "jax>=0.5.0", "flax>=0.7.1"]

# Use legacy backend to import local packages in setup.py
build-backend = "setuptools.build_meta:__legacy__"
21 changes: 21 additions & 0 deletions qa/L1_pytorch_thunder_integration/test.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

set -x

# Location of the lightning-thunder checkout and the JUnit XML output
# directory; both may be overridden by the caller's environment.
: ${THUNDER_PATH:=/opt/pytorch/lightning-thunder}
: ${XML_LOG_DIR:=/logs}
mkdir -p "$XML_LOG_DIR"

pip3 install pytest==8.1.1 pytest-benchmark==5.1.0

# Run Thunder's TransformerEngine executor tests. Capture the exit status
# on the very next line so nothing inserted later can clobber $?.
python3 -m pytest -v -s --junitxml="$XML_LOG_DIR/pytest.xml" "${THUNDER_PATH}/thunder/tests/test_transformer_engine_executor.py"
RC=$?

# Exit code 5 is fine: the Thunder tests are skipped on systems without
# FP8 support, and pytest returns 5 when no tests were collected/run.
if [ ${RC} -eq 5 ]; then
    RC=0
fi
exit ${RC}
14 changes: 7 additions & 7 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,15 +209,15 @@ def git_check_submodules() -> None:

if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
if "pytorch" in frameworks:
from build_tools.pytorch import setup_pytorch_extension
from build_tools.pytorch import setup_pytorch_stable_extension

ext_modules.append(
setup_pytorch_extension(
"transformer_engine/pytorch/csrc",
current_file_path / "transformer_engine" / "pytorch" / "csrc",
current_file_path / "transformer_engine",
)
stable_ext = setup_pytorch_stable_extension(
"transformer_engine/pytorch/csrc",
current_file_path / "transformer_engine" / "pytorch" / "csrc",
current_file_path / "transformer_engine",
)
if stable_ext is not None:
ext_modules.append(stable_ext)
if "jax" in frameworks:
from build_tools.jax import setup_jax_extension

Expand Down
34 changes: 7 additions & 27 deletions tests/jax/test_custom_call_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
ScaledTensor1x,
ScaledTensor2x,
GroupedScaledTensor1x,
GroupedNoScaleTensor,
ScalingMode,
QuantizerFactory,
QuantizeLayout,
Expand Down Expand Up @@ -151,13 +150,8 @@ def assert_dequantized_grouped_scaled_tensor(
a: Union[GroupedScaledTensor1x, ScaledTensor2x], b: jnp.ndarray
):
if isinstance(a, GroupedScaledTensor1x):
group_sizes = (
a.first_dims
if a.first_dims is not None
else jnp.ones(a.original_shape[0], dtype=jnp.int32)
)
assert group_sizes.sum() == b.shape[0]
b = jnp.split(b, jnp.cumulative_sum(group_sizes)[:-1], axis=0)
assert a.group_sizes.sum() == b.shape[0]
b = jnp.split(b, jnp.cumulative_sum(a.group_sizes)[:-1], axis=0)
dq_a = a.dequantize()
for dq_a_i, b_i in zip(dq_a, b):
if len(dq_a_i) == 0:
Expand Down Expand Up @@ -1793,18 +1787,13 @@ def test_grouped_gemm_fp16(self, dtype, input_shape, layout):
ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims)

# jitting grouped_gemm
lhs_tensor = GroupedNoScaleTensor(
data=lhs, amax=None, first_dims=group_sizes, last_dims=None, original_shape=lhs.shape
)
rhs_tensor = GroupedNoScaleTensor(
data=rhs, amax=None, first_dims=None, last_dims=None, original_shape=rhs.shape
)
prim_out = jax.jit(
tex.grouped_gemm, static_argnames=("contracting_dims", "use_async_d2h_group_sizes")
)(
lhs_tensor,
rhs_tensor,
contracting_dims=contracting_dims,
lhs,
rhs,
group_sizes,
contracting_dims,
use_async_d2h_group_sizes=True,
)

Expand Down Expand Up @@ -1836,17 +1825,8 @@ def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout
)
ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims)

lhs_tensor = GroupedNoScaleTensor(
data=lhs, amax=None, first_dims=group_sizes, last_dims=None, original_shape=lhs.shape
)
rhs_tensor = GroupedNoScaleTensor(
data=rhs, amax=None, first_dims=None, last_dims=None, original_shape=rhs.shape
)
prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))(
lhs_tensor,
rhs_tensor,
contracting_dims=contracting_dims,
quantizer_set=quantizer_set,
lhs, rhs, group_sizes, contracting_dims, quantizer_set=quantizer_set
)

allclose_dtype = jnp.float8_e4m3fn
Expand Down
8 changes: 5 additions & 3 deletions tests/pytorch/test_float8_blockwise_gemm_exact.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,9 +782,11 @@ def test_gelu_unsupported_cases_error(
is_x_1d_scaled,
is_w_1d_scaled,
) -> None:
if use_grad and not use_bias and out_dtype == torch.bfloat16:
pytest.skip("DGELU epilogue is supported for bfloat16.")
elif use_grad and not use_bias:
pytest.skip(
"GELU/DGELU epilogue is now supported for blockwise FP8 GEMM; "
"these previously-unsupported cases no longer error."
)
if use_grad and not use_bias:
expected_err = "an unsupported value or parameter was passed"
else:
expected_err = "Epilogue requested outside of the available"
Expand Down
Loading
Loading