Add PyTorch stable ABI extension for libtorch compatibility #2813
Draft
pstjohn wants to merge 5 commits into NVIDIA:main from
Conversation
Member
Why would we need to have a separate stable directory for those files? If it works, we would want to use it by default, no? Also, then maybe the diff would be easier to understand and review?
Force-pushed 5977a83 to c7d2553
Replace the pybind11-based C++ extension (transformer_engine_torch) with a stable ABI layer using torch::Library-registered operations and torch::stable::Tensor. This allows the extension to be built once and work across multiple PyTorch versions without recompilation.

C++ changes (csrc/extensions/):
- Replace all extension files with stable ABI implementations using torch::stable::Tensor instead of at::Tensor
- Add stable_common.h with helper utilities for the stable ABI layer
- Add registration.cpp for torch::Library op registration
- Consolidate multi_tensor/*.cpp into a single multi_tensor.cpp
- Consolidate fp8/nvfp4 partial cast into partial_cast.cpp
- Add grouped_gemm.cpp for grouped GEMM operations
- Remove pybind11 bindings (pybind.cpp, quantizer.cpp, common.cpp, etc.)
- Add quantize_bidirectional for fused rowwise+columnwise quantization

Python changes:
- Add _stable_torch_module.py as the Python-side module replacing pybind11
- Add _stable_ops.py and _tex.py shims for backward compatibility
- Add tensor extraction and stable quantization utilities
- Fix distributed amax reduction for NVFP4 and FP8 current scaling
- Update __init__.py to wire in the stable module

Build system:
- Replace setup_pytorch_extension with setup_pytorch_stable_extension, which uses stable ABI headers and avoids ATen/pybind11 dependencies
- Add .gitignore for build-time artifact directories

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
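The backward-compatibility shim idea behind _stable_ops.py / _tex.py can be sketched in pure Python: callers that previously imported functions from the pybind11 module keep working, while calls are forwarded to the new stable-op implementations. This is a minimal sketch under assumptions; all names (make_compat_module, generic_gemm, the stub) are illustrative and not the real TransformerEngine API.

```python
import types

def make_compat_module(name, impl_table):
    """Build a module-like object whose attributes forward to impl_table."""
    mod = types.ModuleType(name)
    for op_name, fn in impl_table.items():
        setattr(mod, op_name, fn)
    return mod

# Stand-in for a stable-ABI op that would be registered via torch::Library.
def _gemm_stub(a, b):
    # Naive matrix multiply, purely to make the shim callable.
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
            for row in a]

# Old call sites do `tex.generic_gemm(...)`; the shim keeps that working.
tex = make_compat_module("transformer_engine_torch_compat",
                         {"generic_gemm": _gemm_stub})
print(tex.generic_gemm([[1, 2]], [[3], [4]]))  # [[11]]
```

In the actual PR the forwarded targets would be torch::Library-registered stable-ABI ops rather than a Python stub; the shim layer only preserves the old import surface.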
Force-pushed c7d2553 to 64a7124
Contributor
Author
I think it's going to be a large diff, but I'll definitely try to cut it back once I have all the tests passing locally.
The stable ABI quantize path does not swizzle MXFP8 scales during quantization, but the optimize_for_gemm flag on MXFP8Quantizer was causing _with_gemm_swizzled_scales=True to be set on output tensors. This made the GEMM skip the on-the-fly swizzle, producing wrong results with unswizzled scales.

Fix: always set _with_gemm_swizzled_scales=False after quantization in the stable path, and in MXFP8Quantizer.make_empty(). The GEMM C++ code will swizzle on-the-fly when it sees the False flag.

Fixes test_numerics[mxfp8] and test_linear(column, sequence_parallel=True).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
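The flag logic described in this commit can be modeled minimally: the quantizer marks its output as not pre-swizzled, and the GEMM wrapper swizzles on the fly only when the flag is False. This is a hypothetical sketch; the real swizzle is a CUDA-side scale-layout permutation, and QuantizedTensor/swizzle/gemm here are illustrative stand-ins.

```python
class QuantizedTensor:
    def __init__(self, scales):
        self.scales = scales
        # The fix: always start with the flag False after quantization,
        # so the GEMM knows the scales still need swizzling.
        self._with_gemm_swizzled_scales = False

def swizzle(scales):
    # Stand-in for the real scale-layout transform.
    return list(reversed(scales))

def gemm(t):
    # Swizzle on the fly exactly once, then record that it happened.
    if not t._with_gemm_swizzled_scales:
        t.scales = swizzle(t.scales)
        t._with_gemm_swizzled_scales = True
    return t.scales

t = QuantizedTensor([1, 2, 3])
print(gemm(t))  # [3, 2, 1]
print(gemm(t))  # [3, 2, 1]  (flag set, not swizzled a second time)
```

The bug was the inverse situation: the flag arrived as True on scales that had never been swizzled, so the on-the-fly branch was skipped and the GEMM consumed them in the wrong layout.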
Three related MXFP8 columnwise fixes:

1. Wire quantize_bidirectional for MXFP8 quantization: when both rowwise and columnwise buffers are allocated, use the fused bidirectional kernel to fill both in one nvte_quantize_v2 call.
2. Force MXFP8 quantizers to always allocate columnwise when rowwise is enabled (override set_usage), since GEMM with NT/NN layouts requires columnwise data.
3. Create MXFP8 columnwise data on-the-fly in the GEMM wrapper when it is missing (e.g., for tensors produced by GEMM+GELU fusion that only have rowwise data): dequantize rowwise and re-quantize bidirectionally to produce the correct per-column-block scales.

Fixes the gradient_accumulation_fusion test with MXFP8 and the test_numerics[mxfp8] row-parallel tests.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
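Fix 2 above (overriding set_usage so rowwise always implies columnwise) can be sketched as a small class hierarchy. This is an assumption-laden illustration: the class and method names mirror the commit message, but the bodies are not the real TransformerEngine implementation.

```python
class Quantizer:
    """Base quantizer with independent rowwise/columnwise usage flags."""
    def __init__(self):
        self.rowwise_usage = False
        self.columnwise_usage = False

    def set_usage(self, rowwise=None, columnwise=None):
        if rowwise is not None:
            self.rowwise_usage = rowwise
        if columnwise is not None:
            self.columnwise_usage = columnwise

class MXFP8Quantizer(Quantizer):
    """Override: NT/NN GEMM layouts need columnwise data, so enabling
    rowwise must always enable columnwise as well."""
    def set_usage(self, rowwise=None, columnwise=None):
        super().set_usage(rowwise, columnwise)
        if self.rowwise_usage:
            self.columnwise_usage = True

q = MXFP8Quantizer()
q.set_usage(rowwise=True, columnwise=False)
print(q.columnwise_usage)  # True: the override wins over the caller
```

With both usages forced on, the fused bidirectional kernel from fix 1 can then fill both buffers in a single call instead of quantizing twice.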
Lint fixes:
- Rename the 'input' parameter to 'inp' in _stable_ops.py to avoid shadowing the built-in
- Add missing docstrings and pylint disable comments
- Remove duplicate ctypes/glob imports in _stable_torch_module.py
- Fix unused variables, self-assignment, and unnecessary else/elif
- Use enumerate() instead of range(len())
- Use dict literals instead of dict()
- Fix unbalanced tuple unpacking
- Fix import order in __init__.py
- Add a pylint disable for the wildcard import in _tex.py

NVFP4 bidirectional quantization:
- Use quantize_bidirectional for NVFP4 (not just MXFP8) to produce correct columnwise data with independent per-block scales
- Handle NVFP4 columnwise-only tensors in backward re-quantization
- Fix FP4 packed shape handling in the quantize_bidirectional C++ code
- Pass the nvfp4_2d_quantization flag through to the bidirectional kernel

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
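Two of the lint patterns listed above (enumerate over range(len()), and dict literals over dict() calls) look like this in generic form; the snippet is illustrative and not the actual TransformerEngine code.

```python
items = ["fp8", "mxfp8", "nvfp4"]

# Before: index-based loop that pylint flags (consider-using-enumerate)
labels_old = []
for i in range(len(items)):
    labels_old.append(f"{i}:{items[i]}")

# After: enumerate() yields index and value directly
labels_new = [f"{i}:{name}" for i, name in enumerate(items)]
assert labels_old == labels_new

# Before: cfg = dict(rowwise=True, columnwise=True)
# After: a dict literal (pylint: use-dict-literal)
cfg = {"rowwise": True, "columnwise": True}

print(labels_new)  # ['0:fp8', '1:mxfp8', '2:nvfp4']
```

Both rewrites are behavior-preserving; they only remove pylint warnings and make the intent of the loop and the mapping construction explicit.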
Implement a stable ABI layer that replaces the pybind11-based C++ extension with torch::Library-registered operations using torch::stable::Tensor. This allows the PyTorch extension to be built once and work across multiple Python/PyTorch versions without recompilation.
Key changes: