Add PyTorch stable ABI extension for libtorch compatibility #2813

Draft
pstjohn wants to merge 5 commits into NVIDIA:main from pstjohn:pstjohn/libtorch-stable-abi-clean

Conversation

@pstjohn (Contributor)

@pstjohn pstjohn commented Mar 30, 2026

Implement a stable ABI layer that replaces the pybind11-based C++ extension with torch::Library-registered operations using torch::stable::Tensor. This allows the PyTorch extension to be built once and work across multiple Python/PyTorch versions without recompilation.

Key changes:

  • Add csrc/stable/ with 20 C++ files implementing all TE ops via stable ABI
  • Add _stable_torch_module.py as the Python-side module replacing pybind11
  • Add _stable_ops.py and _tex.py shims for backward compatibility
  • Add tensor extraction and stable quantization utilities
  • Update build system to compile the stable extension separately
  • Add .gitignore for build-time artifact directories

@ptrendx (Member)

ptrendx commented Mar 30, 2026

Why would we need to have a separate stable directory for those files? If it works, we would want to use it by default, no? Also, then maybe the diff would be easier to understand and review?

@pstjohn pstjohn force-pushed the pstjohn/libtorch-stable-abi-clean branch from 5977a83 to c7d2553 on March 30, 2026 at 18:53
Replace the pybind11-based C++ extension (transformer_engine_torch)
with a stable ABI layer using torch::Library-registered operations
and torch::stable::Tensor. This allows the extension to be built once
and work across multiple PyTorch versions without recompilation.

C++ changes (csrc/extensions/):
- Replace all extension files with stable ABI implementations using
  torch::stable::Tensor instead of at::Tensor
- Add stable_common.h with helper utilities for the stable ABI layer
- Add registration.cpp for torch::Library op registration
- Consolidate multi_tensor/*.cpp into single multi_tensor.cpp
- Consolidate fp8/nvfp4 partial cast into partial_cast.cpp
- Add grouped_gemm.cpp for grouped GEMM operations
- Remove pybind11 bindings (pybind.cpp, quantizer.cpp, common.cpp, etc.)
- Add quantize_bidirectional for fused rowwise+columnwise quantization

Python changes:
- Add _stable_torch_module.py as Python-side module replacing pybind11
- Add _stable_ops.py and _tex.py shims for backward compatibility
- Add tensor extraction and stable quantization utilities
- Fix distributed amax reduction for NVFP4 and FP8 current scaling
- Update __init__.py to wire in the stable module

Build system:
- Replace setup_pytorch_extension with setup_pytorch_stable_extension
  that uses stable ABI headers and avoids ATen/pybind11 dependencies
- Add .gitignore for build-time artifact directories

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@pstjohn pstjohn force-pushed the pstjohn/libtorch-stable-abi-clean branch from c7d2553 to 64a7124 on March 30, 2026 at 18:58
@pstjohn (Contributor, Author)

pstjohn commented Mar 30, 2026

> Why would we need to have a separate stable directory for those files? If it works, we would want to use it by default, no? Also, then maybe the diff would be easier to understand and review?

I think it's going to be a large diff, but I'll definitely try to cut it back once I have all the tests passing locally.

pstjohn and others added 4 commits on March 30, 2026 at 12:46
The stable ABI quantize path does not swizzle MXFP8 scales during
quantization, but the optimize_for_gemm flag on MXFP8Quantizer was
causing _with_gemm_swizzled_scales=True to be set on output tensors.
This made the GEMM skip the on-the-fly swizzle, producing wrong
results with unswizzled scales.

Fix: always set _with_gemm_swizzled_scales=False after quantization
in the stable path, and in MXFP8Quantizer.make_empty(). The GEMM
C++ code will swizzle on-the-fly when it sees the False flag.

Fixes test_numerics[mxfp8] test_linear(column, sequence_parallel=True).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
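The flag logic described in this commit can be sketched in a few lines of Python. The class and function names below are stand-ins mirroring the commit message (only `_with_gemm_swizzled_scales` comes from the actual PR), not the real TE types:

```python
class FakeMXFP8Tensor:
    """Stand-in for an MXFP8 output tensor (not the real TE class)."""

    def __init__(self, scales):
        self.scales = scales
        # The fix: the stable quantize path never pre-swizzles scales,
        # so this flag must always be False after quantization.
        self._with_gemm_swizzled_scales = False


def prepare_scales_for_gemm(t, swizzle):
    # The GEMM wrapper swizzles on the fly only when the tensor says
    # its scales are not already in GEMM layout. The bug was a True
    # flag on unswizzled scales, which skipped this branch.
    if not t._with_gemm_swizzled_scales:
        return swizzle(t.scales)
    return t.scales


t = FakeMXFP8Tensor([1, 2, 3])
# Toy "swizzle" standing in for the real layout transform.
out = prepare_scales_for_gemm(t, swizzle=lambda s: list(reversed(s)))
```

With the flag forced to False, the GEMM path always applies its own swizzle, so the scale layout it consumes is correct by construction.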
Three related MXFP8 columnwise fixes:

1. Wire quantize_bidirectional for MXFP8 quantization: when both
   rowwise and columnwise buffers are allocated, use the fused
   bidirectional kernel to fill both in one nvte_quantize_v2 call.

2. Force MXFP8 quantizers to always allocate columnwise when rowwise
   is enabled (override set_usage), since GEMM with NT/NN layouts
   requires columnwise data.

3. Create MXFP8 columnwise data on-the-fly in the GEMM wrapper when
   it's missing (e.g., for tensors produced by GEMM+GELU fusion that
   only have rowwise data). Dequantizes rowwise and re-quantizes
   bidirectionally to produce the correct per-column-block scales.

Fixes gradient_accumulation_fusion test with MXFP8 and
test_numerics[mxfp8] row-parallel tests.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
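Point 2 above (the `set_usage` override) can be sketched as follows. This is an illustrative stand-in built from the commit message, not the actual TE quantizer class:

```python
class FakeMXFP8Quantizer:
    """Stand-in illustrating the set_usage override from the commit."""

    def __init__(self):
        self.rowwise_usage = False
        self.columnwise_usage = False

    def set_usage(self, rowwise=None, columnwise=None):
        if rowwise is not None:
            self.rowwise_usage = rowwise
        if columnwise is not None:
            self.columnwise_usage = columnwise
        # GEMM with NT/NN layouts needs columnwise MXFP8 data, so
        # columnwise allocation is forced on whenever rowwise is
        # enabled, even if the caller asked to disable it.
        if self.rowwise_usage:
            self.columnwise_usage = True


q = FakeMXFP8Quantizer()
q.set_usage(rowwise=True, columnwise=False)
```

After the call, `q.columnwise_usage` is True despite `columnwise=False`, which is exactly the forced-allocation behavior the commit describes.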
Lint fixes:
- Rename 'input' parameter to 'inp' in _stable_ops.py to avoid
  shadowing built-in
- Add missing docstrings and pylint disable comments
- Remove duplicate ctypes/glob imports in _stable_torch_module.py
- Fix unused variables, self-assignment, unnecessary else/elif
- Use enumerate() instead of range(len())
- Use dict literals instead of dict()
- Fix unbalanced tuple unpacking
- Fix import order in __init__.py
- Add pylint disable for wildcard import in _tex.py

NVFP4 bidirectional quantization:
- Use quantize_bidirectional for NVFP4 (not just MXFP8) to produce
  correct columnwise data with independent per-block scales
- Handle NVFP4 columnwise-only tensors in backward re-quantization
- Fix FP4 packed shape handling in quantize_bidirectional C++ code
- Pass nvfp4_2d_quantization flag through to bidirectional kernel

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
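On the FP4 packed-shape fix: FP4 stores two 4-bit values per byte, so the byte-level buffer halves the innermost dimension. A minimal Python sketch of that arithmetic (a hypothetical helper, not the C++ code from the PR):

```python
def fp4_packed_shape(shape):
    """Byte-level shape of an FP4 tensor: two 4-bit values share one
    byte, so the innermost dimension halves. Illustrative helper."""
    *lead, last = shape
    if last % 2 != 0:
        raise ValueError("innermost dim must be even to pack FP4 pairs")
    return (*lead, last // 2)
```

For example, a logical `(128, 64)` NVFP4 tensor occupies a `(128, 32)` byte buffer; getting this mapping wrong is the kind of shape mismatch the bidirectional-quantization fix addresses.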