diff --git a/build_tools/build_ext.py b/build_tools/build_ext.py index cbb8838b00..143c02e02d 100644 --- a/build_tools/build_ext.py +++ b/build_tools/build_ext.py @@ -129,11 +129,27 @@ def run(self) -> None: install_dir=install_dir, ) - # Build non-CMake extensions as usual + # Build non-CMake extensions as usual. + # Add cmake install/build dirs to library_dirs so the linker + # can find libtransformer_engine.so at link time. + cmake_lib_dirs = [] + for ext in self.extensions: + if isinstance(ext, CMakeExtension): + package_path = Path(self.get_ext_fullpath(ext.name)) + cmake_lib_dirs.append(str(package_path.resolve().parent)) + build_dir = os.getenv("NVTE_CMAKE_BUILD_DIR") + if build_dir: + cmake_lib_dirs.append(str(Path(build_dir).resolve())) + else: + root_dir = Path(__file__).resolve().parent.parent + cmake_lib_dirs.append(str(root_dir / "build" / "cmake")) + all_extensions = self.extensions self.extensions = [ ext for ext in self.extensions if not isinstance(ext, CMakeExtension) ] + for ext in self.extensions: + ext.library_dirs = cmake_lib_dirs + (ext.library_dirs or []) super().run() self.extensions = all_extensions diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index fdfdee9b1c..49676b968f 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -6,15 +6,16 @@ import os from pathlib import Path +from typing import List + import setuptools -from .utils import all_files_in_dir, cuda_version, get_cuda_include_dirs, debug_build_enabled -from typing import List +from .utils import all_files_in_dir, get_cuda_include_dirs, debug_build_enabled def install_requirements() -> List[str]: """Install dependencies for TE/PyTorch extensions.""" - return ["torch>=2.1", "einops", "onnxscript", "onnx", "packaging", "pydantic", "nvdlfw-inspect"] + return ["torch>=2.6", "einops", "onnxscript", "onnx", "packaging", "pydantic", "nvdlfw-inspect"] def test_requirements() -> List[str]: @@ -29,17 +30,26 @@ def test_requirements() -> List[str]: ] -def 
setup_pytorch_extension( +def setup_pytorch_stable_extension( csrc_source_files, csrc_header_files, common_header_files, ) -> setuptools.Extension: - """Setup CUDA extension for PyTorch support""" + """Setup stable ABI extension for PyTorch support. - # Source files - sources = all_files_in_dir(Path(csrc_source_files), name_extension="cpp") + This extension uses only the PyTorch stable ABI (torch/csrc/stable/), + producing a binary that is compatible across PyTorch versions. + It does NOT use CppExtension to avoid pulling in unstable ATen headers. + """ + import torch - # Header files + # Source files from csrc/extensions/ directory + stable_dir = Path(csrc_source_files) / "extensions" + sources = all_files_in_dir(stable_dir, name_extension="cpp") + if not sources: + return None + + # Include directories include_dirs = get_cuda_include_dirs() include_dirs.extend( [ @@ -47,56 +57,56 @@ def setup_pytorch_extension( common_header_files / "common", common_header_files / "common" / "include", csrc_header_files, + # PyTorch headers (for stable ABI only) + Path(torch.utils.cmake_prefix_path).parent.parent / "include", ] ) # Compiler flags - cxx_flags = ["-O3", "-fvisibility=hidden"] + cxx_flags = ["-O3", "-fvisibility=hidden", "-std=c++17", "-DUSE_CUDA"] + if bool(int(os.environ.get("NVTE_ENABLE_NVSHMEM", "0"))): + cxx_flags.append("-DNVTE_ENABLE_NVSHMEM") + nvshmem_home = os.environ.get("NVSHMEM_HOME", "") + if nvshmem_home: + include_dirs.append(Path(nvshmem_home) / "include") + # Try system NVSHMEM paths (Debian/Ubuntu packages) + for nvshmem_inc in ["/usr/include/nvshmem_13", "/usr/local/include/nvshmem"]: + if os.path.isdir(nvshmem_inc): + include_dirs.append(Path(nvshmem_inc)) + break if debug_build_enabled(): cxx_flags.append("-g") cxx_flags.append("-UNDEBUG") else: cxx_flags.append("-g0") - # Version-dependent CUDA options - try: - version = cuda_version() - except FileNotFoundError: - print("Could not determine CUDA version") - else: - if version < (12, 0): - 
raise RuntimeError("Transformer Engine requires CUDA 12.0 or newer") - - if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))): - assert ( - os.getenv("MPI_HOME") is not None - ), "MPI_HOME=/path/to/mpi must be set when compiling with NVTE_UB_WITH_MPI=1!" - mpi_path = Path(os.getenv("MPI_HOME")) - include_dirs.append(mpi_path / "include") - cxx_flags.append("-DNVTE_UB_WITH_MPI") - - library_dirs = [] - libraries = [] - if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", 0))): - assert ( - os.getenv("NVSHMEM_HOME") is not None - ), "NVSHMEM_HOME must be set when compiling with NVTE_ENABLE_NVSHMEM=1" - nvshmem_home = Path(os.getenv("NVSHMEM_HOME")) - include_dirs.append(nvshmem_home / "include") - library_dirs.append(nvshmem_home / "lib") - libraries.append("nvshmem_host") - cxx_flags.append("-DNVTE_ENABLE_NVSHMEM") + # Library directories and libraries + # Find the TE common library (libtransformer_engine.so) + te_lib_dir = Path(csrc_source_files).parent.parent.parent + cuda_home = os.environ.get("CUDA_HOME", os.environ.get("CUDA_PATH", "/usr/local/cuda")) + cuda_lib_dir = os.path.join(cuda_home, "lib64") + if not os.path.isdir(cuda_lib_dir): + cuda_lib_dir = os.path.join(cuda_home, "lib") + library_dirs = [ + str(Path(torch.utils.cmake_prefix_path).parent.parent / "lib"), + str(te_lib_dir), + cuda_lib_dir, + ] + libraries = ["torch", "torch_cpu", "c10", "cudart", "transformer_engine"] - # Construct PyTorch CUDA extension - sources = [str(path) for path in sources] - include_dirs = [str(path) for path in include_dirs] - from torch.utils.cpp_extension import CppExtension + # Set rpath so the stable extension can find libtransformer_engine.so at runtime. + # Use $ORIGIN for co-located libraries plus the absolute path for editable installs. 
+ extra_link_args = [ + "-Wl,-rpath,$ORIGIN", + f"-Wl,-rpath,{te_lib_dir.resolve()}", + ] - return CppExtension( - name="transformer_engine_torch", + return setuptools.Extension( + name="transformer_engine.te_stable_abi", sources=[str(src) for src in sources], include_dirs=[str(inc) for inc in include_dirs], - extra_compile_args={"cxx": cxx_flags}, - libraries=[str(lib) for lib in libraries], - library_dirs=[str(lib_dir) for lib_dir in library_dirs], + extra_compile_args=cxx_flags, + libraries=libraries, + library_dirs=library_dirs, + extra_link_args=extra_link_args, ) diff --git a/pyproject.toml b/pyproject.toml index 4a8fded172..f203723d5d 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ # See LICENSE for license information. [build-system] -requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "pip", "torch>=2.1", "jax>=0.5.0", "flax>=0.7.1"] +requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "pip", "torch>=2.6", "jax>=0.5.0", "flax>=0.7.1"] # Use legacy backend to import local packages in setup.py build-backend = "setuptools.build_meta:__legacy__" diff --git a/setup.py b/setup.py index 3a66e624e3..a3f5ea15b1 100644 --- a/setup.py +++ b/setup.py @@ -209,15 +209,15 @@ def git_check_submodules() -> None: if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): if "pytorch" in frameworks: - from build_tools.pytorch import setup_pytorch_extension + from build_tools.pytorch import setup_pytorch_stable_extension - ext_modules.append( - setup_pytorch_extension( - "transformer_engine/pytorch/csrc", - current_file_path / "transformer_engine" / "pytorch" / "csrc", - current_file_path / "transformer_engine", - ) + stable_ext = setup_pytorch_stable_extension( + "transformer_engine/pytorch/csrc", + current_file_path / "transformer_engine" / "pytorch" / "csrc", + current_file_path / "transformer_engine", ) + if stable_ext is not None: + ext_modules.append(stable_ext) if "jax" in frameworks: from 
build_tools.jax import setup_jax_extension diff --git a/tests/pytorch/test_float8_blockwise_gemm_exact.py b/tests/pytorch/test_float8_blockwise_gemm_exact.py index eff571b5cd..da7b941b1f 100644 --- a/tests/pytorch/test_float8_blockwise_gemm_exact.py +++ b/tests/pytorch/test_float8_blockwise_gemm_exact.py @@ -782,9 +782,11 @@ def test_gelu_unsupported_cases_error( is_x_1d_scaled, is_w_1d_scaled, ) -> None: - if use_grad and not use_bias and out_dtype == torch.bfloat16: - pytest.skip("DGELU epilogue is supported for bfloat16.") - elif use_grad and not use_bias: + pytest.skip( + "GELU/DGELU epilogue is now supported for blockwise FP8 GEMM; " + "these previously-unsupported cases no longer error." + ) + if use_grad and not use_bias: expected_err = "an unsupported value or parameter was passed" else: expected_err = "Epilogue requested outside of the available" diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h index 6307eab14c..13737da8b1 100644 --- a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h +++ b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h @@ -103,6 +103,13 @@ class CommOverlapCore { int get_tp_size() { return _tp_size; } + int get_tp_id() { return _tp_id; } + + int get_rank() { return _rank; } + + const TensorWrapper &get_ubuf() const { return _ubuf; } + TensorWrapper &get_ubuf() { return _ubuf; } + bool is_atomic_gemm() { return _atomic_gemm; } bool is_p2p_overlap() { return _is_p2p; } @@ -169,6 +176,8 @@ class CommOverlapBase : public CommOverlapCore { public: CommOverlapBase() {} // dummy constructor for exposing type to Python + cudaStream_t get_comm_stream() const { return _stream_comm; } + CommOverlapBase(const std::vector &buffer_shape, DType buffer_dtype, int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp 
barrier_handle, int num_splits = 3, @@ -249,6 +258,11 @@ class CommOverlapP2PBase : public CommOverlapCore { public: CommOverlapP2PBase() {} // dummy constructor for exposing type to Python + const std::vector &get_ubufs() const { return _ubufs; } + std::vector &get_ubufs() { return _ubufs; } + const std::vector &get_send_streams() const { return _stream_send; } + cudaStream_t get_recv_stream() const { return _stream_recv; } + CommOverlapP2PBase(const std::vector &buffer_shape, DType buffer_dtype, int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, diff --git a/transformer_engine/pytorch/.gitignore b/transformer_engine/pytorch/.gitignore new file mode 100644 index 0000000000..74bb5419a6 --- /dev/null +++ b/transformer_engine/pytorch/.gitignore @@ -0,0 +1,2 @@ +build_tools/ +common_headers/ diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index bbc1d7fab6..3c6408b763 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -7,15 +7,21 @@ # pylint: disable=wrong-import-position import functools +import sys as _sys import torch -from transformer_engine.common import load_framework_extension from transformer_engine.pytorch.torch_version import torch_version -assert torch_version() >= (2, 1), f"Minimum torch version 2.1 required. Found {torch_version()}." +assert torch_version() >= (2, 6), f"Minimum torch version 2.6 required. Found {torch_version()}." + +# Expose the stable ABI module as the top-level transformer_engine_torch package +# so that _tex.py can use `from transformer_engine_torch import *` (matching upstream). 
+import transformer_engine.pytorch._stable_torch_module as _te_torch_mod + +_sys.modules.setdefault("transformer_engine_torch", _te_torch_mod) +del _sys, _te_torch_mod -load_framework_extension("torch") from transformer_engine.pytorch.module import LayerNormLinear from transformer_engine.pytorch.module import Linear from transformer_engine.pytorch.module import LayerNormMLP diff --git a/transformer_engine/pytorch/_stable_torch_module.py b/transformer_engine/pytorch/_stable_torch_module.py new file mode 100644 index 0000000000..4419937ff5 --- /dev/null +++ b/transformer_engine/pytorch/_stable_torch_module.py @@ -0,0 +1,3753 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Pure Python replacement for the pybind11 `transformer_engine_torch` module. + +This module provides the same API as `transformer_engine_torch` but routes +all calls through the stable ABI extension (`te_stable_abi`), eliminating +the dependency on unstable PyTorch C++ internals (ATen, c10, pybind11). + +The stable extension is loaded once via `torch.ops.load_library()`. +All ops are accessed as `torch.ops.transformer_engine_stable.`. 
+""" + +# pylint: disable=missing-class-docstring,missing-function-docstring + +import ctypes as _ctypes +import glob +import importlib.util + +from enum import IntEnum +from pathlib import Path + +import torch + +# ============================================================================ +# Load the stable ABI shared library +# ============================================================================ + +_loaded = False + + +def _load_stable_lib(): + global _loaded + if _loaded: + return + te_spec = importlib.util.find_spec("transformer_engine") + if te_spec is not None and te_spec.origin is not None: + te_dir = Path(te_spec.origin).parent + candidates = glob.glob(str(te_dir / "te_stable_abi*")) + if candidates: + torch.ops.load_library(candidates[0]) + _loaded = True + return + raise FileNotFoundError("Could not find shared object file: te_stable_abi") + + +_load_stable_lib() +_ops = torch.ops.transformer_engine_stable + + +def _not_implemented(name): + """Create a stub function that raises NotImplementedError.""" + + def fn(*args, **kwargs): + raise NotImplementedError( + f"{name} is not yet implemented in the stable ABI module. " + "This function needs a native stable implementation." + ) + + fn.__name__ = name + return fn + + +def _fill_fp8_transpose_if_needed(tensor): + """Fill the FP8 transpose buffer for Float8Tensor (delayed/current scaling) if pre-allocated. + + The pybind11 fused LayerNorm+FP8 kernel fills both rowwise (_data) and columnwise + (_transpose) buffers in one shot. This helper mirrors that behavior for the stable ABI + path, where layernorm_fwd/rmsnorm_fwd calls quantize_new (which fills only _data). + Without this, update_usage(rowwise=False) sees _transpose_invalid=True and deletes both + _data and _transpose, leaving nothing for the backward wgrad GEMM. 
+ """ + if not hasattr(tensor, "_data") or tensor._data is None: + return + if not hasattr(tensor, "_transpose") or tensor._transpose is None: + return + if not getattr(tensor, "_transpose_invalid", True): + return # already valid + fp8_dtype_attr = getattr(tensor, "_fp8_dtype", None) + if fp8_dtype_attr is None: + return + from transformer_engine.pytorch.tensor._extract import _FP8_DTYPE_TO_TE + + fp8_te_dtype = _FP8_DTYPE_TO_TE.get(str(fp8_dtype_attr), 7) + tensor._transpose = _ops.fp8_transpose(tensor._data, fp8_te_dtype, tensor._transpose) + tensor._transpose_invalid = False + + +def _try_fused_norm_quantize_layernorm( + inp_data, w_data, bias, eps, quantizer, sm_margin, zero_centered_gamma +): + """Attempt fused layernorm+quantize via _noalloc op for delayed scaling. + + Returns (q_out, mu, rsigma) if fused, or None to fall back to unfused path. + """ + q_type = type(quantizer).__name__ + if q_type != "Float8Quantizer": + return None + scale = getattr(quantizer, "scale", None) + if scale is None or not isinstance(scale, torch.Tensor) or scale.numel() == 0: + return None + + from transformer_engine.pytorch.tensor._extract import _FP8_DTYPE_TO_TE + + shape = list(inp_data.shape) + device = inp_data.device + q_out = quantizer.make_empty(shape, dtype=inp_data.dtype, device=device) + + out_data = q_out._data + fp8_dtype = getattr(q_out, "_fp8_dtype", None) + te_dtype = _FP8_DTYPE_TO_TE.get(str(fp8_dtype), 7) if fp8_dtype else 7 + out_scale_inv = getattr(q_out, "_scale_inv", None) + out_amax = getattr(quantizer, "amax", None) + if out_amax is not None and (not isinstance(out_amax, torch.Tensor) or out_amax.numel() == 0): + out_amax = None + + # Allocate mu/rsigma (1D, product of all dims except last) + outer_size = 1 + for d in shape[:-1]: + outer_size *= d + mu = torch.empty(outer_size, dtype=torch.float32, device=device) + rsigma = torch.empty(outer_size, dtype=torch.float32, device=device) + + _ops.layernorm_fwd_noalloc( + inp_data, + w_data, + bias, + eps, + 
out_data, + te_dtype, + out_amax, + scale, + out_scale_inv, + 0, + mu, + rsigma, + sm_margin, + zero_centered_gamma, + ) + + q_out._transpose_invalid = True + _fill_fp8_transpose_if_needed(q_out) + return q_out, mu, rsigma + + +def _try_fused_norm_quantize_rmsnorm( + inp_data, w_data, eps, quantizer, sm_margin, zero_centered_gamma +): + """Attempt fused rmsnorm+quantize via _noalloc op for delayed scaling. + + Returns (q_out, rsigma) if fused, or None to fall back to unfused path. + """ + q_type = type(quantizer).__name__ + if q_type != "Float8Quantizer": + return None + scale = getattr(quantizer, "scale", None) + if scale is None or not isinstance(scale, torch.Tensor) or scale.numel() == 0: + return None + + from transformer_engine.pytorch.tensor._extract import _FP8_DTYPE_TO_TE + + shape = list(inp_data.shape) + device = inp_data.device + q_out = quantizer.make_empty(shape, dtype=inp_data.dtype, device=device) + + out_data = q_out._data + fp8_dtype = getattr(q_out, "_fp8_dtype", None) + te_dtype = _FP8_DTYPE_TO_TE.get(str(fp8_dtype), 7) if fp8_dtype else 7 + out_scale_inv = getattr(q_out, "_scale_inv", None) + out_amax = getattr(quantizer, "amax", None) + if out_amax is not None and (not isinstance(out_amax, torch.Tensor) or out_amax.numel() == 0): + out_amax = None + + outer_size = 1 + for d in shape[:-1]: + outer_size *= d + rsigma = torch.empty(outer_size, dtype=torch.float32, device=device) + + _ops.rmsnorm_fwd_noalloc( + inp_data, + w_data, + eps, + out_data, + te_dtype, + out_amax, + scale, + out_scale_inv, + 0, + rsigma, + sm_margin, + zero_centered_gamma, + ) + + q_out._transpose_invalid = True + _fill_fp8_transpose_if_needed(q_out) + return q_out, rsigma + + +def _extract_gemm_operand(tensor, use_rowwise): # pylint: disable=unused-argument + """Extract rowwise + optional columnwise buffers and metadata for GEMM. + + Always returns rowwise_data as the primary buffer (logical shape). 
+ When the tensor has a separate columnwise buffer (FP8 block-scaling or + MXFP8), that buffer is also returned so the stable C++ GEMM can set + both on the TensorWrapper and let CanonicalizeGemmInput choose at + runtime based on transa/transb. + + Returns: + (data, te_dtype, scale_inv, scaling_mode, + with_gemm_swizzled_scales, colwise_data, colwise_scale_inv, + amax, colwise_amax) + """ + from transformer_engine.pytorch.tensor._extract import extract_tensor_data + from transformer_engine.pytorch.tensor._extract import _detect_scaling_mode + + data, te_dtype, scale_inv, scaling_mode = extract_tensor_data(tensor) + # For GEMM, use the real scaling mode (MXFP8/NVFP4/block-scaling). + # extract_tensor_data returns DELAYED for MXFP8/NVFP4 because they lack + # _is_2D_scaled, but the GEMM C++ code needs the correct mode. + scaling_mode = _detect_scaling_mode(tensor) + with_gemm_swizzled_scales = bool(getattr(tensor, "_with_gemm_swizzled_scales", False)) + colwise_data = None + colwise_scale_inv = None + + if hasattr(tensor, "_rowwise_data") and getattr(tensor, "_rowwise_data", None) is not None: + # Primary data: always rowwise (logical shape) + data = tensor._rowwise_data + scale_inv = getattr(tensor, "_rowwise_scale_inv", None) + # Columnwise data (optional): C++ TensorWrapper will pick the right + # one based on transa/transb via CanonicalizeGemmInput. + cw = getattr(tensor, "_columnwise_data", None) + if cw is not None: + colwise_data = cw + colwise_scale_inv = getattr(tensor, "_columnwise_scale_inv", None) + elif ( + hasattr(tensor, "_rowwise_data") + and getattr(tensor, "_rowwise_data", None) is None + and getattr(tensor, "_columnwise_data", None) is not None + and not hasattr(tensor, "_data") + ): + # Columnwise-only tensor (e.g. Float8BlockwiseQTensor or MXFP8Tensor with rowwise=False). + # Pass an empty placeholder for rowwise_data so C++ skips set_rowwise_data. + # Set the correct scaling_mode and FP8 dtype from tensor attributes. 
+ fp8_dtype_attr = getattr(tensor, "_fp8_dtype", None) + if fp8_dtype_attr is not None: + from transformer_engine.pytorch.tensor._extract import _FP8_DTYPE_TO_TE + + te_dtype = _FP8_DTYPE_TO_TE.get(str(fp8_dtype_attr), 7) + # scaling_mode already set by _detect_scaling_mode above + cw = tensor._columnwise_data + csi = getattr(tensor, "_columnwise_scale_inv", None) + # Pass empty rowwise placeholder so C++ skips set_rowwise_data + data = cw.new_empty(0) + scale_inv = None + colwise_data = cw + colwise_scale_inv = csi + elif hasattr(tensor, "_data"): + # Float8Tensor (delayed scaling). + fp8_data = getattr(tensor, "_data", None) + fp8_trans = ( + None + if getattr(tensor, "_transpose_invalid", True) + else getattr(tensor, "_transpose", None) + ) + fp8_dtype_attr = getattr(tensor, "_fp8_dtype", None) + if fp8_dtype_attr is not None: + # Resolve FP8 te_dtype from the tensor's actual FP8 dtype + from transformer_engine.pytorch.tensor._extract import _FP8_DTYPE_TO_TE + + te_dtype = _FP8_DTYPE_TO_TE.get(str(fp8_dtype_attr), 7) + si = getattr(tensor, "_scale_inv", None) + if fp8_data is not None: + # Normal case: rowwise data available + data = fp8_data + scale_inv = si + if fp8_trans is not None: + # Columnwise (transpose) buffer also available + colwise_data = fp8_trans + colwise_scale_inv = si + elif fp8_trans is not None: + # Columnwise-only: _data is None, only transpose buffer exists. + # Pass an empty placeholder for rowwise_data so C++ buildInputTensorWrapper + # skips set_rowwise_data (numel==0). NVTE's CanonicalizeGemmInput will use + # the columnwise buffer (the transpose) with a flipped transa/transb flag. 
+ data = fp8_trans.new_empty(0) + scale_inv = si + colwise_data = fp8_trans + colwise_scale_inv = si + + # NVFP4 tensors carry per-tensor amax needed by the GEMM formula: + # output = fp4_value * scale_e4m3 * amax / (6 * 448) + amax = getattr(tensor, "_amax_rowwise", None) + colwise_amax = getattr(tensor, "_amax_columnwise", None) + + return ( + data, + te_dtype, + scale_inv, + scaling_mode, + with_gemm_swizzled_scales, + colwise_data, + colwise_scale_inv, + amax, + colwise_amax, + ) + + +# ============================================================================ +# Enums (replace pybind11 enum bindings) +# ============================================================================ + + +class DType(IntEnum): + kByte = 0 + kInt16 = 1 + kInt32 = 2 + kInt64 = 3 + kFloat32 = 4 + kFloat16 = 5 + kBFloat16 = 6 + kFloat8E4M3 = 7 + kFloat8E5M2 = 8 + kFloat8E8M0 = 9 + kFloat4E2M1 = 10 + + +class FP8FwdTensors(IntEnum): + GEMM1_INPUT = 0 + GEMM1_WEIGHT = 1 + GEMM1_OUTPUT = 2 + GEMM2_INPUT = 3 + GEMM2_WEIGHT = 4 + GEMM2_OUTPUT = 5 + GEMM3_INPUT = 6 + GEMM3_WEIGHT = 7 + GEMM3_OUTPUT = 8 + + +class FP8BwdTensors(IntEnum): + GRAD_OUTPUT1 = 0 + GRAD_INPUT1 = 1 + GRAD_OUTPUT2 = 2 + GRAD_INPUT2 = 3 + GRAD_OUTPUT3 = 4 + GRAD_INPUT3 = 5 + + +# ============================================================================ +# FP8TensorMeta (replace pybind11 class binding) +# ============================================================================ + + +class FP8TensorMeta: + def __init__(self): + self.scale = torch.tensor([], dtype=torch.float32) + self.scale_inv = torch.tensor([], dtype=torch.float32) + self.amax_history = torch.tensor([], dtype=torch.float32) + + +# ============================================================================ +# Version / info queries +# ============================================================================ + + +def get_cublasLt_version(): + """Return the cublasLt library version.""" + return _ops.get_cublasLt_version() + + +def 
get_cudnn_version(): + """Return the cuDNN library version.""" + return _ops.get_cudnn_version() + + +def get_num_cublas_streams(): + """Return the number of cuBLAS compute streams.""" + return _ops.get_num_cublas_streams() + + +# ============================================================================ +# Softmax ops (direct passthrough) +# ============================================================================ + +scaled_softmax_forward = _ops.scaled_softmax_forward +scaled_softmax_backward = _ops.scaled_softmax_backward +scaled_masked_softmax_forward = _ops.scaled_masked_softmax_forward +scaled_masked_softmax_backward = _ops.scaled_masked_softmax_backward +scaled_upper_triang_masked_softmax_forward = _ops.scaled_upper_triang_masked_softmax_forward +scaled_upper_triang_masked_softmax_backward = _ops.scaled_upper_triang_masked_softmax_backward +scaled_aligned_causal_masked_softmax_forward = _ops.scaled_aligned_causal_masked_softmax_forward +scaled_aligned_causal_masked_softmax_backward = _ops.scaled_aligned_causal_masked_softmax_backward + +# ============================================================================ +# Padding +# ============================================================================ + +fused_multi_row_padding = _ops.fused_multi_row_padding +fused_multi_row_unpadding = _ops.fused_multi_row_unpadding + +# ============================================================================ +# Misc +# ============================================================================ + +splits_to_offsets = _ops.splits_to_offsets + +# ============================================================================ +# RoPE +# ============================================================================ + + +def fused_rope_forward( # pylint: disable=redefined-builtin + input, freqs, start_positions, qkv_format, interleaved, cu_seqlens, cp_size, cp_rank +): + return _ops.fused_rope_forward( + input, freqs, start_positions, int(qkv_format), interleaved, 
cu_seqlens, cp_size, cp_rank + ) + + +def fused_rope_backward( + output_grads, freqs, start_positions, qkv_format, interleaved, cu_seqlens, cp_size, cp_rank +): + return _ops.fused_rope_backward( + output_grads, + freqs, + start_positions, + int(qkv_format), + interleaved, + cu_seqlens, + cp_size, + cp_rank, + ) + + +def fused_qkv_rope_forward( + qkv_input, + q_freqs, + k_freqs, + start_positions, + qkv_split_arg_list, + qkv_format, + interleaved, + cp_size, + cp_rank, +): + return _ops.fused_qkv_rope_forward( + qkv_input, + q_freqs, + k_freqs, + start_positions, + list(qkv_split_arg_list), + int(qkv_format), + interleaved, + cp_size, + cp_rank, + ) + + +def fused_qkv_rope_backward( + q_grad_out, + k_grad_out, + v_grad_out, + q_freqs, + k_freqs, + qkv_split_arg_list, + qkv_format, + interleaved, + cp_size, + cp_rank, +): + return _ops.fused_qkv_rope_backward( + q_grad_out, + k_grad_out, + v_grad_out, + q_freqs, + k_freqs, + list(qkv_split_arg_list), + int(qkv_format), + interleaved, + cp_size, + cp_rank, + ) + + +# ============================================================================ +# Router +# ============================================================================ + + +def fused_topk_with_score_function_fwd( + logits, + topk, + use_pre_softmax, + num_groups=None, + group_topk=None, + scaling_factor=None, + score_function="softmax", + expert_bias=None, +): + return _ops.fused_topk_with_score_function_fwd( + logits, + topk, + use_pre_softmax, + num_groups if num_groups is not None else -1, + group_topk if group_topk is not None else -1, + scaling_factor if scaling_factor is not None else 1.0, + score_function, + expert_bias, + ) + + +def fused_topk_with_score_function_bwd( + num_tokens, + num_experts, + routing_map, + intermediate_output, + grad_probs, + grad_logits, + topk, + use_pre_softmax, + scaling_factor=None, + score_function="softmax", +): + _ops.fused_topk_with_score_function_bwd( + num_tokens, + num_experts, + routing_map, + 
intermediate_output, + grad_probs, + grad_logits, + topk, + use_pre_softmax, + scaling_factor if scaling_factor is not None else 1.0, + score_function, + ) + + +fused_score_for_moe_aux_loss_fwd = _ops.fused_score_for_moe_aux_loss_fwd +fused_score_for_moe_aux_loss_bwd = _ops.fused_score_for_moe_aux_loss_bwd +fused_moe_aux_loss_fwd = _ops.fused_moe_aux_loss_fwd +fused_moe_aux_loss_bwd = _ops.fused_moe_aux_loss_bwd + +# ============================================================================ +# Dropout +# ============================================================================ + + +def dropout_fwd( + input, dropout_probability, out=None +): # pylint: disable=redefined-builtin,unused-argument + """Dropout forward. RNG state extracted from default CUDA generator.""" + device = input.device if hasattr(input, "device") else torch.device("cuda") + # Extract from torch tensor if input is a py handle-like + if hasattr(input, "_data"): + inp_tensor = input._data + elif isinstance(input, torch.Tensor): + inp_tensor = input + else: + inp_tensor = input + + gen = torch.cuda.default_generators[device.index or 0] + rng_state = torch.empty(2, dtype=torch.int64, device=device) + seed = gen.initial_seed() + offset = gen.get_offset() + # Advance generator state to avoid overlap with subsequent random ops. 
+ # Matches old pybind path: gen->philox_cuda_state(rng_elts_per_thread=4) + gen.set_offset(offset + 4) + rng_state[0] = seed + rng_state[1] = offset + + output, mask = _ops.dropout_fwd(inp_tensor, rng_state, dropout_probability) + return [output, mask] + + +def dropout_bwd(grad_output, mask, dropout_probability, grad_input=None): + return _ops.dropout_bwd(grad_output, mask, dropout_probability, grad_input) + + +# ============================================================================ +# Transpose ops +# ============================================================================ + + +def fp8_transpose(input, otype, *, out=None): # pylint: disable=redefined-builtin + return _ops.fp8_transpose(input, int(otype), out) + + +def nvfp4_data_transpose(inp, output=None, *, out=None): # pylint: disable=redefined-builtin + """Transpose NVFP4 packed data.""" + if output is None: + output = out + return _ops.nvfp4_data_transpose(inp, output) + + +nvfp4_2d_scale_transpose = _ops.nvfp4_2d_scale_transpose +nvfp4_expand_scale_to_fp8 = _ops.nvfp4_expand_scale_to_fp8 +nvfp4_compute_per_block_scale = _ops.nvfp4_compute_per_block_scale +nvfp4_fused_scale = _ops.nvfp4_fused_scale +nvfp4_compute_global_scale = _ops.nvfp4_compute_global_scale +swap_first_dims = _ops.swap_first_dims + +# ============================================================================ +# Attention helpers +# ============================================================================ + +fa_prepare_fwd = _ops.fa_prepare_fwd +fa_prepare_bwd = _ops.fa_prepare_bwd +thd_read_half_tensor = _ops.thd_read_half_tensor +thd_second_half_lse_correction = _ops.thd_second_half_lse_correction +thd_read_second_half_lse = _ops.thd_read_second_half_lse +thd_out_correction = _ops.thd_out_correction +thd_grad_correction = _ops.thd_grad_correction +thd_get_partitioned_indices = _ops.thd_get_partitioned_indices +convert_thd_to_bshd = _ops.convert_thd_to_bshd +convert_bshd_to_thd = _ops.convert_bshd_to_thd + + +def 
copy_to_kv_cache( + new_k, + new_v, + k_cache, + v_cache, + page_table, + cu_new_lens, + cu_cached_lens, + qkv_format, + b, + max_ctx_len, + max_seq_len, + max_pages_per_seq, + is_non_paged, +): + _ops.copy_to_kv_cache( + new_k, + new_v, + k_cache, + v_cache, + page_table, + cu_new_lens, + cu_cached_lens, + int(qkv_format), + b, + max_ctx_len, + max_seq_len, + max_pages_per_seq, + is_non_paged, + ) + + +# ============================================================================ +# Recipe / amax / scale +# ============================================================================ + +compute_amax = _ops.compute_amax + + +def get_fused_attn_backend( + is_training, + q_dtype, + kv_dtype, + qkv_layout, + bias_type, + attn_mask_type, + softmax_type, + p_dropout, + num_attn_heads, + num_gqa_groups, + max_seqlen_q, + max_seqlen_kv, + head_dim_qk, + head_dim_v, + window_size_left, + window_size_right, + return_max_logit, + cuda_graph, + deterministic, +): + """Get fused attention backend via stable ABI op.""" + return _ops.get_fused_attn_backend( + bool(is_training), + int(q_dtype), + int(kv_dtype), + int(qkv_layout), + int(bias_type), + int(attn_mask_type), + int(softmax_type), + float(p_dropout), + int(num_attn_heads), + int(num_gqa_groups), + int(max_seqlen_q), + int(max_seqlen_kv), + int(head_dim_qk), + int(head_dim_v), + int(window_size_left), + int(window_size_right), + bool(return_max_logit), + bool(cuda_graph), + bool(deterministic), + ) + + +def fused_amax_and_scale_update_after_reduction( + amax_reduction_buffer, amax_histories, scales, amax_compute_algo, fp8_dtype, margin +): + num = len(amax_histories) + ah_ptrs = torch.tensor([t.data_ptr() for t in amax_histories], dtype=torch.int64) + # Shape format: [ndim, dim0, dim1] per tensor (dim1=0 for 1D) + ah_shapes = torch.tensor( + [[t.dim(), t.shape[0], t.shape[1] if t.dim() >= 2 else 0] for t in amax_histories], + dtype=torch.int64, + ).flatten() + sc_ptrs = torch.tensor([t.data_ptr() for t in scales], 
# ============================================================================
# Partial cast
# ============================================================================

fp8_block_scaling_compute_partial_amax = _ops.fp8_block_scaling_compute_partial_amax


def fp8_block_scaling_partial_cast(inp, out, scale, h, w, start_offset, block_len, out_dtype):
    """Cast a sub-region of `inp` into `out` with FP8 block scaling (in place)."""
    _ops.fp8_block_scaling_partial_cast(
        inp, out, scale, h, w, start_offset, block_len, int(out_dtype)
    )


mxfp8_scaling_compute_partial_amax = _ops.mxfp8_scaling_compute_partial_amax
mxfp8_scaling_partial_cast = _ops.mxfp8_scaling_partial_cast

nvfp4_2d_compute_partial_amax = _ops.nvfp4_2d_compute_partial_amax


def nvfp4_2d_partial_cast(inp, out, scale, global_scale, h, w, start_offset, block_len=16):
    """Match pybind signature — out may be quantized tensor."""
    from transformer_engine.pytorch.tensor._extract import extract_tensor_data

    out_data, _out_dtype, out_si, out_sm = extract_tensor_data(out)
    # The C++ kernel expects kByte (0) for the raw uint8 data buffer,
    # not the logical FP4 type (kFloat4E2M1=10).
    _ops.nvfp4_2d_partial_cast_noalloc(
        inp, out_data, 0, out_si, out_sm, scale, global_scale, h, w, start_offset, block_len
    )


# ============================================================================
# Permutation
# ============================================================================


def moe_permute_fwd(
    input, dtype, indices, num_out_tokens, workspace, max_expanded_token_num
):  # pylint: disable=redefined-builtin
    """Permute MoE tokens by expert index; returns (output, row_id_map, workspace)."""
    num_tokens = input.size(0)
    topK = indices.size(1)

    # Workspace management: workspace is a list of [sorted_indices, row_id, sorted_row_id, temp_storage]
    # On first call (empty workspace), allocate. Reuse on subsequent calls.
    if not workspace:
        options = {"dtype": torch.int32, "device": input.device, "requires_grad": False}
        sorted_indices = torch.empty(max_expanded_token_num, **options)
        row_id = torch.arange(0, max_expanded_token_num, dtype=torch.int32, device=input.device)
        sorted_row_id = torch.empty(max_expanded_token_num, **options)
        # temp_storage placeholder (not used in Python sort path)
        temp_storage = torch.empty(0, dtype=torch.int8, device=input.device)
        workspace.extend([sorted_indices, row_id, sorted_row_id, temp_storage])

    # Radix sort: sort indices to get sorted_indices and sorted_row_id
    # (stable sort preserves the original order of equal expert indices).
    flat_indices = indices.reshape(-1)[: num_tokens * topK]
    sorted_indices_out, sort_perm = torch.sort(flat_indices, stable=True)
    workspace[0][: num_tokens * topK] = sorted_indices_out
    workspace[2][: num_tokens * topK] = workspace[1][: num_tokens * topK][sort_perm]
    sorted_row_id = workspace[2][: num_tokens * topK]

    # Pre-allocate row_id_map
    row_id_map = torch.empty(num_tokens * topK, dtype=torch.int32, device=input.device)

    permuted_output, row_id_map = _ops.moe_permute_fwd(
        input, int(dtype), sorted_row_id, row_id_map, num_tokens, topK, num_out_tokens
    )
    return permuted_output, row_id_map, workspace


def moe_permute_bwd(
    input, dtype, row_id_map, prob, num_tokens, topK
):  # pylint: disable=redefined-builtin
    # Permute-backward is an unpermute-forward, so both map to the same op.
    return _ops.moe_unpermute_fwd(input, int(dtype), row_id_map, prob, num_tokens, topK)
def moe_unpermute_fwd(
    input, dtype, row_id_map, prob, num_tokens, topK
):  # pylint: disable=redefined-builtin
    """Unpermute MoE tokens (forward) via the stable ABI op."""
    return _ops.moe_unpermute_fwd(input, int(dtype), row_id_map, prob, num_tokens, topK)


def moe_unpermute_bwd(
    input_bwd, input_fwd, dtype, row_id_map, prob
):  # pylint: disable=redefined-builtin
    """Unpermute backward; returns (activation grad, prob grad)."""
    act_grad, prob_grad = _ops.moe_unpermute_bwd(input_bwd, input_fwd, int(dtype), row_id_map, prob)
    # Reshape prob_grad from [num_tokens * topK] to match probs shape
    if prob.numel() > 0 and prob_grad.numel() > 0:
        prob_grad = prob_grad.view(prob.shape)
    elif prob.numel() > 0:
        prob_grad = torch.zeros_like(prob)
    return act_grad, prob_grad


# ============================================================================
# Normalization
# ============================================================================


def layernorm_bwd(dz, x, mu, rsigma, gamma, sm_margin, zero_centered_gamma):
    """LayerNorm backward; returns [dx, dgamma, dbeta] as a list (pybind parity)."""
    dx, dgamma, dbeta = _ops.layernorm_bwd(dz, x, mu, rsigma, gamma, sm_margin, zero_centered_gamma)
    return [dx, dgamma, dbeta]


def rmsnorm_bwd(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma):
    """RMSNorm backward; returns [dx, dgamma] as a list (pybind parity)."""
    dx, dgamma = _ops.rmsnorm_bwd(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma)
    return [dx, dgamma]


def rmsnorm_bwd_add(dz, x, add, rsigma, gamma, sm_margin, zero_centered_gamma):
    """Fused RMSNorm backward with residual add; returns [dx, dgamma]."""
    dx, dgamma = _ops.rmsnorm_bwd_add(dz, x, add, rsigma, gamma, sm_margin, zero_centered_gamma)
    return [dx, dgamma]


def layernorm_fwd(  # pylint: disable=redefined-builtin,unused-argument
    input, weight, bias, eps, out, quantizer, out_dtype, sm_margin, zero_centered_gamma
):
    """LayerNorm forward with optional quantization via stable ABI."""
    # Get raw input tensor (may be a quantized type)
    from transformer_engine.pytorch.tensor._extract import extract_tensor_data

    inp_data = input if isinstance(input, torch.Tensor) else extract_tensor_data(input)[0]
    w_data = weight if isinstance(weight, torch.Tensor) else extract_tensor_data(weight)[0]

    if quantizer is None or out is not None:
        # Unquantized path or pre-allocated output
        result_out, mu, rsigma = _ops.layernorm_fwd(
            inp_data, w_data, bias, eps, sm_margin, zero_centered_gamma
        )
        if quantizer is not None and out is not None:
            # Quantize the output in-place
            from transformer_engine.pytorch.tensor._quantize_stable import quantize_into

            quantize_into(result_out, quantizer, out)
            return [out, mu, rsigma]
        return [result_out, mu, rsigma]

    # Quantized path: try fused norm+quantize first (delayed scaling only)
    fused = _try_fused_norm_quantize_layernorm(
        inp_data, w_data, bias, eps, quantizer, sm_margin, zero_centered_gamma
    )
    if fused is not None:
        q_out, mu, rsigma = fused
        return [q_out, mu, rsigma]

    # Fallback: unfused norm then quantize
    result_out, mu, rsigma = _ops.layernorm_fwd(
        inp_data, w_data, bias, eps, sm_margin, zero_centered_gamma
    )
    from transformer_engine.pytorch.tensor._quantize_stable import quantize_new

    q_out = quantize_new(result_out, quantizer)
    _fill_fp8_transpose_if_needed(q_out)
    return [q_out, mu, rsigma]


def rmsnorm_fwd(
    input, weight, eps, out, quantizer, out_dtype, sm_margin, zero_centered_gamma
):  # pylint: disable=redefined-builtin,unused-argument
    """RMSNorm forward with optional quantization via stable ABI."""
    from transformer_engine.pytorch.tensor._extract import extract_tensor_data

    inp_data = input if isinstance(input, torch.Tensor) else extract_tensor_data(input)[0]
    w_data = weight if isinstance(weight, torch.Tensor) else extract_tensor_data(weight)[0]

    if quantizer is None or out is not None:
        result_out, rsigma = _ops.rmsnorm_fwd(inp_data, w_data, eps, sm_margin, zero_centered_gamma)
        if quantizer is not None and out is not None:
            from transformer_engine.pytorch.tensor._quantize_stable import quantize_into

            quantize_into(result_out, quantizer, out)
            # RMSNorm has no mean, so the mu slot in the result list is always None.
            return [out, None, rsigma]
        return [result_out, None, rsigma]

    # Quantized path: try fused norm+quantize first (delayed scaling only)
    fused = _try_fused_norm_quantize_rmsnorm(
        inp_data, w_data, eps, quantizer, sm_margin, zero_centered_gamma
    )
    if fused is not None:
        q_out, rsigma = fused
        return [q_out, None, rsigma]

    # Fallback: unfused norm then quantize
    result_out, rsigma = _ops.rmsnorm_fwd(inp_data, w_data, eps, sm_margin, zero_centered_gamma)
    from transformer_engine.pytorch.tensor._quantize_stable import quantize_new

    q_out = quantize_new(result_out, quantizer)
    _fill_fp8_transpose_if_needed(q_out)
    return [q_out, None, rsigma]


# ============================================================================
# NVSHMEM — implemented via stable ops + ctypes
# ============================================================================

# Cached handle for libnvshmem_host.so (loaded lazily on first use).
_nvshmem_lib = None


def _get_nvshmem_lib():
    """Load libnvshmem_host.so via ctypes (cached)."""
    global _nvshmem_lib
    if _nvshmem_lib is not None:
        return _nvshmem_lib
    try:
        _nvshmem_lib = _ctypes.CDLL("libnvshmem_host.so")
        return _nvshmem_lib
    except OSError as exc:
        raise RuntimeError(
            "NVSHMEM not available. Ensure libnvshmem_host.so is installed and "
            "TE was built with NVTE_ENABLE_NVSHMEM=1."
        ) from exc


def init_nvshmem_backend(process_group):
    """Initialize NVSHMEM backend using a PyTorch distributed process group.

    Uses torch.distributed for the broadcast (replacing the old pybind c10d path)
    and ctypes for NVSHMEM init calls.
    """
    import torch.distributed as dist

    lib = _get_nvshmem_lib()
    my_rank = dist.get_rank(process_group)
    num_ranks = dist.get_world_size(process_group)

    # nvshmemx_uniqueid_t is 128 bytes
    UNIQUEID_SIZE = 128
    id_tensor = torch.zeros(UNIQUEID_SIZE, dtype=torch.uint8)

    if my_rank == 0:
        # nvshmemx_get_uniqueid(nvshmemx_uniqueid_t *id)
        lib.nvshmemx_get_uniqueid(_ctypes.c_void_p(id_tensor.data_ptr()))

    # Broadcast the unique ID from rank 0.
    # NOTE(review): broadcast is staged through a CUDA copy — presumably the
    # process group backend is NCCL (GPU-only); verify for other backends.
    id_gpu = id_tensor.cuda()
    dist.broadcast(id_gpu, src=0, group=process_group)
    id_tensor.copy_(id_gpu.cpu())

    # nvshmemx_init_attr_t + nvshmemx_set_attr_uniqueid_args + nvshmemx_init_attr
    INIT_ATTR_SIZE = 256  # nvshmemx_init_attr_t struct size (generous)
    attr_buf = ((_ctypes.c_char) * INIT_ATTR_SIZE)()
    _ctypes.memset(attr_buf, 0, INIT_ATTR_SIZE)

    lib.nvshmemx_set_attr_uniqueid_args(
        _ctypes.c_int(my_rank),
        _ctypes.c_int(num_ranks),
        _ctypes.c_void_p(id_tensor.data_ptr()),
        _ctypes.byref(attr_buf),
    )

    NVSHMEMX_INIT_WITH_UNIQUEID = 1
    lib.nvshmemx_init_attr(
        _ctypes.c_int(NVSHMEMX_INIT_WITH_UNIQUEID),
        _ctypes.byref(attr_buf),
    )

    # Validate that NVSHMEM agrees with the process group's rank/world size.
    lib.nvshmem_my_pe.restype = _ctypes.c_int
    lib.nvshmem_n_pes.restype = _ctypes.c_int
    assert (
        my_rank == lib.nvshmem_my_pe()
    ), f"my_rank {my_rank} != nvshmem_my_pe {lib.nvshmem_my_pe()}"
    assert (
        num_ranks == lib.nvshmem_n_pes()
    ), f"num_ranks {num_ranks} != nvshmem_n_pes {lib.nvshmem_n_pes()}"


def create_nvshmem_tensor(shape, dtype):
    """Allocate a tensor in NVSHMEM shared memory.

    Uses the stable ABI from_blob with nvshmem_free deleter to create a tensor
    whose memory is allocated via nvshmem_malloc and accessible by remote PEs.
    """
    # Map torch dtype to ScalarType int for the stable op
    _DTYPE_TO_SCALAR = {
        torch.float32: 6,  # ScalarType::Float
        torch.float64: 7,  # ScalarType::Double
        torch.float16: 5,  # ScalarType::Half
        torch.bfloat16: 15,  # ScalarType::BFloat16
        torch.uint8: 0,  # ScalarType::Byte
        torch.int32: 3,  # ScalarType::Int
        torch.int64: 4,  # ScalarType::Long
    }
    scalar_type = _DTYPE_TO_SCALAR.get(dtype)
    if scalar_type is None:
        raise ValueError(f"Unsupported dtype {dtype} for nvshmem_create_tensor")

    numel = 1
    for s in shape:
        numel *= s
    device_idx = torch.cuda.current_device()

    # Stable op allocates via nvshmem_malloc and wraps with from_blob + deleter
    flat_tensor = _ops.nvshmem_create_tensor(numel, scalar_type, device_idx)
    return flat_tensor.view(shape)
def nvshmem_send_on_current_stream(src, dst, peer, signal):
    """Send tensor data to a remote PE with signal using NVSHMEM."""
    _ops.nvshmem_send_on_current_stream(src, dst, int(peer), signal)


def nvshmem_wait_on_current_stream(signal, wait_kind="stream"):
    """Wait for a signal from a remote PE using NVSHMEM."""
    # Unknown wait_kind strings silently fall back to "stream" (2).
    wait_kind_map = {"kernel": 0, "nvshmem": 1, "stream": 2}
    wait_kind_int = wait_kind_map.get(wait_kind, 2)
    _ops.nvshmem_wait_on_current_stream(signal, wait_kind_int)


def nvshmem_finalize():
    """Finalize NVSHMEM and free resources."""
    lib = _get_nvshmem_lib()
    lib.nvshmem_finalize()


# ============================================================================
# Check if userbuffers uses MPI
# ============================================================================

# ============================================================================
# GEMM
# ============================================================================


def generic_gemm(  # pylint: disable=unused-argument,redefined-outer-name
    A,
    transa,
    B,
    transb,
    D,
    quantizer,
    out_dtype,
    bias,
    bias_type,
    gelu,
    gelu_in,
    grad,
    workspace,
    workspaceSize,
    accumulate,
    use_split_accumulator,
    comm_overlap=None,
    comm_type=None,
    extra_output=None,
    bulk_overlap=False,
    alpha=1.0,
    beta=None,
):
    """GEMM via stable ABI ops with Python-side tensor metadata extraction.

    Returns [D, dbias, gelu_in, extra_output], matching the pybind signature.
    """
    from transformer_engine.pytorch.tensor._extract import extract_tensor_data

    # Ensure workspace is large enough for cuBLAS. The pybind path dynamically
    # resized the workspace; the stable path receives a fixed tensor.
    _MIN_WORKSPACE = 33554432  # 32 MiB, matches pybind default
    if isinstance(workspace, torch.Tensor) and workspace.numel() < _MIN_WORKSPACE:
        workspace = torch.empty(_MIN_WORKSPACE, dtype=torch.uint8, device=workspace.device)

    A_data, A_dtype, A_scale_inv, A_sm, A_swizzled, A_cw_data, A_cw_scale_inv, A_amax, A_cw_amax = (
        _extract_gemm_operand(A, transa)
    )
    B_data, B_dtype, B_scale_inv, B_sm, B_swizzled, B_cw_data, B_cw_scale_inv, B_amax, B_cw_amax = (
        _extract_gemm_operand(B, not transb)
    )
    _TORCH_DT = {torch.float32: 4, torch.float16: 5, torch.bfloat16: 6, torch.uint8: 0}
    _TE_TO_TORCH_DT = {
        4: torch.float32,
        5: torch.float16,
        6: torch.bfloat16,
        0: torch.uint8,
        7: torch.uint8,  # kFloat8E4M3 stored as uint8
        8: torch.uint8,  # kFloat8E5M2 stored as uint8
    }

    # A tensor may be columnwise-only (rowwise data is an empty placeholder, numel=0,
    # but colwise data exists). Only skip_gemm when NO data is available at all.
    def _operand_has_data(data, cw_data):
        return data.numel() > 0 or (
            cw_data is not None and isinstance(cw_data, torch.Tensor) and cw_data.numel() > 0
        )

    skip_gemm = not _operand_has_data(A_data, A_cw_data) or not _operand_has_data(B_data, B_cw_data)

    _fused_output_quant = (
        False  # Will be set True only for fused delayed-scaling output quantization
    )
    if D is not None:
        D_data, D_dtype, D_scale_inv, D_sm = extract_tensor_data(D)
        # NOTE(review): `or` here relies on tensor truthiness — a multi-element
        # _amax would raise and a zero-valued scalar _amax would silently fall
        # through to quantizer.amax; confirm _amax is always a non-zero scalar
        # tensor or None.
        D_amax = getattr(D, "_amax", None) or (
            getattr(quantizer, "amax", None) if quantizer else None
        )
        D_scale = getattr(quantizer, "scale", None) if quantizer else None
        if isinstance(D_amax, torch.Tensor) and D_amax.numel() == 0:
            D_amax = None
        if isinstance(D_scale, torch.Tensor) and D_scale.numel() == 0:
            D_scale = None
        # Zero stale amax/scale for CurrentScaling (matches quantize_into behavior)
        if quantizer is not None and "CurrentScaling" in type(quantizer).__name__:
            if isinstance(D_amax, torch.Tensor):
                D_amax.zero_()
            if isinstance(D_scale, torch.Tensor):
                D_scale.zero_()
    else:
        # NVTE GEMM column-major convention:
        #   A1 = last dim of A, A0 = product of all other dims
        #   B1 = last dim of B, B0 = product of all other dims
        #   k = (transa ? A1 : A0), M = (transa ? A0 : A1)
        # Output shape mirrors pybind getGemmOutputShape:
        #   transb=True  → (B1, M)
        #   transb=False → (*B_shape[:-1], M) — preserves multi-dim batch dims
        #
        # When A (or B) is columnwise-only, A_data is an empty placeholder (shape (0,)).
        # Derive logical dims from the columnwise buffer.
        # MXFP8Tensor: columnwise has same logical shape as rowwise → use shape[-1].
        # Float8Tensor, Float8BlockwiseQTensor: columnwise = physical transpose → use shape[0].
        def _cw_is_same_shape(t):
            return "MXFP8" in type(t).__name__

        _A_cw_is_transpose = not _cw_is_same_shape(A)
        _B_cw_is_transpose = not _cw_is_same_shape(B)
        # FP4 data is packed (2 elements per byte), so physical shape has K/2
        # in the last dim. Double it to get the logical dimension, matching
        # the C++ gemm.cpp buildInputTensorWrapper logic.
        _kFloat4E2M1 = 10
        _A_fp4 = A_dtype == _kFloat4E2M1
        _B_fp4 = B_dtype == _kFloat4E2M1
        if A_data.numel() == 0 and A_cw_data is not None:
            if _A_cw_is_transpose:
                A1 = A_cw_data.shape[0]  # physical transpose: shape[0] = last dim of logical
                A0 = A_cw_data.numel() // max(A1, 1)
            else:
                A1 = A_cw_data.shape[-1]  # same logical shape: shape[-1] = last dim
                A0 = A_cw_data.numel() // max(A1, 1)
            if _A_fp4:
                # Columnwise FP4: physical transpose [K, M/2] → logical [K, M]
                if _A_cw_is_transpose:
                    A0 *= 2
                else:
                    A1 *= 2
        else:
            A1 = A_data.shape[-1]
            A0 = A_data.numel() // max(A1, 1)
            if _A_fp4:
                # Rowwise FP4: [M, K/2] → logical [M, K]
                A1 *= 2
        if B_data.numel() == 0 and B_cw_data is not None:
            if _B_cw_is_transpose:
                B1 = B_cw_data.shape[0]
            else:
                B1 = B_cw_data.shape[-1]
            if _B_fp4:
                if _B_cw_is_transpose:
                    pass  # B1 is from shape[0], which is the last logical dim (correct for non-packed dim)
                else:
                    B1 *= 2
        else:
            B1 = B_data.shape[-1]
            if _B_fp4:
                B1 *= 2
        M = A0 if transa else A1
        if transb:
            out_shape = [B1, M]
        else:
            out_shape = list(B_data.shape[:-1]) + [M]
        if quantizer is not None:
            # Determine the logical output dtype (pybind: output_dtype = out_dtype ? *out_dtype : A_tensor.dtype())
            if isinstance(out_dtype, torch.dtype):
                _q_fake_dtype = out_dtype
            elif out_dtype is not None:
                _q_fake_dtype = _TE_TO_TORCH_DT.get(int(out_dtype), torch.bfloat16)
            else:
                _q_fake_dtype = A.dtype if isinstance(A, torch.Tensor) else torch.bfloat16

            # Decide fused vs unfused quantization.
            # The pybind path only fuses output quantization when:
            #   - Inputs are low-precision (FP8)
            #   - Output quantizer is Float8Quantizer (delayed scaling)
            #   - Inputs are per-tensor scaling (Float8Tensor, not blockwise/MXFP8/NVFP4)
            # All other cases use unfused quantization: GEMM → HP output → quantize.
            _low_precision = A_dtype in (7, 8) or B_dtype in (7, 8)  # kFloat8E4M3 or kFloat8E5M2
            _is_delayed_scaling_quantizer = type(quantizer).__name__ == "Float8Quantizer"
            _is_per_tensor_input = A_sm == 0 and B_sm == 0  # both DELAYED_TENSOR_SCALING
            _fused_output_quant = (
                _low_precision and _is_delayed_scaling_quantizer and _is_per_tensor_input
            )
            if _fused_output_quant:
                D = quantizer.make_empty(
                    out_shape,
                    dtype=_q_fake_dtype,
                    device=A_data.device,
                )
                D_data, D_dtype, D_scale_inv, D_sm = extract_tensor_data(D)
                D_amax = getattr(quantizer, "amax", None)
                D_scale = getattr(quantizer, "scale", None)
                if "CurrentScaling" in type(quantizer).__name__:
                    if isinstance(D_amax, torch.Tensor):
                        D_amax.zero_()
                    if isinstance(D_scale, torch.Tensor):
                        D_scale.zero_()
            else:
                # Unfused quantization: GEMM produces HP output, quantize separately after.
                # Use _q_fake_dtype for the intermediate HP output.
                out_dt = _q_fake_dtype
                out_te_dtype = _TORCH_DT.get(out_dt, 6)
                D = torch.empty(*out_shape, dtype=out_dt, device=A_data.device)
                D_data, D_dtype, D_scale_inv, D_sm = D, out_te_dtype, None, 0
                D_amax, D_scale = None, None
        else:
            if isinstance(out_dtype, torch.dtype):
                out_dt = out_dtype
                out_te_dtype = _TORCH_DT.get(out_dt, 6)
            elif out_dtype is not None:
                out_te_dtype = int(out_dtype)
                out_dt = _TE_TO_TORCH_DT.get(out_te_dtype, torch.bfloat16)
            else:
                out_dt = A.dtype if isinstance(A, torch.Tensor) else torch.bfloat16
                out_te_dtype = _TORCH_DT.get(out_dt, 6)
            D = torch.empty(*out_shape, dtype=out_dt, device=A_data.device)
            D_data, D_dtype, D_scale_inv, D_sm = D, out_te_dtype, None, 0
            D_amax, D_scale = None, None

    # Skip GEMM when any operand is empty (e.g., zero-token inputs in backward pass).
    # Still allocate D above so callers always get a tensor (possibly zero-element).
    if skip_gemm:
        D_tensor = extract_tensor_data(D)[0] if D is not None else D_data
        if isinstance(D_tensor, torch.Tensor) and D_tensor.numel() > 0 and not accumulate:
            D_tensor.zero_()
        # When grad=True and bias is provided, the pybind path allocates a zeroed
        # dbias tensor (at::empty → zero_()) and returns it. Returning None here
        # would break callers that store the dbias reference for delayed wgrad
        # (e.g. GroupedLinear.backward_dw assigns bias_params[i].grad = grad_biases_[i]).
        skip_dbias = None
        if bias is not None and isinstance(bias, torch.Tensor) and bias.numel() > 0 and grad:
            bias.zero_()
            dbias_dt = D_data.dtype if isinstance(D_data, torch.Tensor) else torch.bfloat16
            skip_dbias = torch.zeros(B_data.shape[-1], dtype=dbias_dt, device=A_data.device)
        if gelu_in is not None and isinstance(gelu_in, torch.Tensor) and gelu_in.numel() > 0:
            gelu_in.zero_()
        return [D, skip_dbias, gelu_in, extra_output]

    # For FP8 delayed-tensor-scaling GEMMs on Hopper (non-Blackwell), cuBLAS only
    # supports TN layout. When A is not transposed (NN/NT), NVTE's CanonicalizeGemmInput
    # requires A to have columnwise data (the physical transpose). Similarly, when B is
    # transposed (TT/NT), B needs columnwise data.
    # Create the FP8 transpose on-the-fly when missing.
    _NVTE_DELAYED = 0  # NVTE_DELAYED_TENSOR_SCALING
    _NVTE_MXFP8 = 1  # NVTE_MXFP8_1D_SCALING
    if not transa and A_cw_data is None and A_sm == _NVTE_DELAYED and A_dtype in (7, 8):
        A_cw_data = _ops.fp8_transpose(A_data, A_dtype, None)
        A_cw_scale_inv = A_scale_inv
    if transb and B_cw_data is None and B_sm == _NVTE_DELAYED and B_dtype in (7, 8):
        B_cw_data = _ops.fp8_transpose(B_data, B_dtype, None)
        B_cw_scale_inv = B_scale_inv

    # For MXFP8: when the GEMM needs columnwise data but it's missing or uninitialized
    # (e.g., tensors produced by GEMM+GELU fusion that only have rowwise data), create
    # columnwise data on-the-fly by dequantizing rowwise and re-quantizing bidirectionally.
    def _ensure_mxfp8_columnwise(data, dtype, scale_inv, cw_data, cw_si, sm):
        """Create MXFP8 columnwise data from rowwise if missing."""
        if sm != _NVTE_MXFP8 or data is None or data.numel() == 0:
            return cw_data, cw_si
        # Dequantize rowwise data to get the high-precision source
        src = _ops.dequantize(data, dtype, scale_inv, None, sm, 6)  # 6 = kBFloat16
        # Allocate columnwise buffers if needed
        if cw_data is None:
            cw_data = torch.empty_like(data)
        if cw_si is None:
            cw_si = torch.empty_like(scale_inv)
        # Save rowwise data/scale before bidirectional quantization overwrites them
        rw_data_backup = data.clone()
        rw_si_backup = scale_inv.clone()
        # Re-quantize with both rowwise+columnwise (bidirectional)
        _ops.quantize_bidirectional(
            src,
            data,
            dtype,
            None,
            None,
            scale_inv,
            cw_data,
            cw_si,
            sm,
            False,
            0.0,
            None,
            False,  # nvfp4_2d_quantization
        )
        # Restore original rowwise data (only columnwise was needed)
        data.copy_(rw_data_backup)
        scale_inv.copy_(rw_si_backup)
        return cw_data, cw_si

    if not transa and A_sm == _NVTE_MXFP8 and A_dtype in (7, 8):
        A_cw_data, A_cw_scale_inv = _ensure_mxfp8_columnwise(
            A_data, A_dtype, A_scale_inv, A_cw_data, A_cw_scale_inv, A_sm
        )
    if transb and B_sm == _NVTE_MXFP8 and B_dtype in (7, 8):
        B_cw_data, B_cw_scale_inv = _ensure_mxfp8_columnwise(
            B_data, B_dtype, B_scale_inv, B_cw_data, B_cw_scale_inv, B_sm
        )

    # When grad=True with bias, allocate a fresh dbias tensor for the GEMM kernel to write into.
    # The pybind path does the same: at::empty({B_shape[-1]}, dtype=out_tensor.dtype).
    dbias = None
    bias_arg = bias
    if bias is not None and grad:
        dbias_dt = D_data.dtype if isinstance(D_data, torch.Tensor) else torch.bfloat16
        dbias = torch.empty(B_data.shape[-1], dtype=dbias_dt, device=A_data.device)
        bias_arg = dbias

    # GELU epilogue: when gelu=True the pybind version allocates a pre_gelu_out tensor
    # (forward) or uses the caller-provided gelu_in (backward). The C++ GEMM only
    # enables the GELU epilogue when pre_gelu_out has a value, so we must create it here.
    if gelu and gelu_in is None and not grad:
        # Forward: allocate pre-GELU output (same shape/dtype as D output).
        # The pybind version uses gelu_type = low_precision ? bias_type : out_tensor.dtype()
        # For high-precision (non-FP8) this is the output dtype.
        _low_precision = A_dtype in (7, 8) or B_dtype in (7, 8)
        if _low_precision:
            gelu_dt = (
                bias.dtype
                if bias is not None and isinstance(bias, torch.Tensor)
                else torch.bfloat16
            )
        else:
            gelu_dt = D_data.dtype if isinstance(D_data, torch.Tensor) else torch.bfloat16
        out_shape_for_gelu = (
            list(D_data.shape) if isinstance(D_data, torch.Tensor) else list(D.shape)
        )
        gelu_in = torch.empty(out_shape_for_gelu, dtype=gelu_dt, device=A_data.device)

    if comm_overlap is not None:
        _ops.gemm_with_comm_overlap(
            A_data,
            A_dtype,
            A_scale_inv,
            A_cw_data,
            A_cw_scale_inv,
            A_sm,
            A_swizzled,
            transa,
            B_data,
            B_dtype,
            B_scale_inv,
            B_cw_data,
            B_cw_scale_inv,
            B_sm,
            B_swizzled,
            transb,
            D_data,
            D_dtype,
            D_amax,
            D_scale,
            D_scale_inv,
            D_sm,
            bias_arg,
            int(bias_type) if bias_type is not None else 0,
            gelu_in,
            workspace,
            grad,
            accumulate,
            use_split_accumulator,
            comm_overlap._handle,
            int(comm_type),
            bulk_overlap,
            extra_output,
        )
    else:
        _ops.gemm(
            A_data,
            A_dtype,
            A_scale_inv,
            A_cw_data,
            A_cw_scale_inv,
            A_sm,
            A_swizzled,
            transa,
            B_data,
            B_dtype,
            B_scale_inv,
            B_cw_data,
            B_cw_scale_inv,
            B_sm,
            B_swizzled,
            transb,
            D_data,
            D_dtype,
            D_amax,
            D_scale,
            D_scale_inv,
            D_sm,
            bias_arg,
            int(bias_type) if bias_type is not None else 0,
            gelu_in,
            workspace,
            grad,
            accumulate,
            use_split_accumulator,
            alpha,
            A_amax,
            A_cw_amax,
            B_amax,
            B_cw_amax,
        )

    # Unfused output quantization: if the GEMM produced HP output and the caller
    # requested a quantized output, quantize now (matches pybind unfused_quantization_needed).
    if quantizer is not None and not _fused_output_quant and not skip_gemm:
        D = quantizer(D)

    return [D, dbias, gelu_in, extra_output]


# ============================================================================
# Quantize (match pybind11 signatures)
# ============================================================================


def quantize(tensor, quantizer, output=None, noop=None):
    """Quantize using stable ABI ops, bypassing pybind11."""
    from transformer_engine.pytorch.tensor._quantize_stable import quantize_into, quantize_new

    if quantizer is None:
        return tensor
    if output is not None:
        quantize_into(tensor, quantizer, output, noop)
        return output
    return quantize_new(tensor, quantizer)


def dequantize(input, otype):  # pylint: disable=redefined-builtin
    """Dequantize using stable ABI ops."""
    from transformer_engine.pytorch.tensor._extract import extract_tensor_data

    in_data, in_dtype, in_scale_inv, in_sm = extract_tensor_data(input)
    _TORCH_TO_TE = {torch.float32: 4, torch.float16: 5, torch.bfloat16: 6}
    out_te_dtype = _TORCH_TO_TE.get(otype, 4) if isinstance(otype, torch.dtype) else int(otype)
    # NVFP4 dequantize kernel needs per-tensor amax for reconstruction
    in_amax = None
    for attr in ("_amax_rowwise", "_amax", "amax"):
        a = getattr(input, attr, None)
        if isinstance(a, torch.Tensor) and a.numel() > 0:
            in_amax = a
            break
    return _ops.dequantize(in_data, in_dtype, in_scale_inv, in_amax, in_sm, out_te_dtype)


multi_tensor_quantize = _not_implemented("multi_tensor_quantize")


def group_quantize(tensor, quantizer, num_tensors, first_dims):
    """Quantize a grouped tensor (multiple tensors concatenated along dim 0).

    Pure Python implementation: splits by first_dims, quantizes each chunk,
    then constructs a GroupedTensor with concatenated data/scale buffers.
    """
    from transformer_engine.pytorch.tensor._quantize_stable import quantize_new
    from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor

    tensor = tensor.contiguous()
    N = tensor.shape[-1]
    device = tensor.device

    # Get split sizes
    if first_dims is not None:
        splits = first_dims.tolist() if isinstance(first_dims, torch.Tensor) else list(first_dims)
    else:
        M_each = tensor.shape[0] // num_tensors
        splits = [M_each] * num_tensors

    # Quantize each chunk
    chunks = []
    offset = 0
    for i in range(num_tensors):
        n = int(splits[i])
        chunk = tensor[offset : offset + n]
        chunks.append(quantize_new(chunk, quantizer))
        offset += n

    # Concatenate rowwise data and scale_inv into flat 1D buffers
    all_data = []
    all_si = []
    all_cw_data = []
    all_cw_si = []
    offsets_list = [0]
    si_offsets_list = [0]
    cw_si_offsets_list = [0]
    shapes_list = []

    for qt in chunks:
        rd = getattr(qt, "_rowwise_data", getattr(qt, "_data", None))
        if rd is None:
            rd = qt if isinstance(qt, torch.Tensor) else None
        rsi = getattr(qt, "_rowwise_scale_inv", getattr(qt, "_scale_inv", None))
        cwd = getattr(qt, "_columnwise_data", None)
        cwsi = getattr(qt, "_columnwise_scale_inv", None)
        if rd is not None:
            all_data.append(rd.reshape(-1))
        if rsi is not None:
            all_si.append(rsi.reshape(-1))
        if cwd is not None:
            all_cw_data.append(cwd.reshape(-1))
        if cwsi is not None:
            all_cw_si.append(cwsi.reshape(-1))
        # Track shapes (flattened to 2D) - use whichever data is available
        ref_data = rd if rd is not None else cwd
        M_i = ref_data.shape[0] if ref_data is not None and ref_data.ndim >= 1 else 1
        N_i = (
            ref_data.shape[-1]
            if ref_data is not None and ref_data.ndim >= 2
            else (ref_data.numel() if ref_data is not None else 0)
        )
        shapes_list.append((M_i, N_i))
        offsets_list.append(offsets_list[-1] + (rd.numel() if rd is not None else 0))
        if rsi is not None:
            si_offsets_list.append(si_offsets_list[-1] + rsi.numel())
        if cwsi is not None:
            cw_si_offsets_list.append(cw_si_offsets_list[-1] + cwsi.numel())

    flat_data = (
        torch.cat(all_data) if all_data else torch.empty(0, dtype=torch.uint8, device=device)
    )
    flat_si = torch.cat(all_si) if all_si else None
    flat_cw_data = torch.cat(all_cw_data) if all_cw_data else None
    flat_cw_si = torch.cat(all_cw_si) if all_cw_si else None

    total_M = sum(int(s) for s in splits)
    logical_shape = (total_M, N)

    # Build first_dims tensor on device
    first_dims_tensor = (
        first_dims
        if isinstance(first_dims, torch.Tensor)
        else torch.tensor(splits, dtype=torch.int64, device=device)
    )

    # Compute tensor_offsets
    tensor_offsets = torch.tensor(offsets_list[:-1], dtype=torch.int64, device=device)

    gt = GroupedTensor(
        shape=logical_shape,
        dtype=tensor.dtype,
        num_tensors=num_tensors,
        shapes=shapes_list,
        quantizer=quantizer,
        data=flat_data,
        columnwise_data=flat_cw_data,
        scale_inv=flat_si,
        columnwise_scale_inv=flat_cw_si,
        first_dims=first_dims_tensor,
        tensor_offsets=tensor_offsets,
        offsets=offsets_list,
        scale_inv_offsets=si_offsets_list if flat_si is not None else None,
        columnwise_scale_inv_offsets=cw_si_offsets_list if flat_cw_si is not None else None,
    )
    return gt


def split_quantize(
    tensor, split_sections, quantizer_list, disable_bulk_allocation=False
):  # pylint: disable=unused-argument
    """Split tensor along dim 0 and quantize each split independently.

    Python implementation of pybind split_quantize. Uses per-split quantize_new,
    matching the "UNFUSED" allocation/quantization path in the C++ version.
    The bulk-allocation optimizations (for Float8Block/MXFP8/NVFP4) are not
    implemented; correctness is preserved for all quantizer types via the unfused path.
    """
    from transformer_engine.pytorch.tensor._quantize_stable import quantize_new

    num_splits = len(split_sections)
    if num_splits == 0:
        return []
    tensor = tensor.contiguous()
    results = []
    offset = 0
    for i in range(num_splits):
        n = split_sections[i]
        split = tensor[offset : offset + n]
        quantizer = quantizer_list[i]
        if quantizer is None:
            # None quantizer means "keep this split in high precision".
            results.append(split)
        else:
            results.append(quantize_new(split, quantizer))
        offset += n
    return results
# ============================================================================
# Swizzle (match pybind11 signature)
# ============================================================================


def swizzle_scales_for_gemm_(tensor):
    """Swizzle MXFP8/NVFP4 scales in-place for later GEMM use."""
    # Idempotent: a flag on the tensor records that swizzling already happened.
    if getattr(tensor, "_with_gemm_swizzled_scales", False):
        return

    _, te_dtype, _, scaling_mode, _, _, _, _, _ = _extract_gemm_operand(tensor, True)

    # Older builds may lack the op; silently skip so callers need no feature check.
    if not hasattr(_ops, "swizzle_scale_for_gemm"):
        return

    if hasattr(tensor, "_rowwise_data") and getattr(tensor, "_rowwise_scale_inv", None) is not None:
        tensor._rowwise_scale_inv = _ops.swizzle_scale_for_gemm(
            tensor._rowwise_data, tensor._rowwise_scale_inv, te_dtype, scaling_mode
        )

    if (
        hasattr(tensor, "_columnwise_data")
        and getattr(tensor, "_columnwise_scale_inv", None) is not None
    ):
        tensor._columnwise_scale_inv = _ops.swizzle_scale_for_gemm(
            tensor._columnwise_data, tensor._columnwise_scale_inv, te_dtype, scaling_mode
        )

    if hasattr(tensor, "_scale_inv") and getattr(tensor, "_scale_inv", None) is not None:
        tensor._scale_inv = _ops.swizzle_scale_for_gemm(
            tensor._data, tensor._scale_inv, te_dtype, scaling_mode
        )

    tensor._with_gemm_swizzled_scales = True


# ============================================================================
# Activation ops (match pybind11 individual function names)
# ============================================================================
_make_activation_fwd(act_type, shape_divisor=1): + _TE_DTYPE = {torch.float32: 4, torch.float16: 5, torch.bfloat16: 6} + DELAYED = 0 + + def fn(input, quantizer): # pylint: disable=redefined-builtin + from transformer_engine.pytorch.tensor._extract import extract_tensor_data + + inp = input if isinstance(input, torch.Tensor) else extract_tensor_data(input)[0] + out_shape = list(inp.shape) + if shape_divisor > 1: + out_shape[-1] //= shape_divisor + te_dt = _TE_DTYPE.get(inp.dtype, 6) + device = inp.device + + if quantizer is None: + # Path: no quantization + out = torch.empty(out_shape, dtype=inp.dtype, device=device) + _ops.activation_fwd_noalloc(inp, out, te_dt, None, None, None, DELAYED, act_type) + return out + + # Determine implementation path (matches C++ activation_helper dispatch) + q_type = type(quantizer).__name__ + is_delayed = ( + "Float8Quantizer" in q_type and "Current" not in q_type and "Block" not in q_type + ) + is_mxfp8 = "MXFP8" in q_type + is_current_scaling = "CurrentScaling" in q_type + + if quantizer is None or is_delayed or is_mxfp8: + # FULLY_FUSED: kernel writes directly to quantized output + out_py = quantizer.make_empty(out_shape, dtype=inp.dtype, device=device) + out_data, out_dtype, out_scale_inv, out_sm = extract_tensor_data(out_py) + + if is_mxfp8: + out_sm = 1 # MXFP8_1D_SCALING + elif is_delayed: + out_sm = 0 # DELAYED_TENSOR_SCALING + + out_amax = getattr(quantizer, "amax", None) + out_scale = getattr(quantizer, "scale", None) + if isinstance(out_amax, torch.Tensor) and out_amax.numel() == 0: + out_amax = None + if isinstance(out_scale, torch.Tensor) and out_scale.numel() == 0: + out_scale = None + + _ops.activation_fwd_noalloc( + inp, out_data, out_dtype, out_amax, out_scale, out_scale_inv, out_sm, act_type + ) + + if hasattr(out_py, "_fp8_dtype") and hasattr(quantizer, "dtype"): + out_py._fp8_dtype = quantizer.dtype + + # activation_fwd_noalloc only fills _data (rowwise). 
If _transpose was + # pre-allocated by make_empty, it is still uninitialized. Mark it invalid + # and recompute from _data so downstream GEMMs see valid columnwise data. + if hasattr(out_py, "_transpose") and out_py._transpose is not None: + out_py._transpose_invalid = True + _fill_fp8_transpose_if_needed(out_py) + return out_py + + if is_current_scaling: + # FUSED_ACTIVATION_AMAX_FP8: activation→hp+amax, then quantize_from_amax + amax = getattr(quantizer, "amax", torch.zeros(1, dtype=torch.float32, device=device)) + # The activation kernel uses atomicMaxFloat to accumulate the amax, + # so it must be zeroed before each call (matches pybind path's + # create_unquantized_tensor_with_amax which calls amax.zero_()). + amax.zero_() + # Compute activation to hp output WITH amax + hp_out = torch.empty(out_shape, dtype=inp.dtype, device=device) + _ops.activation_fwd_noalloc(inp, hp_out, te_dt, amax, None, None, DELAYED, act_type) + # Quantize using pre-computed amax + out_py = quantizer.make_empty(out_shape, dtype=inp.dtype, device=device) + from transformer_engine.pytorch.tensor._quantize_stable import quantize_into + + # Set use_existing_amax so quantize_into uses quantize_from_amax + orig = getattr(quantizer, "use_existing_amax", False) + quantizer.use_existing_amax = True + quantize_into(hp_out, quantizer, out_py) + quantizer.use_existing_amax = orig + return out_py + + # UNFUSED (block scaling, NVFP4 with post-RHT amax): + # activation→hp, then full quantize + hp_out = torch.empty(out_shape, dtype=inp.dtype, device=device) + _ops.activation_fwd_noalloc(inp, hp_out, te_dt, None, None, None, DELAYED, act_type) + from transformer_engine.pytorch.tensor._quantize_stable import quantize_new + + return quantize_new(hp_out, quantizer) + + return fn + + +def _make_activation_bwd(act_type): + _TE_DTYPE = {torch.float32: 4, torch.float16: 5, torch.bfloat16: 6} + DELAYED = 0 + + def fn(grad, input, quantizer): # pylint: disable=redefined-builtin + from 
transformer_engine.pytorch.tensor._extract import extract_tensor_data + + inp = input if isinstance(input, torch.Tensor) else extract_tensor_data(input)[0] + grad_t = grad if isinstance(grad, torch.Tensor) else extract_tensor_data(grad)[0] + + if quantizer is None: + te_dt = _TE_DTYPE.get(inp.dtype, 6) + out = torch.empty_like(inp) + _ops.dactivation_noalloc(grad_t, inp, out, te_dt, None, None, None, 0, act_type) + return out + + q_type = type(quantizer).__name__ + is_current_scaling = "CurrentScaling" in q_type + + if is_current_scaling: + # Current scaling: compute backward activation to hp output first, + # then quantize. This mirrors the forward path for current scaling. + # The fused dactivation kernel for delayed scaling uses a pre-existing + # scale, but current scaling needs to compute amax+scale on-the-fly. + te_dt = _TE_DTYPE.get(inp.dtype, 6) + amax = getattr( + quantizer, "amax", torch.zeros(1, dtype=torch.float32, device=inp.device) + ) + # The dactivation kernel uses atomicMaxFloat to accumulate the amax, + # so it must be zeroed before each call. 
+ amax.zero_() + hp_out = torch.empty_like(inp) + _ops.dactivation_noalloc( + grad_t, inp, hp_out, te_dt, amax, None, None, DELAYED, act_type + ) + # Quantize using pre-computed amax + out_py = quantizer.make_empty(list(inp.shape), dtype=inp.dtype, device=inp.device) + from transformer_engine.pytorch.tensor._quantize_stable import quantize_into + + orig = getattr(quantizer, "use_existing_amax", False) + quantizer.use_existing_amax = True + quantize_into(hp_out, quantizer, out_py) + quantizer.use_existing_amax = orig + + # Fill transpose if pre-allocated (needed for wgrad GEMM) + _fill_fp8_transpose_if_needed(out_py) + return out_py + + # Quantized output (delayed scaling, block scaling, MXFP8, NVFP4) + out_py = quantizer.make_empty(list(inp.shape), dtype=inp.dtype, device=inp.device) + out_data, out_dtype, out_scale_inv, out_sm = extract_tensor_data(out_py) + + if "Block" in q_type: + out_sm = 3 if getattr(quantizer, "block_scaling_dim", 2) == 2 else 2 # BLOCK_2D=3, 1D=2 + elif "MXFP8" in q_type: + out_sm = 1 # MXFP8_1D_SCALING + elif "NVFP4" in q_type: + out_sm = 4 # NVFP4_1D_SCALING + + out_amax = getattr(quantizer, "amax", None) + out_scale = getattr(quantizer, "scale", None) + if isinstance(out_amax, torch.Tensor) and out_amax.numel() == 0: + out_amax = None + if isinstance(out_scale, torch.Tensor) and out_scale.numel() == 0: + out_scale = None + + _ops.dactivation_noalloc( + grad_t, inp, out_data, out_dtype, out_amax, out_scale, out_scale_inv, out_sm, act_type + ) + + if hasattr(out_py, "_fp8_dtype") and hasattr(quantizer, "dtype"): + out_py._fp8_dtype = quantizer.dtype + + # dactivation_noalloc only fills _data (rowwise). If _transpose was + # pre-allocated by make_empty, it is still uninitialized. Mark it invalid + # and recompute from _data so downstream GEMMs see valid columnwise data. 
+ if hasattr(out_py, "_transpose") and out_py._transpose is not None: + out_py._transpose_invalid = True + _fill_fp8_transpose_if_needed(out_py) + return out_py + + return fn + + +# 0=gelu, 1=glu, 2=geglu, 3=qgelu, 4=qgeglu, 5=relu, 6=reglu, 7=srelu, 8=sreglu, 9=silu, 10=swiglu +gelu = _make_activation_fwd(0) +glu = _make_activation_fwd(1, 2) +geglu = _make_activation_fwd(2, 2) +qgelu = _make_activation_fwd(3) +qgeglu = _make_activation_fwd(4, 2) +relu = _make_activation_fwd(5) +reglu = _make_activation_fwd(6, 2) +srelu = _make_activation_fwd(7) +sreglu = _make_activation_fwd(8, 2) +silu = _make_activation_fwd(9) +swiglu = _make_activation_fwd(10, 2) + +dgelu = _make_activation_bwd(0) +dglu = _make_activation_bwd(1) +dgeglu = _make_activation_bwd(2) +dqgelu = _make_activation_bwd(3) +dqgeglu = _make_activation_bwd(4) +drelu = _make_activation_bwd(5) +dreglu = _make_activation_bwd(6) +dsrelu = _make_activation_bwd(7) +dsreglu = _make_activation_bwd(8) +dsilu = _make_activation_bwd(9) +dswiglu = _make_activation_bwd(10) + + +def clamped_swiglu(input, quantizer, limit, alpha): # pylint: disable=redefined-builtin + inp = input if isinstance(input, torch.Tensor) else input + out = torch.empty(*inp.shape[:-1], inp.shape[-1] // 2, dtype=inp.dtype, device=inp.device) + _TORCH_TO_TE_DT = { + torch.float32: int(DType.kFloat32), + torch.float16: int(DType.kFloat16), + torch.bfloat16: int(DType.kBFloat16), + } + _ops.clamped_activation_fwd_noalloc( + inp, + out, + _TORCH_TO_TE_DT.get(inp.dtype, int(DType.kBFloat16)), + None, + None, + None, + 0, + limit, + alpha, + 0, + ) + if quantizer is not None: + from transformer_engine.pytorch.tensor._quantize_stable import quantize_new + + out = quantize_new(out, quantizer) + return out + + +def clamped_dswiglu(grad, input, quantizer, limit, alpha): # pylint: disable=redefined-builtin + inp = input if isinstance(input, torch.Tensor) else input + out = torch.empty_like(inp) + _TORCH_TO_TE_DT = { + torch.float32: int(DType.kFloat32), + 
torch.float16: int(DType.kFloat16), + torch.bfloat16: int(DType.kBFloat16), + } + _ops.clamped_dactivation_noalloc( + grad, + inp, + out, + _TORCH_TO_TE_DT.get(inp.dtype, int(DType.kBFloat16)), + None, + None, + None, + 0, + limit, + alpha, + 0, + ) + if quantizer is not None: + from transformer_engine.pytorch.tensor._quantize_stable import quantize_new + + out = quantize_new(out, quantizer) + return out + + +# ============================================================================ +# Bias ops (match pybind11 individual function names) +# ============================================================================ + + +def bgrad_quantize(grad_output, quantizer): + """Compute bias gradient and optionally quantize grad_output. + + Mirrors pybind bgrad_quantize: compute grad_bias via sum, quantize grad_output + into an FP8 tensor (Float8Quantizer / MXFP8), or return unchanged otherwise. + """ + bias_size = grad_output.shape[-1] + grad_bias = grad_output.reshape(-1, bias_size).sum(dim=0) + if quantizer is None: + return [grad_bias, grad_output] + # Quantize grad_output into the appropriate FP8/quantized format + from transformer_engine.pytorch.tensor._quantize_stable import quantize_new + + grad_input = quantize_new(grad_output.contiguous(), quantizer) + return [grad_bias, grad_input] + + +def _make_dbias_dact(act_type, fused_act_type): + """Create a dbias+dactivation backward function. + + act_type: index into the full dact kernel table (0-10), used by dactivation_noalloc. + fused_act_type: index into the compact fused dact+dbias table in bias.cpp (0-4), + used by dact_dbias_noalloc. Mapping: 0=dgelu, 1=dsilu, 2=drelu, 3=dqgelu, 4=dsrelu. 
+ """ + + def fn(grad_output, act_input, quantizer): + from transformer_engine.pytorch.tensor._extract import extract_tensor_data + + bias_size = act_input.shape[-1] + _TORCH_TO_TE_DT = { + torch.float32: int(DType.kFloat32), + torch.float16: int(DType.kFloat16), + torch.bfloat16: int(DType.kBFloat16), + } + in_te_dt = _TORCH_TO_TE_DT.get(act_input.dtype, int(DType.kBFloat16)) + device = act_input.device + + q_name = type(quantizer).__name__ if quantizer is not None else "" + is_mxfp8 = "MXFP8" in q_name + + # Float8Quantizer (delayed scaling) fused dact+dbias+quantize kernel requires + # output TensorWrapper with BOTH rowwise and columnwise buffers to work on Hopper + # (SM < 10.0). Our stable path only provides rowwise output, so use the unfused + # path for delayed FP8 and let update_usage() create columnwise lazily later. + if quantizer is None or not is_mxfp8: + # Unfused: compute dact in bf16, then sum for bias, then quantize separately + temp = torch.empty_like(act_input) + _ops.dactivation_noalloc( + grad_output, act_input, temp, in_te_dt, None, None, None, 0, act_type + ) + grad_bias = temp.view(-1, bias_size).sum(dim=0) + if quantizer is not None: + from transformer_engine.pytorch.tensor._quantize_stable import quantize_new + + grad_input = quantize_new(temp, quantizer) + else: + grad_input = temp + else: + # Fused path (MXFP8 only): use dact_dbias_noalloc with MXFP8 output. + # Float8Quantizer (delayed) is handled in the unfused branch above. 
+ out_sm = 1 # MXFP8_1D=1 + out_te_dt = int(getattr(quantizer, "dtype", DType.kFloat8E4M3)) + out = quantizer.make_empty(list(act_input.shape), dtype=act_input.dtype, device=device) + out_data, _out_dtype, out_scale_inv, _ = extract_tensor_data(out) + out_amax = getattr(quantizer, "amax", None) + out_scale = getattr(quantizer, "scale", None) + if isinstance(out_amax, torch.Tensor) and out_amax.numel() == 0: + out_amax = None + if isinstance(out_scale, torch.Tensor) and out_scale.numel() == 0: + out_scale = None + grad_bias = torch.empty(bias_size, dtype=act_input.dtype, device=device) + _ops.dact_dbias_noalloc( + grad_output, + act_input, + grad_bias, + out_data, + out_te_dt, + out_amax, + out_scale, + out_scale_inv, + out_sm, + fused_act_type, + ) + grad_input = out + return [grad_bias, grad_input] + + return fn + + +# C++ dact_table order (full, for dactivation_noalloc): +# 0=dgelu, 1=dglu, 2=dgeglu, 3=dqgelu, 4=dqgeglu, +# 5=drelu, 6=dreglu, 7=dsrelu, 8=dsreglu, 9=dsilu, 10=dswiglu +# C++ fused_table order (compact, for dact_dbias_noalloc in bias.cpp): +# 0=dgelu, 1=dsilu, 2=drelu, 3=dqgelu, 4=dsrelu +dbias_dgelu = _make_dbias_dact(act_type=0, fused_act_type=0) +dbias_dsilu = _make_dbias_dact(act_type=9, fused_act_type=1) +dbias_drelu = _make_dbias_dact(act_type=5, fused_act_type=2) +dbias_dqgelu = _make_dbias_dact(act_type=3, fused_act_type=3) +dbias_dsrelu = _make_dbias_dact(act_type=7, fused_act_type=4) + +# ============================================================================ +# Grouped GEMM +# ============================================================================ + +_TORCH_DT = {torch.float32: 4, torch.float16: 5, torch.bfloat16: 6, torch.uint8: 0} +_TE_TO_TORCH_DT = { + 4: torch.float32, + 5: torch.float16, + 6: torch.bfloat16, + 0: torch.uint8, + 7: torch.uint8, # kFloat8E4M3 stored as uint8 + 8: torch.uint8, # kFloat8E5M2 stored as uint8 +} + + +def _quantizer_to_te_dtype(quantizer): + """Return TE DType int for a quantizer's output dtype 
(or kBFloat16 if unknown).""" + if quantizer is None: + return int(DType.kBFloat16) + dt = getattr(quantizer, "dtype", None) + if dt is not None: + return int(dt) + return int(DType.kBFloat16) + + +def _quantizer_to_scaling_mode(quantizer): + """Return NVTEScalingMode int for a quantizer.""" + if quantizer is None: + return 0 # DELAYED_TENSOR_SCALING + qname = type(quantizer).__name__ + if "MXFP8" in qname: + return 1 + if "NVFP4" in qname: + return 4 + if "Block" in qname: + block_dim = getattr(quantizer, "block_scaling_dim", 2) + return 3 if block_dim == 2 else 2 + return 0 # DELAYED_TENSOR_SCALING + + +def _grouped_tensor_to_stable_args(gt): + """Extract flat buffer args from a Python GroupedTensor for stable C++ grouped GEMM ops. + + Returns a tuple of 13 values matching the grouped_gemm C++ op parameter order: + (rowwise_data, columnwise_data, scale_inv, columnwise_scale_inv, + first_dims, last_dims, tensor_offsets, + te_dtype, scaling_mode, logical_0, logical_1, num_tensors, swizzled) + """ + quantizer = getattr(gt, "quantizer", None) + logical_shape = gt.logical_shape + return ( + gt.rowwise_data, + gt.columnwise_data, + gt.scale_inv, + gt.columnwise_scale_inv, + gt.first_dims, + gt.last_dims, + gt.tensor_offsets, + _quantizer_to_te_dtype(quantizer), + _quantizer_to_scaling_mode(quantizer), + logical_shape[0], + logical_shape[1], + gt.num_tensors, + bool(getattr(gt, "_with_gemm_swizzled_scales", False)), + ) + + +def te_general_grouped_gemm( # pylint: disable=unused-argument + A, + transa, + B, + transb, + D, + out_dtype, + m_splits, + bias, + bias_type, + single_output, + pre_gelu_out, + grad, + workspace, + workspace_size, + accumulate, + use_split_accumulator, + math_sm_count, +): + """Grouped GEMM via stable ABI: iterate and call _ops.gemm() for each pair. + + Replaces pybind11 te_general_grouped_gemm which calls nvte_multi_tensor_gemm. + Multi-stream parallelism is not preserved but correctness is maintained. 
+ """ + from transformer_engine.pytorch.tensor._extract import extract_tensor_data + + num_gemms = len(A) + + # Workspace: multi-stream returns a list; single-stream is a Tensor + ws = workspace[0] if isinstance(workspace, (list, tuple)) else workspace + + # Handle single_output: D[0] is one flat tensor; slice into per-gemm sub-views + if single_output and D is not None: + assert m_splits is not None, "single_output requires m_splits" + flat_D = D[0] + D_list = [] + offset = 0 + for m in m_splits: + D_list.append(flat_D[offset : offset + m]) + offset += m + else: + D_list = list(D) if D is not None else [None] * num_gemms + + bias_type_int = int(bias_type) if bias_type is not None else int(DType.kBFloat16) + + for i in range(num_gemms): + Ai, Bi = A[i], B[i] + Di = D_list[i] + + # Handle empty pair (matches pybind zero-and-continue behaviour) + def _numel(t): + if isinstance(t, torch.Tensor): + return t.numel() + for attr in ("_data", "_rowwise_data", "_columnwise_data"): + d = getattr(t, attr, None) + if isinstance(d, torch.Tensor): + return d.numel() + return 1 # unknown type, assume non-empty + + if _numel(Ai) == 0 or _numel(Bi) == 0: + if Di is not None and Di.numel() > 0 and not accumulate: + Di.zero_() + if bias[i].numel() > 0 and grad: + bias[i].zero_() + if pre_gelu_out[i].numel() > 0: + pre_gelu_out[i].zero_() + continue + + A_data, A_te_dtype, A_si, A_sm, A_swizzled, A_cw, A_cw_si, A_amax_i, A_cw_amax_i = ( + _extract_gemm_operand(Ai, transa) + ) + B_data, B_te_dtype, B_si, B_sm, B_swizzled, B_cw, B_cw_si, B_amax_i, B_cw_amax_i = ( + _extract_gemm_operand(Bi, not transb) + ) + + # Mirror generic_gemm: compute on-the-fly transpose when delayed-scaling FP8 + # tensor is missing its columnwise buffer (e.g. _transpose_invalid=True). 
+ _NVTE_DELAYED = 0 + if not transa and A_cw is None and A_sm == _NVTE_DELAYED and A_te_dtype in (7, 8): + A_cw = _ops.fp8_transpose(A_data, A_te_dtype, None) + A_cw_si = A_si + if transb and B_cw is None and B_sm == _NVTE_DELAYED and B_te_dtype in (7, 8): + B_cw = _ops.fp8_transpose(B_data, B_te_dtype, None) + B_cw_si = B_si + + if Di is None: + # Allocate output: column-major convention → shape (N, M) + _kFloat4E2M1 = 10 + A1 = A_data.shape[-1] + A0 = A_data.numel() // max(A1, 1) + B1 = B_data.shape[-1] + # FP4 packed: double last dim for logical shape + if A_te_dtype == _kFloat4E2M1: + A1 *= 2 + if B_te_dtype == _kFloat4E2M1: + B1 *= 2 + M = A0 if transa else A1 + N = B1 if transb else B_data.shape[-2] if B_data.ndim > 1 else B1 + out_te_int = int(out_dtype) if out_dtype is not None else int(DType.kBFloat16) + out_dt = _TE_TO_TORCH_DT.get(out_te_int, torch.bfloat16) + Di = torch.empty(N, M, dtype=out_dt, device=A_data.device) + + D_data, D_te_dtype, D_si, D_sm = extract_tensor_data(Di) + + bias_i = bias[i] if bias[i].numel() > 0 else None + gelu_i = pre_gelu_out[i] if pre_gelu_out[i].numel() > 0 else None + + # For grad=True with bias: the kernel writes dbias into bias_i in-place, + # which is already the pre-allocated grad_bias tensor passed by the caller. 
+ _ops.gemm( + A_data, + A_te_dtype, + A_si, + A_cw, + A_cw_si, + A_sm, + A_swizzled, + transa, + B_data, + B_te_dtype, + B_si, + B_cw, + B_cw_si, + B_sm, + B_swizzled, + transb, + D_data, + D_te_dtype, + None, + None, + D_si, + D_sm, + bias_i, + bias_type_int, + gelu_i, + ws, + grad, + accumulate, + use_split_accumulator, + 1.0, + A_amax_i, + A_cw_amax_i, + B_amax_i, + B_cw_amax_i, + ) + + if single_output and D is not None: + # The D_list slice is already a view into D[0]; no copy needed + pass + + return bias + + +def te_general_grouped_gemm_for_grouped_tensor( + A, + transa, + B, + transb, + D, + bias, + alpha, + beta, + workspace_setup, + workspace_cublas, + use_split_accumulator, + math_sm_count, +): + """Grouped GEMM for GroupedTensor inputs (Blackwell+ nvte_grouped_gemm).""" + A_args = _grouped_tensor_to_stable_args(A) + B_args = _grouped_tensor_to_stable_args(B) + D_args = _grouped_tensor_to_stable_args(D) + + if bias is not None: + bias_args = _grouped_tensor_to_stable_args(bias) + has_bias = True + else: + bias_args = ( + None, + None, + None, + None, + None, + None, + None, + int(DType.kBFloat16), + 0, + 1, + 1, + 1, + False, + ) + has_bias = False + + _ops.grouped_gemm_for_grouped_tensor( + *A_args, + transa, + *B_args, + transb, + *D_args, + alpha, + beta, + workspace_setup, + workspace_cublas, + use_split_accumulator, + math_sm_count, + has_bias, + *bias_args, + ) + return D + + +def te_general_grouped_gemm_for_discrete_in( + A, + transa, + B, + transb, + D, + bias, + alpha, + beta, + workspace_setup, + workspace_cublas, + use_split_accumulator, + math_sm_count, +): + """Grouped GEMM with discrete A list, GroupedTensor B/D (Blackwell+).""" + B_args = _grouped_tensor_to_stable_args(B) + D_args = _grouped_tensor_to_stable_args(D) + + if bias is not None: + bias_args = _grouped_tensor_to_stable_args(bias) + has_bias = True + else: + bias_args = ( + None, + None, + None, + None, + None, + None, + None, + int(DType.kBFloat16), + 0, + 1, + 1, + 1, + False, 
+ ) + has_bias = False + + # Pack A tensors: each element of A is an individual tensor (weight per expert). + # We pass per-tensor fields as flat packed int64 pointer tensors. + num_a = len(A) + device = B_args[0].device if B_args[0] is not None else alpha.device + A_rowwise_ptrs = torch.zeros(num_a, dtype=torch.int64, device=device) + A_colwise_ptrs = torch.zeros(num_a, dtype=torch.int64, device=device) + A_si_ptrs = torch.zeros(num_a, dtype=torch.int64, device=device) + A_csi_ptrs = torch.zeros(num_a, dtype=torch.int64, device=device) + A_shapes = torch.zeros(num_a, 2, dtype=torch.int64, device="cpu") + A_te_dtypes = torch.zeros(num_a, dtype=torch.int32, device="cpu") + A_scaling_modes = torch.zeros(num_a, dtype=torch.int32, device="cpu") + for i, Ai in enumerate(A): + ai_data, ai_dtype, ai_si, ai_sm, _, ai_cw, ai_cw_si, _, _ = _extract_gemm_operand( + Ai, transa + ) + if ai_data is not None and ai_data.numel() > 0: + A_rowwise_ptrs[i] = ai_data.data_ptr() + A_shapes[i, 0] = ai_data.shape[0] + A_shapes[i, 1] = ai_data.shape[1] if ai_data.ndim > 1 else 1 + if ai_cw is not None and ai_cw.numel() > 0: + A_colwise_ptrs[i] = ai_cw.data_ptr() + if ai_si is not None and ai_si.numel() > 0: + A_si_ptrs[i] = ai_si.data_ptr() + if ai_cw_si is not None and ai_cw_si.numel() > 0: + A_csi_ptrs[i] = ai_cw_si.data_ptr() + A_te_dtypes[i] = ai_dtype + A_scaling_modes[i] = ai_sm + + _ops.grouped_gemm_for_discrete_in( + A_rowwise_ptrs, + A_colwise_ptrs, + A_si_ptrs, + A_csi_ptrs, + A_shapes.to(device), + A_te_dtypes.to(device), + A_scaling_modes.to(device), + num_a, + *B_args, + transb, + *D_args, + alpha, + beta, + workspace_setup, + workspace_cublas, + use_split_accumulator, + math_sm_count, + has_bias, + *bias_args, + ) + return D + + +def te_general_grouped_gemm_for_discrete_out( # pylint: disable=unused-argument + A, + transa, + B, + transb, + D, + bias, + alpha, + beta, + workspace_setup, + workspace_cublas, + use_split_accumulator, + math_sm_count, +): + """Grouped GEMM with 
GroupedTensor A/B, discrete D list (Blackwell+).""" + A_args = _grouped_tensor_to_stable_args(A) + B_args = _grouped_tensor_to_stable_args(B) + + num_d = len(D) + device = A_args[0].device if A_args[0] is not None else alpha.device + D_rowwise_ptrs = torch.zeros(num_d, dtype=torch.int64, device=device) + D_si_ptrs = torch.zeros(num_d, dtype=torch.int64, device=device) + D_shapes = torch.zeros(num_d, 2, dtype=torch.int64, device="cpu") + D_te_dtypes = torch.zeros(num_d, dtype=torch.int32, device="cpu") + D_scaling_modes = torch.zeros(num_d, dtype=torch.int32, device="cpu") + from transformer_engine.pytorch.tensor._extract import extract_tensor_data + + for i, Di in enumerate(D): + d_data, d_dtype, d_si, d_sm = extract_tensor_data(Di) + if d_data is not None and d_data.numel() > 0: + D_rowwise_ptrs[i] = d_data.data_ptr() + D_shapes[i, 0] = d_data.shape[0] + D_shapes[i, 1] = d_data.shape[1] if d_data.ndim > 1 else 1 + if d_si is not None and d_si.numel() > 0: + D_si_ptrs[i] = d_si.data_ptr() + D_te_dtypes[i] = d_dtype + D_scaling_modes[i] = d_sm + + _ops.grouped_gemm_for_discrete_out( + *A_args, + transa, + *B_args, + transb, + D_rowwise_ptrs, + D_si_ptrs, + D_shapes.to(device), + D_te_dtypes.to(device), + D_scaling_modes.to(device), + num_d, + alpha, + beta, + workspace_setup, + workspace_cublas, + use_split_accumulator, + math_sm_count, + ) + return D + + +# ============================================================================ +# NVFP4 multi-tensor ops (iterate using single-tensor stable ops) +# ============================================================================ + + +def nvfp4_multi_tensor_fused_scale( + block_amax_list, + global_amax_list, + per_block_scale_list, + target_scale_list, + target_amax_list, + tile_rows_list, + tile_cols_list, + rows_padded_list, + block_len, +): + for i, block_amax in enumerate(block_amax_list): + _ops.nvfp4_fused_scale( + block_amax, + global_amax_list[i], + per_block_scale_list[i], + target_scale_list[i], + 
target_amax_list[i], + tile_rows_list[i], + tile_cols_list[i], + rows_padded_list[i], + block_len, + ) + + +def nvfp4_2d_multi_tensor_transpose( + rowwise_data_list, + columnwise_data_list, + rowwise_scale_inv_list, + columnwise_scale_inv_list, + M_list, + K_list, +): + for i, _rowwise_data in enumerate(rowwise_data_list): + _ops.nvfp4_data_transpose(_rowwise_data, columnwise_data_list[i]) + M = M_list[i] + K = K_list[i] + M_tiles = (M + 15) // 16 + K_tiles = (K + 15) // 16 + _ops.nvfp4_2d_scale_transpose( + rowwise_scale_inv_list[i], columnwise_scale_inv_list[i], M_tiles, K_tiles + ) + + +def nvfp4_multi_tensor_2d_partial_cast( + inp_list, + out_list, + scale_list, + global_scale_list, + h_list, + w_list, + start_offset_list, + block_len=16, +): + for i, _inp in enumerate(inp_list): + # out_list[i] may be a quantized tensor — extract raw data + out = out_list[i] + if isinstance(out, torch.Tensor): + _ops.nvfp4_2d_partial_cast_noalloc( + inp_list[i], + out, + 0, # kByte — raw uint8 data buffer type + None, + 4, + scale_list[i], + global_scale_list[i], + h_list[i], + w_list[i], + start_offset_list[i], + block_len, + ) + + +def nvfp4_multi_tensor_compute_partial_amax( + master_weight_list, + partial_amax_list, + global_amax_list, + h_list, + w_list, + start_offset_list, + block_len=16, +): + for i, _master_weight in enumerate(master_weight_list): + _ops.nvfp4_2d_compute_partial_amax( + _master_weight, + partial_amax_list[i], + h_list[i], + w_list[i], + start_offset_list[i], + block_len, + ) + _ops.compute_amax(partial_amax_list[i], global_amax_list[i]) + + +# ============================================================================ +# Multi-tensor ops (match pybind11 signatures with pointer packing) +# ============================================================================ + + +def _pack_tensor_lists(tensor_lists): + """Pack tensor lists into flat int64 tensors for the pointer-pack pattern.""" + num_lists = len(tensor_lists) + num_tensors = 
len(tensor_lists[0]) + ptrs = torch.tensor([t.data_ptr() for lst in tensor_lists for t in lst], dtype=torch.int64) + shapes = torch.tensor( + [[t.numel(), t.element_size()] for lst in tensor_lists for t in lst], dtype=torch.int64 + ).flatten() + _TORCH_DT = { + torch.float32: 4, + torch.float16: 5, + torch.bfloat16: 6, + torch.uint8: 0, + torch.int16: 1, + torch.int32: 2, + torch.int64: 3, + torch.bool: 0, + } + dtypes = torch.tensor( + [_TORCH_DT.get(t.dtype, 4) for lst in tensor_lists for t in lst], dtype=torch.int64 + ) + return ptrs, shapes, dtypes, num_lists, num_tensors + + +def multi_tensor_scale(chunk_size, is_infinite, tensor_lists, scale): + ptrs, shapes, dtypes, nl, nt = _pack_tensor_lists(tensor_lists) + _ops.multi_tensor_scale(chunk_size, is_infinite, ptrs, shapes, dtypes, nl, nt, scale) + + +def multi_tensor_scale_tensor(chunk_size, is_infinite, tensor_lists, scale): + ptrs, shapes, dtypes, nl, nt = _pack_tensor_lists(tensor_lists) + _ops.multi_tensor_scale_tensor(chunk_size, is_infinite, ptrs, shapes, dtypes, nl, nt, scale) + + +def multi_tensor_l2norm(chunk_size, noop_flag, tensor_lists, per_tensor=False): + ptrs, shapes, dtypes, nl, nt = _pack_tensor_lists(tensor_lists) + return _ops.multi_tensor_l2norm(chunk_size, noop_flag, ptrs, shapes, dtypes, nl, nt, per_tensor) + + +def multi_tensor_unscale_l2norm(chunk_size, noop_flag, tensor_lists, inv_scale, per_tensor=False): + ptrs, shapes, dtypes, nl, nt = _pack_tensor_lists(tensor_lists) + return _ops.multi_tensor_unscale_l2norm( + chunk_size, noop_flag, ptrs, shapes, dtypes, nl, nt, inv_scale, per_tensor + ) + + +def multi_tensor_adam( + chunk_size, + noop_flag, + tensor_lists, + lr, + beta1, + beta2, + epsilon, + step, + mode, + bias_correction, + weight_decay, +): + ptrs, shapes, dtypes, nl, nt = _pack_tensor_lists(tensor_lists) + _ops.multi_tensor_adam( + chunk_size, + noop_flag, + ptrs, + shapes, + dtypes, + nl, + nt, + lr, + beta1, + beta2, + epsilon, + step, + mode, + bias_correction, + 
weight_decay, + ) + + +def multi_tensor_adam_capturable( + chunk_size, + noop_flag, + tensor_lists, + lr, + beta1, + beta2, + epsilon, + step, + mode, + bias_correction, + weight_decay, + inv_scale, +): + ptrs, shapes, dtypes, nl, nt = _pack_tensor_lists(tensor_lists) + _ops.multi_tensor_adam_capturable( + chunk_size, + noop_flag, + ptrs, + shapes, + dtypes, + nl, + nt, + lr, + beta1, + beta2, + epsilon, + step, + mode, + bias_correction, + weight_decay, + inv_scale, + ) + + +def multi_tensor_adam_capturable_master( + chunk_size, + noop_flag, + tensor_lists, + lr, + beta1, + beta2, + epsilon, + step, + mode, + bias_correction, + weight_decay, + inv_scale, +): + ptrs, shapes, dtypes, nl, nt = _pack_tensor_lists(tensor_lists) + _ops.multi_tensor_adam_capturable_master( + chunk_size, + noop_flag, + ptrs, + shapes, + dtypes, + nl, + nt, + lr, + beta1, + beta2, + epsilon, + step, + mode, + bias_correction, + weight_decay, + inv_scale, + ) + + +def multi_tensor_adam_param_remainder( + chunk_size, + noop_flag, + tensor_lists, + lr, + beta1, + beta2, + epsilon, + step, + mode, + bias_correction, + weight_decay, +): + ptrs, shapes, dtypes, nl, nt = _pack_tensor_lists(tensor_lists) + _ops.multi_tensor_adam_param_remainder( + chunk_size, + noop_flag, + ptrs, + shapes, + dtypes, + nl, + nt, + lr, + beta1, + beta2, + epsilon, + step, + mode, + bias_correction, + weight_decay, + ) + + +def multi_tensor_adam_fp8( + chunk_size, + noop_flag, + tensor_lists, + lr, + beta1, + beta2, + epsilon, + step, + mode, + bias_correction, + weight_decay, + fp8_dtype, +): + ptrs, shapes, dtypes, nl, nt = _pack_tensor_lists(tensor_lists) + _ops.multi_tensor_adam_fp8( + chunk_size, + noop_flag, + ptrs, + shapes, + dtypes, + nl, + nt, + lr, + beta1, + beta2, + epsilon, + step, + mode, + bias_correction, + weight_decay, + int(fp8_dtype), + ) + + +def multi_tensor_sgd( + chunk_size, + noop_flag, + tensor_lists, + wd, + momentum, + dampening, + lr, + nesterov, + first_run, + wd_after_momentum, + 
scale, +): + ptrs, shapes, dtypes, nl, nt = _pack_tensor_lists(tensor_lists) + _ops.multi_tensor_sgd( + chunk_size, + noop_flag, + ptrs, + shapes, + dtypes, + nl, + nt, + wd, + momentum, + dampening, + lr, + nesterov, + first_run, + wd_after_momentum, + scale, + ) + + +def multi_tensor_compute_scale_and_scale_inv( + chunk_size, noop_flag, tensor_lists, max_fp8, force_pow_2_scales=False, epsilon=0.0 +): + ptrs, shapes, dtypes, nl, nt = _pack_tensor_lists(tensor_lists) + _ops.multi_tensor_compute_scale_and_scale_inv( + chunk_size, noop_flag, ptrs, shapes, dtypes, nl, nt, max_fp8, force_pow_2_scales, epsilon + ) + + +def multi_tensor_compute_scale_inv_e8m0( + chunk_size, dummy, tensor_lists +): # pylint: disable=unused-argument + ptrs, shapes, dtypes, nl, nt = _pack_tensor_lists(tensor_lists) + # Pass a dummy CUDA tensor to drive dispatch to CUDA backend + dummy_cuda = torch.empty(1, device="cuda", dtype=torch.int64) + _ops.multi_tensor_compute_scale_inv_e8m0(chunk_size, dummy_cuda, ptrs, shapes, dtypes, nl, nt) + + +# ============================================================================ +# CommOverlap types and classes +# ============================================================================ + + +class CommOverlapType(IntEnum): + RS = 0 + AG = 1 + + +class CommOverlapAlgo(IntEnum): + BULK_OVERLAP_AG = 0 + BULK_OVERLAP_RS = 1 + SPLIT_PIPELINED_AG_P2P = 2 + SPLIT_PIPELINED_RS = 3 + SPLIT_PIPELINED_RS_P2P = 4 + ATOMIC_GEMM_RS = 5 + ATOMIC_GEMM_AG_P2P = 6 + ATOMIC_GEMM_RS_P2P = 7 + EXTERNAL_BULK_OVERLAP_AG = 8 + + +class Float8BlockScaleTensorFormat(IntEnum): + GEMM_READY = 0 + COMPACT = 1 + INVALID = 2 + + +class CommOverlapCore: + def __init__(self): + pass + + def is_atomic_gemm(self): + return False + + def is_p2p_overlap(self): + return False + + def is_fp8_ubuf(self): + return False + + +class CommOverlapBase(CommOverlapCore): + pass + + +class CommOverlapP2PBase(CommOverlapCore): + pass + + +_AllgatherCB = _ctypes.CFUNCTYPE( + None, + 
_ctypes.c_void_p, + _ctypes.c_size_t, + _ctypes.c_void_p, + _ctypes.c_size_t, + _ctypes.c_char_p, +) +_BarrierCB = _ctypes.CFUNCTYPE(None, _ctypes.c_char_p) + +_TORCH_TO_TE_DTYPE = { + torch.float32: 4, + torch.float16: 5, + torch.bfloat16: 6, + torch.uint8: 0, + torch.int8: 0, +} + + +class CommOverlapHelper: + """Python replacement for pybind11 CommOverlapHelper.""" + + def __init__(self, world_group=None, intra_domain_group=None): + if world_group is None: + raise RuntimeError( + "CommOverlapHelper requires a process group (MPI-only builds " + "are not supported in the stable ABI path)" + ) + self.myrank = torch.distributed.get_rank(world_group) + self.numranks = torch.distributed.get_world_size(world_group) + backend = torch.distributed.get_backend(world_group) + self.backend_is_nccl = backend == "nccl" + if intra_domain_group is not None: + self.mylocal = torch.distributed.get_rank(intra_domain_group) + self.numlocal = torch.distributed.get_world_size(intra_domain_group) + if self.numlocal == self.numranks: + self.mynode, self.numnodes = 0, 1 + else: + self.mynode = self.myrank // self.numlocal + self.numnodes = self.numranks // self.numlocal + else: + self.mylocal = self.myrank + self.numlocal = self.numranks + self.mynode, self.numnodes = 0, 1 + self._groups = { + "world": world_group, + "intra": intra_domain_group if intra_domain_group is not None else world_group, + } + self.initialized = True + + def ub_allgather(self, globaldata_ptr, globalbytes, localdata_ptr, localbytes, group_name): + group = self._groups.get(group_name, self._groups["world"]) + num_ranks = torch.distributed.get_world_size(group) + local_buf = (_ctypes.c_uint8 * localbytes).from_address(localdata_ptr) + local_tensor = torch.frombuffer(local_buf, dtype=torch.uint8).clone() + if self.backend_is_nccl: + local_tensor = local_tensor.cuda() + chunks = [torch.empty_like(local_tensor) for _ in range(num_ranks)] + torch.distributed.all_gather(chunks, local_tensor, group=group) + global_tensor 
= torch.cat(chunks) + if self.backend_is_nccl: + global_tensor = global_tensor.cpu() + _ctypes.memmove(globaldata_ptr, global_tensor.data_ptr(), globalbytes) + + def ub_barrier(self, group_name): + group = self._groups.get(group_name, self._groups["world"]) + torch.distributed.barrier(group=group) + + +def _make_comm_callbacks(helper): + """Create ctypes callback objects bound to a CommOverlapHelper.""" + h = helper + + @_AllgatherCB + def _ag_cb(gptr, gb, lptr, lb, grp): + h.ub_allgather(gptr, gb, lptr, lb, grp.decode() if grp else "world") + + @_BarrierCB + def _bar_cb(grp): + h.ub_barrier(grp.decode() if grp else "world") + + return _ag_cb, _bar_cb + + +class CommOverlap: + """Python replacement for pybind11 CommOverlap (handle-based stable ABI).""" + + def __init__( + self, + shape, + dtype, + helper, + tp_size, + num_splits=3, + num_max_streams=3, + comm_cga_size=2, + gemm_priority=0, + comm_priority=0, + num_comm_sm=16, + set_sm_margin=True, + atomic_gemm=False, + rs_overlap_first_gemm=False, + ): + self._ag_cb, self._bar_cb = _make_comm_callbacks(helper) + ag_ptr = _ctypes.cast(self._ag_cb, _ctypes.c_void_p).value + bar_ptr = _ctypes.cast(self._bar_cb, _ctypes.c_void_p).value + _ops.register_comm_callbacks(ag_ptr, bar_ptr) + + buf_dtype = _TORCH_TO_TE_DTYPE.get(dtype, 6) + self._handle = _ops.create_comm_overlap( + list(shape), + buf_dtype, + helper.myrank, + helper.numranks, + helper.mylocal, + helper.numlocal, + helper.mynode, + helper.numnodes, + tp_size, + num_splits, + num_max_streams, + comm_cga_size, + gemm_priority, + comm_priority, + num_comm_sm, + set_sm_margin, + atomic_gemm, + rs_overlap_first_gemm, + ) + + def copy_into_buffer(self, input, local_chunk=False): # pylint: disable=redefined-builtin + _ops.comm_overlap_copy_into_buffer(input, self._handle, local_chunk) + + def get_buffer(self, local_chunk=False, shape=None): + if shape is not None and len(shape) >= 2: + dim0, dim1 = shape[0], shape[1] + else: + dim0, dim1 = -1, -1 + return 
_ops.comm_overlap_get_buffer(self._handle, local_chunk, dim0, dim1) + + def get_communication_stream(self): + raw = _ops.comm_overlap_get_stream(self._handle) + s = torch.cuda.ExternalStream(raw) + return s, s # send == recv for non-P2P + + def is_fp8_ubuf(self): + return _ops.comm_overlap_is_fp8_ubuf(self._handle) + + def is_atomic_gemm(self): + return _ops.comm_overlap_is_atomic_gemm(self._handle) + + def is_p2p_overlap(self): + return _ops.comm_overlap_is_p2p(self._handle) + + def __del__(self): + if hasattr(self, "_handle"): + _ops.destroy_comm_overlap(self._handle) + + +class CommOverlapP2P: + """Python replacement for pybind11 CommOverlapP2P (handle-based stable ABI).""" + + def __init__( + self, + shape, + dtype, + helper, + tp_size, + comm_type, + num_max_streams=3, + comm_cga_size=1, + gemm_priority=0, + comm_priority=0, + num_comm_sm=1, + set_sm_margin=False, + use_ce=True, + atomic_gemm=False, + aggregate=False, + ): + self._ag_cb, self._bar_cb = _make_comm_callbacks(helper) + ag_ptr = _ctypes.cast(self._ag_cb, _ctypes.c_void_p).value + bar_ptr = _ctypes.cast(self._bar_cb, _ctypes.c_void_p).value + _ops.register_comm_callbacks(ag_ptr, bar_ptr) + + buf_dtype = _TORCH_TO_TE_DTYPE.get(dtype, 6) + self._handle = _ops.create_comm_overlap_p2p( + list(shape), + buf_dtype, + helper.myrank, + helper.numranks, + helper.mylocal, + helper.numlocal, + helper.mynode, + helper.numnodes, + tp_size, + int(comm_type), + num_max_streams, + comm_cga_size, + gemm_priority, + comm_priority, + num_comm_sm, + set_sm_margin, + use_ce, + atomic_gemm, + aggregate, + ) + + def copy_into_buffer(self, input, local_chunk=False): # pylint: disable=redefined-builtin + _ops.comm_overlap_copy_into_buffer(input, self._handle, local_chunk) + + def get_buffer(self, local_chunk=False, shape=None): + if shape is not None and len(shape) >= 2: + dim0, dim1 = shape[0], shape[1] + else: + dim0, dim1 = -1, -1 + return _ops.comm_overlap_get_buffer(self._handle, local_chunk, dim0, dim1) + + def 
get_communication_stream(self): + send_raw, recv_raw = _ops.comm_overlap_p2p_get_streams(self._handle) + return torch.cuda.ExternalStream(send_raw), torch.cuda.ExternalStream(recv_raw) + + def is_fp8_ubuf(self): + return _ops.comm_overlap_is_fp8_ubuf(self._handle) + + def is_atomic_gemm(self): + return _ops.comm_overlap_is_atomic_gemm(self._handle) + + def is_p2p_overlap(self): + return _ops.comm_overlap_is_p2p(self._handle) + + def __del__(self): + if hasattr(self, "_handle"): + _ops.destroy_comm_overlap_p2p(self._handle) + + +# ============================================================================ +# Fused attention (match pybind11 signatures) +# ============================================================================ + + +def fused_attn_fwd( # pylint: disable=unused-argument + max_seqlen_q, + max_seqlen_kv, + is_training, + attn_scale, + p_dropout, + set_zero, + qkv_layout, + bias_type, + attn_mask_type, + softmax_type, + window_size, + bottom_right_diagonal, + cu_seqlens_q, + cu_seqlens_kv, + Q, + K, + V, + fake_dtype, + cu_seqlens_q_padded=None, + cu_seqlens_kv_padded=None, + page_table_k=None, + page_table_v=None, + s_quantizer=None, + o_quantizer=None, + Bias=None, + SoftmaxOffset=None, + rng_gen=None, + rng_elts_per_thread=0, + return_max_logit=False, + cuda_graph=False, +): + """Fused attention forward via stable ABI fused_attn_fwd_noalloc.""" + from transformer_engine.pytorch.tensor._extract import extract_tensor_data + + # Extract Q/K/V raw buffers + Q_data, Q_dtype_int, Q_si, Q_sm = extract_tensor_data(Q) + K_data, K_dtype_int, K_si, K_sm = extract_tensor_data(K) + V_data, V_dtype_int, V_si, V_sm = extract_tensor_data(V) + + device = Q_data.device + _TORCH_DT = {torch.float32: 4, torch.float16: 5, torch.bfloat16: 6, torch.uint8: 0} + + # Determine O dtype from fake_dtype + if isinstance(fake_dtype, torch.dtype): + O_torch_dtype = fake_dtype + O_dtype_int = _TORCH_DT.get(fake_dtype, Q_dtype_int) + else: + O_dtype_int = Q_dtype_int + 
O_torch_dtype = Q_data.dtype + + # Allocate O: Q shape with V's last dim + O_shape = list(Q_data.shape) + O_shape[-1] = V_data.shape[-1] + if o_quantizer is not None: + O_tensor = o_quantizer.make_empty(O_shape, dtype=O_torch_dtype, device=device) + O_data, O_dtype_int, O_si, O_sm = extract_tensor_data(O_tensor) + O_amax = getattr(o_quantizer, "amax", None) + O_scale = getattr(o_quantizer, "scale", None) + # Initialize scale_inv = 1/scale (pybind does this in Float8Quantizer::create_tensor). + # make_empty leaves _scale_inv uninitialized; the NVTE kernel does NOT write it. + if O_scale is not None and O_si is not None: + O_si.copy_(O_scale.float().reciprocal()) + else: + O_tensor = torch.empty(O_shape, dtype=O_torch_dtype, device=device) + O_data, _, O_si, O_sm = O_tensor, O_dtype_int, None, 0 + O_amax, O_scale = None, None + + # Allocate S (softmax placeholder — shape determined by kernel on first pass) + if s_quantizer is not None: + S_tensor = s_quantizer.make_empty([0], dtype=torch.float32, device=device) + S_data, S_dtype_int, S_si, S_sm = extract_tensor_data(S_tensor) + S_amax = getattr(s_quantizer, "amax", None) + S_scale = getattr(s_quantizer, "scale", None) + # Initialize scale_inv = 1/scale (same as O above) + if S_scale is not None and S_si is not None: + S_si.copy_(S_scale.float().reciprocal()) + else: + S_tensor = torch.empty([0], dtype=torch.float32, device=device) + S_data, S_dtype_int, S_si, S_sm = S_tensor, 4, None, 0 + S_amax, S_scale = None, None + + # rng_state [seed, offset] — zeros for p_dropout=0; for training with dropout + # the kernel writes the actual used state into the aux tensor pack. 
+ rng_state = torch.zeros([2], dtype=torch.int64, device=device) + + result = _ops.fused_attn_fwd_noalloc( + int(max_seqlen_q), + int(max_seqlen_kv), + bool(is_training), + float(attn_scale), + float(p_dropout), + bool(set_zero), + int(qkv_layout), + int(bias_type), + int(attn_mask_type), + int(softmax_type), + list(window_size), + bool(bottom_right_diagonal), + cu_seqlens_q, + cu_seqlens_kv, + Q_data, + Q_dtype_int, + Q_si, + Q_sm, + K_data, + K_dtype_int, + K_si, + K_sm, + V_data, + V_dtype_int, + V_si, + V_sm, + S_data, + S_dtype_int, + S_amax, + S_scale, + S_si, + S_sm, + O_data, + O_dtype_int, + O_amax, + O_scale, + O_si, + O_sm, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + page_table_k, + page_table_v, + Bias, + SoftmaxOffset, + rng_state, + bool(return_max_logit), + bool(cuda_graph), + ) + # result = (aux0, ..., aux9, num_aux) + *aux_tensors, num_aux = result + # Return format matches pybind: [O, stats, (max if applicable), rng_state, ...] + return [O_tensor] + list(aux_tensors[:num_aux]) + + +_te_lib_handle = None + + +def _get_te_lib(): + """Get cached ctypes handle to libtransformer_engine.so.""" + global _te_lib_handle + if _te_lib_handle is not None: + return _te_lib_handle + te_spec = importlib.util.find_spec("transformer_engine") + if te_spec is not None and te_spec.origin is not None: + te_dir = Path(te_spec.origin).parent.parent + candidates = glob.glob(str(te_dir / "libtransformer_engine*.so")) + if candidates: + _te_lib_handle = _ctypes.CDLL(candidates[0]) + return _te_lib_handle + raise RuntimeError("Could not find libtransformer_engine.so") + + +def _get_qkv_layout_group(qkv_layout_int): + """Call nvte_get_qkv_layout_group via ctypes. 
Returns int layout group.""" + try: + lib = _get_te_lib() + fn = lib.nvte_get_qkv_layout_group + fn.restype = _ctypes.c_int + fn.argtypes = [_ctypes.c_int] + return fn(qkv_layout_int) + except RuntimeError: + return 4 # fallback: NVTE_HD_HD_HD + + +def fused_attn_bwd( + max_seqlen_q, + max_seqlen_kv, + attn_scale, + p_dropout, + set_zero, + qkv_layout, + bias_type, + attn_mask_type, + softmax_type, + window_size, + bottom_right_diagonal, + deterministic, + cu_seqlens_q, + cu_seqlens_kv, + Q, + K, + V, + O, + dO, + fake_dtype, + dqkv_dtype, + aux_ctx_tensors, + cu_seqlens_q_padded=None, + cu_seqlens_kv_padded=None, + s_quantizer=None, + dp_quantizer=None, + dqkv_quantizer=None, + cuda_graph=False, +): + """Fused attention backward via stable ABI fused_attn_bwd_packed.""" + from transformer_engine.pytorch.tensor._extract import extract_tensor_data + + _TORCH_DT = {torch.float32: 4, torch.float16: 5, torch.bfloat16: 6, torch.uint8: 0} + _TE_TO_TORCH_DT = { + 4: torch.float32, + 5: torch.float16, + 6: torch.bfloat16, + 0: torch.uint8, + 7: torch.uint8, # kFloat8E4M3 stored as uint8 + 8: torch.uint8, # kFloat8E5M2 stored as uint8 + } + + # Extract input tensor data + Q_data, Q_dtype, Q_si, Q_sm = extract_tensor_data(Q) + K_data, K_dtype, K_si, K_sm = extract_tensor_data(K) + V_data, V_dtype, V_si, V_sm = extract_tensor_data(V) + O_data, O_dtype, O_si, O_sm = extract_tensor_data(O) + dO_data, dO_dtype, dO_si, dO_sm = extract_tensor_data(dO) + + device = Q_data.device + + # S and dP: empty placeholder tensors for backward. + # For FP8, the pybind version creates these via quantizer_helper(s_quantizer, {0}, ...) + # which sets up FP8 dtype and scale_inv. Replicate by detecting quantizer type. 
+ S_data = torch.empty([0], dtype=torch.float32, device=device) + S_dtype = 4 # kFloat32 + S_amax, S_scale, S_si, S_sm = None, None, None, 0 + dP_data = torch.empty([0], dtype=torch.float32, device=device) + dP_dtype = 4 + dP_amax, dP_scale, dP_si, dP_sm = None, None, None, 0 + + def _fp8_quantizer_metadata(quantizer): + """Extract FP8 metadata from a quantizer for S/dP tensors.""" + q_dtype = getattr(quantizer, "dtype", None) + if q_dtype is None: + return None + # dtype may be torch.dtype or integer TE enum + _FP8_DTYPES = {torch.float8_e4m3fn, torch.float8_e5m2, 7, 8} + if q_dtype not in _FP8_DTYPES: + return None + te_dt = 7 if q_dtype in (torch.float8_e4m3fn, 7) else 8 + q_scale = getattr(quantizer, "scale", None) + # Use the quantizer's amax tensor so the kernel's amax output flows back to + # FP8GlobalStateManager's amax_history buffer (matching the pybind path where + # Float8Quantizer::set_quantization_params wires the TensorWrapper's amax). + q_amax = getattr(quantizer, "amax", None) + amax = ( + q_amax if q_amax is not None else torch.zeros([1], dtype=torch.float32, device=device) + ) + if q_scale is not None: + scale = q_scale.clone().detach().to(dtype=torch.float32, device=device).reshape(1) + si = (1.0 / q_scale).to(dtype=torch.float32, device=device).reshape(1) + else: + scale = torch.ones([1], dtype=torch.float32, device=device) + si = torch.ones([1], dtype=torch.float32, device=device) + return te_dt, amax, scale, si + + if s_quantizer is not None: + meta = _fp8_quantizer_metadata(s_quantizer) + if meta is not None: + S_data = torch.empty([0], dtype=torch.uint8, device=device) + S_dtype, S_amax, S_scale, S_si = meta + S_sm = 0 # NVTE_DELAYED_TENSOR_SCALING + if dp_quantizer is not None: + meta = _fp8_quantizer_metadata(dp_quantizer) + if meta is not None: + dP_data = torch.empty([0], dtype=torch.uint8, device=device) + dP_dtype, dP_amax, dP_scale, dP_si = meta + dP_sm = 0 # NVTE_DELAYED_TENSOR_SCALING + + # Determine output grad dtype. 
+ # For Float8CurrentScalingQuantizer, the pybind version allocates dQ/dK/dV as + # fake_dtype (BF16), NOT uint8. Replicate that here. + _is_current_scaling = ( + dqkv_quantizer is not None and "CurrentScaling" in type(dqkv_quantizer).__name__ + ) + if dqkv_dtype is not None: + dqkv_te_dtype = int(dqkv_dtype) + if _is_current_scaling: + # Current scaling: output is high-precision with amax tracking + dqkv_torch_dtype = fake_dtype if isinstance(fake_dtype, torch.dtype) else Q_data.dtype + else: + dqkv_torch_dtype = _TE_TO_TORCH_DT.get(dqkv_te_dtype, Q_data.dtype) + elif isinstance(fake_dtype, torch.dtype): + dqkv_torch_dtype = fake_dtype + dqkv_te_dtype = _TORCH_DT.get(fake_dtype, Q_dtype) + else: + dqkv_torch_dtype = Q_data.dtype + dqkv_te_dtype = Q_dtype + + Q_shape = list(Q_data.shape) + K_shape = list(K_data.shape) + V_shape = list(V_data.shape) + + # Allocate dQ/dK/dV based on layout group. + # IMPORTANT: for packed layouts, dQ/dK/dV must be NON-CONTIGUOUS VIEWS of the packed + # tensor (dQKV / dKV), NOT separate .contiguous() copies. cuDNN computes the output + # gradient stride from the qkv_layout flag (e.g. NVTE_3HD → stride [3*B*H*D, 3*H*D, D, 1]) + # and writes using that stride starting from dQ.data_ptr(). If dQ.data_ptr() points to + # a small contiguous tensor (only S*B*H*D elements) the write overflows → illegal address. + # The pybind11 backend does the same (extensions/attention.cpp:367–378). 
+ layout_group = _get_qkv_layout_group(int(qkv_layout)) + if layout_group == 0: # NVTE_3HD: packed dQKV with 3 in third-to-last dim + dQKV_shape = Q_shape[:-2] + [3] + Q_shape[-2:] + dQKV = torch.empty(dQKV_shape, dtype=dqkv_torch_dtype, device=device) + dQ = dQKV[..., 0, :, :] # non-contiguous view; data_ptr = dQKV.data_ptr() + dK = dQKV[..., 1, :, :] # non-contiguous view; data_ptr = dQKV.data_ptr() + H*D*sizeof + dV = dQKV[..., 2, :, :] # non-contiguous view; data_ptr = dQKV.data_ptr() + 2*H*D*sizeof + elif layout_group == 1: # NVTE_H3D: packed dQKV with 3 in second-to-last + dQKV_shape = Q_shape[:-1] + [3, Q_shape[-1]] + dQKV = torch.empty(dQKV_shape, dtype=dqkv_torch_dtype, device=device) + dQ = dQKV[..., 0, :] # non-contiguous view + dK = dQKV[..., 1, :] + dV = dQKV[..., 2, :] + elif layout_group == 2: # NVTE_HD_2HD + dQ = torch.empty(Q_shape, dtype=dqkv_torch_dtype, device=device) + dKV_shape = K_shape[:-2] + [2] + K_shape[-2:] + dKV = torch.empty(dKV_shape, dtype=dqkv_torch_dtype, device=device) + dK = dKV[..., 0, :, :] # non-contiguous view + dV = dKV[..., 1, :, :] + elif layout_group == 3: # NVTE_HD_H2D + dQ = torch.empty(Q_shape, dtype=dqkv_torch_dtype, device=device) + dKV_shape = K_shape[:-1] + [2, K_shape[-1]] + dKV = torch.empty(dKV_shape, dtype=dqkv_torch_dtype, device=device) + dK = dKV[..., 0, :] # non-contiguous view + dV = dKV[..., 1, :] + else: # NVTE_HD_HD_HD (4) and Paged_KV (5) + dQ = torch.empty(Q_shape, dtype=dqkv_torch_dtype, device=device) + dK = torch.empty(K_shape, dtype=dqkv_torch_dtype, device=device) + dV = torch.empty(V_shape, dtype=dqkv_torch_dtype, device=device) + + dQ_data, dQ_te_dtype, dQ_si, dQ_sm = extract_tensor_data(dQ) + dK_data, dK_te_dtype, dK_si, dK_sm = extract_tensor_data(dK) + dV_data, dV_te_dtype, dV_si, dV_sm = extract_tensor_data(dV) + dQ_amax = dQ_scale = dK_amax = dK_scale = dV_amax = dV_scale = None + if dqkv_te_dtype in (7, 8) and not _is_current_scaling: + # Delayed scaling: dQ/dK/dV are FP8 (uint8), need 
scale/scale_inv/amax. + # The pybind version uses Float8Quantizer::create_tensor which computes + # scale_inv = 1/scale from the quantizer's scale. Replicate that here. + dQ_te_dtype = dqkv_te_dtype + dK_te_dtype = dqkv_te_dtype + dV_te_dtype = dqkv_te_dtype + dqkv_q_scale = getattr(dqkv_quantizer, "scale", None) if dqkv_quantizer else None + if dqkv_q_scale is not None: + dqkv_scale_f32 = dqkv_q_scale.detach().to(dtype=torch.float32, device=device).reshape(1) + dqkv_si_f32 = (1.0 / dqkv_q_scale).to(dtype=torch.float32, device=device).reshape(1) + else: + dqkv_scale_f32 = torch.ones([1], dtype=torch.float32, device=device) + dqkv_si_f32 = torch.ones([1], dtype=torch.float32, device=device) + # Use the quantizer's amax tensor so the kernel's amax output flows back + # to FP8GlobalStateManager's amax_history buffer. The pybind path achieves + # this via Float8Quantizer::create_tensor → set_quantization_params, which + # wires the TensorWrapper's amax to the quantizer's amax. Without this, the + # backward amax never updates, causing stale scales in batch test runs. + dqkv_q_amax = getattr(dqkv_quantizer, "amax", None) + dQ_amax = ( + dqkv_q_amax + if dqkv_q_amax is not None + else torch.zeros([1], dtype=torch.float32, device=device) + ) + dQ_scale = dqkv_scale_f32 + dQ_si = dqkv_si_f32.clone() + dK_amax = ( + dqkv_q_amax + if dqkv_q_amax is not None + else torch.zeros([1], dtype=torch.float32, device=device) + ) + dK_scale = dqkv_scale_f32.clone() + dK_si = dqkv_si_f32.clone() + dV_amax = ( + dqkv_q_amax + if dqkv_q_amax is not None + else torch.zeros([1], dtype=torch.float32, device=device) + ) + dV_scale = dqkv_scale_f32.clone() + dV_si = dqkv_si_f32.clone() + elif _is_current_scaling: + # Current scaling: dQ/dK/dV are high-precision (BF16/FP16) with amax tracking. + # The pybind version uses create_unquantized_tensor_with_amax → NoneQuantizer tensor. + # dQ/dK/dV are already allocated as fake_dtype, extract_tensor_data returns BF16 dtype. 
+ # Just need amax for cuDNN. + dQ_amax = torch.zeros([1], dtype=torch.float32, device=device) + dK_amax = torch.zeros([1], dtype=torch.float32, device=device) + dV_amax = torch.zeros([1], dtype=torch.float32, device=device) + + # dBias: allocate when bias_type not in {NO_BIAS=0, ALIBI=3} + # The pybind version derives dBias shape from the last aux_ctx_tensor (the saved bias + # from forward) when aux_ctx_tensors has >= 2 entries. This is critical because the + # bias may have batch dimension > 1, e.g. [b, h, s, s] instead of [1, h, s, s]. + # Hardcoding [1, h, s, s] causes cuDNN dimension-mismatch errors on backward. + num_heads_q = Q_shape[-2] if len(Q_shape) >= 2 else 1 + num_aux = len(aux_ctx_tensors) if aux_ctx_tensors else 0 + dBias = None + if int(bias_type) not in (0, 3): # 0=NVTE_NO_BIAS, 3=NVTE_ALIBI + if num_aux >= 2: + bias_shape = list(aux_ctx_tensors[-1].shape) + dBias = torch.empty(bias_shape, dtype=dqkv_torch_dtype, device=device) + else: + dBias = torch.empty( + [1, num_heads_q, max_seqlen_q, max_seqlen_kv], + dtype=dqkv_torch_dtype, + device=device, + ) + # For THD format, cuDNN accumulates into dBias, so it must be zero-initialized. + # Use zeros for safety (the pybind version only zero-fills for THD, but zeros + # is always correct and the perf difference is negligible). 
+ dBias.zero_() + + # dSoftmaxOffset: allocate when softmax_type != VANILLA (0) + dSoftmaxOffset = None + if int(softmax_type) != 0: + dSoftmaxOffset = torch.zeros([1, num_heads_q, 1, 1], dtype=torch.float32, device=device) + + # Pack dtype/scaling_mode info into a CPU int64 tensor + dtype_info = torch.tensor( + [ + Q_dtype, + Q_sm, + K_dtype, + K_sm, + V_dtype, + V_sm, + O_dtype, + O_sm, + dO_dtype, + dO_sm, + S_dtype, + S_sm, + dP_dtype, + dP_sm, + dQ_te_dtype, + dQ_sm, + dK_te_dtype, + dK_sm, + dV_te_dtype, + dV_sm, + ], + dtype=torch.int64, + device="cpu", + ) + + # Flatten aux_ctx_tensors, padding to 10 slots + num_aux = len(aux_ctx_tensors) if aux_ctx_tensors else 0 + aux_list = (list(aux_ctx_tensors) if aux_ctx_tensors else []) + [None] * (10 - num_aux) + + _ops.fused_attn_bwd_packed( + int(max_seqlen_q), + int(max_seqlen_kv), + float(attn_scale), + float(p_dropout), + bool(set_zero), + int(qkv_layout), + int(bias_type), + int(attn_mask_type), + int(softmax_type), + list(window_size), + bool(bottom_right_diagonal), + bool(deterministic), + bool(cuda_graph), + cu_seqlens_q, + cu_seqlens_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + Q_data, + Q_si, + K_data, + K_si, + V_data, + V_si, + O_data, + O_si, + dO_data, + dO_si, + S_data, + S_amax, + S_scale, + S_si, + dP_data, + dP_amax, + dP_scale, + dP_si, + dQ_data, + dQ_amax, + dQ_scale, + dQ_si, + dK_data, + dK_amax, + dK_scale, + dK_si, + dV_data, + dV_amax, + dV_scale, + dV_si, + dBias, + dSoftmaxOffset, + dtype_info, + num_aux, + *aux_list, + ) + + # For delayed-scaling FP8 output, wrap raw uint8 tensors in Float8Tensor with + # scale_inv so downstream operations correctly dequantize. For current scaling, + # the output is already high-precision (BF16) and doesn't need wrapping. 
+ if dqkv_te_dtype in (7, 8) and dqkv_quantizer is not None and not _is_current_scaling: + from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor + + fp8_dtype_val = getattr(dqkv_quantizer, "dtype", dqkv_te_dtype) + fake_dt = fake_dtype if isinstance(fake_dtype, torch.dtype) else torch.bfloat16 + + def _wrap_fp8(raw_tensor, si_tensor): + return Float8Tensor( + raw_tensor.shape, + fake_dt, + data=raw_tensor, + fp8_scale_inv=si_tensor, + fp8_dtype=fp8_dtype_val, + quantizer=dqkv_quantizer, + ) + + dQ = _wrap_fp8(dQ, dQ_si) + dK = _wrap_fp8(dK, dK_si) + dV = _wrap_fp8(dV, dV_si) + + return [dQ, dK, dV, dBias, dSoftmaxOffset] + + +def bulk_overlap_ag_with_external_gemm(allgather_communicator, send_stream, recv_stream): + _ops.bulk_overlap_ag_with_external_gemm( + allgather_communicator._handle, send_stream.cuda_stream, recv_stream.cuda_stream + ) + + +# ============================================================================ +# NVTE enums exposed via pybind +# ============================================================================ + + +class NVTE_QKV_Layout(IntEnum): + NVTE_SB3HD = 0 + NVTE_SBH3D = 1 + NVTE_SBHD_SB2HD = 2 + NVTE_SBHD_SBH2D = 3 + NVTE_SBHD_SBHD_SBHD = 4 + NVTE_BS3HD = 5 + NVTE_BSH3D = 6 + NVTE_BSHD_BS2HD = 7 + NVTE_BSHD_BSH2D = 8 + NVTE_BSHD_BSHD_BSHD = 9 + NVTE_T3HD = 10 + NVTE_TH3D = 11 + NVTE_THD_T2HD = 12 + NVTE_THD_TH2D = 13 + NVTE_THD_THD_THD = 14 + NVTE_SBHD_BSHD_BSHD = 15 + NVTE_BSHD_SBHD_SBHD = 16 + NVTE_THD_BSHD_BSHD = 17 + NVTE_THD_SBHD_SBHD = 18 + NVTE_Paged_KV_BSHD_BSHD_BSHD = 19 + NVTE_Paged_KV_BSHD_SBHD_SBHD = 20 + NVTE_Paged_KV_SBHD_BSHD_BSHD = 21 + NVTE_Paged_KV_SBHD_SBHD_SBHD = 22 + NVTE_Paged_KV_THD_BSHD_BSHD = 23 + NVTE_Paged_KV_THD_SBHD_SBHD = 24 + + +class NVTE_QKV_Format(IntEnum): + NVTE_SBHD = 0 + NVTE_BSHD = 1 + NVTE_THD = 2 + NVTE_BSHD_2SBHD = 3 + NVTE_SBHD_2BSHD = 4 + NVTE_THD_2BSHD = 5 + NVTE_THD_2SBHD = 6 + + +class NVTE_Bias_Type(IntEnum): + NVTE_NO_BIAS = 0 + NVTE_PRE_SCALE_BIAS = 1 + 
NVTE_POST_SCALE_BIAS = 2 + NVTE_ALIBI = 3 + + +class NVTE_Mask_Type(IntEnum): + NVTE_NO_MASK = 0 + NVTE_PADDING_MASK = 1 + NVTE_CAUSAL_MASK = 2 + NVTE_PADDING_CAUSAL_MASK = 3 + NVTE_CAUSAL_BOTTOM_RIGHT_MASK = 4 + NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK = 5 + NVTE_ARBITRARY_MASK = 6 + + +class NVTE_Softmax_Type(IntEnum): + NVTE_VANILLA_SOFTMAX = 0 + NVTE_OFF_BY_ONE_SOFTMAX = 1 + NVTE_LEARNABLE_SOFTMAX = 2 + + +class NVTE_Fused_Attn_Backend(IntEnum): + NVTE_No_Backend = -1 + NVTE_F16_max512_seqlen = 0 + NVTE_F16_arbitrary_seqlen = 1 + NVTE_FP8 = 2 + + +def device_supports_multicast(): + """Check if current device supports multicast.""" + return bool(_ops.device_supports_multicast(-1)) + + +def get_stream_priority_range(): + """Get CUDA stream priority range.""" + result = _ops.get_stream_priority_range(-1) + return int(result[0]), int(result[1]) + + +def ubuf_built_with_mpi(): + """Check if TE was built with NVTE_UB_WITH_MPI=1.""" + return bool(_ops.ubuf_built_with_mpi()) + + +# Register stable GEMM op as a passthrough in QuantizedTensor.__torch_dispatch__ +# so that FP8/quantized tensors are NOT dequantized before entering the GEMM op. +# This mirrors how te_moe ops are registered in permutation.py. +def _register_passthrough_ops(): + import sys + + # Only run if quantized_tensor is already in sys.modules (avoid circular import). + # If it's not imported yet, quantized_tensor.py will call this on its own after + # importing _stable_torch_module. 
+ if "transformer_engine.pytorch.quantized_tensor" not in sys.modules: + return + try: + from transformer_engine.pytorch.quantized_tensor import ( + _quantized_tensor_passthrough_ops, + ) + + _quantized_tensor_passthrough_ops.add(torch.ops.transformer_engine_stable.gemm.default) + except (ImportError, AttributeError): + pass + + +_register_passthrough_ops() diff --git a/transformer_engine/pytorch/_tex.py b/transformer_engine/pytorch/_tex.py new file mode 100644 index 0000000000..f28e4616e0 --- /dev/null +++ b/transformer_engine/pytorch/_tex.py @@ -0,0 +1,11 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Routing module for transformer_engine_torch. + +All symbols come from the stable ABI implementation registered as +transformer_engine_torch in sys.modules by pytorch/__init__.py. +""" + +from transformer_engine_torch import * # noqa: F401,F403 # pylint: disable=wildcard-import,unused-wildcard-import diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 442366035a..a294a05092 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1132,7 +1132,7 @@ def convert_to_torch_float8(tensor, dtype): shape=output_data.shape, ) else: - output = output.view(batch_size, max_seqlen_q // cp_size, -1).transpose(0, 1) + output = output.reshape(batch_size, max_seqlen_q // cp_size, -1).transpose(0, 1) elif q_format == "bshd": # (bs)hd -> bs(hd) output = output.reshape(batch_size, max_seqlen_q // cp_size, -1) diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp deleted file mode 100644 index b06f6f5619..0000000000 --- a/transformer_engine/pytorch/csrc/common.cpp +++ /dev/null @@ -1,328 +0,0 @@ 
-/************************************************************************* - * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#include "common.h" - -#include "c10/util/ArrayRef.h" -#include "pybind.h" -#include "transformer_engine/transformer_engine.h" - -namespace transformer_engine::pytorch { - -/*! convert fp4 data shape back to original shape */ -std::vector convert_shape_back_from_fp4(const std::vector& shape, bool transpose) { - std::vector ret; - size_t start_idx = (transpose) ? 1 : 0; - for (size_t i = start_idx; i < shape.size() - 1; ++i) { - ret.push_back(shape[i]); - } - ret.push_back(shape.back() * 2); - if (transpose) { - ret.push_back(shape.front()); - } - return ret; -} - -std::vector getTensorShape(const at::Tensor& t) { - std::vector shape; - for (auto s : t.sizes()) { - shape.push_back(s); - } - return shape; -} - -NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape) { - NVTEShape ret; - ret.ndim = torch_shape.size(); - constexpr int max_dimensions = sizeof(ret.data) / sizeof(size_t); - NVTE_CHECK(ret.ndim < max_dimensions, - "Torch tensor has too many dimensions. 
Max supported: ", max_dimensions, " and got ", - ret.ndim, "."); - for (size_t i = 0; i < ret.ndim; ++i) { - const auto& v = torch_shape[i]; - ret.data[i] = static_cast(v); - } - return ret; -} - -std::unique_ptr convert_quantizer(py::handle quantizer) { - init_extension(); - if (quantizer.is_none()) { - return std::make_unique(quantizer); - } - for (auto [_check_type, check_quantizer_type, _create_tensor, create_quantizer] : - detail::custom_types_converters) { - if (check_quantizer_type(quantizer.ptr())) { - return create_quantizer(quantizer); - } - } - - NVTE_ERROR("Unexpected type for quantizer"); -} - -transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid, - const std::string& fp8_recipe) { - // if e4m3 or hybrid + forward - if ((fp8_recipe == "E4M3") || ((fp8_recipe == "HYBRID") && e4m3_if_hybrid)) { - return transformer_engine::DType::kFloat8E4M3; - } - return transformer_engine::DType::kFloat8E5M2; -} - -TensorWrapper makeTransformerEngineTensor(py::handle tensor, py::handle quantizer) { - NVTE_CHECK(!tensor.is_none(), "Tensor is not allocated!"); - std::unique_ptr my_quantizer = convert_quantizer(quantizer); - // check for both quantizer & tensor type: - // mxfp8 tensor -> mxfp8 quantizer - // float8 tensor -> delayed scaling quantizer OR current scaling quantizer - // also during dequantize, the quantizer param is unknown -> so quantizer is NoneQuantizer - for (auto [check_type, check_quantizer_type, create_tensor, _] : - detail::custom_types_converters) { - if (check_type(tensor.ptr())) { - if (!(quantizer.is_none() || check_quantizer_type(quantizer.ptr()))) { - continue; - } - auto x = create_tensor(tensor, my_quantizer.get()); - return x; - } - } - NVTE_CHECK(dynamic_cast(my_quantizer.get()) != nullptr, - "Unexpected quantization params type."); - - // Regular pyTorch tensor - at::Tensor torch_tensor = tensor.cast(); - - // #TODO (pgadzinski) - needed in attention for non-contiguous tensors. 
- //if (!torch_tensor.is_contiguous()) { - // torch_tensor = torch_tensor.contiguous(); - //} - auto ret = TensorWrapper(my_quantizer->get_scaling_mode()); - ret.set_rowwise_data(torch_tensor.data_ptr(), - GetTransformerEngineDType(torch_tensor.scalar_type()), - getTensorShape(torch_tensor)); - my_quantizer->set_quantization_params(&ret); - return ret; -} - -transformer_engine::TensorWrapper makeTransformerEngineTensor( - void* data_ptr, const NVTEShape& shape, const transformer_engine::DType type) { - return transformer_engine::TensorWrapper(data_ptr, shape, type); -} - -transformer_engine::TensorWrapper makeTransformerEngineTensor( - void* data_ptr, const std::vector& shape, const transformer_engine::DType type) { - return transformer_engine::TensorWrapper(data_ptr, shape, type); -} - -transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor) { - transformer_engine::DType dtype = GetTransformerEngineDType(tensor.scalar_type()); - std::vector shape; - for (auto s : tensor.sizes()) { - shape.push_back(s); - } - return makeTransformerEngineTensor(tensor.data_ptr(), shape, dtype); -} - -std::tuple, std::vector>, - std::vector, size_t, size_t> -makeTransformerEngineTensorList(std::vector> at_tensor_lists) { - size_t num_lists = at_tensor_lists.size(); - - NVTE_CHECK(num_lists > 0, "List of tensors is empty."); - - size_t num_tensors = at_tensor_lists[0].size(); - - std::vector> nvte_tensor_lists; - std::vector nvte_tensor_list_ptrs; - std::vector tensorWrappers; - nvte_tensor_lists.reserve(num_lists); - nvte_tensor_list_ptrs.reserve(num_lists); - tensorWrappers.reserve(num_lists * num_tensors); - - for (const auto& at_list : at_tensor_lists) { - NVTE_CHECK(at_list.size() == num_tensors, "Wrong number of tensors"); - std::vector te_list; - te_list.reserve(num_tensors); - - for (const auto& at_tensor : at_list) { - tensorWrappers.push_back(makeTransformerEngineTensor(at_tensor)); - te_list.push_back(tensorWrappers.back().data()); - } - - 
nvte_tensor_lists.push_back(std::move(te_list)); - } - - for (auto& te_tensor_list : nvte_tensor_lists) { - nvte_tensor_list_ptrs.push_back(te_tensor_list.data()); - } - - return std::make_tuple(std::move(tensorWrappers), std::move(nvte_tensor_lists), - std::move(nvte_tensor_list_ptrs), num_lists, num_tensors); -} - -transformer_engine::TensorWrapper makeTransformerEngineTensor( - void* data_ptr, const std::vector& shape, const transformer_engine::DType type, - void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, std::vector scale_inv_shape, - NVTEScalingMode scaling_mode) { - TensorWrapper ret(scaling_mode); - ret.set_rowwise_data(data_ptr, type, shape); - const std::vector meta_shape{1}; - ret.set_amax(amax_ptr, DType::kFloat32, meta_shape); - ret.set_scale(scale_ptr, DType::kFloat32, meta_shape); - auto scale_inv_dtype = - (scaling_mode == NVTE_MXFP8_1D_SCALING) ? DType::kFloat8E8M0 : DType::kFloat32; - ret.set_rowwise_scale_inv(scale_inv_ptr, scale_inv_dtype, scale_inv_shape); - return ret; -} - -transformer_engine::TensorWrapper makeTransformerEngineTensor( - void* data_ptr, void* columnwise_data_ptr, const std::vector& shape, - const std::vector& columnwise_shape, const transformer_engine::DType type, - void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr, - const std::vector& scale_inv_shape, - const std::vector& columnwise_scale_inv_shape, NVTEScalingMode scaling_mode) { - TensorWrapper ret(scaling_mode); - ret.set_rowwise_data(data_ptr, type, shape); - ret.set_columnwise_data(columnwise_data_ptr, type, columnwise_shape); - const std::vector meta_shape{1}; - ret.set_amax(amax_ptr, DType::kFloat32, meta_shape); - ret.set_scale(scale_ptr, DType::kFloat32, meta_shape); - auto scale_inv_dtype = (scaling_mode == NVTE_MXFP8_1D_SCALING) ? DType::kFloat8E8M0 - : (scaling_mode == NVTE_NVFP4_1D_SCALING) ? 
DType::kFloat8E4M3 - : DType::kFloat32; - ret.set_rowwise_scale_inv(scale_inv_ptr, scale_inv_dtype, scale_inv_shape); - ret.set_columnwise_scale_inv(columnwise_scale_inv_ptr, scale_inv_dtype, - columnwise_scale_inv_shape); - return ret; -} - -transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor, at::Tensor amax, - const at::Tensor scale, - at::Tensor scale_inv, - NVTEScalingMode scaling_mode) { - transformer_engine::DType dtype = GetTransformerEngineDType(tensor.scalar_type()); - - auto tensor_shape = getTensorShape(tensor); - auto scale_inv_shape = getTensorShape(scale_inv); - - NVTE_CHECK(amax.scalar_type() == at::kFloat); - NVTE_CHECK(scale.scalar_type() == at::kFloat); - NVTE_CHECK(scale_inv.scalar_type() == at::kFloat); - - return makeTransformerEngineTensor(tensor.data_ptr(), tensor_shape, dtype, amax.data_ptr(), - scale.data_ptr(), scale_inv.data_ptr(), scale_inv_shape, - scaling_mode); -} - -template -T product(const std::vector& shape) { - T ret = 1; - for (auto s : shape) { - ret *= s; - } - return ret; -} - -template size_t product(const std::vector& shape); -template int64_t product(const std::vector& shape); - -size_t product(const NVTEShape& shape, size_t begin, size_t end) { - NVTE_CHECK(begin <= end && end <= shape.ndim, "Attempted to access entries ", begin, " to ", end, - " in a shape with ", shape.ndim, " entries"); - size_t ret = 1; - for (size_t i = begin; i < end; ++i) { - ret *= shape.data[i]; - } - return ret; -} - -std::vector nvte_shape_to_vector(const NVTEShape& nvte_shape) { - std::vector shape; - for (size_t i = 0; i < nvte_shape.ndim; i++) { - shape.push_back(nvte_shape.data[i]); - } - return shape; -} - -at::Tensor allocateSpace(const std::vector& shape, const transformer_engine::DType type, - bool init_to_zeros) { - std::vector shape_int64(shape.begin(), shape.end()); - c10::IntArrayRef ar_shape(shape_int64); - if (init_to_zeros) { - return at::zeros(ar_shape, at::CUDA(GetATenDType(type))); - } else { - 
return at::empty(ar_shape, at::CUDA(GetATenDType(type))); - } -} - -at::Tensor allocateSpace(const NVTEShape& shape, const transformer_engine::DType type, - bool init_to_zeros) { - auto size = shape.ndim; - if (size == 2 && init_to_zeros) { - return at::zeros({static_cast(shape.data[0]), static_cast(shape.data[1])}, - at::CUDA(GetATenDType(type))); - } else if (size == 2) { - return at::empty({static_cast(shape.data[0]), static_cast(shape.data[1])}, - at::CUDA(GetATenDType(type))); - } else if (size == 1 && init_to_zeros) { - return at::zeros({static_cast(shape.data[0])}, at::CUDA(GetATenDType(type))); - } else if (size == 1) { - return at::empty({static_cast(shape.data[0])}, at::CUDA(GetATenDType(type))); - } - NVTE_ERROR("Unsupported tensor allocation: ndim=", size, ", init_to_zeros=", init_to_zeros, - ". Only 1D and 2D tensors are supported."); -} - -at::Tensor allocateTorchTensor(int M, int N, transformer_engine::DType dtype) { - return at::empty({static_cast(M), static_cast(N)}, - at::CUDA(GetATenDType(dtype))); -} - -at::Tensor allocateTorchTensor(int M, transformer_engine::DType dtype) { - return at::empty({static_cast(M)}, at::CUDA(GetATenDType(dtype))); -} - -void* getDataPtr(at::Tensor tensor, int offset) { - void* dptr = nullptr; - if (tensor.numel() > 0) { - dptr = tensor.data_ptr(); - } - if (dptr != nullptr && offset != 0) { - char* char_ptr = reinterpret_cast(dptr); - char_ptr += offset * tensor.element_size(); - dptr = reinterpret_cast(char_ptr); - } - return dptr; -} - -std::vector convertShape(const NVTEShape& shape) { - return std::vector(shape.data, shape.data + shape.ndim); -} - -size_t roundup(size_t value, size_t multiple) { - assert(multiple > 0); - return ((value + multiple - 1) / multiple) * multiple; -} - -size_t ceildiv(size_t numer, size_t denom) { return (numer + denom - 1) / denom; } - -void philox_unpack(at::PhiloxCudaState arg, int64_t* rng_state_ptr) { - NVTE_SCOPED_GIL_RELEASE({ - nvte_extract_seed_and_offset(rng_state_ptr, 
arg.captured_, arg.seed_.ptr, arg.seed_.val, - arg.offset_.ptr, arg.offset_.val, arg.offset_intragraph_, - at::cuda::getCurrentCUDAStream()); - }); -} - -// extract PhiloxCudaState from CUDA random number generator -at::PhiloxCudaState init_philox_state(at::CUDAGeneratorImpl* gen, size_t elts_per_thread) { - at::PhiloxCudaState philox_args; - std::lock_guard lock(gen->mutex_); - philox_args = gen->philox_cuda_state(elts_per_thread); - return philox_args; -} - -} // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h deleted file mode 100644 index 63a2e86e67..0000000000 --- a/transformer_engine/pytorch/csrc/common.h +++ /dev/null @@ -1,582 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_ -#define TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include - -#include "c10/util/ArrayRef.h" -#include "common/util/logging.h" - -namespace transformer_engine::pytorch { - -// in python we have: dist_group_type = torch.distributed.ProcessGroup -using dist_group_type = c10d::ProcessGroup; - -// Each tensor here is shape (N, ) holding all scaling -// data for a single FP8 block, e.g. 
LayerNormLinear -class FP8TensorMeta { - public: - at::Tensor scale; - at::Tensor scale_inv; - at::Tensor amax_history; -}; - -// Used as named indices on the `scale`, `scale_inv`, -// and `amax` tensors in the `FP8TensorMeta` class. -enum FP8FwdTensors { - GEMM1_INPUT = 0, - GEMM1_WEIGHT = 1, - GEMM1_OUTPUT = 2, - GEMM2_INPUT = 3, - GEMM2_WEIGHT = 4, - GEMM2_OUTPUT = 5, - GEMM3_INPUT = 6, - GEMM3_WEIGHT = 7, - GEMM3_OUTPUT = 8 -}; - -// Used as named indices on the `scale`, `scale_inv`, -// and `amax` tensors in the `FP8TensorMeta` class. -enum FP8BwdTensors { - GRAD_OUTPUT1 = 0, - GRAD_INPUT1 = 1, - GRAD_OUTPUT2 = 2, - GRAD_INPUT2 = 3, - GRAD_OUTPUT3 = 4, - GRAD_INPUT3 = 5 -}; - -class Quantizer { - public: - virtual NVTEScalingMode get_scaling_mode() const = 0; - - virtual void set_quantization_params(TensorWrapper* tensor) const = 0; - - /*! @brief Construct a tensor with uninitialized data */ - virtual std::pair create_tensor(const std::vector& shape, - DType dtype) const = 0; - - /*! @brief Construct a grouped tensor with uninitialized data */ - virtual std::pair create_grouped_tensor( - size_t num_tensors, const std::vector& logical_shape, DType dtype, - py::object quantizer, const std::optional& first_dims, size_t logical_first_dim, - size_t logical_last_dim) const = 0; - - /*! @brief Convert a PyTorch tensor into a Transformer Engine C++ tensor - * - * The PyTorch tensor's attributes are modified to match the - * quantizer's configuration. - */ - virtual std::pair convert_and_update_tensor( - py::object tensor) const = 0; - - /*! 
@brief Convert to a quantized data format */ - virtual void quantize(const TensorWrapper& input, TensorWrapper& out, - const std::optional& noop_flag = std::nullopt) = 0; - - virtual ~Quantizer() = default; - - bool rowwise_usage = true; - bool columnwise_usage = true; - bool internal = false; - bool optimize_for_gemm = false; - py::handle quantizer; - - protected: - explicit Quantizer(const py::handle& quantizer); -}; - -class NoneQuantizer : public Quantizer { - public: - explicit NoneQuantizer(const py::handle& quantizer) : Quantizer(quantizer) {} - - NVTEScalingMode get_scaling_mode() const override { return NVTE_DELAYED_TENSOR_SCALING; } - - void set_quantization_params(TensorWrapper* tensor) const override {} - - std::pair create_tensor(const std::vector& shape, - DType dtype) const override; - - std::pair create_grouped_tensor( - size_t num_tensors, const std::vector& logical_shape, DType dtype, - py::object quantizer, const std::optional& first_dims, size_t logical_first_dim, - size_t logical_last_dim) const override; - - /*! 
@brief Construct a tensor with pre-initialized data */ - std::pair create_tensor(const std::vector& shape, DType dtype, - at::Tensor data) const; - - std::pair convert_and_update_tensor(py::object tensor) const override; - - void quantize(const TensorWrapper& input, TensorWrapper& out, - const std::optional& noop_flag = std::nullopt) override; -}; - -class Float8Quantizer : public Quantizer { - public: - at::Tensor scale; - at::Tensor scale_inv; - at::Tensor amax; - DType dtype; - - explicit Float8Quantizer(const py::handle& quantizer); - - NVTEScalingMode get_scaling_mode() const override { return NVTE_DELAYED_TENSOR_SCALING; } - - void set_quantization_params(TensorWrapper* tensor) const override; - - std::pair create_tensor(const std::vector& shape, - DType dtype) const override; - - std::pair create_grouped_tensor( - size_t num_tensors, const std::vector& logical_shape, DType dtype, - py::object quantizer, const std::optional& first_dims, size_t logical_first_dim, - size_t logical_last_dim) const override; - - /*! 
@brief Construct a tensor with pre-initialized data */ - std::pair create_tensor(const std::vector& shape, DType dtype, - std::optional data, - std::optional transpose, - std::optional scale_inv) const; - - std::pair convert_and_update_tensor(py::object shape) const override; - - void quantize(const TensorWrapper& input, TensorWrapper& out, - const std::optional& noop_flag = std::nullopt) override; -}; - -class Float8CurrentScalingQuantizer : public Quantizer { - public: - at::Tensor scale; - at::Tensor scale_inv; - at::Tensor amax; - DType dtype; - bool with_amax_reduction; - c10::intrusive_ptr amax_reduction_group; - bool force_pow_2_scales = false; - float amax_epsilon = 0.0; - - explicit Float8CurrentScalingQuantizer(const py::handle& quantizer); - - NVTEScalingMode get_scaling_mode() const override { return NVTE_DELAYED_TENSOR_SCALING; } - - void set_quantization_params(TensorWrapper* tensor) const override; - - std::pair create_tensor(const std::vector& shape, - DType dtype) const override; - - std::pair create_grouped_tensor( - size_t num_tensors, const std::vector& logical_shape, DType dtype, - py::object quantizer, const std::optional& first_dims, size_t logical_first_dim, - size_t logical_last_dim) const override; - - /*! @brief Construct an unquantized tensor that shares the quantizer's amax pointer. - * - * The amax is zeroed out. Most TE kernels that output amax expect - * amax to be initialized to zero. - */ - std::pair create_unquantized_tensor_with_amax( - const std::vector& shape, DType dtype, std::optional data = std::nullopt); - - std::pair convert_and_update_tensor(py::object shape) const override; - - void quantize(const TensorWrapper& input, TensorWrapper& out, - const std::optional& noop_flag = std::nullopt) override; - - /*! @brief Quantize to FP8, skipping local amax computation - * - * The quantizer's amax pointer is assumed to already hold the local - * amax. The amax may still be reduced across the amax reduction - * group. 
- */ - void quantize_with_amax(TensorWrapper& input, TensorWrapper& out, - const std::optional& noop_flag = std::nullopt); - - private: - void quantize_impl(const TensorWrapper& input, TensorWrapper& out, - const std::optional& noop_flag, bool compute_amax); -}; - -class Float8BlockQuantizer : public Quantizer { - public: - // Which float8 type is used for q data. - DType dtype; - // Options about how to quantize the tensor - // Quantization scales are rounded down to powers of 2. - bool force_pow_2_scales = false; - // Amax within quantization tile has a floor of epsilon. - float amax_epsilon = 0.0; - - private: - int block_scaling_dim = 2; - - public: - // Initializes from a python handle to a Float8BlockQuantizer - explicit Float8BlockQuantizer(const py::handle& quantizer); - - NVTEScalingMode get_scaling_mode() const override { - return (block_scaling_dim == 2) ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D; - } - - // Gets rowwise and columnwise_data from tensor and sets them on wrapper - void set_quantization_params(TensorWrapper* tensor) const override; - - // Create a python Float8BlockQuantized tensor and C++ wrapper - // for the tensor. Should set quantized data, scales for rowwise - // and optionally columnwise usage. 
- std::pair create_tensor(const std::vector& shape, - DType dtype) const override; - - std::pair create_grouped_tensor( - size_t num_tensors, const std::vector& logical_shape, DType dtype, - py::object quantizer, const std::optional& first_dims, size_t logical_first_dim, - size_t logical_last_dim) const override; - - std::pair convert_and_update_tensor(py::object shape) const override; - - void quantize(const TensorWrapper& input, TensorWrapper& out, - const std::optional& noop_flag = std::nullopt) override; - - std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; -}; - -class MXFP8Quantizer : public Quantizer { - public: - DType dtype; - - explicit MXFP8Quantizer(const py::handle& quantizer); - - NVTEScalingMode get_scaling_mode() const override { return NVTE_MXFP8_1D_SCALING; } - - void set_quantization_params(TensorWrapper* tensor) const override; - - std::pair create_tensor(const std::vector& shape, - DType dtype) const override; - - std::pair create_grouped_tensor( - size_t num_tensors, const std::vector& logical_shape, DType dtype, - py::object quantizer, const std::optional& first_dims, size_t logical_first_dim, - size_t logical_last_dim) const override; - - std::pair convert_and_update_tensor(py::object shape) const override; - - void quantize(const TensorWrapper& input, TensorWrapper& out, - const std::optional& noop_flag = std::nullopt) override; - - std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; -}; - -class NVFP4Quantizer : public Quantizer { - public: - // fp4 dtype - DType dtype; - // amax reduction for low precision FP4 AG - bool with_amax_reduction; - c10::intrusive_ptr amax_reduction_group; - // random hadamard transform - bool with_rht; - bool with_post_rht_amax; - // 2D block scaling - bool with_2d_quantization; - bool stochastic_rounding; - - int rht_matrix_random_sign_mask_t; - at::Tensor rht_matrix; - - explicit NVFP4Quantizer(const py::handle& quantizer); - - NVTEScalingMode 
get_scaling_mode() const override { return NVTE_NVFP4_1D_SCALING; } - - void set_quantization_params(TensorWrapper* tensor) const override; - - std::pair create_tensor(const std::vector& shape, - DType dtype) const override; - - std::pair create_grouped_tensor( - size_t num_tensors, const std::vector& logical_shape, DType dtype, - py::object quantizer, const std::optional& first_dims, size_t logical_first_dim, - size_t logical_last_dim) const override; - - /*! @brief Construct an unquantized tensor that shares NVFP4 tensor's amax pointer - * - * The amax is zeroed out. Most TE kernels that output amax expect - * amax to be initialized to zero. - */ - std::pair create_unquantized_tensor_with_amax( - TensorWrapper& quantized_tensor, DType dtype); - - std::pair convert_and_update_tensor(py::object shape) const override; - - void quantize(const TensorWrapper& input, TensorWrapper& out, - const std::optional& noop_flag = std::nullopt) override; - - /*! @brief Quantize to NVFP4, skipping local amax computation - * - * The input tensor's amax pointer is assumed to already hold the - * local amax. The amax may still be reduced across the amax - * reduction group. 
- */ - void quantize_with_amax(TensorWrapper& input, TensorWrapper& out); - - std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; - - private: - void quantize_impl(const TensorWrapper& input, TensorWrapper& out, - const std::optional& noop_flag, bool compute_amax); - void quantize_with_rht_unfused_helper(const TensorWrapper& input, TensorWrapper& out, - TensorWrapper& rht_output_t_cpp, - QuantizationConfigWrapper& quant_config, - QuantizationConfigWrapper& quant_config_columnwise, - cudaStream_t stream); -}; - -std::unique_ptr convert_quantizer(py::handle quantizer); - -std::vector getTensorShape(const at::Tensor& t); - -transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid, - const std::string& fp8_recipe); - -inline size_t typeToNumBits(transformer_engine::DType t) { - switch (t) { - case transformer_engine::DType::kInt64: - return 64; - case transformer_engine::DType::kInt32: - case transformer_engine::DType::kFloat32: - return 32; - case transformer_engine::DType::kInt16: - case transformer_engine::DType::kFloat16: - case transformer_engine::DType::kBFloat16: - return 16; - case transformer_engine::DType::kByte: - case transformer_engine::DType::kFloat8E4M3: - case transformer_engine::DType::kFloat8E5M2: - case transformer_engine::DType::kFloat8E8M0: - return 8; - case transformer_engine::DType::kFloat4E2M1: - return 4; - default: - NVTE_ERROR("Invalid type (", static_cast(t), ")."); - } -} - -inline at::ScalarType GetATenDType(transformer_engine::DType t) { - switch (t) { - case transformer_engine::DType::kInt16: - return torch::kInt16; - case transformer_engine::DType::kInt32: - return torch::kInt32; - case transformer_engine::DType::kInt64: - return torch::kInt64; - case transformer_engine::DType::kFloat32: - return at::kFloat; - case transformer_engine::DType::kFloat16: - return at::kHalf; - case transformer_engine::DType::kBFloat16: - return at::kBFloat16; - case transformer_engine::DType::kByte: - return 
at::kByte; - case transformer_engine::DType::kFloat8E4M3: - return at::kFloat8_e4m3fn; - case transformer_engine::DType::kFloat8E5M2: - return at::kFloat8_e5m2; - case transformer_engine::DType::kFloat8E8M0: - return at::kByte; // e8m0 dtype requires PyTorch 2.7.0+ - default: - NVTE_ERROR("Invalid type (", static_cast(t), ")."); - } -} - -inline transformer_engine::DType GetTransformerEngineDType(at::ScalarType t) { - switch (t) { - case at::kFloat8_e4m3fn: - return transformer_engine::DType::kFloat8E4M3; - case at::kFloat8_e5m2: - return transformer_engine::DType::kFloat8E5M2; - case at::kHalf: - return transformer_engine::DType::kFloat16; - case at::kFloat: - return transformer_engine::DType::kFloat32; - case at::kBFloat16: - return transformer_engine::DType::kBFloat16; - case at::kBool: - return transformer_engine::DType::kByte; - case torch::kByte: - return transformer_engine::DType::kByte; - case torch::kInt16: - return transformer_engine::DType::kInt16; - case torch::kInt32: - return transformer_engine::DType::kInt32; - case torch::kInt64: - return transformer_engine::DType::kInt64; - default: - NVTE_ERROR("Invalid type (", static_cast(t), ")."); - } -} - -inline transformer_engine::DType GetTransformerEngineDType(int DType_value) { - return static_cast(DType_value); -} - -transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr, - const std::vector& shape, - const transformer_engine::DType type); - -transformer_engine::TensorWrapper makeTransformerEngineTensor( - void* data_ptr, const std::vector& shape, const transformer_engine::DType type, - void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, std::vector scale_inv_shape = {1}, - NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING); - -transformer_engine::TensorWrapper makeTransformerEngineTensor( - void* data_ptr, void* columnwise_data_ptr, const std::vector& shape, - const std::vector& columnwise_shape, const transformer_engine::DType type, - void* amax_ptr, void* scale_ptr, 
void* scale_inv_ptr, void* columnwise_scale_inv_ptr, - const std::vector& scale_inv_shape = {1}, - const std::vector& columnwise_scale_inv_shape = {1}, - NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING); - -transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr, - const NVTEShape& shape, - const transformer_engine::DType type); - -transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor); - -std::tuple, std::vector>, - std::vector, size_t, size_t> -makeTransformerEngineTensorList(std::vector> at_tensor_lists); - -TensorWrapper makeTransformerEngineTensor(py::handle tensor, py::handle quantizer); - -transformer_engine::TensorWrapper makeTransformerEngineTensor( - at::Tensor tensor, at::Tensor amax, const at::Tensor scale, at::Tensor scale_inv, - NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING); - -template -T product(const std::vector& shape); - -size_t product(const NVTEShape& shape, size_t begin, size_t end); - -std::vector nvte_shape_to_vector(const NVTEShape& nvte_shape); - -at::Tensor allocateSpace(const std::vector& shape, const transformer_engine::DType type, - bool init_to_zeros); - -at::Tensor allocateSpace(const NVTEShape& shape, const transformer_engine::DType type, - bool init_to_zeros = false); - -at::Tensor allocateTorchTensor(int M, int N, transformer_engine::DType dtype); - -at::Tensor allocateTorchTensor(int M, transformer_engine::DType dtype); - -void* getDataPtr(at::Tensor tensor, int offset = 0); - -std::vector convertShape(const NVTEShape& shape); - -size_t roundup(size_t value, size_t multiple); - -size_t ceildiv(size_t numer, size_t denom); - -NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape); - -std::vector convert_shape_back_from_fp4(const std::vector& shape, bool transpose); - -// unpack the PhiloxCudaState into CUDA tensor -void philox_unpack(at::PhiloxCudaState arg, int64_t* rng_state_ptr); - -// extract PhiloxCudaState from CUDA random number generator 
-at::PhiloxCudaState init_philox_state(at::CUDAGeneratorImpl* gen, size_t elts_per_thread); - -} // namespace transformer_engine::pytorch - -namespace std { -template -string to_string(const vector& vec) { - string ret = "["; - for (const auto& val : vec) { - ret += to_string(val) + ","; - } - if (ret.size() > 1) { - ret[ret.size() - 1] = ']'; - } else { - ret += "]"; - } - return ret; -} - -// Torch shape -> string -template -string to_string(const c10::ArrayRef& vec) { - string ret = "["; - for (const auto& val : vec) { - ret += to_string(val) + ","; - } - if (ret.size() > 1) { - ret[ret.size() - 1] = ']'; - } else { - ret += "]"; - } - return ret; -} - -inline string to_string(const NVTEShape& s) { - string ret = "["; - for (size_t i = 0; i < s.ndim; ++i) { - ret += to_string(s.data[i]) + ","; - } - if (ret.size() > 1) { - ret[ret.size() - 1] = ']'; - } else { - ret += "]"; - } - return ret; -} -} // namespace std - -#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_ diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h deleted file mode 100644 index 1c5116a8da..0000000000 --- a/transformer_engine/pytorch/csrc/extensions.h +++ /dev/null @@ -1,661 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. 
- ************************************************************************/ - -#ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_ -#define TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_ - -#include -#include -#include -#include -#include -#include - -#include "common.h" - -class CommOverlapHelper; -class CommOverlap; -class CommOverlapP2P; - -namespace transformer_engine::pytorch { - -/*************************************************************************************************** - * Router fusion - **************************************************************************************************/ - -std::tuple fused_topk_with_score_function_fwd( - at::Tensor logits, int topk, bool use_pre_softmax, std::optional num_groups, - std::optional group_topk, std::optional scaling_factor, std::string score_function, - std::optional expert_bias); - -void fused_topk_with_score_function_bwd(int num_tokens, int num_experts, at::Tensor routing_map, - at::Tensor intermediate_output, at::Tensor grad_probs, - at::Tensor grad_logits, int topk, bool use_pre_softmax, - std::optional scaling_factor, - std::string score_function); - -std::tuple fused_score_for_moe_aux_loss_fwd( - at::Tensor logits, int topk, std::string score_function); - -void fused_score_for_moe_aux_loss_bwd(int num_tokens, int num_experts, - at::Tensor intermediate_output, at::Tensor grad_probs, - at::Tensor grad_logits, int topk, std::string score_function); - -std::tuple fused_moe_aux_loss_fwd(at::Tensor probs, - at::Tensor tokens_per_expert, - int total_num_tokens, int num_experts, - int num_rows, int num_cols, int topk, - float coeff); - -at::Tensor fused_moe_aux_loss_bwd(at::Tensor Const_buf, at::Tensor tokens_per_expert, int num_rows, - int num_cols, at::Tensor grad_aux_loss); - -/*************************************************************************************************** - * Permutation - **************************************************************************************************/ - 
-std::tuple> moe_permute_fwd( - at::Tensor input, const DType dtype, at::Tensor indices, int64_t num_out_tokens, - std::vector workspace, int64_t max_expanded_token_num); - -at::Tensor moe_permute_bwd(at::Tensor input, const DType dtype, at::Tensor row_id_map, - at::Tensor prob, int64_t num_tokens, int64_t topK); - -at::Tensor moe_unpermute_fwd(at::Tensor input, const DType dtype, at::Tensor row_id_map, - at::Tensor prob, int64_t num_tokens, int64_t topK); - -std::tuple moe_unpermute_bwd(at::Tensor input_bwd, at::Tensor input_fwd, - const DType dtype, at::Tensor row_id_map, - at::Tensor prob); - -/*************************************************************************************************** - * Attention - **************************************************************************************************/ - -NVTE_Fused_Attn_Backend get_fused_attn_backend( - bool is_training, const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, - int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic); - -std::vector fused_attn_fwd( - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, - bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - const std::vector window_size, bool bottom_right_diagonal, - const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, - const py::handle K, const py::handle V, const at::ScalarType fake_dtype, - const std::optional cu_seqlens_q_padded, - const std::optional cu_seqlens_kv_padded, - const std::optional page_table_k, const std::optional page_table_v, - py::handle s_quantizer, py::handle 
o_quantizer, const std::optional Bias, - const std::optional SoftmaxOffset, const std::optional rng_gen, - size_t rng_elts_per_thread, bool return_max_logit, bool cuda_graph); - -std::vector fused_attn_bwd( - size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - NVTE_Softmax_Type softmax_type, const std::vector window_size, - bool bottom_right_diagonal, bool deterministic, const at::Tensor cu_seqlens_q, - const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, - const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const DType dqkv_type, - const std::vector Aux_CTX_Tensors, - const std::optional cu_seqlens_q_padded, - const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, - py::handle dp_quantizer, py::handle dqkv_quantizer, bool cuda_graph); - -at::Tensor fa_prepare_fwd(at::Tensor qkvi); -at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v); - -at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b, int max_seq_len); -at::Tensor convert_bshd_to_thd(at::Tensor tensor, at::Tensor cu_seqlens, int t); -void copy_to_kv_cache(at::Tensor new_k, at::Tensor new_v, at::Tensor k_cache, at::Tensor v_cache, - at::Tensor page_table, at::Tensor cu_new_lens, at::Tensor cu_cached_lens, - NVTE_QKV_Format kv_format, int b, int max_ctx_len, int max_seq_len, - int max_pages_per_seq, bool is_non_paged); - -/*************************************************************************************************** - * GEMM - **************************************************************************************************/ - -using MaybeTensor = std::optional; - -std::vector gemm(py::handle A, bool transa, py::handle B, bool transb, py::object D, - py::handle quantizer, std::optional out_dtype, MaybeTensor bias, - DType bias_type, bool gelu, MaybeTensor gelu_in, 
bool grad, - at::Tensor workspace, size_t workspaceSize, bool accumulate, - bool use_split_accumulator, CommOverlapCore *comm_overlap = nullptr, - std::optional comm_type = std::nullopt, - MaybeTensor extra_output = std::nullopt, bool bulk_overlap = false, - float alpha = 1.0f, std::optional beta = std::nullopt); - -void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type, - std::vector A_scaling_mode, bool transa, at::Tensor B, - at::Tensor B_scale_inverse, DType B_type, std::vector B_scaling_mode, - bool transb, at::Tensor D, at::Tensor D_scale, DType D_type, at::Tensor D_amax, - at::Tensor bias, DType bias_type, at::Tensor pre_gelu_out, bool grad, - at::Tensor workspace, size_t workspaceSize, bool accumulate, - bool use_split_accumulator, int math_sm_count, int m_split, int n_split, - bool gemm_producer, at::Tensor counter); - -std::optional> te_general_grouped_gemm( - std::vector A, bool transa, std::vector B, bool transb, - std::optional> D, DType D_type, std::vector m_splits, - std::vector bias, DType bias_type, bool single_output, - std::vector pre_gelu_out, bool grad, std::vector workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count); - -py::object te_general_grouped_gemm_for_grouped_tensor( - py::handle A, bool transa, py::handle B, bool transb, py::handle D, py::object bias, - at::Tensor alpha, at::Tensor beta, at::Tensor workspace_setup, at::Tensor workspace_cublas, - bool use_split_accumulator, int math_sm_count); - -py::object te_general_grouped_gemm_for_discrete_in(py::handle A, bool transa, py::handle B, - bool transb, py::handle D, py::object bias, - at::Tensor alpha, at::Tensor beta, - at::Tensor workspace_setup, - at::Tensor workspace_cublas, - bool use_split_accumulator, int math_sm_count); - -py::object te_general_grouped_gemm_for_discrete_out(py::handle A, bool transa, py::handle B, - bool transb, py::handle D, py::object bias, - at::Tensor alpha, at::Tensor beta, - at::Tensor 
workspace_setup, - at::Tensor workspace_cublas, - bool use_split_accumulator, int math_sm_count); - -/*************************************************************************************************** - * Transpose - **************************************************************************************************/ - -at::Tensor fp8_transpose(at::Tensor input, DType otype, - std::optional output = std::nullopt); - -at::Tensor nvfp4_data_transpose(at::Tensor input, std::optional output = std::nullopt); - -void nvfp4_2d_scale_transpose(at::Tensor input, at::Tensor output, int64_t M_tiles, - int64_t K_tiles); - -void nvfp4_2d_multi_tensor_transpose(std::vector rowwise_data_list, - std::vector columnwise_data_list, - std::vector rowwise_scale_inv_list, - std::vector columnwise_scale_inv_list, - std::vector M_list, std::vector K_list); - -void nvfp4_multi_tensor_compute_partial_amax( - std::vector master_weight_list, std::vector partial_amax_list, - std::vector global_amax_list, std::vector h_list, - std::vector w_list, std::vector start_offset_list, int64_t block_len); - -void nvfp4_expand_scale_to_fp8(at::Tensor input, at::Tensor output, int64_t tile_rows, - int64_t tile_cols, int64_t rows_padded, int64_t block_len); - -void nvfp4_compute_per_block_scale(at::Tensor block_amax, at::Tensor scale, at::Tensor global_amax); - -void nvfp4_fused_scale(at::Tensor block_amax, at::Tensor global_amax, at::Tensor per_block_scale, - at::Tensor target_scale, at::Tensor target_amax, int64_t tile_rows, - int64_t tile_cols, int64_t rows_padded, int64_t block_len); - -void nvfp4_multi_tensor_fused_scale( - std::vector block_amax_list, std::vector global_amax_list, - std::vector per_block_scale_list, std::vector target_scale_list, - std::vector target_amax_list, std::vector tile_rows_list, - std::vector tile_cols_list, std::vector rows_padded_list, int64_t block_len); - -void nvfp4_compute_global_scale(at::Tensor global_amax, at::Tensor global_scale); - -at::Tensor 
swap_first_dims(at::Tensor tensor, std::optional out = std::nullopt); - -/*************************************************************************************************** - * Activations - **************************************************************************************************/ - -/* GLU (sigmoid gate) */ -py::object glu(const at::Tensor &input, py::handle quantizer); - -py::object dglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); - -/* GELU and variants*/ -py::object gelu(const at::Tensor &input, py::handle quantizer); - -py::object dgelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); - -py::object geglu(const at::Tensor &input, py::handle quantizer); - -py::object dgeglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); - -py::object qgelu(const at::Tensor &input, py::handle quantizer); - -py::object dqgelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); - -py::object qgeglu(const at::Tensor &input, py::handle quantizer); - -py::object dqgeglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); - -/* ReLU and variants*/ -py::object relu(const at::Tensor &input, py::handle quantizer); - -py::object drelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); - -py::object reglu(const at::Tensor &input, py::handle quantizer); - -py::object dreglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); - -py::object srelu(const at::Tensor &input, py::handle quantizer); - -py::object dsrelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); - -py::object sreglu(const at::Tensor &input, py::handle quantizer); - -py::object dsreglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); - -/* Silu and variants*/ -py::object silu(const at::Tensor &input, py::handle quantizer); - -py::object dsilu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); - 
-py::object swiglu(const at::Tensor &input, py::handle quantizer); - -py::object dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); - -py::object clamped_swiglu(const at::Tensor &input, py::handle quantizer, float limit, float alpha); - -py::object clamped_dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer, - float limit, float alpha); -/*************************************************************************************************** - * LayerNorm - **************************************************************************************************/ - -std::vector layernorm_bwd(const at::Tensor &dz, const at::Tensor &x, - const at::Tensor &mu, const at::Tensor &rsigma, - const at::Tensor &gamma, const int sm_margin, - const bool zero_centered_gamma); - -std::vector layernorm_fwd(py::handle input, py::handle weight, MaybeTensor bias, - float eps, py::object ln_out, py::handle quantizer, - DType out_dtype, const int sm_margin, - const bool zero_centered_gamma); - -/*************************************************************************************************** - * RMSNorm - **************************************************************************************************/ - -std::vector rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x, - const at::Tensor &rsigma, const at::Tensor &gamma, - const int sm_margin, const bool zero_centered_gamma); - -std::vector rmsnorm_bwd_add(const at::Tensor &dz, const at::Tensor &x, - const at::Tensor &add, const at::Tensor &rsigma, - const at::Tensor &gamma, const int sm_margin, - const bool zero_centered_gamma); - -std::vector rmsnorm_fwd(const py::handle &input, const py::handle &weight, float eps, - py::object ln_out, py::handle quantizer, DType otype, - const int sm_margin, const bool zero_centered_gamma); - -/*************************************************************************************************** - * Cast - 
**************************************************************************************************/ - -py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::object &output, - std::optional noop_flag); - -py::object dequantize(const py::handle &input, DType otype); - -py::object group_quantize(const at::Tensor &tensor, py::handle quantizer, const size_t num_tensors, - std::optional first_dims); - -std::vector multi_tensor_quantize(const std::vector &tensor_list, - std::vector quantizer_list); - -std::vector split_quantize(const at::Tensor &tensor, - const std::vector &split_sections, - std::vector quantizer_list, - bool disable_bulk_allocation = false); - -/*************************************************************************************************** - * Bias gradient fusions - **************************************************************************************************/ - -std::vector bgrad_quantize(const at::Tensor &input, py::handle py_quantizer); - -std::vector dbias_dgelu(const at::Tensor &grad_output, const at::Tensor &act_input, - py::handle quantizer); - -std::vector dbias_dsilu(const at::Tensor &grad_output, const at::Tensor &act_input, - py::handle quantizer); - -std::vector dbias_drelu(const at::Tensor &grad_output, const at::Tensor &act_input, - py::handle quantizer); - -std::vector dbias_dqgelu(const at::Tensor &grad_output, const at::Tensor &act_input, - py::handle quantizer); - -std::vector dbias_dsrelu(const at::Tensor &grad_output, const at::Tensor &act_input, - py::handle quantizer); - -/*************************************************************************************************** - * Dropout - **************************************************************************************************/ - -std::vector dropout_fwd(const py::handle &input, const float dropout_probability, - std::optional out = std::nullopt); - -py::object dropout_bwd(const at::Tensor &grad_output, const at::Tensor &mask, - const float 
dropout_probability, - std::optional grad_input = std::nullopt); - -/*************************************************************************************************** - * Softmax - **************************************************************************************************/ - -at::Tensor scaled_softmax_forward(at::Tensor input, float scale_factor); - -at::Tensor scaled_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_results_, - float scale_factor); - -at::Tensor scaled_masked_softmax_forward(at::Tensor input, at::Tensor mask, float scale_factor); - -at::Tensor scaled_masked_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_results_, - float scale_factor); - -at::Tensor scaled_upper_triang_masked_softmax_forward(at::Tensor input, float scale_factor); - -at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_, - at::Tensor softmax_results_, - float scale_factor); - -at::Tensor scaled_aligned_causal_masked_softmax_forward(at::Tensor input, float scale_factor); - -at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grads_, - at::Tensor softmax_results_, - float scale_factor); - -/*************************************************************************************************** - * FP8 recipe - **************************************************************************************************/ - -void compute_amax(const at::Tensor &tensor, at::Tensor &amax); - -void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reduction_buffer, - std::vector amax_histories, - std::vector scales, - const std::string &amax_compute_algo, - DType fp8_dtype, float margin); - -// Note that the start_offset is the logical offset along the tensor dimension. 
-// The offset in bytes is start_offset * sizeof(tensor.dtype) -void fp8_block_scaling_compute_partial_amax(const at::Tensor &tensor, at::Tensor amax, size_t h, - size_t w, size_t start_offset, size_t block_len); - -void fp8_block_scaling_partial_cast(const at::Tensor &inp, at::Tensor out, const at::Tensor &scale, - size_t h, size_t w, size_t start_offset, size_t block_len, - const DType out_dtype); - -void nvfp4_2d_compute_partial_amax(const at::Tensor &tensor, at::Tensor amax, size_t h, size_t w, - size_t start_offset, size_t block_len); - -void nvfp4_2d_partial_cast(const at::Tensor &inp, py::handle out, const at::Tensor &scale, - const at::Tensor &global_scale, size_t h, size_t w, size_t start_offset, - size_t block_len); - -void nvfp4_multi_tensor_2d_partial_cast(std::vector inp_list, - std::vector out_list, - std::vector scale_list, - std::vector global_scale_list, - std::vector h_list, std::vector w_list, - std::vector start_offset_list, int64_t block_len); -void mxfp8_scaling_compute_partial_amax(const at::Tensor &input, at::Tensor amax_rowwise, - at::Tensor amax_colwise, int rows, int cols, - size_t start_offset); - -void mxfp8_scaling_partial_cast(const at::Tensor &input, at::Tensor output_rowwise, - at::Tensor output_colwise, const at::Tensor &scale_inv_rowwise, - const at::Tensor &scale_inv_colwise, int rows, int cols, - size_t start_offset); - -/*************************************************************************************************** - * Rotary positional embedding - **************************************************************************************************/ - -at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs, - const std::optional start_positions, - const NVTE_QKV_Format qkv_format, const bool interleaved, - const std::optional cu_seqlens, const int cp_size, - const int cp_rank); - -at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs, - const std::optional 
start_positions, - const NVTE_QKV_Format qkv_format, const bool interleaved, - const std::optional cu_seqlens, const int cp_size, - const int cp_rank); - -std::tuple fused_qkv_rope_forward( - const at::Tensor &qkv_input, const at::Tensor &q_freqs, const at::Tensor &k_freqs, - const std::optional start_positions, const std::vector &qkv_split_arg_list, - const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size, const int cp_rank); - -at::Tensor fused_qkv_rope_backward(const at::Tensor &q_grad_out, const at::Tensor &k_grad_out, - const at::Tensor &v_grad_out, const at::Tensor &q_freqs, - const at::Tensor &k_freqs, - const std::vector &qkv_split_arg_list, - const NVTE_QKV_Format qkv_format, const bool interleaved, - const int cp_size, const int cp_rank); - -/*************************************************************************************************** - * Miscellaneous - **************************************************************************************************/ - -size_t get_cublasLt_version(); - -size_t get_cudnn_version(); - -at::Tensor splits_to_offsets(const at::Tensor &first_dims, int64_t logical_last_dim); - -/*************************************************************************************************** - * Support THD format for Context Parallel - **************************************************************************************************/ - -at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_seqlens, - int half_idx); - -void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_step, - const at::Tensor &cu_seqlens, bool lse_packed); - -at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_seqlens, - bool lse_packed, int second_half_lse_seqlen); - -void thd_out_correction(at::Tensor out, const at::Tensor &out_per_step, const at::Tensor &lse, - const at::Tensor &lse_per_step, const at::Tensor &cu_seqlens, - bool only_second_half, bool lse_packed); - 
-void thd_grad_correction(at::Tensor grad, const at::Tensor &grad_per_step, - const at::Tensor &cu_seqlens, const std::string &first_half, - const std::string &second_half); - -at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_tokens, - int world_size, int rank); - -/*************************************************************************************************** - * multi_tensor_* kernels - **************************************************************************************************/ - -void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag, - std::vector> tensor_lists, float scale); - -void multi_tensor_scale_tensor_cuda(int chunk_size, at::Tensor is_infinite, - std::vector> tensor_lists, - at::Tensor scale); - -std::tuple multi_tensor_l2norm_cuda( - int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, - at::optional per_tensor_python); - -std::tuple multi_tensor_unscale_l2norm_cuda( - int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, - at::Tensor inv_scale, at::optional per_tensor_python); - -void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag, - std::vector> tensor_lists, const float lr, - const float beta1, const float beta2, const float epsilon, - const int step, const int mode, const int bias_correction, - const float weight_decay); - -void multi_tensor_adam_param_remainder_cuda(int chunk_size, at::Tensor noop_flag, - std::vector> tensor_lists, - const float lr, const float beta1, const float beta2, - const float epsilon, const int step, const int mode, - const int bias_correction, const float weight_decay); - -void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag, - std::vector> tensor_lists, const float lr, - const float beta1, const float beta2, const float epsilon, - const int step, const int mode, const int bias_correction, - const float weight_decay, DType fp8_dtype); - -void multi_tensor_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag, - 
std::vector> tensor_lists, - at::Tensor lr, const float beta1, const float beta2, - const float epsilon, at::Tensor step, const int mode, - const int bias_correction, const float weight_decay, - at::Tensor inv_scale); - -void multi_tensor_adam_capturable_master_cuda(int chunk_size, at::Tensor noop_flag, - std::vector> tensor_lists, - at::Tensor lr, const float beta1, const float beta2, - const float epsilon, at::Tensor step, const int mode, - const int bias_correction, const float weight_decay, - at::Tensor inv_scale); - -void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag, - std::vector> tensor_lists, float wd, - float momentum, float dampening, float lr, bool nesterov, bool first_run, - bool wd_after_momentum, float scale); - -void multi_tensor_compute_scale_and_scale_inv_cuda( - int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, - float max_fp8, bool force_pow_2_scales, float epsilon); - -void multi_tensor_compute_scale_inv_e8m0_cuda(int chunk_size, const py::object &dummy, - std::vector> tensor_lists); - -/*************************************************************************************************** - * padding - **************************************************************************************************/ - -void fused_multi_row_padding(at::Tensor input, at::Tensor output, - std::vector input_row_list, - std::vector padded_input_row_list); - -void fused_multi_row_unpadding(at::Tensor input, at::Tensor output, - std::vector input_row_list, - std::vector unpadded_input_row_list); - -/*************************************************************************************************** - * Scale swizzling for GEMM - **************************************************************************************************/ - -void inplace_swizzle_scale_for_gemm(py::handle &tensor); - -/*************************************************************************************************** - * NVSHMEM APIs - 
**************************************************************************************************/ - -void init_nvshmem_backend(c10d::ProcessGroup *process_group); - -at::Tensor create_nvshmem_tensor(const std::vector &shape, c10::ScalarType dtype); - -void nvshmem_send_on_current_stream(at::Tensor src, at::Tensor dst, int peer, at::Tensor signal); - -void nvshmem_wait_on_current_stream(at::Tensor signal, const std::string &wait_kind); - -void nvshmem_finalize(); - -/*************************************************************************************************** - * Comm+GEMM Overlap Wrappers - **************************************************************************************************/ - -void bulk_overlap_ag_with_external_gemm(CommOverlap &allgather_communicator, at::Stream send_stream, - at::Stream recv_stream); - -} // namespace transformer_engine::pytorch - -/*************************************************************************************************** - * Comm+GEMM Overlap Wrappers - **************************************************************************************************/ - -class CommOverlapHelper : torch::CustomClassHolder { - private: - bool initialized{false}; - bool backend_is_nccl{false}; - std::map pgs; - - public: - int myrank = -1; - int numranks = -1; - int mylocal = -1; - int numlocal = -1; - int mynode = -1; - int numnodes = -1; - - CommOverlapHelper(); - - CommOverlapHelper(c10d::ProcessGroup *world_group, - std::optional intra_node_group); - - ~CommOverlapHelper(); - - void ub_allgather(void *globaldata, size_t globalbytes, void *localdata, size_t localbytes, - ExtComm comm); - - void ub_barrier(ExtComm comm); -}; - -class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOverlapBase { - public: - CommOverlap(const std::vector &buffer_shape, at::ScalarType buffer_dtype, - CommOverlapHelper *helper, int tp_size, int num_splits = 3, - int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int 
comm_cga_size = 2, - int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 16, - bool set_sm_margin = true, bool atomic_gemm = false, - bool rs_overlap_first_gemm = false); - - ~CommOverlap() {} - - using transformer_engine::CommOverlapCore::copy_into_buffer; - void copy_into_buffer(const at::Tensor &input, bool local_chunk = false); - - at::Tensor get_buffer(bool local_chunk = false, - std::optional> shape = std::nullopt); - - std::pair get_communication_stream(); - -}; // CommOverlap - -class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::CommOverlapP2PBase { - public: - CommOverlapP2P(const std::vector &buffer_shape, at::ScalarType buffer_dtype, - CommOverlapHelper *helper, int tp_size, - transformer_engine::CommOverlapType comm_type, - int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2, - int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 3, - bool set_sm_margin = true, bool atomic_gemm = false, bool use_ce = true, - bool aggregate = false); - - ~CommOverlapP2P() {} - - using transformer_engine::CommOverlapP2PBase::copy_into_buffer; - void copy_into_buffer(const at::Tensor &input, bool local_chunk = false); - - at::Tensor get_buffer(bool local_chunk = false, - std::optional> shape = std::nullopt); - - std::pair get_communication_stream(); - -}; // CommOverlapP2P - -#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_ diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index 99b9c1fefa..737a006ec8 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -3,339 +3,168 @@ * * See LICENSE for license information. 
************************************************************************/ -#include "../extensions.h" -#include "common.h" -#include "pybind.h" -namespace transformer_engine { -namespace pytorch { - -namespace { -using FuncType = void(const NVTETensor, NVTETensor, cudaStream_t); -using DFuncType = void(const NVTETensor, const NVTETensor, NVTETensor, cudaStream_t); - -template -py::object activation_helper(const at::Tensor& input, py::handle quantizer, int shape_divisor = 1, - Args&&... args) { - init_extension(); - - // Input tensor - auto input_tensor = input.contiguous(); - const TensorWrapper& input_nvte = makeTransformerEngineTensor(input_tensor); - - // Construct output tensor - auto quantizer_cpp = convert_quantizer(quantizer); - const auto input_shape = input_nvte.shape(); - std::vector output_shape(input_shape.data, input_shape.data + input_shape.ndim); - output_shape.back() /= shape_divisor; - auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type()); - auto [out_nvte, out_py] = quantizer_cpp->create_tensor(output_shape, fake_dtype); - - // Choose implementation - enum class Impl { UNFUSED, FULLY_FUSED, FUSED_ACTIVATION_AMAX_FP8, FUSED_ACTIVATION_AMAX_NVFP4 }; - Impl impl = Impl::UNFUSED; - if (quantizer.is_none() || detail::IsFloat8Quantizers(quantizer.ptr()) || - detail::IsMXFP8Quantizers(quantizer.ptr())) { - impl = Impl::FULLY_FUSED; - } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { - impl = Impl::FUSED_ACTIVATION_AMAX_FP8; - } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { - auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); - NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); - if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) { - // Post-RHT amax is handled within NVFP4 quantizer - impl = Impl::UNFUSED; - } else { - impl = Impl::FUSED_ACTIVATION_AMAX_NVFP4; - } - } - - // Perform compute - auto stream = at::cuda::getCurrentCUDAStream(); 
- switch (impl) { - case Impl::UNFUSED: - // Compute activation in high precision, then quantize - { - auto [temp_nvte, _] = NoneQuantizer(py::none()).create_tensor(output_shape, fake_dtype); - NVTE_SCOPED_GIL_RELEASE({ - if constexpr (act_func == nullptr) { - act_func_with_args(input_nvte.data(), temp_nvte.data(), std::forward(args)..., - stream); - } else { - act_func(input_nvte.data(), temp_nvte.data(), stream); - } - }); - quantizer_cpp->quantize(temp_nvte, out_nvte); - } - break; - case Impl::FULLY_FUSED: - // Compute activation directly - { - NVTE_SCOPED_GIL_RELEASE({ - if constexpr (act_func == nullptr) { - act_func_with_args(input_nvte.data(), out_nvte.data(), std::forward(args)..., - stream); - } else { - act_func(input_nvte.data(), out_nvte.data(), stream); - } - }); - } - break; - case Impl::FUSED_ACTIVATION_AMAX_FP8: - // Compute activation and amax in high precision, then quantize to FP8 - { - auto fp8_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); - NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer"); - auto [temp_nvte, _] = - fp8_quantizer_cpp->create_unquantized_tensor_with_amax(output_shape, fake_dtype); - NVTE_SCOPED_GIL_RELEASE({ - if constexpr (act_func == nullptr) { - act_func_with_args(input_nvte.data(), temp_nvte.data(), std::forward(args)..., - stream); - } else { - act_func(input_nvte.data(), temp_nvte.data(), stream); - } - }); - fp8_quantizer_cpp->quantize_with_amax(temp_nvte, out_nvte); - } - break; - case Impl::FUSED_ACTIVATION_AMAX_NVFP4: - // Compute activation and amax in high precision, then quantize to NVFP4 - { - auto nvfp4_quantizer_cpp = - static_cast(quantizer_cpp.get()); // Already checked cast is valid - auto [temp_nvte, _] = - nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax(out_nvte, fake_dtype); - NVTE_SCOPED_GIL_RELEASE({ - if constexpr (act_func == nullptr) { - act_func_with_args(input_nvte.data(), temp_nvte.data(), std::forward(args)..., - stream); - } else { - 
act_func(input_nvte.data(), temp_nvte.data(), stream); - } - }); - nvfp4_quantizer_cpp->quantize_with_amax(temp_nvte, out_nvte); - } - break; - default: - NVTE_ERROR("Invalid activation implementation (", static_cast(impl), ")"); +#include + +#include "../stable_common.h" + +namespace transformer_engine::pytorch::stable { + +using Tensor = torch::stable::Tensor; + +// ============================================================================ +// Generic activation forward — no-alloc variant +// +// Handles all fusion paths: +// FULLY_FUSED: output is quantized, kernel writes directly +// NORM+AMAX: output is hp + amax attached, kernel computes act + amax +// UNFUSED: output is hp, Python calls quantize_from_amax separately +// +// The Python shim selects the path by choosing output buffer configuration. +// ============================================================================ + +// Forward activation with pre-allocated output buffers (no quantizer dispatch) +// shape_divisor: 2 for GLU variants (output last dim = input last dim / 2), 1 otherwise +void activation_fwd_noalloc(Tensor input, Tensor output_data, int64_t output_te_dtype, + std::optional output_amax, std::optional output_scale, + std::optional output_scale_inv, int64_t scaling_mode, + int64_t activation_type) { + auto input_ = torch::stable::contiguous(input); + auto input_cu = makeTransformerEngineTensor(input_); + auto shape = getStableTensorShape(input_); + // Output shape may differ (GLU halves last dim) — use output_data's shape + auto out_shape = getStableTensorShape(output_data); + auto te_dtype = static_cast(output_te_dtype); + auto nvte_scaling = static_cast(scaling_mode); + + auto output_cu = makeQuantizedTensorWrapper(output_data, te_dtype, out_shape, output_amax, + output_scale, output_scale_inv, nvte_scaling); + + auto stream = getCurrentCUDAStreamRaw(input_.get_device_index()); + + // Dispatch activation type + using ActFn = void (*)(const NVTETensor, NVTETensor, cudaStream_t); + // 
Activation type enum matches the order in registration + static constexpr ActFn act_table[] = { + nvte_gelu, nvte_glu, nvte_geglu, nvte_qgelu, nvte_qgeglu, nvte_relu, + nvte_reglu, nvte_srelu, nvte_sreglu, nvte_silu, nvte_swiglu, + }; + constexpr int num_acts = sizeof(act_table) / sizeof(act_table[0]); + STD_TORCH_CHECK(activation_type >= 0 && activation_type < num_acts, + "Invalid activation_type: ", activation_type); + act_table[activation_type](input_cu.data(), output_cu.data(), stream); +} + +// Backward activation (grad_output, input → grad_input) +// Same noalloc pattern — output buffer may be quantized +void dactivation_noalloc(Tensor grad_output, Tensor input, Tensor grad_input_data, + int64_t grad_input_te_dtype, std::optional grad_input_amax, + std::optional grad_input_scale, + std::optional grad_input_scale_inv, int64_t scaling_mode, + int64_t activation_type) { + auto grad_output_ = torch::stable::contiguous(grad_output); + auto input_ = torch::stable::contiguous(input); + + auto grad_output_cu = makeTransformerEngineTensor(grad_output_); + auto input_cu = makeTransformerEngineTensor(input_); + auto grad_shape = getStableTensorShape(input_); + auto te_dtype = static_cast(grad_input_te_dtype); + auto nvte_scaling = static_cast(scaling_mode); + + auto grad_input_cu = + makeQuantizedTensorWrapper(grad_input_data, te_dtype, grad_shape, grad_input_amax, + grad_input_scale, grad_input_scale_inv, nvte_scaling); + + auto stream = getCurrentCUDAStreamRaw(input_.get_device_index()); + + using DActFn = void (*)(const NVTETensor, const NVTETensor, NVTETensor, cudaStream_t); + static constexpr DActFn dact_table[] = { + nvte_dgelu, nvte_dglu, nvte_dgeglu, nvte_dqgelu, nvte_dqgeglu, nvte_drelu, + nvte_dreglu, nvte_dsrelu, nvte_dsreglu, nvte_dsilu, nvte_dswiglu, + }; + constexpr int num_acts = sizeof(dact_table) / sizeof(dact_table[0]); + STD_TORCH_CHECK(activation_type >= 0 && activation_type < num_acts, + "Invalid activation_type: ", activation_type); + 
dact_table[activation_type](grad_output_cu.data(), input_cu.data(), grad_input_cu.data(), stream); +} + +// Clamped activations (with extra float params) +void clamped_activation_fwd_noalloc(Tensor input, Tensor output_data, int64_t output_te_dtype, + std::optional output_amax, + std::optional output_scale, + std::optional output_scale_inv, int64_t scaling_mode, + double limit, double alpha, int64_t activation_type) { + auto input_ = torch::stable::contiguous(input); + auto input_cu = makeTransformerEngineTensor(input_); + auto out_shape = getStableTensorShape(output_data); + auto te_dtype = static_cast(output_te_dtype); + auto nvte_scaling = static_cast(scaling_mode); + auto output_cu = makeQuantizedTensorWrapper(output_data, te_dtype, out_shape, output_amax, + output_scale, output_scale_inv, nvte_scaling); + auto stream = getCurrentCUDAStreamRaw(input_.get_device_index()); + + // 0 = clamped_swiglu + if (activation_type == 0) { + nvte_clamped_swiglu(input_cu.data(), output_cu.data(), static_cast(limit), + static_cast(alpha), stream); + } else { + NVTE_ERROR("Invalid clamped activation_type: ", activation_type); } - - return out_py; } -template -py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& input, - py::handle quantizer, Args&&... 
args) { - init_extension(); - - // Grad output and input tensors - auto grad_output_tensor = grad_output.contiguous(); - auto input_tensor = input.contiguous(); - const TensorWrapper& grad_output_nvte = makeTransformerEngineTensor(grad_output_tensor); - const TensorWrapper& input_nvte = makeTransformerEngineTensor(input_tensor); - - // Construct grad input tensor - auto quantizer_cpp = convert_quantizer(quantizer); - const auto input_shape_te = input_nvte.shape(); - const std::vector input_shape(input_shape_te.data, - input_shape_te.data + input_shape_te.ndim); - auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type()); - auto [grad_input_nvte, grad_input_py] = quantizer_cpp->create_tensor(input_shape, fake_dtype); - - // Choose implementation - enum class Impl { UNFUSED, FULLY_FUSED, FUSED_ACTIVATION_AMAX_FP8, FUSED_ACTIVATION_AMAX_NVFP4 }; - Impl impl = Impl::UNFUSED; - if (quantizer.is_none() || detail::IsFloat8Quantizers(quantizer.ptr()) || - detail::IsMXFP8Quantizers(quantizer.ptr())) { - impl = Impl::FULLY_FUSED; - } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { - impl = Impl::FUSED_ACTIVATION_AMAX_FP8; - } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { - auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); - NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); - if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) { - // Post-RHT amax is handled within NVFP4 quantizer - impl = Impl::UNFUSED; - } else { - impl = Impl::FUSED_ACTIVATION_AMAX_NVFP4; - } +void clamped_dactivation_noalloc(Tensor grad_output, Tensor input, Tensor grad_input_data, + int64_t grad_input_te_dtype, std::optional grad_input_amax, + std::optional grad_input_scale, + std::optional grad_input_scale_inv, int64_t scaling_mode, + double limit, double alpha, int64_t activation_type) { + auto grad_output_ = torch::stable::contiguous(grad_output); + auto input_ = 
torch::stable::contiguous(input); + auto grad_output_cu = makeTransformerEngineTensor(grad_output_); + auto input_cu = makeTransformerEngineTensor(input_); + auto grad_shape = getStableTensorShape(input_); + auto te_dtype = static_cast(grad_input_te_dtype); + auto nvte_scaling = static_cast(scaling_mode); + auto grad_input_cu = + makeQuantizedTensorWrapper(grad_input_data, te_dtype, grad_shape, grad_input_amax, + grad_input_scale, grad_input_scale_inv, nvte_scaling); + auto stream = getCurrentCUDAStreamRaw(input_.get_device_index()); + + if (activation_type == 0) { + nvte_clamped_dswiglu(grad_output_cu.data(), input_cu.data(), grad_input_cu.data(), + static_cast(limit), static_cast(alpha), stream); + } else { + NVTE_ERROR("Invalid clamped activation_type: ", activation_type); } - - // Perform compute - auto stream = at::cuda::getCurrentCUDAStream(); - switch (impl) { - case Impl::UNFUSED: - // Compute activation backward in high precision, then quantize - { - auto [temp_nvte, _] = NoneQuantizer(py::none()).create_tensor(input_shape, fake_dtype); - NVTE_SCOPED_GIL_RELEASE({ - if constexpr (dact_func == nullptr) { - dact_func_with_args(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), - std::forward(args)..., stream); - } else { - dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), stream); - } - }); - quantizer_cpp->quantize(temp_nvte, grad_input_nvte); - } - break; - case Impl::FULLY_FUSED: - // Compute activation backward directly - { - NVTE_SCOPED_GIL_RELEASE({ - if constexpr (dact_func == nullptr) { - dact_func_with_args(grad_output_nvte.data(), input_nvte.data(), grad_input_nvte.data(), - std::forward(args)..., stream); - } else { - dact_func(grad_output_nvte.data(), input_nvte.data(), grad_input_nvte.data(), stream); - } - }); - } - break; - case Impl::FUSED_ACTIVATION_AMAX_FP8: - // Compute activation and amax in high precision, then quantize to FP8 - { - auto fp8_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); - 
NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer"); - auto [temp_nvte, _] = - fp8_quantizer_cpp->create_unquantized_tensor_with_amax(input_shape, fake_dtype); - NVTE_SCOPED_GIL_RELEASE({ - if constexpr (dact_func == nullptr) { - dact_func_with_args(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), - std::forward(args)..., stream); - } else { - dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), stream); - } - }); - fp8_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte); - } - break; - case Impl::FUSED_ACTIVATION_AMAX_NVFP4: - // Compute activation and amax in high precision, then quantize to NVFP4 - { - auto nvfp4_quantizer_cpp = - static_cast(quantizer_cpp.get()); // Already checked cast is valid - auto [temp_nvte, _] = - nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax(grad_input_nvte, fake_dtype); - NVTE_SCOPED_GIL_RELEASE({ - if constexpr (dact_func == nullptr) { - dact_func_with_args(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), - std::forward(args)..., stream); - } else { - dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), stream); - } - }); - nvfp4_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte); - } - break; - default: - NVTE_ERROR("Invalid activation implementation (", static_cast(impl), ")"); - } - - return grad_input_py; -} -} // namespace - -/* GELU and variants */ -py::object gelu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer); -} - -py::object dgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); -} - -py::object glu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer, 2); -} - -py::object dglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); -} - -py::object geglu(const 
at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer, 2); -} - -py::object dgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); -} - -py::object qgelu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer); -} - -py::object dqgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); } -py::object qgeglu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer, 2); +} // namespace transformer_engine::pytorch::stable + +STABLE_TORCH_LIBRARY_FRAGMENT(transformer_engine_stable, m) { + // activation_type: 0=gelu, 1=glu, 2=geglu, 3=qgelu, 4=qgeglu, + // 5=relu, 6=reglu, 7=srelu, 8=sreglu, 9=silu, 10=swiglu + m.def( + "activation_fwd_noalloc(Tensor input, Tensor output_data, int output_te_dtype, Tensor? " + "output_amax, Tensor? output_scale, Tensor? output_scale_inv, int scaling_mode, int " + "activation_type) -> ()"); + m.def( + "dactivation_noalloc(Tensor grad_output, Tensor input, Tensor grad_input_data, int " + "grad_input_te_dtype, Tensor? grad_input_amax, Tensor? grad_input_scale, Tensor? " + "grad_input_scale_inv, int scaling_mode, int activation_type) -> ()"); + m.def( + "clamped_activation_fwd_noalloc(Tensor input, Tensor output_data, int output_te_dtype, " + "Tensor? output_amax, Tensor? output_scale, Tensor? output_scale_inv, int scaling_mode, " + "float limit, float alpha, int activation_type) -> ()"); + m.def( + "clamped_dactivation_noalloc(Tensor grad_output, Tensor input, Tensor grad_input_data, int " + "grad_input_te_dtype, Tensor? grad_input_amax, Tensor? grad_input_scale, Tensor? 
" + "grad_input_scale_inv, int scaling_mode, float limit, float alpha, int activation_type) -> " + "()"); +} + +STABLE_TORCH_LIBRARY_IMPL(transformer_engine_stable, CUDA, m) { + using namespace transformer_engine::pytorch::stable; + m.impl("activation_fwd_noalloc", TORCH_BOX(activation_fwd_noalloc)); + m.impl("dactivation_noalloc", TORCH_BOX(dactivation_noalloc)); + m.impl("clamped_activation_fwd_noalloc", TORCH_BOX(clamped_activation_fwd_noalloc)); + m.impl("clamped_dactivation_noalloc", TORCH_BOX(clamped_dactivation_noalloc)); } - -py::object dqgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); -} - -/* ReLU and variants */ -py::object relu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer); -} - -py::object drelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); -} - -py::object reglu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer, 2); -} - -py::object dreglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); -} - -py::object srelu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer); -} - -py::object dsrelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); -} - -py::object sreglu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer, 2); -} - -py::object dsreglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); -} -/* Silu and variants */ -py::object silu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer); -} - -py::object dsilu(const at::Tensor& grad, const at::Tensor& input, 
py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); -} - -py::object swiglu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer, 2); -} - -py::object dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); -} - -/* clamped functions */ -py::object clamped_swiglu(const at::Tensor& input, py::handle quantizer, float limit, float alpha) { - return activation_helper(input, quantizer, 2, limit, alpha); -} - -py::object clamped_dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer, - float limit, float alpha) { - return dactivation_helper(grad, input, quantizer, limit, alpha); -} - -} // namespace pytorch -} // namespace transformer_engine diff --git a/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp b/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp index 4392fa4b43..8662a8953b 100644 --- a/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp +++ b/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp @@ -4,283 +4,236 @@ * See LICENSE for license information. 
************************************************************************/ -#include "../extensions.h" -#include "common.h" - -namespace transformer_engine::pytorch { - -at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs, - const std::optional start_positions, - const NVTE_QKV_Format qkv_format, const bool interleaved, - const std::optional cu_seqlens, const int cp_size, - const int cp_rank) { - TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor"); - TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1, - "expected the second and third dims of the freqs tensor equal 1"); - TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float, - "Dtype of the freqs tensor must be float"); - - // output - auto act_options = at::TensorOptions().dtype(input.scalar_type()).device(input.device()); - auto output = at::empty(input.sizes(), act_options); +#include + +#include "../stable_common.h" + +namespace transformer_engine::pytorch::stable { + +using Tensor = torch::stable::Tensor; + +Tensor fused_rope_forward(Tensor input, Tensor freqs, std::optional start_positions, + int64_t qkv_format, bool interleaved, std::optional cu_seqlens, + int64_t cp_size, int64_t cp_rank) { + auto nvte_qkv_format = static_cast(qkv_format); + + STD_TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor"); + STD_TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1, + "expected the second and third dims of the freqs tensor equal 1"); + STD_TORCH_CHECK(freqs.scalar_type() == ScalarType::Float, + "Dtype of the freqs tensor must be float"); + + // Allocate contiguous output (must NOT use empty_like which preserves + // non-contiguous strides from transposed inputs) + auto sizes = input.sizes(); + std::vector shape_vec(sizes.begin(), sizes.end()); + auto output = allocateStableTensor(shape_vec, input.scalar_type(), input.get_device_index()); auto input_cu = makeTransformerEngineTensor(input); auto freqs_cu = makeTransformerEngineTensor(freqs); auto output_cu = 
makeTransformerEngineTensor(output); - auto start_positions_cu = TensorWrapper(); // empty start_positions tensor - if (start_positions) { + auto start_positions_cu = TensorWrapper(); + if (start_positions.has_value()) { start_positions_cu = makeTransformerEngineTensor(start_positions.value()); - TORCH_CHECK(start_positions_cu.ndim() == 1, "expected 1D tensor"); + STD_TORCH_CHECK(start_positions_cu.ndim() == 1, "expected 1D tensor"); } - if (qkv_format == NVTE_QKV_Format::NVTE_THD) { - TORCH_CHECK(input.dim() == 3, "expected 3D tensor"); - TORCH_CHECK(cu_seqlens.has_value(), "expected cu_seqlens tensor"); - TORCH_CHECK(cu_seqlens.value().dim() == 1, "expected 1D tensor"); - TORCH_CHECK(input.size(2) >= freqs.size(3), - "expected the last dim of the input tensor equals or is " - "greater than the freqs tensor"); - - // input sizes: (t, h, d) - // t: cumulative sum of sequence lengths - // h: head num - // d: dim of each head - // const int t = input.size(0); - const int h = input.size(1); - const int d = input.size(2); - // input strides - const int stride_t = input.stride(0); - const int stride_h = input.stride(1); - const int stride_d = input.stride(2); - // batch size - const int b = cu_seqlens.value().size(0) - 1; - // freqs' shape is (max_s, 1, 1, d2) - const int max_s = freqs.size(0); - const int d2 = freqs.size(3); + auto stream = getCurrentCUDAStreamRaw(input.get_device_index()); + + if (nvte_qkv_format == NVTE_QKV_Format::NVTE_THD) { + STD_TORCH_CHECK(input.dim() == 3, "expected 3D tensor"); + STD_TORCH_CHECK(cu_seqlens.has_value(), "expected cu_seqlens tensor"); + STD_TORCH_CHECK(cu_seqlens.value().dim() == 1, "expected 1D tensor"); + + const int h = static_cast(input.size(1)); + const int d = static_cast(input.size(2)); + const int stride_t = static_cast(input.stride(0)); + const int stride_h = static_cast(input.stride(1)); + const int stride_d = static_cast(input.stride(2)); + const int b = static_cast(cu_seqlens.value().size(0) - 1); + const int max_s = 
static_cast(freqs.size(0)); + const int d2 = static_cast(freqs.size(3)); auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens.value()); nvte_fused_rope_forward(input_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(), - start_positions_cu.data(), output_cu.data(), qkv_format, interleaved, - cp_size, cp_rank, max_s, b, h, d, d2, stride_t, /*stride_b=*/0, - stride_h, stride_d, at::cuda::getCurrentCUDAStream()); + start_positions_cu.data(), output_cu.data(), nvte_qkv_format, + interleaved, static_cast(cp_size), static_cast(cp_rank), + max_s, b, h, d, d2, stride_t, 0, stride_h, stride_d, stream); return output; } - TORCH_CHECK(input.dim() == 4, "expected 4D tensor"); - // input sizes: (s, b, h, d) or (b, s, h, d) - // s: sequence length - // b: batch size - // h: head num - // d: dim of each head - const int s = qkv_format == NVTE_QKV_Format::NVTE_SBHD ? input.size(0) : input.size(1); - const int b = qkv_format == NVTE_QKV_Format::NVTE_SBHD ? input.size(1) : input.size(0); - const int h = input.size(2); - const int d = input.size(3); - // input strides - const int stride_s = qkv_format == NVTE_QKV_Format::NVTE_SBHD ? input.stride(0) : input.stride(1); - const int stride_b = qkv_format == NVTE_QKV_Format::NVTE_SBHD ? input.stride(1) : input.stride(0); - const int stride_h = input.stride(2); - const int stride_d = input.stride(3); - // freqs' shape is always (s, 1, 1, d2), so the strides are same under - // different memory formats - const int d2 = freqs.size(3); - - TORCH_CHECK(s * cp_size <= freqs.size(0), - "expected freqs tensor has a longer sequence length than input"); - TORCH_CHECK(d >= d2, - "expected the last dim of the input tensor equals or is " - "greater than the freqs tensor"); - - auto cu_seqlens_cu = TensorWrapper(); // empty cu_seqlens tensor + STD_TORCH_CHECK(input.dim() == 4, "expected 4D tensor"); + const bool is_sbhd = nvte_qkv_format == NVTE_QKV_Format::NVTE_SBHD; + const int s = static_cast(is_sbhd ? 
input.size(0) : input.size(1)); + const int b = static_cast(is_sbhd ? input.size(1) : input.size(0)); + const int h = static_cast(input.size(2)); + const int d = static_cast(input.size(3)); + const int stride_s = static_cast(is_sbhd ? input.stride(0) : input.stride(1)); + const int stride_b = static_cast(is_sbhd ? input.stride(1) : input.stride(0)); + const int stride_h = static_cast(input.stride(2)); + const int stride_d = static_cast(input.stride(3)); + const int d2 = static_cast(freqs.size(3)); + + auto cu_seqlens_cu = TensorWrapper(); nvte_fused_rope_forward(input_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(), - start_positions_cu.data(), output_cu.data(), qkv_format, interleaved, - cp_size, cp_rank, s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, - at::cuda::getCurrentCUDAStream()); + start_positions_cu.data(), output_cu.data(), nvte_qkv_format, interleaved, + static_cast(cp_size), static_cast(cp_rank), s, b, h, d, d2, + stride_s, stride_b, stride_h, stride_d, stream); return output; } -std::tuple fused_qkv_rope_forward( - const at::Tensor &qkv_input, const at::Tensor &q_freqs, const at::Tensor &k_freqs, - const std::optional start_positions, const std::vector &qkv_split_arg_list, - const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size, - const int cp_rank) { - TORCH_CHECK(q_freqs.dim() == 4, "expected 4D tensor"); - TORCH_CHECK(q_freqs.size(1) == 1 && q_freqs.size(2) == 1, - "expected the second and third dims of the freqs tensor equal 1"); - TORCH_CHECK(q_freqs.scalar_type() == at::ScalarType::Float, - "Dtype of the freqs tensor must be float"); - TORCH_CHECK(k_freqs.dim() == 4, "expected 4D tensor"); - TORCH_CHECK(k_freqs.size(1) == 1 && k_freqs.size(2) == 1, - "expected the second and third dims of the freqs tensor equal 1"); - TORCH_CHECK(k_freqs.scalar_type() == at::ScalarType::Float, - "Dtype of the freqs tensor must be float"); - // output - auto act_options = 
at::TensorOptions().dtype(qkv_input.scalar_type()).device(qkv_input.device()); - auto q_out_size = qkv_input.sizes().vec(); - q_out_size[2] = q_out_size[2] * qkv_split_arg_list[0] / qkv_split_arg_list[1]; - q_out_size[3] = qkv_split_arg_list[1]; - auto q_out = at::empty(q_out_size, act_options); - auto k_out_size = qkv_input.sizes().vec(); - k_out_size[3] = qkv_split_arg_list[1]; - auto k_out = at::empty(k_out_size, act_options); - auto v_out_size = qkv_input.sizes().vec(); - v_out_size[3] = qkv_split_arg_list[2]; - auto v_out = at::empty(v_out_size, act_options); +Tensor fused_rope_backward(Tensor output_grads, Tensor freqs, std::optional start_positions, + int64_t qkv_format, bool interleaved, std::optional cu_seqlens, + int64_t cp_size, int64_t cp_rank) { + auto nvte_qkv_format = static_cast(qkv_format); - auto qkv_cu = makeTransformerEngineTensor(qkv_input); - auto q_freqs_cu = makeTransformerEngineTensor(q_freqs); - auto k_freqs_cu = makeTransformerEngineTensor(k_freqs); - auto q_out_cu = makeTransformerEngineTensor(q_out); - auto k_out_cu = makeTransformerEngineTensor(k_out); - auto v_out_cu = makeTransformerEngineTensor(v_out); - - auto start_positions_cu = TensorWrapper(); // empty cu_seqlens tensor - if (start_positions) { - start_positions_cu = makeTransformerEngineTensor(start_positions.value()); - } - - TORCH_CHECK(qkv_input.dim() == 4, "expected 4D input tensor"); - TORCH_CHECK(qkv_input.is_contiguous(), "input tensor must be contiguous"); - - const bool is_sbhd = qkv_format == NVTE_QKV_Format::NVTE_SBHD; - const int s = is_sbhd ? qkv_input.size(0) : qkv_input.size(1); - const int b = is_sbhd ? 
qkv_input.size(1) : qkv_input.size(0); - const int h = qkv_input.size(2); - const int d = qkv_split_arg_list[2]; - const int d2 = q_freqs.size(3); + STD_TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor"); + STD_TORCH_CHECK(freqs.scalar_type() == ScalarType::Float, + "Dtype of the freqs tensor must be float"); - nvte_fused_qkv_rope_forward(qkv_cu.data(), q_freqs_cu.data(), k_freqs_cu.data(), - start_positions_cu.data(), q_out_cu.data(), k_out_cu.data(), - v_out_cu.data(), qkv_format, interleaved, cp_size, cp_rank, s, b, h, - d, d2, qkv_split_arg_list[0], qkv_split_arg_list[1], - qkv_split_arg_list[2], at::cuda::getCurrentCUDAStream()); - - return std::make_tuple(q_out, k_out, v_out); -} - -at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs, - const std::optional start_positions, - const NVTE_QKV_Format qkv_format, const bool interleaved, - const std::optional cu_seqlens, const int cp_size, - const int cp_rank) { - TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor"); - TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1, - "expected the second and third dims of the freqs tensor equal 1"); - TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float, - "Dtype of the freqs tensor must be float"); - - auto act_options = - at::TensorOptions().dtype(output_grads.scalar_type()).device(output_grads.device()); - auto input_grads = at::empty(output_grads.sizes(), act_options); + auto og_sizes = output_grads.sizes(); + std::vector og_shape(og_sizes.begin(), og_sizes.end()); + auto input_grads = + allocateStableTensor(og_shape, output_grads.scalar_type(), output_grads.get_device_index()); auto output_grads_cu = makeTransformerEngineTensor(output_grads); auto freqs_cu = makeTransformerEngineTensor(freqs); auto input_grads_cu = makeTransformerEngineTensor(input_grads); - auto start_positions_cu = TensorWrapper(); // empty start_positions tensor - if (start_positions) { + auto start_positions_cu = TensorWrapper(); + if 
(start_positions.has_value()) { start_positions_cu = makeTransformerEngineTensor(start_positions.value()); - TORCH_CHECK(start_positions_cu.ndim() == 1, "expected 1D tensor"); } - if (qkv_format == NVTE_QKV_Format::NVTE_THD) { - TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); - TORCH_CHECK(cu_seqlens.has_value(), "expected cu_seqlens tensor"); - TORCH_CHECK(cu_seqlens.value().dim() == 1, "expected 1D tensor"); - TORCH_CHECK(output_grads.size(2) >= freqs.size(3), - "expected the last dim of the output_grads tensor equals or is " - "greater than the freqs tensor"); - - // output_grads sizes: (t, h, d) - // t: cumulative sum of sequence lengths - // h: head num - // d: dim of each head - // const int t = output_grads.size(0); - const int h = output_grads.size(1); - const int d = output_grads.size(2); - // output_grads strides - const int stride_t = output_grads.stride(0); - const int stride_h = output_grads.stride(1); - const int stride_d = output_grads.stride(2); - // batch size - const int b = cu_seqlens.value().size(0) - 1; - // freqs' shape is (max_s, 1, 1, d2) - const int max_s = freqs.size(0); - const int d2 = freqs.size(3); + auto stream = getCurrentCUDAStreamRaw(output_grads.get_device_index()); + + if (nvte_qkv_format == NVTE_QKV_Format::NVTE_THD) { + STD_TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); + STD_TORCH_CHECK(cu_seqlens.has_value(), "expected cu_seqlens tensor"); + + const int h = static_cast(output_grads.size(1)); + const int d = static_cast(output_grads.size(2)); + const int stride_t = static_cast(output_grads.stride(0)); + const int stride_h = static_cast(output_grads.stride(1)); + const int stride_d = static_cast(output_grads.stride(2)); + const int b = static_cast(cu_seqlens.value().size(0) - 1); + const int max_s = static_cast(freqs.size(0)); + const int d2 = static_cast(freqs.size(3)); auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens.value()); nvte_fused_rope_backward(output_grads_cu.data(), 
cu_seqlens_cu.data(), freqs_cu.data(), - start_positions_cu.data(), input_grads_cu.data(), qkv_format, - interleaved, cp_size, cp_rank, max_s, b, h, d, d2, stride_t, - /*stride_b=*/0, stride_h, stride_d, at::cuda::getCurrentCUDAStream()); + start_positions_cu.data(), input_grads_cu.data(), nvte_qkv_format, + interleaved, static_cast(cp_size), static_cast(cp_rank), + max_s, b, h, d, d2, stride_t, 0, stride_h, stride_d, stream); return input_grads; } - TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor"); - // output_grads sizes: (s, b, h, d) - // s: sequence length - // b: batch size - // h: head num - // d: dim of each head - const int s = - qkv_format == NVTE_QKV_Format::NVTE_SBHD ? output_grads.size(0) : output_grads.size(1); - const int b = - qkv_format == NVTE_QKV_Format::NVTE_SBHD ? output_grads.size(1) : output_grads.size(0); - const int h = output_grads.size(2); - const int d = output_grads.size(3); - // output_grads strides - const int stride_s = - qkv_format == NVTE_QKV_Format::NVTE_SBHD ? output_grads.stride(0) : output_grads.stride(1); - const int stride_b = - qkv_format == NVTE_QKV_Format::NVTE_SBHD ? output_grads.stride(1) : output_grads.stride(0); - const int stride_h = output_grads.stride(2); - const int stride_d = output_grads.stride(3); - // freqs' shape is always (s, 1, 1, d2), so the strides are same under - // different memory formats - const int d2 = freqs.size(3); - - TORCH_CHECK(s * cp_size <= freqs.size(0), - "expected freqs tensor has a longer sequence length than output_grads"); - TORCH_CHECK(d >= d2, - "expected the last dim of the output_grads tensor equals or is " - "greater than the freqs tensor"); - - auto cu_seqlens_cu = TensorWrapper(); // empty cu_seqlens tensor + STD_TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor"); + const bool is_sbhd = nvte_qkv_format == NVTE_QKV_Format::NVTE_SBHD; + const int s = static_cast(is_sbhd ? output_grads.size(0) : output_grads.size(1)); + const int b = static_cast(is_sbhd ? 
output_grads.size(1) : output_grads.size(0)); + const int h = static_cast(output_grads.size(2)); + const int d = static_cast(output_grads.size(3)); + const int stride_s = static_cast(is_sbhd ? output_grads.stride(0) : output_grads.stride(1)); + const int stride_b = static_cast(is_sbhd ? output_grads.stride(1) : output_grads.stride(0)); + const int stride_h = static_cast(output_grads.stride(2)); + const int stride_d = static_cast(output_grads.stride(3)); + const int d2 = static_cast(freqs.size(3)); + + auto cu_seqlens_cu = TensorWrapper(); nvte_fused_rope_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(), - start_positions_cu.data(), input_grads_cu.data(), qkv_format, - interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s, stride_b, - stride_h, stride_d, at::cuda::getCurrentCUDAStream()); + start_positions_cu.data(), input_grads_cu.data(), nvte_qkv_format, + interleaved, static_cast(cp_size), static_cast(cp_rank), s, b, + h, d, d2, stride_s, stride_b, stride_h, stride_d, stream); return input_grads; } -at::Tensor fused_qkv_rope_backward(const at::Tensor &q_grad_out, const at::Tensor &k_grad_out, - const at::Tensor &v_grad_out, const at::Tensor &q_freqs, - const at::Tensor &k_freqs, - const std::vector &qkv_split_arg_list, - const NVTE_QKV_Format qkv_format, const bool interleaved, - const int cp_size, const int cp_rank) { - auto act_options = - at::TensorOptions().dtype(q_grad_out.scalar_type()).device(q_grad_out.device()); - auto qkv_grad_size = q_grad_out.sizes().vec(); +std::tuple fused_qkv_rope_forward(Tensor qkv_input, Tensor q_freqs, + Tensor k_freqs, + std::optional start_positions, + std::vector qkv_split_arg_list, + int64_t qkv_format, bool interleaved, + int64_t cp_size, int64_t cp_rank) { + auto nvte_qkv_format = static_cast(qkv_format); + + STD_TORCH_CHECK(q_freqs.dim() == 4, "expected 4D tensor"); + STD_TORCH_CHECK(k_freqs.dim() == 4, "expected 4D tensor"); + STD_TORCH_CHECK(qkv_input.dim() == 4, "expected 4D input tensor"); + 
STD_TORCH_CHECK(qkv_input.is_contiguous(), "input tensor must be contiguous"); + + auto sizes = qkv_input.sizes(); + auto dtype = qkv_input.scalar_type(); + auto device_idx = qkv_input.get_device_index(); + + // q_out shape + std::vector q_out_size = {sizes[0], sizes[1], + sizes[2] * qkv_split_arg_list[0] / qkv_split_arg_list[1], + qkv_split_arg_list[1]}; + auto q_out = allocateStableTensor(q_out_size, dtype, device_idx); + + std::vector k_out_size = {sizes[0], sizes[1], sizes[2], qkv_split_arg_list[1]}; + auto k_out = allocateStableTensor(k_out_size, dtype, device_idx); + + std::vector v_out_size = {sizes[0], sizes[1], sizes[2], qkv_split_arg_list[2]}; + auto v_out = allocateStableTensor(v_out_size, dtype, device_idx); + + auto qkv_cu = makeTransformerEngineTensor(qkv_input); + auto q_freqs_cu = makeTransformerEngineTensor(q_freqs); + auto k_freqs_cu = makeTransformerEngineTensor(k_freqs); + auto q_out_cu = makeTransformerEngineTensor(q_out); + auto k_out_cu = makeTransformerEngineTensor(k_out); + auto v_out_cu = makeTransformerEngineTensor(v_out); + + auto start_positions_cu = TensorWrapper(); + if (start_positions.has_value()) { + start_positions_cu = makeTransformerEngineTensor(start_positions.value()); + } + + const bool is_sbhd = nvte_qkv_format == NVTE_QKV_Format::NVTE_SBHD; + const int s = static_cast(is_sbhd ? qkv_input.size(0) : qkv_input.size(1)); + const int b = static_cast(is_sbhd ? 
qkv_input.size(1) : qkv_input.size(0)); + const int h = static_cast(qkv_input.size(2)); + const int d = static_cast(qkv_split_arg_list[2]); + const int d2 = static_cast(q_freqs.size(3)); + + nvte_fused_qkv_rope_forward( + qkv_cu.data(), q_freqs_cu.data(), k_freqs_cu.data(), start_positions_cu.data(), + q_out_cu.data(), k_out_cu.data(), v_out_cu.data(), nvte_qkv_format, interleaved, + static_cast(cp_size), static_cast(cp_rank), s, b, h, d, d2, + static_cast(qkv_split_arg_list[0]), static_cast(qkv_split_arg_list[1]), + static_cast(qkv_split_arg_list[2]), getCurrentCUDAStreamRaw(device_idx)); + + return std::make_tuple(q_out, k_out, v_out); +} + +Tensor fused_qkv_rope_backward(Tensor q_grad_out, Tensor k_grad_out, Tensor v_grad_out, + Tensor q_freqs, Tensor k_freqs, + std::vector qkv_split_arg_list, int64_t qkv_format, + bool interleaved, int64_t cp_size, int64_t cp_rank) { + auto nvte_qkv_format = static_cast(qkv_format); + auto dtype = q_grad_out.scalar_type(); + auto device_idx = q_grad_out.get_device_index(); + auto total_hd = (q_grad_out.size(2) + k_grad_out.size(2) + v_grad_out.size(2)) * q_grad_out.size(3); auto total_d = qkv_split_arg_list[0] + qkv_split_arg_list[1] + qkv_split_arg_list[2]; - qkv_grad_size[2] = total_hd / total_d; - qkv_grad_size[3] = total_d; - auto qkv_grad_input = at::empty(qkv_grad_size, act_options); - const bool is_sbhd = qkv_format == NVTE_QKV_Format::NVTE_SBHD; - const int s = is_sbhd ? q_grad_out.size(0) : q_grad_out.size(1); - const int b = is_sbhd ? q_grad_out.size(1) : q_grad_out.size(0); - const int h = qkv_grad_input.size(2); - const int d = qkv_split_arg_list[2]; - const int d2 = q_freqs.size(3); + std::vector qkv_grad_size = {q_grad_out.size(0), q_grad_out.size(1), total_hd / total_d, + total_d}; + auto qkv_grad_input = allocateStableTensor(qkv_grad_size, dtype, device_idx); + + const bool is_sbhd = nvte_qkv_format == NVTE_QKV_Format::NVTE_SBHD; + const int s = static_cast(is_sbhd ? 
q_grad_out.size(0) : q_grad_out.size(1)); + const int b = static_cast(is_sbhd ? q_grad_out.size(1) : q_grad_out.size(0)); + const int h = static_cast(qkv_grad_size[2]); + const int d = static_cast(qkv_split_arg_list[2]); + const int d2 = static_cast(q_freqs.size(3)); auto q_grad_out_cu = makeTransformerEngineTensor(q_grad_out); auto k_grad_out_cu = makeTransformerEngineTensor(k_grad_out); @@ -289,13 +242,21 @@ at::Tensor fused_qkv_rope_backward(const at::Tensor &q_grad_out, const at::Tenso auto k_freqs_cu = makeTransformerEngineTensor(k_freqs); auto qkv_grad_cu = makeTransformerEngineTensor(qkv_grad_input); - nvte_fused_qkv_rope_backward(q_grad_out_cu.data(), k_grad_out_cu.data(), v_grad_out_cu.data(), - q_freqs_cu.data(), k_freqs_cu.data(), qkv_grad_cu.data(), qkv_format, - interleaved, cp_size, cp_rank, s, b, h, d, d2, qkv_split_arg_list[0], - qkv_split_arg_list[1], qkv_split_arg_list[2], - at::cuda::getCurrentCUDAStream()); + nvte_fused_qkv_rope_backward( + q_grad_out_cu.data(), k_grad_out_cu.data(), v_grad_out_cu.data(), q_freqs_cu.data(), + k_freqs_cu.data(), qkv_grad_cu.data(), nvte_qkv_format, interleaved, + static_cast(cp_size), static_cast(cp_rank), s, b, h, d, d2, + static_cast(qkv_split_arg_list[0]), static_cast(qkv_split_arg_list[1]), + static_cast(qkv_split_arg_list[2]), getCurrentCUDAStreamRaw(device_idx)); return qkv_grad_input; } -} // namespace transformer_engine::pytorch +STABLE_TORCH_LIBRARY_IMPL(transformer_engine_stable, CUDA, m) { + m.impl("fused_rope_forward", TORCH_BOX(fused_rope_forward)); + m.impl("fused_rope_backward", TORCH_BOX(fused_rope_backward)); + m.impl("fused_qkv_rope_forward", TORCH_BOX(fused_qkv_rope_forward)); + m.impl("fused_qkv_rope_backward", TORCH_BOX(fused_qkv_rope_backward)); +} + +} // namespace transformer_engine::pytorch::stable diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index ff60bb87bb..98a763c297 100644 --- 
a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -4,870 +4,727 @@ * See LICENSE for license information. ************************************************************************/ -#include "../extensions.h" -#include "common.h" -#include "pybind.h" +#include -namespace { +#include "../stable_common.h" -constexpr int block_size = 512; +namespace transformer_engine::pytorch::stable { -// fast zero-fills of tensors -void mha_fill(const transformer_engine::TensorWrapper &self, const at::Tensor &start_index) { - std::vector shape = transformer_engine::pytorch::convertShape(self.shape()); +using Tensor = torch::stable::Tensor; - auto max_tokens = shape[0]; - auto fcd_size = 1; - for (size_t i = 1; i <= shape.size(); i++) { - fcd_size *= shape[i]; - } - - NVTE_CHECK(fcd_size % block_size == 0, "input size not aligned to block size"); - - size_t element_size_bits = transformer_engine::pytorch::typeToNumBits(self.dtype()); - int32_t start_row = start_index.data_ptr()[0]; - void *base_ptr = static_cast(self.get_rowwise_data().data_ptr) + - static_cast(start_row) * fcd_size * element_size_bits / 8; - size_t num_rows_to_zero = max_tokens - start_row; - size_t total_bytes = num_rows_to_zero * fcd_size * element_size_bits / 8; - - NVTE_SCOPED_GIL_RELEASE( - { nvte_memset(base_ptr, 0, total_bytes, at::cuda::getCurrentCUDAStream()); }); -} - -} // namespace - -namespace transformer_engine::pytorch { - -// get the fused attention backend -NVTE_Fused_Attn_Backend get_fused_attn_backend( - bool is_training, const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, - int64_t window_size_right, bool return_max_logit, bool cuda_graph, 
bool deterministic) { - NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - is_training, static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, - bias_type, attn_mask_type, softmax_type, p_dropout, num_attn_heads, num_gqa_groups, - max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right, - return_max_logit, cuda_graph, deterministic); - return fused_attention_backend; -} - -// helper function for S and dP quantizers -std::pair quantizer_helper(py::handle quantizer, - const std::vector &shape, DType dtype, - bool create_hp_tensor_for_cs, - std::optional data) { - std::unique_ptr T_quantizer = convert_quantizer(quantizer); - TensorWrapper te_T; - py::object py_T; - if (quantizer.is_none()) { - // high precision - auto *none_quantizer = dynamic_cast(T_quantizer.get()); - if (data.has_value()) { - std::tie(te_T, py_T) = none_quantizer->create_tensor(shape, dtype, data.value()); - } else { - std::tie(te_T, py_T) = none_quantizer->create_tensor(shape, dtype); - } - } else if (detail::IsFloat8Quantizers(quantizer.ptr())) { - // delayed scaling; this helps initialize scale_inv - auto *T_quantizer_fp8 = dynamic_cast(T_quantizer.get()); - std::tie(te_T, py_T) = - T_quantizer_fp8->create_tensor(shape, dtype, data, std::nullopt, std::nullopt); - } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { - // current scaling - auto *T_quantizer_fp8 = dynamic_cast(T_quantizer.get()); - if (create_hp_tensor_for_cs) { - if (data.has_value()) { - std::tie(te_T, py_T) = - T_quantizer_fp8->create_unquantized_tensor_with_amax(shape, dtype, data.value()); - } else { - std::tie(te_T, py_T) = T_quantizer_fp8->create_unquantized_tensor_with_amax(shape, dtype); - } - } else { - std::tie(te_T, py_T) = T_quantizer_fp8->create_tensor(shape, dtype); - NVTE_CHECK( - !data.has_value(), - "Float8CurrentScalingQuantizer::create_tensor() does not take data tensor as input!"); - } - } - return {std::move(te_T), 
std::move(py_T)}; -} - -// fused attention FWD with separate Q, K and V tensors -std::vector fused_attn_fwd( - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, - bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - const std::vector window_size, bool bottom_right_diagonal, - const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, - const py::handle K, const py::handle V, const at::ScalarType fake_dtype, - const std::optional cu_seqlens_q_padded, - const std::optional cu_seqlens_kv_padded, - const std::optional page_table_k, const std::optional page_table_v, - py::handle s_quantizer, py::handle o_quantizer, const std::optional Bias, - const std::optional SoftmaxOffset, const std::optional rng_gen, - size_t rng_elts_per_thread, bool return_max_logit, bool cuda_graph) { - // Ensure that cuDNN handle is created on the correct device, - // overriding torch.cuda.set_device calls from user side. - // Assumes all tensors passed are on the same device. 
- at::cuda::CUDAGuard device_guard(cu_seqlens_q.device()); - - auto none = py::none(); - - // create QKV tensor wrappers - TensorWrapper te_Q, te_K, te_V; - te_Q = makeTransformerEngineTensor(Q, none); - te_K = makeTransformerEngineTensor(K, none); - te_V = makeTransformerEngineTensor(V, none); - const DType qkv_type = te_Q.dtype(); - - // create S tensor - TensorWrapper te_S; - py::object py_S; - std::tie(te_S, py_S) = quantizer_helper(s_quantizer, {0}, DType::kFloat32, false, std::nullopt); - - // create O tensor - TensorWrapper te_O; - py::object py_O; - std::unique_ptr O_quantizer = convert_quantizer(o_quantizer); - std::vector q_shape = convertShape(te_Q.shape()); - std::vector v_shape = convertShape(te_V.shape()); - auto o_shape = std::vector{q_shape.begin(), q_shape.end()}; - o_shape[o_shape.size() - 1] = v_shape[v_shape.size() - 1]; - const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); - std::tie(te_O, py_O) = quantizer_helper(o_quantizer, o_shape, fake_dtype_te, true, std::nullopt); - - // construct NVTE tensors - TensorWrapper te_Bias; - TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv; - TensorWrapper te_cu_seqlens_q_padded, te_cu_seqlens_kv_padded; - TensorWrapper te_page_table_k, te_page_table_v; - if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { - // FP8 - auto h = q_shape[q_shape.size() - 2]; - auto d = q_shape[q_shape.size() - 1]; - if (set_zero && (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { - if ((h * d) % block_size == 0) { - mha_fill(te_O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); - } else { - te_O.zero_(at::cuda::getCurrentCUDAStream()); - } - } - } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { - if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { - te_O.zero_(at::cuda::getCurrentCUDAStream()); - } - } else { - NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. 
\n"); - } - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) { - auto bias_sizes = Bias.value().sizes().vec(); - std::vector bias_shape{bias_sizes.begin(), bias_sizes.end()}; - te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), bias_shape, DType::kFloat32); - } - auto cu_seqlens_q_sizes = cu_seqlens_q.sizes().vec(); - std::vector cu_seqlens_q_shape{cu_seqlens_q_sizes.begin(), cu_seqlens_q_sizes.end()}; - auto cu_seqlens_kv_sizes = cu_seqlens_kv.sizes().vec(); - std::vector cu_seqlens_kv_shape{cu_seqlens_kv_sizes.begin(), cu_seqlens_kv_sizes.end()}; - te_cu_seqlens_q = - makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), cu_seqlens_q_shape, DType::kInt32); - te_cu_seqlens_kv = - makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape, DType::kInt32); - - if ((cu_seqlens_q_padded.has_value()) && (cu_seqlens_kv_padded.has_value())) { - auto cu_seqlens_q_padded_sizes = cu_seqlens_q_padded.value().sizes().vec(); - std::vector cu_seqlens_q_padded_shape{cu_seqlens_q_padded_sizes.begin(), - cu_seqlens_q_padded_sizes.end()}; - auto cu_seqlens_kv_padded_sizes = cu_seqlens_kv_padded.value().sizes().vec(); - std::vector cu_seqlens_kv_padded_shape{cu_seqlens_kv_padded_sizes.begin(), - cu_seqlens_kv_padded_sizes.end()}; - te_cu_seqlens_q_padded = makeTransformerEngineTensor(cu_seqlens_q_padded.value().data_ptr(), - cu_seqlens_q_padded_shape, DType::kInt32); - te_cu_seqlens_kv_padded = makeTransformerEngineTensor( - cu_seqlens_kv_padded.value().data_ptr(), cu_seqlens_kv_padded_shape, DType::kInt32); - } - - if ((page_table_k.has_value()) && (page_table_v.has_value())) { - auto page_table_k_sizes = page_table_k.value().sizes().vec(); - std::vector page_table_k_shape{page_table_k_sizes.begin(), page_table_k_sizes.end()}; - auto page_table_v_sizes = page_table_v.value().sizes().vec(); - std::vector page_table_v_shape{page_table_v_sizes.begin(), page_table_v_sizes.end()}; - te_page_table_k = - 
makeTransformerEngineTensor(page_table_k.value().data_ptr(), page_table_k_shape, - DType::kInt32, nullptr, nullptr, nullptr); - te_page_table_v = - makeTransformerEngineTensor(page_table_v.value().data_ptr(), page_table_v_shape, - DType::kInt32, nullptr, nullptr, nullptr); - } - - // softmax offset - TensorWrapper te_SoftmaxOffset; - if ((softmax_type != NVTE_VANILLA_SOFTMAX) && (SoftmaxOffset.has_value())) { - auto SoftmaxOffset_sizes = SoftmaxOffset.value().sizes().vec(); - std::vector SoftmaxOffset_shape{SoftmaxOffset_sizes.begin(), SoftmaxOffset_sizes.end()}; - te_SoftmaxOffset = - makeTransformerEngineTensor(SoftmaxOffset.value().data_ptr(), SoftmaxOffset_shape, - DType::kFloat32, nullptr, nullptr, nullptr); - } - - // extract rng seed and offset - auto gen = at::get_generator_or_default( - rng_gen, at::cuda::detail::getDefaultCUDAGenerator()); - at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread); - auto options = torch::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA); - auto rng_state = torch::empty({2}, options); - philox_unpack(philox_args, static_cast(rng_state.data_ptr())); - auto te_rng_state = makeTransformerEngineTensor(rng_state); - - // create auxiliary output tensors - NVTETensorPack nvte_aux_tensor_pack; - nvte_tensor_pack_create(&nvte_aux_tensor_pack); - - // create workspace - TensorWrapper workspace; - - // populate tensors with appropriate shapes and dtypes - NVTE_SCOPED_GIL_RELEASE({ - nvte_fused_attn_fwd( - te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_SoftmaxOffset.data(), te_S.data(), - te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), - te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), - te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, - return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, - softmax_type, window_size[0], window_size[1], 
bottom_right_diagonal, workspace.data(), - at::cuda::getCurrentCUDAStream()); - }); - - // allocate memory for workspace and auxiliary output tensors - auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - workspace = - makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); - - // output_tensors = [O, nvte_aux_tensor_pack.tensors] - std::vector output_tensors; - output_tensors.push_back(py_O); - auto set_tensor_param = [&](size_t i, const at::Tensor &output_tensor) { - output_tensors.push_back(py::cast(output_tensor)); - NVTEBasicTensor temp_data = {output_tensor.data_ptr(), - nvte_tensor_type(nvte_aux_tensor_pack.tensors[i]), - nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])}; - nvte_set_tensor_param(&nvte_aux_tensor_pack.tensors[i], kNVTERowwiseData, &temp_data); - }; - // allocate memory for nvte_aux_tensor_pack.tensors - // f16_max512 : S [b, h, sq, skv] - // f16_arbitrary: - // return_max_logit=false: S [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1] - // return_max_logit=true: S [b, h, sq, 1], Max [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1] - // fp8 : M [b, h, sq, 1], ZInv [b, h, sq, 1], rng_state [2] - size_t i = 0; - at::Tensor output_tensor; - // intermediate softmax tensor, S or M (for fp8) - output_tensor = - allocateSpace(nvte_shape_to_vector(nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])), - static_cast(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false); - set_tensor_param(i++, output_tensor); - // fp8 has an additional softmax stats tensor, ZInv; return_max_logit=true has an additional Max tensor - if (return_max_logit || qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { - output_tensor = - allocateSpace(nvte_shape_to_vector(nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])), - static_cast(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false); 
- set_tensor_param(i++, output_tensor); - } - // rng_state - if (i < nvte_aux_tensor_pack.size) { - set_tensor_param(i++, rng_state); - } - // bias (optional) - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) { - set_tensor_param(i++, Bias.value()); - } - // softmax_offset (optional) - if ((softmax_type != NVTE_VANILLA_SOFTMAX) && (SoftmaxOffset.has_value())) { - set_tensor_param(i++, SoftmaxOffset.value()); - } +// ============================================================================ +// Flash Attention prepare helpers +// ============================================================================ - // execute the kernel - NVTE_SCOPED_GIL_RELEASE({ - nvte_fused_attn_fwd( - te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_SoftmaxOffset.data(), te_S.data(), - te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), - te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), - te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, - return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, - softmax_type, window_size[0], window_size[1], bottom_right_diagonal, workspace.data(), - at::cuda::getCurrentCUDAStream()); - }); - - // destroy tensor wrappers, but not allocated memory - nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); - - // if training, [O, softmax-related tensors, rng_state]; if inference, [O] - return output_tensors; -} - -// fused attention BWD with separate Q, K and V -std::vector fused_attn_bwd( - size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - NVTE_Softmax_Type softmax_type, const std::vector window_size, - bool bottom_right_diagonal, bool deterministic, const at::Tensor cu_seqlens_q, - const at::Tensor cu_seqlens_kv, const py::handle Q, const 
py::handle K, const py::handle V, - const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const DType dqkv_type, - const std::vector Aux_CTX_Tensors, - const std::optional cu_seqlens_q_padded, - const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, - py::handle dp_quantizer, py::handle dqkv_quantizer, bool cuda_graph) { - auto none = py::none(); - - // create QKV, O, dO tensor wrappers - TensorWrapper te_Q, te_K, te_V, te_O, te_dO; - te_Q = makeTransformerEngineTensor(Q, none); - te_K = makeTransformerEngineTensor(K, none); - te_V = makeTransformerEngineTensor(V, none); - te_O = makeTransformerEngineTensor(O, none); - te_dO = makeTransformerEngineTensor(dO, none); - - // create S and dP tensors - TensorWrapper te_S, te_dP; - py::object py_S, py_dP; - std::tie(te_S, py_S) = quantizer_helper(s_quantizer, {0}, DType::kFloat32, false, std::nullopt); - std::tie(te_dP, py_dP) = - quantizer_helper(dp_quantizer, {0}, DType::kFloat32, false, std::nullopt); - - // create dQ, dK, dV tensors - TensorWrapper te_dQ, te_dK, te_dV; - py::object py_dQ, py_dK, py_dV; - std::unique_ptr dQKV_quantizer = convert_quantizer(dqkv_quantizer); - std::vector q_shape = convertShape(te_Q.shape()); - std::vector k_shape = convertShape(te_K.shape()); - std::vector v_shape = convertShape(te_V.shape()); - auto h_q = q_shape[q_shape.size() - 2]; - auto h_kv = k_shape[k_shape.size() - 2]; - auto d_qk = q_shape[q_shape.size() - 1]; - const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); - - at::Tensor dQ, dK, dV, dQKV, dKV; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - std::vector tmp_shape; - auto options = torch::TensorOptions().dtype(GetATenDType(dqkv_type)).device(torch::kCUDA); - if (dqkv_type == DType::kFloat8E4M3 || dqkv_type == DType::kFloat8E5M2) { - options = options.dtype(torch::kUInt8); - } - if (detail::IsFloat8CurrentScalingQuantizers(dqkv_quantizer.ptr())) { - options = options.dtype(fake_dtype); - } - - 
switch (layout_group) { - case NVTE_QKV_Layout_Group::NVTE_3HD: - tmp_shape = std::vector{q_shape.begin(), q_shape.end()}; - tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 2, int64_t(3)); - dQKV = torch::empty(c10::IntArrayRef(tmp_shape), options); - dQ = dQKV.index({"...", torch::indexing::Slice(0, 1, 1), - torch::indexing::Slice(0, torch::indexing::None, 1), - torch::indexing::Slice(0, torch::indexing::None, 1)}) - .squeeze(tmp_shape.size() - 3); - dK = dQKV.index({"...", torch::indexing::Slice(1, 2, 1), - torch::indexing::Slice(0, torch::indexing::None, 1), - torch::indexing::Slice(0, torch::indexing::None, 1)}) - .squeeze(tmp_shape.size() - 3); - dV = dQKV.index({"...", torch::indexing::Slice(2, torch::indexing::None, 1), - torch::indexing::Slice(0, torch::indexing::None, 1), - torch::indexing::Slice(0, torch::indexing::None, 1)}) - .squeeze(tmp_shape.size() - 3); - break; - case NVTE_QKV_Layout_Group::NVTE_H3D: - tmp_shape = std::vector{q_shape.begin(), q_shape.end()}; - tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 1, int64_t(3)); - dQKV = torch::empty(c10::IntArrayRef(tmp_shape), options); - dQ = dQKV.index({"...", torch::indexing::Slice(0, 1, 1), - torch::indexing::Slice(0, torch::indexing::None, 1)}) - .squeeze(tmp_shape.size() - 2); - dK = dQKV.index({"...", torch::indexing::Slice(1, 2, 1), - torch::indexing::Slice(0, torch::indexing::None, 1)}) - .squeeze(tmp_shape.size() - 2); - dV = dQKV.index({"...", torch::indexing::Slice(2, torch::indexing::None, 1), - torch::indexing::Slice(0, torch::indexing::None, 1)}) - .squeeze(tmp_shape.size() - 2); - break; - case NVTE_QKV_Layout_Group::NVTE_HD_2HD: - tmp_shape = std::vector(q_shape.begin(), q_shape.end()); - dQ = torch::empty(tmp_shape, options); - tmp_shape = std::vector{k_shape.begin(), k_shape.end()}; - tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 2, int64_t(2)); - dKV = torch::empty(c10::IntArrayRef(tmp_shape), options); - dK = dKV.index({"...", torch::indexing::Slice(0, 1, 
1), - torch::indexing::Slice(0, torch::indexing::None, 1), - torch::indexing::Slice(0, torch::indexing::None, 1)}) - .squeeze(tmp_shape.size() - 3); - dV = dKV.index({"...", torch::indexing::Slice(1, torch::indexing::None, 1), - torch::indexing::Slice(0, torch::indexing::None, 1), - torch::indexing::Slice(0, torch::indexing::None, 1)}) - .squeeze(tmp_shape.size() - 3); - break; - case NVTE_QKV_Layout_Group::NVTE_HD_H2D: - tmp_shape = std::vector(q_shape.begin(), q_shape.end()); - dQ = torch::empty(tmp_shape, options); - tmp_shape = std::vector{k_shape.begin(), k_shape.end()}; - tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 1, int64_t(2)); - dKV = torch::empty(c10::IntArrayRef(tmp_shape), options); - dK = dKV.index({"...", torch::indexing::Slice(0, 1, 1), - torch::indexing::Slice(0, torch::indexing::None, 1)}) - .squeeze(tmp_shape.size() - 2); - dV = dKV.index({"...", torch::indexing::Slice(1, torch::indexing::None, 1), - torch::indexing::Slice(0, torch::indexing::None, 1)}) - .squeeze(tmp_shape.size() - 2); - break; - case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD: - tmp_shape = std::vector(q_shape.begin(), q_shape.end()); - dQ = torch::empty(tmp_shape, options); - tmp_shape = std::vector(k_shape.begin(), k_shape.end()); - dK = torch::empty(tmp_shape, options); - tmp_shape = std::vector(v_shape.begin(), v_shape.end()); - dV = torch::empty(tmp_shape, options); - break; - default: - NVTE_ERROR("QKV layout not supported!"); - } - - std::tie(te_dQ, py_dQ) = quantizer_helper(dqkv_quantizer, q_shape, fake_dtype_te, true, dQ); - std::tie(te_dK, py_dK) = quantizer_helper(dqkv_quantizer, k_shape, fake_dtype_te, true, dK); - std::tie(te_dV, py_dV) = quantizer_helper(dqkv_quantizer, v_shape, fake_dtype_te, true, dV); - - // construct NVTE tensors - if (dqkv_type == DType::kFloat8E4M3 || dqkv_type == DType::kFloat8E5M2) { - // FP8 - if (set_zero && (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { - if (((h_q * d_qk) % block_size == 0) && ((h_kv * d_qk) % 
block_size == 0) && - dQ.is_contiguous() && dK.is_contiguous() && dV.is_contiguous()) { - mha_fill(te_dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); - mha_fill(te_dK, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); - mha_fill(te_dV, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); - } else { - dQ.fill_(0); - dK.fill_(0); - dV.fill_(0); - } - } - } else if (dqkv_type == DType::kBFloat16 || dqkv_type == DType::kFloat16) { - if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { - dQ.fill_(0); - dK.fill_(0); - dV.fill_(0); - } - } else { - NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n"); - } - - // create cu_seqlens tensorwrappers - auto cu_seqlens_q_sizes = cu_seqlens_q.sizes().vec(); - std::vector cu_seqlens_q_shape{cu_seqlens_q_sizes.begin(), cu_seqlens_q_sizes.end()}; - auto cu_seqlens_kv_sizes = cu_seqlens_kv.sizes().vec(); - std::vector cu_seqlens_kv_shape{cu_seqlens_kv_sizes.begin(), cu_seqlens_kv_sizes.end()}; - TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv; - te_cu_seqlens_q = makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), cu_seqlens_q_shape, - DType::kInt32, nullptr, nullptr, nullptr); - te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape, - DType::kInt32, nullptr, nullptr, nullptr); - - TensorWrapper te_cu_seqlens_q_padded, te_cu_seqlens_kv_padded; - if ((cu_seqlens_q_padded.has_value()) && (cu_seqlens_kv_padded.has_value())) { - auto cu_seqlens_q_padded_sizes = cu_seqlens_q_padded.value().sizes().vec(); - std::vector cu_seqlens_q_padded_shape{cu_seqlens_q_padded_sizes.begin(), - cu_seqlens_q_padded_sizes.end()}; - auto cu_seqlens_kv_padded_sizes = cu_seqlens_kv_padded.value().sizes().vec(); - std::vector cu_seqlens_kv_padded_shape{cu_seqlens_kv_padded_sizes.begin(), - cu_seqlens_kv_padded_sizes.end()}; - te_cu_seqlens_q_padded = 
makeTransformerEngineTensor(cu_seqlens_q_padded.value().data_ptr(), - cu_seqlens_q_padded_shape, DType::kInt32); - te_cu_seqlens_kv_padded = makeTransformerEngineTensor( - cu_seqlens_kv_padded.value().data_ptr(), cu_seqlens_kv_padded_shape, DType::kInt32); - } - - // convert auxiliary tensors from forward to NVTETensors - NVTETensorPack nvte_aux_tensor_pack; - nvte_tensor_pack_create(&nvte_aux_tensor_pack); - nvte_aux_tensor_pack.size = Aux_CTX_Tensors.size(); - for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { - const std::vector &signed_shape = Aux_CTX_Tensors[i].sizes().vec(); - const std::vector tmp(signed_shape.begin(), signed_shape.end()); - - NVTEBasicTensor temp_data = { - Aux_CTX_Tensors[i].data_ptr(), - static_cast(GetTransformerEngineDType(Aux_CTX_Tensors[i].scalar_type())), - nvte_make_shape(tmp.data(), tmp.size())}; - nvte_set_tensor_param(&nvte_aux_tensor_pack.tensors[i], kNVTERowwiseData, &temp_data); - } - - // create dBias the same shape as Bias - at::Tensor dBias; - TensorWrapper te_dBias; - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - if (nvte_aux_tensor_pack.size >= 2) { - std::vector bias_shape(Aux_CTX_Tensors[nvte_aux_tensor_pack.size - 1].sizes().vec()); - dBias = torch::empty(bias_shape, options); - te_dBias = makeTransformerEngineTensor(dBias); - } else { - dBias = torch::empty({1, static_cast(h_q), static_cast(max_seqlen_q), - static_cast(max_seqlen_kv)}, - options); - te_dBias = makeTransformerEngineTensor(dBias); - } - if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { - dBias.fill_(0); - } - } - - // create dSoftmaxOffset in the same shape as SoftmaxOffset - at::Tensor dSoftmaxOffset; - TensorWrapper te_dSoftmaxOffset; - if (softmax_type != NVTE_VANILLA_SOFTMAX) { - options = torch::TensorOptions().dtype(at::kFloat).device(torch::kCUDA); - dSoftmaxOffset = torch::empty({1, static_cast(h_q), 1, 1}, options); - te_dSoftmaxOffset = makeTransformerEngineTensor(dSoftmaxOffset); - } - - // create 
workspace - TensorWrapper workspace; - - // populate tensors with appropriate shapes and dtypes - NVTE_SCOPED_GIL_RELEASE({ - nvte_fused_attn_bwd( - te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), - &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(), - te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), - te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, - attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], - window_size[1], bottom_right_diagonal, deterministic, cuda_graph, workspace.data(), - at::cuda::getCurrentCUDAStream()); - }); - - // allocate memory for workspace - auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - workspace = - makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); - - // execute kernel - NVTE_SCOPED_GIL_RELEASE({ - nvte_fused_attn_bwd( - te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), - &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(), - te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), - te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, - attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], - window_size[1], bottom_right_diagonal, deterministic, cuda_graph, workspace.data(), - at::cuda::getCurrentCUDAStream()); - }); - - // destroy tensor wrappers - nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); - - return {py_dQ, py_dK, py_dV, py::cast(dBias), py::cast(dSoftmaxOffset)}; -} +Tensor fa_prepare_fwd(Tensor qkvi) { + STD_TORCH_CHECK(qkvi.dim() == 4, "Expected 4-dim tensor."); + auto dtype = qkvi.scalar_type(); + STD_TORCH_CHECK(dtype == ScalarType::Half || dtype == ScalarType::BFloat16, + "Expected fp16 or bf16 input."); 
-at::Tensor fa_prepare_fwd(at::Tensor qkvi) { - NVTE_CHECK(qkvi.dim() == 4, "Expected 4-dim tensor."); - NVTE_CHECK(qkvi.scalar_type() == at::ScalarType::Half || - qkvi.scalar_type() == at::ScalarType::BFloat16); - NVTE_CHECK(qkvi.stride(3) == 1, "Wrong stride."); - NVTE_CHECK(qkvi.stride(2) == 3 * qkvi.size(3), "Wrong stride."); - NVTE_CHECK(qkvi.stride(1) == 3 * qkvi.size(3) * qkvi.size(2), "Wrong stride."); - NVTE_CHECK(qkvi.stride(0) == 3 * qkvi.size(3) * qkvi.size(2) * qkvi.size(1), "Wrong stride."); + // qkvi is a non-contiguous 4D view (s, b, n, d) of interleaved (s, b, n, 3, d) storage. + // size(3) = d (head_dim), NOT d*3. The factor of 3 is in the strides, not the sizes. + // Output: (3, b, s, n, d). + auto s = qkvi.size(0), b = qkvi.size(1), n = qkvi.size(2), d = qkvi.size(3); + auto qkv = allocateStableTensor({3, b, s, n, d}, dtype, qkvi.get_device_index()); - // [s, b, n, h * 3] -> [3, b, s, n, h] - std::vector shape = {3, qkvi.size(1), qkvi.size(0), qkvi.size(2), qkvi.size(3)}; - at::Tensor qkv = at::empty(shape, at::CUDA(qkvi.scalar_type())); - - auto te_qkvi = makeTransformerEngineTensor(qkvi); + auto te_dtype = GetTransformerEngineDType(dtype); + // Pass the tensor shape as [s, b, n, d] -- the kernel knows the data is 3-way interleaved + // and uses the factor of 3 internally via blockIdx.y and stride arithmetic. 
+ auto te_qkvi = makeTransformerEngineTensor(qkvi.data_ptr(), + {static_cast(s), static_cast(b), + static_cast(n), static_cast(d)}, + te_dtype); auto te_qkv = makeTransformerEngineTensor(qkv); - - nvte_prepare_flash_attn_fwd(te_qkvi.data(), te_qkv.data(), at::cuda::getCurrentCUDAStream()); - + nvte_prepare_flash_attn_fwd(te_qkvi.data(), te_qkv.data(), + getCurrentCUDAStreamRaw(qkvi.get_device_index())); return qkv; } -at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v) { - NVTE_CHECK(q.is_contiguous()); - NVTE_CHECK(k.is_contiguous()); - NVTE_CHECK(v.is_contiguous()); - NVTE_CHECK(q.dim() == 4, "Expected 4-dim tensor."); - NVTE_CHECK(k.dim() == 4, "Expected 4-dim tensor."); - NVTE_CHECK(v.dim() == 4, "Expected 4-dim tensor."); - NVTE_CHECK(q.scalar_type() == at::ScalarType::Half || - q.scalar_type() == at::ScalarType::BFloat16); - NVTE_CHECK(k.scalar_type() == q.scalar_type()); - NVTE_CHECK(v.scalar_type() == q.scalar_type()); - - // 3 x [s, b, n, h] -> [b, s, n, 3 * h] - std::vector shape = {q.size(1), q.size(0), q.size(2), 3 * q.size(3)}; - at::Tensor qkv = at::empty(shape, at::CUDA(q.scalar_type())); +Tensor fa_prepare_bwd(Tensor q, Tensor k, Tensor v) { + STD_TORCH_CHECK(q.dim() == 4, "Expected 4-dim tensor."); + auto dtype = q.scalar_type(); + auto b = q.size(1), s = q.size(0), n = q.size(2), h = q.size(3); + auto qkv = allocateStableTensor({b, s, n, 3 * h}, dtype, q.get_device_index()); auto te_q = makeTransformerEngineTensor(q); auto te_k = makeTransformerEngineTensor(k); auto te_v = makeTransformerEngineTensor(v); auto te_qkv = makeTransformerEngineTensor(qkv); - nvte_prepare_flash_attn_bwd(te_q.data(), te_k.data(), te_v.data(), te_qkv.data(), - at::cuda::getCurrentCUDAStream()); - + getCurrentCUDAStreamRaw(q.get_device_index())); return qkv; } -/*************************************************************************************************** - * Support THD format for Context Parallel: Read the half of a THD tensor - 
**************************************************************************************************/ - -at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_seqlens, - int half_idx) { - NVTE_CHECK(tensor.dim() == 3 || tensor.dim() == 4); - NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int); - NVTE_CHECK(cu_seqlens.dim() == 1); - NVTE_CHECK(cu_seqlens.size(0) >= 2); +// ============================================================================ +// THD format helpers for Context Parallel +// ============================================================================ - // Shapes of q and dq are [t, h, d], so the dimension of "t" is 0 - // Shapes of kv and dkv are [2, t, h, d], so the dimension of "t" is 1 +Tensor thd_read_half_tensor(Tensor tensor, Tensor cu_seqlens, int64_t half_idx) { int seq_dim = tensor.dim() == 3 ? 0 : 1; - - int num_heads = tensor.size(seq_dim + 1); - int dim_per_head = tensor.size(seq_dim + 2); - int hidden_size_in_bytes = num_heads * dim_per_head * c10::elementSize(tensor.scalar_type()); - - // For 128-bits load/store - NVTE_CHECK(hidden_size_in_bytes % 16 == 0); - - // Generate output - std::vector shape(tensor.dim()); - for (size_t i = 0; i < shape.size(); i++) { - shape[i] = tensor.size(i); + auto sizes = tensor.sizes(); + std::vector shape; + for (int64_t i = 0; i < tensor.dim(); ++i) { + shape.push_back(i == seq_dim ? 
sizes[i] / 2 : sizes[i]); } - shape[seq_dim] /= 2; - at::Tensor half = at::empty(shape, at::CUDA(tensor.scalar_type())); + auto half = allocateStableTensor(shape, tensor.scalar_type(), tensor.get_device_index()); auto te_tensor = makeTransformerEngineTensor(tensor); - auto te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens); + auto te_cu = makeTransformerEngineTensor(cu_seqlens); auto te_half = makeTransformerEngineTensor(half); - - nvte_cp_thd_read_half_tensor(te_tensor.data(), te_cu_seqlens.data(), te_half.data(), half_idx, - at::cuda::getCurrentCUDAStream()); - + nvte_cp_thd_read_half_tensor(te_tensor.data(), te_cu.data(), te_half.data(), + static_cast(half_idx), + getCurrentCUDAStreamRaw(tensor.get_device_index())); return half; } -/*************************************************************************************************** - * Support THD format for Context Parallel: softmax_lse related operations - **************************************************************************************************/ - -void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_step, - const at::Tensor &cu_seqlens, bool lse_packed) { - NVTE_CHECK(lse.scalar_type() == at::ScalarType::Float); - NVTE_CHECK(lse_per_step.scalar_type() == at::ScalarType::Float); - NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int); - NVTE_CHECK(cu_seqlens.dim() == 1); - - int batch, num_heads, lse_seqlen, second_half_lse_seqlen; - - if (lse_packed) { - NVTE_CHECK(lse.dim() == 2); - NVTE_CHECK(lse_per_step.dim() == 2); - - batch = cu_seqlens.size(0) - 1; - num_heads = lse.size(0); - lse_seqlen = lse.size(1); - second_half_lse_seqlen = lse_per_step.size(1); - - NVTE_CHECK(lse_per_step.size(0) == num_heads); - NVTE_CHECK(second_half_lse_seqlen >= lse_seqlen / 2); - } else { - NVTE_CHECK(lse.dim() == 3); - NVTE_CHECK(lse_per_step.dim() == 3); - - batch = lse.size(0); - num_heads = lse.size(1); - lse_seqlen = lse.size(2); - second_half_lse_seqlen = 
lse_per_step.size(2); - - NVTE_CHECK(lse_per_step.size(0) == batch); - NVTE_CHECK(lse_per_step.size(1) == num_heads); - NVTE_CHECK(second_half_lse_seqlen == lse_seqlen / 2); - NVTE_CHECK(cu_seqlens.size(0) == batch + 1); - } - +void thd_second_half_lse_correction(Tensor lse, Tensor lse_per_step, Tensor cu_seqlens, + bool lse_packed) { auto te_lse = makeTransformerEngineTensor(lse); - auto te_lse_per_step = makeTransformerEngineTensor(lse_per_step); - auto te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens); - - nvte_cp_thd_second_half_lse_correction(te_lse.data(), te_lse_per_step.data(), - te_cu_seqlens.data(), lse_packed, - at::cuda::getCurrentCUDAStream()); + auto te_lse_ps = makeTransformerEngineTensor(lse_per_step); + auto te_cu = makeTransformerEngineTensor(cu_seqlens); + nvte_cp_thd_second_half_lse_correction(te_lse.data(), te_lse_ps.data(), te_cu.data(), lse_packed, + getCurrentCUDAStreamRaw(lse.get_device_index())); } -at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_seqlens, - bool lse_packed, int second_half_lse_seqlen) { - NVTE_CHECK(lse.scalar_type() == at::ScalarType::Float); - NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int); - NVTE_CHECK(cu_seqlens.dim() == 1); - - int batch, num_heads, lse_seqlen; +Tensor thd_read_second_half_lse(Tensor lse, Tensor cu_seqlens, bool lse_packed, + int64_t second_half_lse_seqlen) { std::vector shape; - if (lse_packed) { - NVTE_CHECK(lse.dim() == 2); - - batch = cu_seqlens.size(0) - 1; - num_heads = lse.size(0); - lse_seqlen = lse.size(1); - - NVTE_CHECK(second_half_lse_seqlen >= lse_seqlen / 2); - - shape = {num_heads, second_half_lse_seqlen}; + shape = {lse.size(0), second_half_lse_seqlen}; } else { - NVTE_CHECK(lse.dim() == 3); - - batch = lse.size(0); - num_heads = lse.size(1); - lse_seqlen = lse.size(2); - - NVTE_CHECK(cu_seqlens.size(0) == batch + 1); - NVTE_CHECK(second_half_lse_seqlen == lse_seqlen / 2); - - shape = {batch, num_heads, second_half_lse_seqlen}; + shape 
= {lse.size(0), lse.size(1), second_half_lse_seqlen}; } - - at::Tensor half_lse = at::zeros(shape, at::CUDA(lse.scalar_type())); + auto half_lse = allocateStableTensorZeros(shape, ScalarType::Float, lse.get_device_index()); auto te_lse = makeTransformerEngineTensor(lse); - auto te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens); - auto te_half_lse = makeTransformerEngineTensor(half_lse); - - nvte_cp_thd_read_second_half_lse(te_lse.data(), te_cu_seqlens.data(), te_half_lse.data(), - lse_packed, second_half_lse_seqlen, - at::cuda::getCurrentCUDAStream()); - + auto te_cu = makeTransformerEngineTensor(cu_seqlens); + auto te_half = makeTransformerEngineTensor(half_lse); + nvte_cp_thd_read_second_half_lse(te_lse.data(), te_cu.data(), te_half.data(), lse_packed, + static_cast(second_half_lse_seqlen), + getCurrentCUDAStreamRaw(lse.get_device_index())); return half_lse; } -/*************************************************************************************************** - * Support THD format for Context Parallel: Out correction in forward - **************************************************************************************************/ - -void thd_out_correction(at::Tensor out, const at::Tensor &out_per_step, const at::Tensor &lse, - const at::Tensor &lse_per_step, const at::Tensor &cu_seqlens, - bool only_second_half, bool lse_packed) { +void thd_out_correction(Tensor out, Tensor out_per_step, Tensor lse, Tensor lse_per_step, + Tensor cu_seqlens, bool only_second_half, bool lse_packed) { auto te_out = makeTransformerEngineTensor(out); - auto te_out_per_step = makeTransformerEngineTensor(out_per_step); + auto te_ops = makeTransformerEngineTensor(out_per_step); auto te_lse = makeTransformerEngineTensor(lse); - auto te_lse_per_step = makeTransformerEngineTensor(lse_per_step); - auto te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens); - nvte_cp_thd_out_correction(te_out.data(), te_out_per_step.data(), te_lse.data(), - te_lse_per_step.data(), 
te_cu_seqlens.data(), only_second_half, - lse_packed, at::cuda::getCurrentCUDAStream()); + auto te_lps = makeTransformerEngineTensor(lse_per_step); + auto te_cu = makeTransformerEngineTensor(cu_seqlens); + nvte_cp_thd_out_correction(te_out.data(), te_ops.data(), te_lse.data(), te_lps.data(), + te_cu.data(), only_second_half, lse_packed, + getCurrentCUDAStreamRaw(out.get_device_index())); } -/*************************************************************************************************** - * Support THD format for Context Parallel: Gradients correction in backward - **************************************************************************************************/ - -void thd_grad_correction(at::Tensor grad, const at::Tensor &grad_per_step, - const at::Tensor &cu_seqlens, const std::string &first_half, - const std::string &second_half) { +void thd_grad_correction(Tensor grad, Tensor grad_per_step, Tensor cu_seqlens, + std::string first_half, std::string second_half) { auto te_grad = makeTransformerEngineTensor(grad); - auto te_grad_per_step = makeTransformerEngineTensor(grad_per_step); - auto te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens); - nvte_cp_thd_grad_correction(te_grad.data(), te_grad_per_step.data(), te_cu_seqlens.data(), - first_half.data(), second_half.data(), - at::cuda::getCurrentCUDAStream()); + auto te_gps = makeTransformerEngineTensor(grad_per_step); + auto te_cu = makeTransformerEngineTensor(cu_seqlens); + nvte_cp_thd_grad_correction(te_grad.data(), te_gps.data(), te_cu.data(), first_half.data(), + second_half.data(), getCurrentCUDAStreamRaw(grad.get_device_index())); } -/*************************************************************************************************** - * Support THD format for Context Parallel: Generate partitioned indices for input tokens - **************************************************************************************************/ - -at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int 
total_tokens, - int world_size, int rank) { - NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int); - NVTE_CHECK(cu_seqlens.dim() == 1); - NVTE_CHECK(cu_seqlens.size(0) >= 2); - NVTE_CHECK(rank >= 0 && rank < world_size); - NVTE_CHECK(world_size > 0); - NVTE_CHECK(total_tokens > 0 && total_tokens % (world_size * 2) == 0); +Tensor thd_get_partitioned_indices(Tensor cu_seqlens, int64_t total_tokens, int64_t world_size, + int64_t rank) { + auto output = allocateStableTensor({total_tokens / world_size}, ScalarType::Int, + cu_seqlens.get_device_index()); + auto te_cu = makeTransformerEngineTensor(cu_seqlens); + auto te_out = makeTransformerEngineTensor(output); + nvte_cp_thd_get_partitioned_indices(te_cu.data(), te_out.data(), static_cast(total_tokens), + static_cast(world_size), static_cast(rank), + getCurrentCUDAStreamRaw(cu_seqlens.get_device_index())); + return output; +} - std::vector shape = {total_tokens / world_size}; - at::Tensor output = at::empty(shape, at::CUDA(at::ScalarType::Int)); +// ============================================================================ +// Format conversions +// ============================================================================ + +Tensor convert_thd_to_bshd(Tensor tensor, Tensor cu_seqlens, int64_t b, int64_t max_seq_len) { + auto h = tensor.size(1), d = tensor.size(2); + auto new_tensor = allocateStableTensorZeros({b, max_seq_len, h, d}, tensor.scalar_type(), + tensor.get_device_index()); + + auto te_t = makeTransformerEngineTensor(tensor); + auto te_cu = makeTransformerEngineTensor(cu_seqlens); + auto te_new = makeTransformerEngineTensor(new_tensor); + nvte_convert_thd_to_bshd(te_t.data(), te_cu.data(), te_new.data(), static_cast(b), + static_cast(max_seq_len), + getCurrentCUDAStreamRaw(tensor.get_device_index())); + return new_tensor; +} - auto te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens); - auto te_output = makeTransformerEngineTensor(output); +Tensor convert_bshd_to_thd(Tensor tensor, Tensor 
cu_seqlens, int64_t t) { + auto h = tensor.size(2), d = tensor.size(3); + auto new_tensor = + allocateStableTensorZeros({t, h, d}, tensor.scalar_type(), tensor.get_device_index()); - nvte_cp_thd_get_partitioned_indices(te_cu_seqlens.data(), te_output.data(), total_tokens, - world_size, rank, at::cuda::getCurrentCUDAStream()); + auto te_t = makeTransformerEngineTensor(tensor); + auto te_cu = makeTransformerEngineTensor(cu_seqlens); + auto te_new = makeTransformerEngineTensor(new_tensor); + nvte_convert_bshd_to_thd(te_t.data(), te_cu.data(), te_new.data(), static_cast(t), + getCurrentCUDAStreamRaw(tensor.get_device_index())); + return new_tensor; +} - return output; +// ============================================================================ +// KV Cache +// ============================================================================ + +void copy_to_kv_cache(Tensor new_k, Tensor new_v, Tensor k_cache, Tensor v_cache, Tensor page_table, + Tensor cu_new_lens, Tensor cu_cached_lens, int64_t qkv_format, int64_t b, + int64_t max_ctx_len, int64_t max_seq_len, int64_t max_pages_per_seq, + bool is_non_paged) { + auto te_nk = makeTransformerEngineTensor(new_k); + auto te_nv = makeTransformerEngineTensor(new_v); + auto te_kc = makeTransformerEngineTensor(k_cache); + auto te_vc = makeTransformerEngineTensor(v_cache); + auto te_pt = makeTransformerEngineTensor(page_table); + auto te_cnl = makeTransformerEngineTensor(cu_new_lens); + auto te_ccl = makeTransformerEngineTensor(cu_cached_lens); + + nvte_copy_to_kv_cache(te_nk.data(), te_nv.data(), te_kc.data(), te_vc.data(), te_pt.data(), + te_cnl.data(), te_ccl.data(), static_cast(qkv_format), + static_cast(b), static_cast(max_ctx_len), + static_cast(max_seq_len), static_cast(max_pages_per_seq), + is_non_paged, getCurrentCUDAStreamRaw(new_k.get_device_index())); } -/*************************************************************************************************** - * KV Cache: Convert a tensor from qkv_format = thd to 
qkv_format = bshd - **************************************************************************************************/ +// ============================================================================ +// Fused Attention Forward — noalloc variant +// +// Aux tensors are flattened to individual Tensor? args (max 10 per +// NVTETensorPack::MAX_SIZE). The Python shim packs/unpacks them. +// ============================================================================ + +std::tuple +fused_attn_fwd_noalloc( + int64_t max_seqlen_q, int64_t max_seqlen_kv, bool is_training, double attn_scale, + double p_dropout, bool set_zero, int64_t qkv_layout, int64_t bias_type, int64_t attn_mask_type, + int64_t softmax_type, std::vector window_size, bool bottom_right_diagonal, + Tensor cu_seqlens_q, Tensor cu_seqlens_kv, + // Q/K/V with optional quantization + Tensor Q_data, int64_t Q_dtype, std::optional Q_scale_inv, int64_t Q_scaling_mode, + Tensor K_data, int64_t K_dtype, std::optional K_scale_inv, int64_t K_scaling_mode, + Tensor V_data, int64_t V_dtype, std::optional V_scale_inv, int64_t V_scaling_mode, + // S (softmax) — usually empty placeholder + Tensor S_data, int64_t S_dtype, std::optional S_amax, std::optional S_scale, + std::optional S_scale_inv, int64_t S_scaling_mode, + // O (output) — pre-allocated + Tensor O_data, int64_t O_dtype, std::optional O_amax, std::optional O_scale, + std::optional O_scale_inv, int64_t O_scaling_mode, + // Optional tensors + std::optional cu_seqlens_q_padded, std::optional cu_seqlens_kv_padded, + std::optional page_table_k, std::optional page_table_v, + std::optional Bias, std::optional SoftmaxOffset, + // RNG state [seed, offset] as int64 tensor + Tensor rng_state, bool return_max_logit, bool cuda_graph) { + auto nvte_layout = static_cast(qkv_layout); + auto nvte_bias = static_cast(bias_type); + auto nvte_mask = static_cast(attn_mask_type); + auto nvte_softmax = static_cast(softmax_type); + + auto Q_shape = getStableTensorShape(Q_data); + auto 
K_shape = getStableTensorShape(K_data); + auto V_shape = getStableTensorShape(V_data); + + // Build Q/K/V TensorWrappers + auto te_Q = makeQuantizedTensorWrapper(Q_data, static_cast(Q_dtype), Q_shape, std::nullopt, + std::nullopt, Q_scale_inv, + static_cast(Q_scaling_mode)); + auto te_K = makeQuantizedTensorWrapper(K_data, static_cast(K_dtype), K_shape, std::nullopt, + std::nullopt, K_scale_inv, + static_cast(K_scaling_mode)); + auto te_V = makeQuantizedTensorWrapper(V_data, static_cast(V_dtype), V_shape, std::nullopt, + std::nullopt, V_scale_inv, + static_cast(V_scaling_mode)); + + // Build O TensorWrapper (output shape = Q shape with V's last dim) + auto O_shape = getStableTensorShape(O_data); + auto te_O = + makeQuantizedTensorWrapper(O_data, static_cast(O_dtype), O_shape, O_amax, O_scale, + O_scale_inv, static_cast(O_scaling_mode)); + + // Build S TensorWrapper (placeholder — NVTE determines actual shape) + auto S_shape_vec = getStableTensorShape(S_data); + auto te_S = + makeQuantizedTensorWrapper(S_data, static_cast(S_dtype), S_shape_vec, S_amax, S_scale, + S_scale_inv, static_cast(S_scaling_mode)); + + // Zero-fill O if needed for THD format + auto qkv_type = static_cast(Q_dtype); + auto device_idx = Q_data.get_device_index(); + auto stream = getCurrentCUDAStreamRaw(device_idx); + + if (set_zero && nvte_get_qkv_format(nvte_layout) == NVTE_QKV_Format::NVTE_THD) { + te_O.zero_(stream); + } -at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b, int max_seq_len) { - int h = tensor.size(1); - int d = tensor.size(2); - std::vector shape = {b, max_seq_len, h, d}; - at::Tensor new_tensor = at::zeros(shape, at::CUDA(tensor.scalar_type())); + // Optional tensors + TensorWrapper te_Bias, te_SoftmaxOffset; + TensorWrapper te_cu_q, te_cu_kv, te_cu_q_pad, te_cu_kv_pad; + TensorWrapper te_pt_k, te_pt_v; - auto te_tensor = makeTransformerEngineTensor(tensor); - auto te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens); - auto te_new_tensor = 
makeTransformerEngineTensor(new_tensor); + te_cu_q = makeTransformerEngineTensor(cu_seqlens_q); + te_cu_kv = makeTransformerEngineTensor(cu_seqlens_kv); - nvte_convert_thd_to_bshd(te_tensor.data(), te_cu_seqlens.data(), te_new_tensor.data(), b, - max_seq_len, at::cuda::getCurrentCUDAStream()); + if (Bias.has_value()) { + te_Bias = makeTransformerEngineTensor(Bias.value()); + } + if (SoftmaxOffset.has_value()) { + te_SoftmaxOffset = makeTransformerEngineTensor(SoftmaxOffset.value()); + } + if (cu_seqlens_q_padded.has_value()) { + te_cu_q_pad = makeTransformerEngineTensor(cu_seqlens_q_padded.value()); + } + if (cu_seqlens_kv_padded.has_value()) { + te_cu_kv_pad = makeTransformerEngineTensor(cu_seqlens_kv_padded.value()); + } + if (page_table_k.has_value()) { + te_pt_k = makeTransformerEngineTensor(page_table_k.value()); + } + if (page_table_v.has_value()) { + te_pt_v = makeTransformerEngineTensor(page_table_v.value()); + } - return new_tensor; + auto te_rng = makeTransformerEngineTensor(rng_state); + + // Aux tensor pack + NVTETensorPack nvte_aux; + nvte_tensor_pack_create(&nvte_aux); + + // Workspace + TensorWrapper workspace; + + // Phase 1: shape query + nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), + te_SoftmaxOffset.data(), te_S.data(), te_O.data(), &nvte_aux, te_cu_q.data(), + te_cu_kv.data(), te_cu_q_pad.data(), te_cu_kv_pad.data(), te_pt_k.data(), + te_pt_v.data(), te_rng.data(), static_cast(max_seqlen_q), + static_cast(max_seqlen_kv), is_training, return_max_logit, cuda_graph, + static_cast(attn_scale), static_cast(p_dropout), nvte_layout, + nvte_bias, nvte_mask, nvte_softmax, window_size[0], window_size[1], + bottom_right_diagonal, workspace.data(), stream); + + // Allocate workspace — declare ws_data OUTSIDE the if-block so the Tensor stays alive + // through Phase 2 execution. 
If ws_data were declared inside the if-block, it would be + // destroyed before Phase 2, and subsequent aux-tensor allocations could reuse the same + // memory, causing Phase 2 to corrupt the aux tensors (seen as err 700 on 3+ layers). + auto ws_shape = workspace.shape(); + Tensor ws_data; + if (ws_shape.ndim > 0 && workspace.numel() > 0) { + ws_data = + allocateStableTensor(std::vector(ws_shape.data, ws_shape.data + ws_shape.ndim), + workspace.dtype(), device_idx); + workspace = makeTransformerEngineTensor( + ws_data.data_ptr(), std::vector(ws_shape.data, ws_shape.data + ws_shape.ndim), + workspace.dtype()); + } + + // Allocate aux tensors and populate the pack. + // IMPORTANT: the rng_state slot (shape=[2], dtype=int64) must use the caller-supplied + // rng_state tensor directly — NOT a new allocation. cuDNN (line 1182 in + // fused_attn_f16_arbitrary_seqlen.cu) overwrites the aux slot's dptr to point to the + // input rng_state buffer during Phase 2, so whatever the aux pack slot points to + // before Phase 2 must be the same tensor that will be returned to Python. + // Pybind11 does this at extensions/attention.cpp:280: set_tensor_param(i++, rng_state). + // + // Similarly, the Bias slot must use the caller-supplied Bias tensor so that the + // backward pass can read the original bias shape for dBias allocation. The old pybind + // code saves the original Bias tensor at extensions/attention.cpp:283. + // The SoftmaxOffset slot is handled the same way. 
+ bool rng_state_placed = false; + bool bias_placed = false; + bool softmax_offset_placed = false; + std::vector aux_tensors; + for (size_t i = 0; i < nvte_aux.size; ++i) { + auto aux_shape = nvte_tensor_shape(nvte_aux.tensors[i]); + auto aux_dtype = static_cast(nvte_tensor_type(nvte_aux.tensors[i])); + std::vector shape_vec; + for (size_t d = 0; d < aux_shape.ndim; ++d) { + shape_vec.push_back(static_cast(aux_shape.data[d])); + } + + Tensor aux_tensor; + // Detect the rng_state slot: shape [2], dtype int64 (kInt64=3). + bool is_rng_slot = (!rng_state_placed && aux_dtype == DType::kInt64 && aux_shape.ndim == 1 && + aux_shape.data[0] == 2); + // Detect the Bias slot: 4D tensor that comes AFTER the rng_state slot. + // bias_type must be not NO_BIAS(0) and not ALIBI(3), and the caller must have + // provided a Bias tensor. We require rng_state to be placed first because + // the aux order is: S, [Max], rng_state, [Bias], [SoftmaxOffset]. + // Note: NVTE reports the Bias aux slot with QKV dtype (BF16/FP16), not kFloat32, + // so we do NOT filter on aux_dtype here. + bool is_bias_slot = + (rng_state_placed && !bias_placed && Bias.has_value() && nvte_bias != NVTE_NO_BIAS && + nvte_bias != NVTE_ALIBI && aux_shape.ndim == 4); + // Detect the SoftmaxOffset slot: 4D float32 [1, h, 1, 1] tensor after rng_state. 
+ bool is_softmax_offset_slot = + (rng_state_placed && !softmax_offset_placed && SoftmaxOffset.has_value() && + nvte_softmax != NVTE_VANILLA_SOFTMAX && aux_dtype == DType::kFloat32 && + aux_shape.ndim == 4 && aux_shape.data[0] == 1 && aux_shape.data[2] == 1 && + aux_shape.data[3] == 1); + + if (is_rng_slot) { + aux_tensor = rng_state; + rng_state_placed = true; + } else if (is_bias_slot) { + aux_tensor = Bias.value(); + bias_placed = true; + } else if (is_softmax_offset_slot) { + aux_tensor = SoftmaxOffset.value(); + softmax_offset_placed = true; + } else { + aux_tensor = allocateStableTensor(shape_vec, aux_dtype, device_idx); + } + aux_tensors.push_back(aux_tensor); + + // Use the aux_tensor's actual shape for the tensor param (important for Bias + // which may have batch dim > 1 while NVTE reports [1, h, s, s]). + auto actual_shape = getStableTensorShape(aux_tensor); + NVTEBasicTensor temp = {aux_tensor.data_ptr(), nvte_tensor_type(nvte_aux.tensors[i]), + nvte_make_shape(actual_shape.data(), actual_shape.size())}; + nvte_set_tensor_param(&nvte_aux.tensors[i], kNVTERowwiseData, &temp); + } + + // Phase 2: execute + nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), + te_SoftmaxOffset.data(), te_S.data(), te_O.data(), &nvte_aux, te_cu_q.data(), + te_cu_kv.data(), te_cu_q_pad.data(), te_cu_kv_pad.data(), te_pt_k.data(), + te_pt_v.data(), te_rng.data(), static_cast(max_seqlen_q), + static_cast(max_seqlen_kv), is_training, return_max_logit, cuda_graph, + static_cast(attn_scale), static_cast(p_dropout), nvte_layout, + nvte_bias, nvte_mask, nvte_softmax, window_size[0], window_size[1], + bottom_right_diagonal, workspace.data(), stream); + + int64_t num_aux = static_cast(aux_tensors.size()); + nvte_tensor_pack_destroy(&nvte_aux); + + // Pad to 10 slots + while (aux_tensors.size() < 10) { + aux_tensors.push_back(Tensor()); // empty/undefined tensor + } + return std::make_tuple(aux_tensors[0], aux_tensors[1], aux_tensors[2], aux_tensors[3], + 
aux_tensors[4], aux_tensors[5], aux_tensors[6], aux_tensors[7], + aux_tensors[8], aux_tensors[9], num_aux); } -/*************************************************************************************************** - * KV Cache: Convert a tensor from qkv_format = bshd to qkv_format = thd - **************************************************************************************************/ +// ============================================================================ +// Fused Attention Backward — noalloc variant +// +// dQ/dK/dV: pre-allocated by Python based on layout group. +// Aux_CTX_Tensors: the aux tensors from forward pass. +// ============================================================================ + +std::tuple fused_attn_bwd_noalloc( + int64_t max_seqlen_q, int64_t max_seqlen_kv, double attn_scale, double p_dropout, bool set_zero, + int64_t qkv_layout, int64_t bias_type, int64_t attn_mask_type, int64_t softmax_type, + std::vector window_size, bool bottom_right_diagonal, bool deterministic, + Tensor cu_seqlens_q, Tensor cu_seqlens_kv, + // Q/K/V/O/dO with optional quantization + Tensor Q_data, int64_t Q_dtype, std::optional Q_scale_inv, int64_t Q_scaling_mode, + Tensor K_data, int64_t K_dtype, std::optional K_scale_inv, int64_t K_scaling_mode, + Tensor V_data, int64_t V_dtype, std::optional V_scale_inv, int64_t V_scaling_mode, + Tensor O_data, int64_t O_dtype, std::optional O_scale_inv, int64_t O_scaling_mode, + Tensor dO_data, int64_t dO_dtype, std::optional dO_scale_inv, int64_t dO_scaling_mode, + // S and dP (softmax tensors) + Tensor S_data, int64_t S_dtype, std::optional S_amax, std::optional S_scale, + std::optional S_scale_inv, int64_t S_scaling_mode, Tensor dP_data, int64_t dP_dtype, + std::optional dP_amax, std::optional dP_scale, + std::optional dP_scale_inv, int64_t dP_scaling_mode, + // dQ/dK/dV pre-allocated by Python + Tensor dQ_data, int64_t dQ_dtype, std::optional dQ_amax, std::optional dQ_scale, + std::optional dQ_scale_inv, int64_t 
dQ_scaling_mode, Tensor dK_data, int64_t dK_dtype, + std::optional dK_amax, std::optional dK_scale, + std::optional dK_scale_inv, int64_t dK_scaling_mode, Tensor dV_data, int64_t dV_dtype, + std::optional dV_amax, std::optional dV_scale, + std::optional dV_scale_inv, int64_t dV_scaling_mode, + // dBias/dSoftmaxOffset pre-allocated + std::optional dBias, std::optional dSoftmaxOffset, + // Aux context tensors from forward (flattened, max 10) + int64_t num_aux_tensors, std::optional aux0, std::optional aux1, + std::optional aux2, std::optional aux3, std::optional aux4, + std::optional aux5, std::optional aux6, std::optional aux7, + std::optional aux8, std::optional aux9, + std::optional cu_seqlens_q_padded, std::optional cu_seqlens_kv_padded, + bool cuda_graph) { + auto nvte_layout = static_cast(qkv_layout); + auto nvte_bias = static_cast(bias_type); + auto nvte_mask = static_cast(attn_mask_type); + auto nvte_softmax = static_cast(softmax_type); + + auto device_idx = Q_data.get_device_index(); + auto stream = getCurrentCUDAStreamRaw(device_idx); + + // Build TensorWrappers + auto te_Q = makeQuantizedTensorWrapper(Q_data, static_cast(Q_dtype), + getStableTensorShape(Q_data), std::nullopt, std::nullopt, + Q_scale_inv, static_cast(Q_scaling_mode)); + auto te_K = makeQuantizedTensorWrapper(K_data, static_cast(K_dtype), + getStableTensorShape(K_data), std::nullopt, std::nullopt, + K_scale_inv, static_cast(K_scaling_mode)); + auto te_V = makeQuantizedTensorWrapper(V_data, static_cast(V_dtype), + getStableTensorShape(V_data), std::nullopt, std::nullopt, + V_scale_inv, static_cast(V_scaling_mode)); + auto te_O = makeQuantizedTensorWrapper(O_data, static_cast(O_dtype), + getStableTensorShape(O_data), std::nullopt, std::nullopt, + O_scale_inv, static_cast(O_scaling_mode)); + auto te_dO = makeQuantizedTensorWrapper( + dO_data, static_cast(dO_dtype), getStableTensorShape(dO_data), std::nullopt, + std::nullopt, dO_scale_inv, static_cast(dO_scaling_mode)); + + auto te_S = 
makeQuantizedTensorWrapper(S_data, static_cast(S_dtype), + getStableTensorShape(S_data), S_amax, S_scale, S_scale_inv, + static_cast(S_scaling_mode)); + auto te_dP = makeQuantizedTensorWrapper( + dP_data, static_cast(dP_dtype), getStableTensorShape(dP_data), dP_amax, dP_scale, + dP_scale_inv, static_cast(dP_scaling_mode)); + + auto te_dQ = makeQuantizedTensorWrapper( + dQ_data, static_cast(dQ_dtype), getStableTensorShape(dQ_data), dQ_amax, dQ_scale, + dQ_scale_inv, static_cast(dQ_scaling_mode)); + auto te_dK = makeQuantizedTensorWrapper( + dK_data, static_cast(dK_dtype), getStableTensorShape(dK_data), dK_amax, dK_scale, + dK_scale_inv, static_cast(dK_scaling_mode)); + auto te_dV = makeQuantizedTensorWrapper( + dV_data, static_cast(dV_dtype), getStableTensorShape(dV_data), dV_amax, dV_scale, + dV_scale_inv, static_cast(dV_scaling_mode)); + + TensorWrapper te_dBias, te_dSoftmaxOffset; + if (dBias.has_value()) te_dBias = makeTransformerEngineTensor(dBias.value()); + if (dSoftmaxOffset.has_value()) + te_dSoftmaxOffset = makeTransformerEngineTensor(dSoftmaxOffset.value()); + + auto te_cu_q = makeTransformerEngineTensor(cu_seqlens_q); + auto te_cu_kv = makeTransformerEngineTensor(cu_seqlens_kv); + TensorWrapper te_cu_q_pad, te_cu_kv_pad; + if (cu_seqlens_q_padded.has_value()) + te_cu_q_pad = makeTransformerEngineTensor(cu_seqlens_q_padded.value()); + if (cu_seqlens_kv_padded.has_value()) + te_cu_kv_pad = makeTransformerEngineTensor(cu_seqlens_kv_padded.value()); + + // Build aux tensor pack from flattened forward context tensors + std::optional aux_slots[] = {aux0, aux1, aux2, aux3, aux4, aux5, aux6, aux7, aux8, aux9}; + NVTETensorPack nvte_aux; + nvte_tensor_pack_create(&nvte_aux); + nvte_aux.size = static_cast(num_aux_tensors); + for (size_t i = 0; i < nvte_aux.size; ++i) { + NVTE_CHECK(aux_slots[i].has_value(), "aux tensor ", i, + " is None but num_aux_tensors=", num_aux_tensors); + auto& t = aux_slots[i].value(); + auto shape_vec = getStableTensorShape(t); + auto 
dtype = GetTransformerEngineDType(t.scalar_type()); + NVTEBasicTensor temp = {t.data_ptr(), static_cast(dtype), + nvte_make_shape(shape_vec.data(), shape_vec.size())}; + nvte_set_tensor_param(&nvte_aux.tensors[i], kNVTERowwiseData, &temp); + } + + // Zero-fill dQ/dK/dV for THD format if needed + if (set_zero && nvte_get_qkv_format(nvte_layout) == NVTE_QKV_Format::NVTE_THD) { + te_dQ.zero_(stream); + te_dK.zero_(stream); + te_dV.zero_(stream); + } + if (dBias.has_value() && nvte_get_qkv_format(nvte_layout) == NVTE_QKV_Format::NVTE_THD) { + torch::stable::zero_(dBias.value()); + } -at::Tensor convert_bshd_to_thd(at::Tensor tensor, at::Tensor cu_seqlens, int t) { - int h = tensor.size(2); - int d = tensor.size(3); - std::vector shape = {t, h, d}; - at::Tensor new_tensor = at::zeros(shape, at::CUDA(tensor.scalar_type())); + TensorWrapper workspace; - auto te_tensor = makeTransformerEngineTensor(tensor); - auto te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens); - auto te_new_tensor = makeTransformerEngineTensor(new_tensor); + // Phase 1: shape query + nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), + te_dP.data(), &nvte_aux, te_dQ.data(), te_dK.data(), te_dV.data(), + te_dBias.data(), te_dSoftmaxOffset.data(), te_cu_q.data(), te_cu_kv.data(), + te_cu_q_pad.data(), te_cu_kv_pad.data(), static_cast(max_seqlen_q), + static_cast(max_seqlen_kv), static_cast(attn_scale), + static_cast(p_dropout), nvte_layout, nvte_bias, nvte_mask, + nvte_softmax, window_size[0], window_size[1], bottom_right_diagonal, + deterministic, cuda_graph, workspace.data(), stream); + + // Allocate workspace — declare ws_data OUTSIDE the if-block so it stays alive + // through Phase 2 (same issue as fwd: ws_data inside if-block would be freed + // before Phase 2, and subsequent allocations could reuse the workspace memory). 
+ auto ws_shape = workspace.shape(); + Tensor ws_data; + if (ws_shape.ndim > 0 && workspace.numel() > 0) { + ws_data = + allocateStableTensor(std::vector(ws_shape.data, ws_shape.data + ws_shape.ndim), + workspace.dtype(), device_idx); + workspace = makeTransformerEngineTensor( + ws_data.data_ptr(), std::vector(ws_shape.data, ws_shape.data + ws_shape.ndim), + workspace.dtype()); + } - nvte_convert_bshd_to_thd(te_tensor.data(), te_cu_seqlens.data(), te_new_tensor.data(), t, - at::cuda::getCurrentCUDAStream()); + // Phase 2: execute + nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), + te_dP.data(), &nvte_aux, te_dQ.data(), te_dK.data(), te_dV.data(), + te_dBias.data(), te_dSoftmaxOffset.data(), te_cu_q.data(), te_cu_kv.data(), + te_cu_q_pad.data(), te_cu_kv_pad.data(), static_cast(max_seqlen_q), + static_cast(max_seqlen_kv), static_cast(attn_scale), + static_cast(p_dropout), nvte_layout, nvte_bias, nvte_mask, + nvte_softmax, window_size[0], window_size[1], bottom_right_diagonal, + deterministic, cuda_graph, workspace.data(), stream); + + nvte_tensor_pack_destroy(&nvte_aux); + + // Return dBias and dSoftmaxOffset (dQ/dK/dV are written in-place) + Tensor ret_dBias = dBias.has_value() ? dBias.value() : Tensor(); + Tensor ret_dSO = dSoftmaxOffset.has_value() ? dSoftmaxOffset.value() : Tensor(); + return std::make_tuple(ret_dBias, ret_dSO); +} - return new_tensor; +// ============================================================================ +// Fused Attention Backward — packed variant (57 args, under 64-arg limit) +// +// dtype_info is a 1-D int64 CPU tensor with 20 values: +// [Q_dtype, Q_sm, K_dtype, K_sm, V_dtype, V_sm, O_dtype, O_sm, +// dO_dtype, dO_sm, S_dtype, S_sm, dP_dtype, dP_sm, +// dQ_dtype, dQ_sm, dK_dtype, dK_sm, dV_dtype, dV_sm] +// dQ/dK/dV/dBias/dSoftmaxOffset are pre-allocated by the Python caller. 
+// ============================================================================ +std::tuple fused_attn_bwd_packed( + // Config (13) + int64_t max_seqlen_q, int64_t max_seqlen_kv, double attn_scale, double p_dropout, bool set_zero, + int64_t qkv_layout, int64_t bias_type, int64_t attn_mask_type, int64_t softmax_type, + std::vector window_size, bool bottom_right_diagonal, bool deterministic, + bool cuda_graph, + // Sequence lengths (4) + Tensor cu_seqlens_q, Tensor cu_seqlens_kv, std::optional cu_seqlens_q_padded, + std::optional cu_seqlens_kv_padded, + // Input tensors: data + scale_inv (10) + Tensor Q_data, std::optional Q_scale_inv, Tensor K_data, + std::optional K_scale_inv, Tensor V_data, std::optional V_scale_inv, + Tensor O_data, std::optional O_scale_inv, Tensor dO_data, + std::optional dO_scale_inv, + // Softmax buffers (8) + Tensor S_data, std::optional S_amax, std::optional S_scale, + std::optional S_scale_inv, Tensor dP_data, std::optional dP_amax, + std::optional dP_scale, std::optional dP_scale_inv, + // Output grad tensors (12) + Tensor dQ_data, std::optional dQ_amax, std::optional dQ_scale, + std::optional dQ_scale_inv, Tensor dK_data, std::optional dK_amax, + std::optional dK_scale, std::optional dK_scale_inv, Tensor dV_data, + std::optional dV_amax, std::optional dV_scale, + std::optional dV_scale_inv, + // Optional bias outputs (2) + std::optional dBias, std::optional dSoftmaxOffset, + // Packed dtype info (1): [Q_dtype, Q_sm, K_dtype, K_sm, V_dtype, V_sm, + // O_dtype, O_sm, dO_dtype, dO_sm, S_dtype, S_sm, dP_dtype, dP_sm, + // dQ_dtype, dQ_sm, dK_dtype, dK_sm, dV_dtype, dV_sm] + Tensor dtype_info, + // Aux context from forward (11) + int64_t num_aux_tensors, std::optional aux0, std::optional aux1, + std::optional aux2, std::optional aux3, std::optional aux4, + std::optional aux5, std::optional aux6, std::optional aux7, + std::optional aux8, std::optional aux9) { + // Unpack dtype info from CPU int64 tensor (passed from Python as CPU tensor) + 
const auto* dt_ptr = static_cast(dtype_info.data_ptr()); + int64_t Q_dtype = dt_ptr[0], Q_sm = dt_ptr[1]; + int64_t K_dtype = dt_ptr[2], K_sm = dt_ptr[3]; + int64_t V_dtype = dt_ptr[4], V_sm = dt_ptr[5]; + int64_t O_dtype = dt_ptr[6], O_sm = dt_ptr[7]; + int64_t dO_dtype = dt_ptr[8], dO_sm = dt_ptr[9]; + int64_t S_dtype = dt_ptr[10], S_sm = dt_ptr[11]; + int64_t dP_dtype = dt_ptr[12], dP_sm = dt_ptr[13]; + int64_t dQ_dtype = dt_ptr[14], dQ_sm = dt_ptr[15]; + int64_t dK_dtype = dt_ptr[16], dK_sm = dt_ptr[17]; + int64_t dV_dtype = dt_ptr[18], dV_sm = dt_ptr[19]; + + return fused_attn_bwd_noalloc( + max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, set_zero, qkv_layout, bias_type, + attn_mask_type, softmax_type, window_size, bottom_right_diagonal, deterministic, cu_seqlens_q, + cu_seqlens_kv, Q_data, Q_dtype, Q_scale_inv, Q_sm, K_data, K_dtype, K_scale_inv, K_sm, V_data, + V_dtype, V_scale_inv, V_sm, O_data, O_dtype, O_scale_inv, O_sm, dO_data, dO_dtype, + dO_scale_inv, dO_sm, S_data, S_dtype, S_amax, S_scale, S_scale_inv, S_sm, dP_data, dP_dtype, + dP_amax, dP_scale, dP_scale_inv, dP_sm, dQ_data, dQ_dtype, dQ_amax, dQ_scale, dQ_scale_inv, + dQ_sm, dK_data, dK_dtype, dK_amax, dK_scale, dK_scale_inv, dK_sm, dV_data, dV_dtype, dV_amax, + dV_scale, dV_scale_inv, dV_sm, dBias, dSoftmaxOffset, num_aux_tensors, aux0, aux1, aux2, aux3, + aux4, aux5, aux6, aux7, aux8, aux9, cu_seqlens_q_padded, cu_seqlens_kv_padded, cuda_graph); } -void copy_to_kv_cache(at::Tensor new_k, at::Tensor new_v, at::Tensor k_cache, at::Tensor v_cache, - at::Tensor page_table, at::Tensor cu_new_lens, at::Tensor cu_cached_lens, - NVTE_QKV_Format qkv_format, int b, int max_ctx_len, int max_seq_len, - int max_pages_per_seq, bool is_non_paged) { - NVTE_CHECK(k_cache.scalar_type() == v_cache.scalar_type() && - new_k.scalar_type() == new_v.scalar_type() && - new_k.scalar_type() == k_cache.scalar_type(), - "new_k, new_v, k_cache and v_cache must be of the same data type."); - NVTE_CHECK(qkv_format == 
NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD || - qkv_format == NVTE_QKV_Format::NVTE_THD, - "qkv_format must be {BSHD, SBHD, THD}."); - - auto te_new_k = makeTransformerEngineTensor(new_k); - auto te_new_v = makeTransformerEngineTensor(new_v); - auto te_k_cache = makeTransformerEngineTensor(k_cache); - auto te_v_cache = makeTransformerEngineTensor(v_cache); - auto te_page_table = makeTransformerEngineTensor(page_table); - auto te_cu_new_lens = makeTransformerEngineTensor(cu_new_lens); - auto te_cu_cached_lens = makeTransformerEngineTensor(cu_cached_lens); - - nvte_copy_to_kv_cache(te_new_k.data(), te_new_v.data(), te_k_cache.data(), te_v_cache.data(), - te_page_table.data(), te_cu_new_lens.data(), te_cu_cached_lens.data(), - qkv_format, b, max_ctx_len, max_seq_len, max_pages_per_seq, is_non_paged, - at::cuda::getCurrentCUDAStream()); +} // namespace transformer_engine::pytorch::stable + +STABLE_TORCH_LIBRARY_FRAGMENT(transformer_engine_stable, m) { + // Fused attention forward/backward (noalloc, flattened aux tensors) + m.def( + "fused_attn_fwd_noalloc(int max_seqlen_q, int max_seqlen_kv, bool is_training, float " + "attn_scale, float p_dropout, bool set_zero, int qkv_layout, int bias_type, int " + "attn_mask_type, int softmax_type, int[] window_size, bool bottom_right_diagonal, Tensor " + "cu_seqlens_q, Tensor cu_seqlens_kv, Tensor Q_data, int Q_dtype, Tensor? Q_scale_inv, int " + "Q_scaling_mode, Tensor K_data, int K_dtype, Tensor? K_scale_inv, int K_scaling_mode, Tensor " + "V_data, int V_dtype, Tensor? V_scale_inv, int V_scaling_mode, Tensor S_data, int S_dtype, " + "Tensor? S_amax, Tensor? S_scale, Tensor? S_scale_inv, int S_scaling_mode, Tensor O_data, " + "int O_dtype, Tensor? O_amax, Tensor? O_scale, Tensor? O_scale_inv, int O_scaling_mode, " + "Tensor? cu_seqlens_q_padded, Tensor? cu_seqlens_kv_padded, Tensor? page_table_k, Tensor? " + "page_table_v, Tensor? Bias, Tensor? 
SoftmaxOffset, Tensor rng_state, bool return_max_logit, " + "bool cuda_graph) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, " + "Tensor, Tensor, int)"); + m.def( + "fused_attn_bwd_packed(" + "int max_seqlen_q, int max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, " + "int qkv_layout, int bias_type, int attn_mask_type, int softmax_type, int[] window_size, " + "bool bottom_right_diagonal, bool deterministic, bool cuda_graph, " + "Tensor cu_seqlens_q, Tensor cu_seqlens_kv, Tensor? cu_seqlens_q_padded, Tensor? " + "cu_seqlens_kv_padded, " + "Tensor Q_data, Tensor? Q_scale_inv, Tensor K_data, Tensor? K_scale_inv, " + "Tensor V_data, Tensor? V_scale_inv, Tensor O_data, Tensor? O_scale_inv, " + "Tensor dO_data, Tensor? dO_scale_inv, " + "Tensor S_data, Tensor? S_amax, Tensor? S_scale, Tensor? S_scale_inv, " + "Tensor dP_data, Tensor? dP_amax, Tensor? dP_scale, Tensor? dP_scale_inv, " + "Tensor dQ_data, Tensor? dQ_amax, Tensor? dQ_scale, Tensor? dQ_scale_inv, " + "Tensor dK_data, Tensor? dK_amax, Tensor? dK_scale, Tensor? dK_scale_inv, " + "Tensor dV_data, Tensor? dV_amax, Tensor? dV_scale, Tensor? dV_scale_inv, " + "Tensor? dBias, Tensor? dSoftmaxOffset, Tensor dtype_info, " + "int num_aux_tensors, " + "Tensor? aux0, Tensor? aux1, Tensor? aux2, Tensor? aux3, Tensor? aux4, " + "Tensor? aux5, Tensor? aux6, Tensor? aux7, Tensor? aux8, Tensor? aux9" + ") -> (Tensor, Tensor)"); + // fused_attn_bwd_noalloc has 77 args which exceeds the 64-arg PyTorch + // dispatcher limit. Use fused_attn_bwd_packed instead. 
+ // Helpers + m.def("fa_prepare_fwd(Tensor qkvi) -> Tensor"); + m.def("fa_prepare_bwd(Tensor q, Tensor k, Tensor v) -> Tensor"); + m.def("thd_read_half_tensor(Tensor tensor, Tensor cu_seqlens, int half_idx) -> Tensor"); + m.def( + "thd_second_half_lse_correction(Tensor lse, Tensor lse_per_step, Tensor cu_seqlens, bool " + "lse_packed) -> ()"); + m.def( + "thd_read_second_half_lse(Tensor lse, Tensor cu_seqlens, bool lse_packed, int " + "second_half_lse_seqlen) -> Tensor"); + m.def( + "thd_out_correction(Tensor out, Tensor out_per_step, Tensor lse, Tensor lse_per_step, Tensor " + "cu_seqlens, bool only_second_half, bool lse_packed) -> ()"); + m.def( + "thd_grad_correction(Tensor grad, Tensor grad_per_step, Tensor cu_seqlens, str first_half, " + "str second_half) -> ()"); + m.def( + "thd_get_partitioned_indices(Tensor cu_seqlens, int total_tokens, int world_size, int rank) " + "-> Tensor"); + m.def("convert_thd_to_bshd(Tensor tensor, Tensor cu_seqlens, int b, int max_seq_len) -> Tensor"); + m.def("convert_bshd_to_thd(Tensor tensor, Tensor cu_seqlens, int t) -> Tensor"); + m.def( + "copy_to_kv_cache(Tensor new_k, Tensor new_v, Tensor k_cache, Tensor v_cache, Tensor " + "page_table, Tensor cu_new_lens, Tensor cu_cached_lens, int qkv_format, int b, int " + "max_ctx_len, int max_seq_len, int max_pages_per_seq, bool is_non_paged) -> ()"); } -} // namespace transformer_engine::pytorch +STABLE_TORCH_LIBRARY_IMPL(transformer_engine_stable, CUDA, m) { + using namespace transformer_engine::pytorch::stable; + m.impl("fused_attn_fwd_noalloc", TORCH_BOX(fused_attn_fwd_noalloc)); + m.impl("fused_attn_bwd_packed", TORCH_BOX(fused_attn_bwd_packed)); + // fused_attn_bwd_noalloc not registered (77 args > 64 limit); use fused_attn_bwd_packed + m.impl("fa_prepare_fwd", TORCH_BOX(fa_prepare_fwd)); + m.impl("fa_prepare_bwd", TORCH_BOX(fa_prepare_bwd)); + m.impl("thd_read_half_tensor", TORCH_BOX(thd_read_half_tensor)); + m.impl("thd_second_half_lse_correction", 
TORCH_BOX(thd_second_half_lse_correction)); + m.impl("thd_read_second_half_lse", TORCH_BOX(thd_read_second_half_lse)); + m.impl("thd_out_correction", TORCH_BOX(thd_out_correction)); + m.impl("thd_grad_correction", TORCH_BOX(thd_grad_correction)); + m.impl("thd_get_partitioned_indices", TORCH_BOX(thd_get_partitioned_indices)); + m.impl("convert_thd_to_bshd", TORCH_BOX(convert_thd_to_bshd)); + m.impl("convert_bshd_to_thd", TORCH_BOX(convert_bshd_to_thd)); + m.impl("copy_to_kv_cache", TORCH_BOX(copy_to_kv_cache)); +} diff --git a/transformer_engine/pytorch/csrc/extensions/bias.cpp b/transformer_engine/pytorch/csrc/extensions/bias.cpp index c59e3c4f64..ecab586295 100644 --- a/transformer_engine/pytorch/csrc/extensions/bias.cpp +++ b/transformer_engine/pytorch/csrc/extensions/bias.cpp @@ -4,270 +4,150 @@ * See LICENSE for license information. ************************************************************************/ -#include -#include - -#include -#include - -#include "common.h" -#include "extensions.h" -#include "pybind.h" -#include "transformer_engine/cast.h" -#include "transformer_engine/transformer_engine.h" - -namespace transformer_engine { -namespace pytorch { - -std::vector bgrad_quantize(const at::Tensor &grad_output, py::handle quantizer) { - using namespace transformer_engine::pytorch::detail; - init_extension(); - - // Grad output tensor - auto grad_output_torch = grad_output.contiguous(); - const TensorWrapper &grad_output_nvte = makeTransformerEngineTensor(grad_output_torch); - const auto shape = getTensorShape(grad_output_torch); - auto grad_output_dtype = GetTransformerEngineDType(grad_output_torch.scalar_type()); - - // Construct grad bias tensor - const int64_t bias_size = static_cast(shape.back()); - auto grad_bias_torch = allocateTorchTensor(bias_size, grad_output_dtype); - auto grad_bias_nvte = makeTransformerEngineTensor(grad_bias_torch); - - // Unquantized impl only requires computing grad bias - if (quantizer.is_none()) { - if (product(shape) == 
0) { - grad_bias_torch.zero_(); - } else { - at::sum_out(grad_bias_torch, grad_output_torch.reshape({-1, bias_size}), {0}); - } - return {py::cast(std::move(grad_bias_torch)), py::cast(std::move(grad_output_torch))}; - } - - // Construct grad input tensor - auto quantizer_cpp = convert_quantizer(quantizer); - auto [grad_input_nvte, grad_input_py] = quantizer_cpp->create_tensor(shape, grad_output_dtype); - - // Trivial impl if tensors are empty - if (product(shape) == 0) { - grad_bias_torch.zero_(); - return {py::cast(std::move(grad_bias_torch)), std::move(grad_input_py)}; - } - - // Check if fused kernel is supported - bool with_fused_kernel = false; - if (detail::IsFloat8Quantizers(quantizer.ptr())) { - auto prop = at::cuda::getCurrentDeviceProperties(); - const size_t sm_arch = 10 * prop->major + prop->minor; - if (sm_arch >= 100) { - // Fused kernel for dbias + FP8 cast on SM arch 10.0+ - with_fused_kernel = true; - } else if (quantizer_cpp->rowwise_usage && quantizer_cpp->columnwise_usage) { - // Fused kernel for dbias + FP8 cast + FP8 transpose - with_fused_kernel = true; - } - } else if (detail::IsMXFP8Quantizers(quantizer.ptr())) { - // Fused kernel for dbias + MXFP8 quantize - with_fused_kernel = true; - } - - // Apply unfused impl if fused kernel is not supported - if (!with_fused_kernel) { - at::sum_out(grad_bias_torch, grad_output_torch.reshape({-1, bias_size}), {0}); - quantizer_cpp->quantize(grad_output_nvte, grad_input_nvte); - return {py::cast(std::move(grad_bias_torch)), std::move(grad_input_py)}; +#include +#include + +#include "../stable_common.h" + +namespace transformer_engine::pytorch::stable { + +using Tensor = torch::stable::Tensor; + +// ============================================================================ +// bgrad_quantize: compute grad_bias and optionally quantize grad_output +// +// Fused kernel: nvte_quantize_dbias computes both dbias and quantized grad_input +// Unfused: at::sum for dbias + separate quantize +// 
============================================================================ + +void bgrad_quantize_noalloc(Tensor grad_output, + Tensor grad_bias, // pre-allocated [hidden_size] + Tensor grad_input_data, int64_t grad_input_te_dtype, + std::optional grad_input_amax, + std::optional grad_input_scale, + std::optional grad_input_scale_inv, int64_t scaling_mode) { + auto grad_output_ = torch::stable::contiguous(grad_output); + auto grad_output_cu = makeTransformerEngineTensor(grad_output_); + auto grad_bias_cu = makeTransformerEngineTensor(grad_bias); + + auto shape = getStableTensorShape(grad_output_); + auto te_dtype = static_cast(grad_input_te_dtype); + auto nvte_scaling = static_cast(scaling_mode); + + auto grad_input_cu = + makeQuantizedTensorWrapper(grad_input_data, te_dtype, shape, grad_input_amax, + grad_input_scale, grad_input_scale_inv, nvte_scaling); + + auto device_idx = grad_output_.get_device_index(); + auto stream = getCurrentCUDAStreamRaw(device_idx); + + TensorWrapper workspace; + + // First call: query workspace + nvte_quantize_dbias(grad_output_cu.data(), grad_input_cu.data(), grad_bias_cu.data(), + workspace.data(), stream); + + // workspace_data must outlive the second kernel call — hoist out of if block. 
+ Tensor workspace_data; + auto ws_shape = workspace.shape(); + auto ws_dtype = workspace.dtype(); + if (ws_shape.ndim > 0 && workspace.numel() > 0) { + workspace_data = allocateStableTensor( + std::vector(ws_shape.data, ws_shape.data + ws_shape.ndim), ws_dtype, device_idx); + workspace = makeTransformerEngineTensor( + workspace_data.data_ptr(), + std::vector(ws_shape.data, ws_shape.data + ws_shape.ndim), ws_dtype); } - // Query workspace size - TensorWrapper workspace_nvte; - at::Tensor workspace_torch; - auto stream = at::cuda::getCurrentCUDAStream(); - NVTE_SCOPED_GIL_RELEASE({ - nvte_quantize_dbias(grad_output_nvte.data(), grad_input_nvte.data(), grad_bias_nvte.data(), - workspace_nvte.data(), stream); - }); - - // Allocate workspace - if (workspace_nvte.ndim() > 0 && workspace_nvte.numel() > 0) { - workspace_torch = allocateSpace(workspace_nvte.shape(), workspace_nvte.dtype()); - workspace_nvte = makeTransformerEngineTensor(workspace_torch.data_ptr(), workspace_nvte.shape(), - workspace_nvte.dtype()); - } - - // Launch fused kernel - NVTE_SCOPED_GIL_RELEASE({ - nvte_quantize_dbias(grad_output_nvte.data(), grad_input_nvte.data(), grad_bias_nvte.data(), - workspace_nvte.data(), stream); - }); - - return {py::cast(std::move(grad_bias_torch)), std::move(grad_input_py)}; + // Second call: compute + nvte_quantize_dbias(grad_output_cu.data(), grad_input_cu.data(), grad_bias_cu.data(), + workspace.data(), stream); } -namespace { - -std::vector dact_dbias( - void (*dact_dbias_func)(const NVTETensor, const NVTETensor, NVTETensor, NVTETensor, NVTETensor, - cudaStream_t), - void (*dact_func)(const NVTETensor, const NVTETensor, NVTETensor, cudaStream_t), - at::Tensor grad_output_torch, at::Tensor act_input_torch, py::handle quantizer_py) { - using namespace transformer_engine::pytorch::detail; - init_extension(); - - // Grad output and activation input tensors - grad_output_torch = grad_output_torch.contiguous(); - const TensorWrapper &grad_output_nvte = 
makeTransformerEngineTensor(grad_output_torch); - const auto output_shape = getTensorShape(grad_output_torch); - auto grad_output_dtype = GetTransformerEngineDType(grad_output_torch.scalar_type()); - act_input_torch = act_input_torch.contiguous(); - const TensorWrapper &act_input_nvte = makeTransformerEngineTensor(act_input_torch); - const auto input_shape = getTensorShape(act_input_torch); - - // Construct tensors - auto quantizer_cpp = convert_quantizer(quantizer_py); - auto [grad_input_nvte, grad_input_py] = - quantizer_cpp->create_tensor(input_shape, grad_output_dtype); - const int64_t bias_size = static_cast(input_shape.back()); - auto grad_bias_torch = allocateTorchTensor(bias_size, grad_output_dtype); - auto grad_bias_nvte = makeTransformerEngineTensor(grad_bias_torch); - - // Return immediately if tensors are empty - if (product(output_shape) == 0) { - grad_bias_torch.zero_(); - return {py::cast(std::move(grad_bias_torch)), std::move(grad_input_py)}; - } - - // Choose implementation - enum class Impl { - UNFUSED, - FUSED_DACT_DBIAS_QUANTIZE, - FUSED_DACT_AMAX_FP8, - FUSED_DACT_AMAX_NVFP4 +// ============================================================================ +// Fused dact + dbias + quantize +// +// activation_type: 0=dgelu, 1=dsilu, 2=drelu, 3=dqgelu, 4=dsrelu +// Fused kernel computes dact(grad_output, act_input), dbias, and quantize in one pass +// ============================================================================ + +void dact_dbias_noalloc(Tensor grad_output, Tensor act_input, + Tensor grad_bias, // pre-allocated [hidden_size] + Tensor grad_input_data, int64_t grad_input_te_dtype, + std::optional grad_input_amax, + std::optional grad_input_scale, + std::optional grad_input_scale_inv, int64_t scaling_mode, + int64_t activation_type) { + auto grad_output_ = torch::stable::contiguous(grad_output); + auto act_input_ = torch::stable::contiguous(act_input); + + auto grad_output_cu = makeTransformerEngineTensor(grad_output_); + auto 
act_input_cu = makeTransformerEngineTensor(act_input_); + auto grad_bias_cu = makeTransformerEngineTensor(grad_bias); + + auto shape = getStableTensorShape(act_input_); + auto te_dtype = static_cast(grad_input_te_dtype); + auto nvte_scaling = static_cast(scaling_mode); + + auto grad_input_cu = + makeQuantizedTensorWrapper(grad_input_data, te_dtype, shape, grad_input_amax, + grad_input_scale, grad_input_scale_inv, nvte_scaling); + + auto device_idx = grad_output_.get_device_index(); + auto stream = getCurrentCUDAStreamRaw(device_idx); + + // Fused dact + dbias + quantize kernel table + using FusedFn = void (*)(const NVTETensor, const NVTETensor, NVTETensor, NVTETensor, NVTETensor, + cudaStream_t); + static constexpr FusedFn fused_table[] = { + nvte_quantize_dbias_dgelu, nvte_quantize_dbias_dsilu, nvte_quantize_dbias_drelu, + nvte_quantize_dbias_dqgelu, nvte_quantize_dbias_dsrelu, }; - Impl impl = Impl::UNFUSED; - if (detail::IsFloat8Quantizers(quantizer_py.ptr()) || - detail::IsMXFP8Quantizers(quantizer_py.ptr())) { - impl = Impl::FUSED_DACT_DBIAS_QUANTIZE; - } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer_py.ptr())) { - impl = Impl::FUSED_DACT_AMAX_FP8; - } else if (detail::IsNVFP4Quantizers(quantizer_py.ptr())) { - auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); - NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); - if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) { - // Post-RHT amax is handled within NVFP4 quantizer - impl = Impl::UNFUSED; - } else { - impl = Impl::FUSED_DACT_AMAX_NVFP4; - } - } - - // Perform compute - auto stream = at::cuda::getCurrentCUDAStream(); - switch (impl) { - case Impl::UNFUSED: - // Unfused dact, dbias, quantize - { - auto [temp_nvte, temp_py] = - NoneQuantizer(py::none()).create_tensor(input_shape, grad_output_dtype); - NVTE_SCOPED_GIL_RELEASE({ - dact_func(grad_output_nvte.data(), act_input_nvte.data(), temp_nvte.data(), stream); - }); - const 
auto temp_torch = temp_py.cast(); - at::sum_out(grad_bias_torch, temp_torch.reshape({-1, bias_size}), {0}); - quantizer_cpp->quantize(temp_nvte, grad_input_nvte); - break; - } - case Impl::FUSED_DACT_DBIAS_QUANTIZE: - // Fused dact-dbias-quantize kernel - { - // Query workspace size - TensorWrapper workspace_nvte; - NVTE_SCOPED_GIL_RELEASE({ - dact_dbias_func(grad_output_nvte.data(), act_input_nvte.data(), grad_input_nvte.data(), - grad_bias_nvte.data(), workspace_nvte.data(), stream); - }); - - // Allocate workspace - at::Tensor workspace_torch; - if (workspace_nvte.ndim() > 0 && workspace_nvte.numel() > 0) { - workspace_torch = allocateSpace(workspace_nvte.shape(), workspace_nvte.dtype()); - workspace_nvte = makeTransformerEngineTensor( - workspace_torch.data_ptr(), workspace_nvte.shape(), workspace_nvte.dtype()); - } - - // Launch kernel - NVTE_SCOPED_GIL_RELEASE({ - dact_dbias_func(grad_output_nvte.data(), act_input_nvte.data(), grad_input_nvte.data(), - grad_bias_nvte.data(), workspace_nvte.data(), stream); - }); - break; - } - case Impl::FUSED_DACT_AMAX_FP8: - // Fused dact-amax kernel, unfused dbias and FP8 quantize - { - auto *fp8_quantizer_cpp = - dynamic_cast(quantizer_cpp.get()); - NVTE_CHECK(fp8_quantizer_cpp != nullptr, - "Invalid quantizer for fused dact-amax kernel impl"); - auto [temp_nvte, temp_py] = - fp8_quantizer_cpp->create_unquantized_tensor_with_amax(input_shape, grad_output_dtype); - NVTE_SCOPED_GIL_RELEASE({ - dact_func(grad_output_nvte.data(), act_input_nvte.data(), temp_nvte.data(), stream); - }); - const auto temp_torch = temp_py.cast(); - at::sum_out(grad_bias_torch, temp_torch.reshape({-1, bias_size}), {0}); - fp8_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte); - break; - } - case Impl::FUSED_DACT_AMAX_NVFP4: - // Fused dact-amax kernel, unfused dbias and NVFP4 quantize - { - auto *nvfp4_quantizer_cpp = - static_cast(quantizer_cpp.get()); // Already checked cast is valid - NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, - 
"Invalid quantizer for fused dact-amax kernel impl"); - auto [temp_nvte, temp_py] = nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax( - grad_input_nvte, grad_output_dtype); - NVTE_SCOPED_GIL_RELEASE({ - dact_func(grad_output_nvte.data(), act_input_nvte.data(), temp_nvte.data(), stream); - }); - const auto temp_torch = temp_py.cast(); - at::sum_out(grad_bias_torch, temp_torch.reshape({-1, bias_size}), {0}); - nvfp4_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte); - break; - } - default: - NVTE_ERROR("Invalid implementation"); + constexpr int num_fns = sizeof(fused_table) / sizeof(fused_table[0]); + STD_TORCH_CHECK(activation_type >= 0 && activation_type < num_fns, + "Invalid activation_type for dact_dbias: ", activation_type); + + auto fn = fused_table[activation_type]; + TensorWrapper workspace; + + // First call: query workspace + fn(grad_output_cu.data(), act_input_cu.data(), grad_input_cu.data(), grad_bias_cu.data(), + workspace.data(), stream); + + // workspace_data must outlive the second kernel call — hoist out of if block. 
+ Tensor workspace_data; + auto ws_shape = workspace.shape(); + auto ws_dtype = workspace.dtype(); + if (ws_shape.ndim > 0 && workspace.numel() > 0) { + workspace_data = allocateStableTensor( + std::vector(ws_shape.data, ws_shape.data + ws_shape.ndim), ws_dtype, device_idx); + workspace = makeTransformerEngineTensor( + workspace_data.data_ptr(), + std::vector(ws_shape.data, ws_shape.data + ws_shape.ndim), ws_dtype); } - return {py::cast(std::move(grad_bias_torch)), std::move(grad_input_py)}; + // Second call: compute + fn(grad_output_cu.data(), act_input_cu.data(), grad_input_cu.data(), grad_bias_cu.data(), + workspace.data(), stream); } -} // namespace - -std::vector dbias_dgelu(const at::Tensor &grad_output, const at::Tensor &act_input, - py::handle quantizer) { - return dact_dbias(nvte_quantize_dbias_dgelu, nvte_dgelu, grad_output, act_input, quantizer); +} // namespace transformer_engine::pytorch::stable + +STABLE_TORCH_LIBRARY_FRAGMENT(transformer_engine_stable, m) { + m.def( + "bgrad_quantize_noalloc(Tensor grad_output, Tensor grad_bias, Tensor grad_input_data, int " + "grad_input_te_dtype, Tensor? grad_input_amax, Tensor? grad_input_scale, Tensor? " + "grad_input_scale_inv, int scaling_mode) -> ()"); + // activation_type: 0=dgelu, 1=dsilu, 2=drelu, 3=dqgelu, 4=dsrelu + m.def( + "dact_dbias_noalloc(Tensor grad_output, Tensor act_input, Tensor grad_bias, Tensor " + "grad_input_data, int grad_input_te_dtype, Tensor? grad_input_amax, Tensor? " + "grad_input_scale, Tensor? 
grad_input_scale_inv, int scaling_mode, int activation_type) -> " + "()"); } -std::vector dbias_dsilu(const at::Tensor &grad_output, const at::Tensor &act_input, - py::handle quantizer) { - return dact_dbias(nvte_quantize_dbias_dsilu, nvte_dsilu, grad_output, act_input, quantizer); +STABLE_TORCH_LIBRARY_IMPL(transformer_engine_stable, CUDA, m) { + using namespace transformer_engine::pytorch::stable; + m.impl("bgrad_quantize_noalloc", TORCH_BOX(bgrad_quantize_noalloc)); + m.impl("dact_dbias_noalloc", TORCH_BOX(dact_dbias_noalloc)); } - -std::vector dbias_drelu(const at::Tensor &grad_output, const at::Tensor &act_input, - py::handle quantizer) { - return dact_dbias(nvte_quantize_dbias_drelu, nvte_drelu, grad_output, act_input, quantizer); -} - -std::vector dbias_dqgelu(const at::Tensor &grad_output, const at::Tensor &act_input, - py::handle quantizer) { - return dact_dbias(nvte_quantize_dbias_dqgelu, nvte_dqgelu, grad_output, act_input, quantizer); -} - -std::vector dbias_dsrelu(const at::Tensor &grad_output, const at::Tensor &act_input, - py::handle quantizer) { - return dact_dbias(nvte_quantize_dbias_dsrelu, nvte_dsrelu, grad_output, act_input, quantizer); -} - -} // namespace pytorch -} // namespace transformer_engine diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index cb3434ec52..8f0a65d09a 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -4,1412 +4,251 @@ * See LICENSE for license information. 
************************************************************************/ -#include "transformer_engine/cast.h" +#include +#include -#include -#include -#include -#include -#include -#include -#include -#include +#include "../stable_common.h" -#include "../extensions.h" -#include "common.h" -#include "common/util/system.h" -#include "pybind.h" -#include "transformer_engine/transformer_engine.h" +namespace transformer_engine::pytorch::stable { -namespace transformer_engine { -namespace pytorch { +using Tensor = torch::stable::Tensor; -namespace { +// ============================================================================ +// Quantize: input (hp) → output (fp8/fp4) +// Covers delayed scaling (pre-computed scale) and the quantize-only step +// of the FUSED_NORM_AMAX path (scale computed from amax). +// ============================================================================ -std::vector get_tensor_shape(const TensorWrapper &tensor) { - const auto &shape = tensor.shape(); - return std::vector(shape.data, shape.data + shape.ndim); -} - -} // namespace - -py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::object &output, - std::optional noop_flag) { - // Convert quantizer to C++ object - auto quantizer_cpp = convert_quantizer(quantizer); - - // Convert input tensor to C++ object - auto input_contiguous = tensor.contiguous(); - auto input_cpp = makeTransformerEngineTensor(input_contiguous); - - // Set amax if use_existing_amax = true (only valid for CS) - bool use_existing_amax = false; - if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { - use_existing_amax = quantizer.attr("use_existing_amax").cast(); - if (use_existing_amax) { - const at::Tensor &amax = quantizer.attr("amax").cast(); - input_cpp.set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), - getTensorShape(amax)); - } - } +void quantize(Tensor input, Tensor output_data, int64_t output_te_dtype, + std::optional output_amax, std::optional 
output_scale, + std::optional output_scale_inv, int64_t scaling_mode, bool force_pow_2_scales, + double amax_epsilon, std::optional noop_flag, bool nvfp4_2d_quantization) { + auto shape = getStableTensorShape(input); + auto te_dtype = static_cast(output_te_dtype); + auto nvte_scaling = static_cast(scaling_mode); - // Initialize output tensor - TensorWrapper output_cpp; - py::object output_py; - if (output.is_none()) { - const auto shape = get_tensor_shape(input_cpp); - const auto fake_dtype = input_cpp.dtype(); - std::tie(output_cpp, output_py) = quantizer_cpp->create_tensor(shape, fake_dtype); - } else { - std::tie(output_cpp, output_py) = quantizer_cpp->convert_and_update_tensor(output); - } + auto input_cu = makeTransformerEngineTensor(input); + auto output_cu = makeQuantizedTensorWrapper(output_data, te_dtype, shape, output_amax, + output_scale, output_scale_inv, nvte_scaling); - // Initialize no-op flag - std::optional noop_flag_cpp; + QuantizationConfigWrapper quant_config; + std::optional noop_cu; if (noop_flag.has_value()) { - noop_flag_cpp = makeTransformerEngineTensor(*noop_flag); - } - - // Perform quantization - if (use_existing_amax) { - auto *quantizer_cs = dynamic_cast(quantizer_cpp.get()); - quantizer_cs->quantize_with_amax(input_cpp, output_cpp, noop_flag_cpp); - } else { - quantizer_cpp->quantize(input_cpp, output_cpp, noop_flag_cpp); - } - - return output_py; -} - -namespace { - -// helper functions for NVFP4 grouped quantization (cuda graph safe with shapes stored in device without D2H copy) -void group_quantize_nvfp4_impl(const GroupedTensorWrapper &grouped_input_tensor, - GroupedTensorWrapper &grouped_output_tensor, - NVFP4Quantizer *nvfp4_quantizer_cpp, cudaStream_t stream) { - size_t num_tensors = grouped_input_tensor.num_tensors(); - - // assert the 2D scaling case, since 2D scaling grouped quant kernel is not ready yet - NVTE_CHECK(!nvfp4_quantizer_cpp->with_2d_quantization, - "2D scaling grouped quant kernel is not ready yet"); - - auto 
quant_config_cpp = QuantizationConfigWrapper(); - - // stochastic rounding - bool need_stochastic_rounding = nvfp4_quantizer_cpp->stochastic_rounding; - auto opts = at::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA); - at::Tensor rng_states_tensor; // Declare tensor outside, do not allocate yet - TensorWrapper te_rng_state; - - if (need_stochastic_rounding) { - // in fused kernel, one rng state will be used by the grouped kernel to generate random - // number for different tensors in the group, so we only need to allocate one rng state - const size_t rng_elts_per_thread = 1024 * num_tensors; - rng_states_tensor = torch::empty({2}, opts); - auto gen = at::get_generator_or_default( - std::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); - at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread); - philox_unpack(philox_args, static_cast(rng_states_tensor.data_ptr())); - - te_rng_state = makeTransformerEngineTensor(rng_states_tensor); - quant_config_cpp.set_rng_state(te_rng_state.data()); - quant_config_cpp.set_stochastic_rounding(true); - } - - // fast math - const auto use_fast_math = transformer_engine::getenv("NVTE_USE_FAST_MATH"); - if (use_fast_math) { - quant_config_cpp.set_use_fast_math(true); - } - - // so far, only the RHT path has grouped kernel support - // grouped kernels for non-RHT path will be added later - - if (nvfp4_quantizer_cpp->with_rht) { - // post-RHT amax or not - if (nvfp4_quantizer_cpp->with_post_rht_amax) { - NVTE_SCOPED_GIL_RELEASE({ - nvte_group_hadamard_transform_amax_graph_safe( - grouped_input_tensor.data(), grouped_output_tensor.data(), 0, - nvfp4_quantizer_cpp->rht_matrix_random_sign_mask_t, stream); - }); - } else { - NVTE_ERROR("graph safe grouped quant kernel for non-RHT path is not ready yet"); - } - - // RHT cast fusion - auto tile_scheduler_workspace_torch = - at::empty({1}, at::device(at::kCUDA).dtype(torch::kInt32)); - auto nvte_tile_scheduler_workspace = - 
makeTransformerEngineTensor(tile_scheduler_workspace_torch); - - auto rht_matrix_nvte = makeTransformerEngineTensor(nvfp4_quantizer_cpp->rht_matrix); - NVTE_SCOPED_GIL_RELEASE({ - nvte_group_hadamard_transform_cast_fusion_graph_safe( - grouped_input_tensor.data(), grouped_output_tensor.data(), rht_matrix_nvte.data(), - quant_config_cpp, nvte_tile_scheduler_workspace.data(), stream); - }); - - } else { - NVTE_ERROR("graph safe grouped quant kernel for non-RHT path is not ready yet"); - } -} - -} // namespace - -// NOTE: Only supports varying first dim. -py::object group_quantize(const at::Tensor &tensor, py::handle quantizer, const size_t num_tensors, - std::optional first_dims) { - using namespace transformer_engine::pytorch::detail; - init_extension(); - - NVTE_CHECK(tensor.dim() == 2, "Tensor must be 2D"); - - std::vector logical_shape; - for (const auto &d : tensor.sizes()) { - logical_shape.push_back(d); - } - const auto logical_first_dim = logical_shape[0]; - const auto logical_last_dim = logical_shape[1]; - - bool empty_input_buffer = logical_first_dim == 0 || logical_last_dim == 0; - - auto quantizer_cpp = convert_quantizer(quantizer); - - // Create input GroupedTensor. - auto grouped_input_tensor = GroupedTensorWrapper(num_tensors, logical_shape); - grouped_input_tensor.set_rowwise_data( - tensor.data_ptr(), GetTransformerEngineDType(tensor.scalar_type()), getTensorShape(tensor)); - - // Create output GroupedTensor. 
- auto [grouped_output_tensor_cpp, grouped_output_py] = quantizer_cpp->create_grouped_tensor( - num_tensors, logical_shape, GetTransformerEngineDType(tensor.scalar_type()), - py::reinterpret_borrow(quantizer), first_dims, logical_first_dim, - logical_last_dim); - - // dispatch to scaling methods - enum class GroupedQuantizationMode { - MXFP8_GROUPED_QUANTIZE, - NVFP4_GROUPED_QUANTIZE, - INVALID_FOR_GROUPED_QUANTIZE - }; - GroupedQuantizationMode grouped_quantization_mode = - GroupedQuantizationMode::INVALID_FOR_GROUPED_QUANTIZE; - if (detail::IsMXFP8Quantizers(quantizer.ptr())) { - grouped_quantization_mode = GroupedQuantizationMode::MXFP8_GROUPED_QUANTIZE; - } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { - grouped_quantization_mode = GroupedQuantizationMode::NVFP4_GROUPED_QUANTIZE; - } - - if (empty_input_buffer) { - // early return for empty input buffer - // just return the output tensor as is - // no need to quantize - return py::reinterpret_borrow(grouped_output_py); - } - - switch (grouped_quantization_mode) { - case GroupedQuantizationMode::NVFP4_GROUPED_QUANTIZE: { - // NVFP4 grouped quantization - NVFP4Quantizer *nvfp4_quantizer_cpp = static_cast(quantizer_cpp.get()); - group_quantize_nvfp4_impl(grouped_input_tensor, grouped_output_tensor_cpp, - nvfp4_quantizer_cpp, at::cuda::getCurrentCUDAStream()); - break; - } - case GroupedQuantizationMode::MXFP8_GROUPED_QUANTIZE: { - NVTE_SCOPED_GIL_RELEASE({ - nvte_group_quantize(grouped_input_tensor.data(), grouped_output_tensor_cpp.data(), - at::cuda::getCurrentCUDAStream()); - }); - break; - } - case GroupedQuantizationMode::INVALID_FOR_GROUPED_QUANTIZE: - default: - NVTE_ERROR("group_quantize: only support NVFP4 or MXFP8 quantizer."); - break; - } - - return py::reinterpret_borrow(grouped_output_py); -} - -py::object dequantize(const py::handle &input, transformer_engine::DType otype) { - init_extension(); - - const auto none = py::none(); - - const auto &input_tensor = 
makeTransformerEngineTensor(input, none); - - NoneQuantizer q(none); - - const auto &shape = convertShape(input_tensor.shape()); - - auto [out_tensor, out] = q.create_tensor(shape, otype); - - NVTE_SCOPED_GIL_RELEASE({ - nvte_dequantize(input_tensor.data(), out_tensor.data(), at::cuda::getCurrentCUDAStream()); - }); - - return out; -} - -namespace { - -void multi_tensor_quantize_impl(const std::vector &input_list, - std::vector &quantizer_py_list, - std::vector> &quantizer_cpp_list, - std::vector &output_list) { - // Check number of tensors - const size_t num_tensors = input_list.size(); - NVTE_CHECK(quantizer_py_list.size() == num_tensors, "Expected ", num_tensors, - " Python quantizers, but got ", quantizer_py_list.size()); - NVTE_CHECK(quantizer_cpp_list.size() == num_tensors, "Expected ", num_tensors, - " C++ quantizers, but got ", quantizer_cpp_list.size()); - NVTE_CHECK(output_list.size() == num_tensors, "Expected ", num_tensors, - " output tensors, but got ", output_list.size()); - - // Choose implementation - // Note: Currently only have fused kernel for FP8 delayed scaling - bool with_fused_kernel = true; - for (size_t i = 0; i < num_tensors; i++) { - if (!detail::IsFloat8Quantizers(quantizer_py_list[i].ptr())) { - with_fused_kernel = false; - break; - } - if (nvte_tensor_data(output_list[i].data()) == nullptr || - nvte_tensor_columnwise_data(output_list[i].data()) == nullptr) { - with_fused_kernel = false; - break; - } - } - - // Launch TE kernel - if (with_fused_kernel) { - // Fused kernel for multi-tensor quantize - std::vector nvte_tensor_input_list; - std::vector nvte_tensor_output_list; - for (size_t i = 0; i < num_tensors; ++i) { - nvte_tensor_input_list.push_back(input_list[i].data()); - nvte_tensor_output_list.push_back(output_list[i].data()); - } - NVTE_SCOPED_GIL_RELEASE({ - nvte_multi_cast_transpose(nvte_tensor_input_list.size(), nvte_tensor_input_list.data(), - nvte_tensor_output_list.data(), at::cuda::getCurrentCUDAStream()); - }); - } else { 
- // Quantize kernels individually - for (size_t i = 0; i < num_tensors; ++i) { - quantizer_cpp_list[i]->quantize(input_list[i], output_list[i]); - } - } -} - -} // namespace - -std::vector multi_tensor_quantize(const std::vector &tensor_list, - std::vector quantizer_list) { - // Check number of tensors - const size_t num_tensors = tensor_list.size(); - NVTE_CHECK(quantizer_list.size() == num_tensors, "Expected ", num_tensors, - " quantizers, but got ", quantizer_list.size()); - - // Convert quantizers to C++ objects - std::vector> quantizer_cpp_list; - for (size_t i = 0; i < num_tensors; i++) { - quantizer_cpp_list.push_back(convert_quantizer(quantizer_list[i])); - } - - // Initialize input and output tensors - std::vector input_cpp_list; - std::vector output_cpp_list; - std::vector output_py_list; - for (size_t i = 0; i < num_tensors; ++i) { - // Convert input tensor to C++ object - const auto &input_py = tensor_list[i]; - NVTE_CHECK(input_py.is_contiguous(), "Input tensor ", i, " is not contiguous"); - input_cpp_list.emplace_back(makeTransformerEngineTensor(input_py)); - const auto &input_cpp = input_cpp_list.back(); - const auto input_shape = input_cpp.shape(); - const auto input_dtype = GetTransformerEngineDType(input_py.scalar_type()); - - // Construct output tensor - std::vector output_shape(input_shape.data, input_shape.data + input_shape.ndim); - auto [output_cpp, output_py] = quantizer_cpp_list[i]->create_tensor(output_shape, input_dtype); - output_cpp_list.emplace_back(std::move(output_cpp)); - output_py_list.emplace_back(std::move(output_py)); - } - - // Perform multi-tensor quantization - multi_tensor_quantize_impl(input_cpp_list, quantizer_list, quantizer_cpp_list, output_cpp_list); - - return output_py_list; -} - -namespace { - -std::tuple, std::vector> bulk_allocate_fp8_blockwise_tensors( - std::vector> &shape_list, std::vector &quantizer_py_list, - std::vector &quantizer_cpp_list) { - init_extension(); - std::tuple, std::vector> retval; - auto 
&tensor_py_list = std::get<0>(retval); - auto &tensor_cpp_list = std::get<1>(retval); - - // Number of tensors - const size_t num_tensors = shape_list.size(); - if (num_tensors == 0) { - return retval; + noop_cu.emplace(makeTransformerEngineTensor(noop_flag.value())); + quant_config.set_noop_tensor(noop_cu->data()); } + quant_config.set_force_pow_2_scales(force_pow_2_scales); + quant_config.set_amax_epsilon(static_cast(amax_epsilon)); + quant_config.set_nvfp4_2d_quantization(nvfp4_2d_quantization); - // Quantization parameters - const auto rowwise_usage = quantizer_cpp_list[0]->rowwise_usage; - const auto columnwise_usage = quantizer_cpp_list[0]->columnwise_usage; - const auto scaling_mode = quantizer_cpp_list[0]->get_scaling_mode(); - const auto is_2D_scaled = scaling_mode == NVTE_BLOCK_SCALING_2D; - const auto fp8_dtype = quantizer_cpp_list[0]->dtype; - constexpr size_t fp8_elem_size = 1; - constexpr size_t scale_elem_size = 4; - - // Helper function to construct tensor view - // Note: Deleter holds a shared_ptr for the buffer, so the buffer - // will survive until all views are deleted. 
- auto make_torch_view = [](std::shared_ptr &buffer, const std::vector &shape, - size_t offset, at::ScalarType dtype) -> at::Tensor { - std::vector shape_int64(shape.begin(), shape.end()); - bool is_empty_shape = product(shape) == 0; - if (buffer->data_ptr() == nullptr || is_empty_shape) { - return at::empty(shape_int64, at::device(at::kCUDA).dtype(dtype)); - } - return at::from_blob( - buffer->data_ptr() + offset, shape_int64, - [buffer](void *) {}, // deleter holds shared_ptr - at::device(at::kCUDA).dtype(dtype)); - }; - - // Allocate row-wise data - std::vector rowwise_data_list, rowwise_scale_list; - std::vector> rowwise_data_shapes, rowwise_scale_shapes; - if (rowwise_usage) { - // Tensor sizes - for (size_t i = 0; i < num_tensors; ++i) { - rowwise_data_shapes.emplace_back(shape_list[i]); - rowwise_scale_shapes.emplace_back( - quantizer_cpp_list[i]->get_scale_shape(shape_list[i], false)); - } - - // Offsets in full buffer - size_t buffer_size = 0; - std::vector data_offsets, scale_offsets; - for (size_t i = 0; i < num_tensors; ++i) { - buffer_size = roundup(buffer_size, 256); // align to 256B - data_offsets.push_back(buffer_size); - buffer_size += product(rowwise_data_shapes[i]) * fp8_elem_size; - } - for (size_t i = 0; i < num_tensors; ++i) { - buffer_size = roundup(buffer_size, 16); // align to 16B - scale_offsets.push_back(buffer_size); - buffer_size += product(rowwise_scale_shapes[i]) * scale_elem_size; - } - - // Allocate full buffer - auto buffer = std::make_shared( - at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); - - // Construct tensor views - for (size_t i = 0; i < num_tensors; ++i) { - rowwise_data_list.emplace_back( - make_torch_view(buffer, rowwise_data_shapes[i], data_offsets[i], torch::kUInt8)); - rowwise_scale_list.emplace_back( - make_torch_view(buffer, rowwise_scale_shapes[i], scale_offsets[i], torch::kFloat32)); - } - } - - // Allocate column-wise data - std::vector columnwise_data_list, columnwise_scale_list; 
- std::vector> columnwise_data_shapes, columnwise_scale_shapes; - if (columnwise_usage) { - // Tensor sizes - for (size_t i = 0; i < num_tensors; ++i) { - columnwise_data_shapes.emplace_back(); - auto &shape = columnwise_data_shapes.back(); - shape.push_back(shape_list[i].back()); - for (size_t j = 0; j < shape_list[i].size() - 1; ++j) { - shape.push_back(shape_list[i][j]); - } - columnwise_scale_shapes.emplace_back( - quantizer_cpp_list[i]->get_scale_shape(shape_list[i], true)); - } - - // Offsets in full buffer - size_t buffer_size = 0; - std::vector data_offsets, scale_offsets; - for (size_t i = 0; i < num_tensors; ++i) { - buffer_size = roundup(buffer_size, 256); // align to 256B - data_offsets.push_back(buffer_size); - buffer_size += product(columnwise_data_shapes[i]) * fp8_elem_size; - } - for (size_t i = 0; i < num_tensors; ++i) { - buffer_size = roundup(buffer_size, 16); // align to 16B - scale_offsets.push_back(buffer_size); - buffer_size += product(columnwise_scale_shapes[i]) * scale_elem_size; - } - - // Allocate full buffer - auto buffer = std::make_shared( - at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); - - // Construct tensor views - for (size_t i = 0; i < num_tensors; ++i) { - columnwise_data_list.emplace_back( - make_torch_view(buffer, columnwise_data_shapes[i], data_offsets[i], torch::kUInt8)); - columnwise_scale_list.emplace_back( - make_torch_view(buffer, columnwise_scale_shapes[i], scale_offsets[i], torch::kFloat32)); - } - } - - // Construct FP8 block-wise tensors - py::handle Float8BlockwiseQTensorClass( - reinterpret_cast(Float8BlockwiseQTensorStoragePythonClass)); - for (size_t i = 0; i < num_tensors; ++i) { - // Create tensor objects with proper reference counting - py::object rowwise_data = rowwise_usage ? py::cast(rowwise_data_list[i]) : py::none(); - py::object rowwise_scale = rowwise_usage ? py::cast(rowwise_scale_list[i]) : py::none(); - py::object columnwise_data = - (columnwise_usage ? 
py::cast(columnwise_data_list[i]) : py::none()); - py::object columnwise_scale = - (columnwise_usage ? py::cast(columnwise_scale_list[i]) : py::none()); - - // Construct Python tensor - tensor_py_list.emplace_back( - Float8BlockwiseQTensorClass(rowwise_data, rowwise_scale, columnwise_data, columnwise_scale, - fp8_dtype, quantizer_py_list[i], is_2D_scaled)); - - // Construct C++ tensor - tensor_cpp_list.emplace_back(makeTransformerEngineTensor( - rowwise_usage ? rowwise_data_list[i].data_ptr() : nullptr, - columnwise_usage ? columnwise_data_list[i].data_ptr() : nullptr, - rowwise_usage ? rowwise_data_shapes[i] : std::vector{0}, - columnwise_usage ? columnwise_data_shapes[i] : std::vector{0}, fp8_dtype, nullptr, - nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr, - columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr, - rowwise_usage ? rowwise_scale_shapes[i] : std::vector{0}, - columnwise_usage ? columnwise_scale_shapes[i] : std::vector{0}, scaling_mode)); - } - - return retval; + auto stream = getCurrentCUDAStreamRaw(input.get_device_index()); + nvte_quantize_v2(input_cu.data(), output_cu.data(), quant_config, stream); } -std::tuple, std::vector> bulk_allocate_mxfp8_tensors( - std::vector> &shape_list, std::vector &quantizer_py_list, - std::vector &quantizer_cpp_list) { - init_extension(); - std::tuple, std::vector> retval; - auto &tensor_py_list = std::get<0>(retval); - auto &tensor_cpp_list = std::get<1>(retval); - - // Number of tensors - const size_t num_tensors = shape_list.size(); - if (num_tensors == 0) { - return retval; - } - - // Quantization parameters - const auto rowwise_usage = quantizer_cpp_list[0]->rowwise_usage; - const auto columnwise_usage = quantizer_cpp_list[0]->columnwise_usage; - const auto scaling_mode = quantizer_cpp_list[0]->get_scaling_mode(); - const auto fp8_dtype = quantizer_cpp_list[0]->dtype; - const bool with_gemm_swizzled_scales = quantizer_cpp_list[0]->optimize_for_gemm; - - constexpr size_t 
fp8_elem_size = 1; - constexpr size_t scale_elem_size = 1; - - // Helper function to construct tensor view - // Note: Deleter holds a shared_ptr for the buffer, so the buffer - // will survive until all views are deleted. - auto make_torch_view = [](std::shared_ptr &buffer, const std::vector &shape, - size_t offset, at::ScalarType dtype) -> at::Tensor { - std::vector shape_int64(shape.begin(), shape.end()); - bool is_empty_shape = product(shape) == 0; - if (buffer->data_ptr() == nullptr || is_empty_shape) { - return at::empty(shape_int64, at::device(at::kCUDA).dtype(dtype)); - } - return at::from_blob( - buffer->data_ptr() + offset, shape_int64, - [buffer](void *) {}, // deleter holds shared_ptr - at::device(at::kCUDA).dtype(dtype)); - }; - - // Allocate row-wise data - std::vector rowwise_data_list, rowwise_scale_list; - std::vector> rowwise_data_shapes, rowwise_scale_shapes; - if (rowwise_usage) { - // Tensor sizes - for (size_t i = 0; i < num_tensors; ++i) { - rowwise_data_shapes.emplace_back(shape_list[i]); - rowwise_scale_shapes.emplace_back( - quantizer_cpp_list[i]->get_scale_shape(shape_list[i], false)); - } - - // Offsets in full buffer - size_t buffer_size = 0; - std::vector data_offsets, scale_offsets; - for (size_t i = 0; i < num_tensors; ++i) { - buffer_size = roundup(buffer_size, 256); // align to 256B - data_offsets.push_back(buffer_size); - buffer_size += product(rowwise_data_shapes[i]) * fp8_elem_size; - } - for (size_t i = 0; i < num_tensors; ++i) { - buffer_size = roundup(buffer_size, 16); // align to 16B - scale_offsets.push_back(buffer_size); - buffer_size += product(rowwise_scale_shapes[i]) * scale_elem_size; - } +// ============================================================================ +// Quantize with both rowwise and columnwise output in one fused kernel call. +// This is needed for MXFP8 (and potentially other formats) where +// nvte_quantize_v2 can fill both rowwise and columnwise buffers simultaneously. 
+// ============================================================================ - // Allocate full buffer - auto buffer = std::make_shared( - at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); +void quantize_bidirectional(Tensor input, Tensor output_rowwise_data, int64_t output_te_dtype, + std::optional output_amax, std::optional output_scale, + Tensor output_rowwise_scale_inv, Tensor output_columnwise_data, + Tensor output_columnwise_scale_inv, int64_t scaling_mode, + bool force_pow_2_scales, double amax_epsilon, + std::optional noop_flag, bool nvfp4_2d_quantization) { + auto shape = getStableTensorShape(input); + auto te_dtype = static_cast(output_te_dtype); + auto nvte_scaling = static_cast(scaling_mode); - // Construct tensor views - for (size_t i = 0; i < num_tensors; ++i) { - rowwise_data_list.emplace_back( - make_torch_view(buffer, rowwise_data_shapes[i], data_offsets[i], torch::kUInt8)); - rowwise_scale_list.emplace_back( - make_torch_view(buffer, rowwise_scale_shapes[i], scale_offsets[i], torch::kUInt8)); - } - } - - // Allocate column-wise data - std::vector columnwise_data_list, columnwise_scale_list; - std::vector> columnwise_data_shapes, columnwise_scale_shapes; - if (columnwise_usage) { - // Tensor sizes - for (size_t i = 0; i < num_tensors; ++i) { - // For MXFP8, the columnwise data doesn't need transpose - // because of TN, NT, NN layout support in SM100 - columnwise_data_shapes.emplace_back(shape_list[i]); - columnwise_scale_shapes.emplace_back( - quantizer_cpp_list[i]->get_scale_shape(shape_list[i], true)); - } - - // Offsets in full buffer - size_t buffer_size = 0; - std::vector data_offsets, scale_offsets; - for (size_t i = 0; i < num_tensors; ++i) { - buffer_size = roundup(buffer_size, 256); // align to 256B - data_offsets.push_back(buffer_size); - buffer_size += product(columnwise_data_shapes[i]) * fp8_elem_size; - } - for (size_t i = 0; i < num_tensors; ++i) { - buffer_size = roundup(buffer_size, 16); // align to 
16B - scale_offsets.push_back(buffer_size); - buffer_size += product(columnwise_scale_shapes[i]) * scale_elem_size; - } + auto input_cu = makeTransformerEngineTensor(input); - // Allocate full buffer - auto buffer = std::make_shared( - at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); + // Build output TensorWrapper with both rowwise and columnwise data + TensorWrapper output_cu(nvte_scaling); + output_cu.set_rowwise_data(output_rowwise_data.data_ptr(), te_dtype, shape); - // Construct tensor views - for (size_t i = 0; i < num_tensors; ++i) { - columnwise_data_list.emplace_back( - make_torch_view(buffer, columnwise_data_shapes[i], data_offsets[i], torch::kUInt8)); - columnwise_scale_list.emplace_back( - make_torch_view(buffer, columnwise_scale_shapes[i], scale_offsets[i], torch::kUInt8)); - } + // Determine scale_inv dtype from scaling mode + DType si_dtype = DType::kFloat32; + if (nvte_scaling == NVTE_MXFP8_1D_SCALING) { + si_dtype = DType::kFloat8E8M0; + } else if (nvte_scaling == NVTE_NVFP4_1D_SCALING) { + si_dtype = DType::kFloat8E4M3; } - // Construct mxfp8 tensors - py::handle MXFP8TensorClass(reinterpret_cast(MXFP8TensorStoragePythonClass)); - for (size_t i = 0; i < num_tensors; ++i) { - // Create tensor objects with proper reference counting - py::object rowwise_data = rowwise_usage ? py::cast(rowwise_data_list[i]) : py::none(); - py::object rowwise_scale = rowwise_usage ? py::cast(rowwise_scale_list[i]) : py::none(); - py::object columnwise_data = - (columnwise_usage ? py::cast(columnwise_data_list[i]) : py::none()); - py::object columnwise_scale = - (columnwise_usage ? 
py::cast(columnwise_scale_list[i]) : py::none()); + auto rw_si_shape = getStableTensorShape(output_rowwise_scale_inv); + output_cu.set_rowwise_scale_inv(output_rowwise_scale_inv.data_ptr(), si_dtype, rw_si_shape); - // Construct Python tensor - tensor_py_list.emplace_back(MXFP8TensorClass(rowwise_data, rowwise_scale, columnwise_data, - columnwise_scale, fp8_dtype, quantizer_py_list[i], - with_gemm_swizzled_scales)); - - // Construct C++ tensor - tensor_cpp_list.emplace_back(makeTransformerEngineTensor( - rowwise_usage ? rowwise_data_list[i].data_ptr() : nullptr, - columnwise_usage ? columnwise_data_list[i].data_ptr() : nullptr, - rowwise_usage ? rowwise_data_shapes[i] : std::vector{0}, - columnwise_usage ? columnwise_data_shapes[i] : std::vector{0}, fp8_dtype, nullptr, - nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr, - columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr, - rowwise_usage ? rowwise_scale_shapes[i] : std::vector{0}, - columnwise_usage ? columnwise_scale_shapes[i] : std::vector{0}, scaling_mode)); - tensor_cpp_list.back().set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); + // For MXFP8, columnwise data has the same logical shape as the input [M, K]. + // For NVFP4/block-scaling, columnwise data is the transpose [K, M]. + // Use the actual tensor shape to let TensorWrapper::shape() compute correctly. + auto cw_data_shape = getStableTensorShape(output_columnwise_data); + // FP4 data is packed (2 elements per byte). Double last dim for logical shape. 
+ if (is_fp4_dtype(te_dtype) && !cw_data_shape.empty()) { + cw_data_shape.back() *= 2; } + output_cu.set_columnwise_data(output_columnwise_data.data_ptr(), te_dtype, cw_data_shape); - return retval; -} - -// allocate fp4 data, fp8 scalings, and amax values -// layout: [fp4_data0, ..., fp4_dataN, fp8_scaling0, ..., fp8_scalingN, amax0, ..., amaxN] -// amax buffer will be zeroed out by later amax kernels, so we can use empty to allocate -std::tuple, std::vector, bool> bulk_allocate_nvfp4_tensors( - std::vector> &shape_list, std::vector &quantizer_py_list, - std::vector &quantizer_cpp_list) { - init_extension(); - std::tuple, std::vector, bool> retval; - auto &tensor_py_list = std::get<0>(retval); - auto &tensor_cpp_list = std::get<1>(retval); - auto &contiguous_data_and_scale = std::get<2>(retval); - contiguous_data_and_scale = true; + auto cw_si_shape = getStableTensorShape(output_columnwise_scale_inv); + output_cu.set_columnwise_scale_inv(output_columnwise_scale_inv.data_ptr(), si_dtype, cw_si_shape); - // Number of tensors - const size_t num_tensors = shape_list.size(); - if (num_tensors == 0) { - return retval; + const std::vector scalar_shape{1}; + if (output_amax.has_value() && output_amax->numel() > 0) { + output_cu.set_amax(output_amax->data_ptr(), DType::kFloat32, scalar_shape); } - - // Quantization parameters - const auto rowwise_usage = quantizer_cpp_list[0]->rowwise_usage; - const auto columnwise_usage = quantizer_cpp_list[0]->columnwise_usage; - const auto scaling_mode = quantizer_cpp_list[0]->get_scaling_mode(); - const auto fp4_dtype = quantizer_cpp_list[0]->dtype; - const bool with_gemm_swizzled_scales = false; /// TODO (tmoon) Enable based on optimize_for_gemm; - constexpr size_t scale_elem_size = 1; - - // Helper function to construct tensor view - // Note: Deleter holds a shared_ptr for the buffer, so the buffer - // will survive until all views are deleted. 
- auto make_torch_view = [](std::shared_ptr &buffer, const std::vector &shape, - size_t offset, at::ScalarType dtype) -> at::Tensor { - std::vector shape_int64(shape.begin(), shape.end()); - bool is_empty_shape = product(shape) == 0; - if (buffer->data_ptr() == nullptr || is_empty_shape) { - return at::empty(shape_int64, at::device(at::kCUDA).dtype(dtype)); - } - return at::from_blob( - buffer->data_ptr() + offset, shape_int64, - [buffer](void *) {}, // deleter holds shared_ptr - at::device(at::kCUDA).dtype(dtype)); - }; - - // Lambda function for converting std::vector shape to NVFP4 shape (last dim divided by 2) - auto to_fp4_shape = [](const std::vector &shape) { - std::vector fp4_shape(shape.begin(), shape.end()); - if (!fp4_shape.empty()) { - fp4_shape.back() /= 2; - } - return fp4_shape; - }; - - // Allocate row-wise data - std::vector rowwise_data_list, rowwise_scale_list, amax_rowwise_list; - std::vector> rowwise_data_shapes, rowwise_scale_shapes; - if (rowwise_usage) { - // Tensor sizes - for (size_t i = 0; i < num_tensors; ++i) { - rowwise_data_shapes.emplace_back(shape_list[i]); - rowwise_scale_shapes.emplace_back( - quantizer_cpp_list[i]->get_scale_shape(shape_list[i], false)); - } - - // Offsets in full buffer - size_t buffer_size = 0; - std::vector data_offsets, scale_offsets, amax_offsets; - for (size_t i = 0; i < num_tensors; ++i) { - // FP4 data is aligned to 256B - const auto offset = roundup(buffer_size, 256); - if (offset != buffer_size) { - contiguous_data_and_scale = false; - } - data_offsets.push_back(offset); - buffer_size = offset + (product(rowwise_data_shapes[i]) + 1) / 2; - } - for (size_t i = 0; i < num_tensors; ++i) { - // Scales are aligned to 16B - const auto offset = roundup(buffer_size, 16); - if (offset != buffer_size) { - contiguous_data_and_scale = false; - } - scale_offsets.push_back(offset); - buffer_size = offset + product(rowwise_scale_shapes[i]) * scale_elem_size; - } - for (size_t i = 0; i < num_tensors; ++i) { - // Amaxes 
(FP32) are aligned to 16B - // Note: Multi-quantize kernel does not require contiguous amaxes. - const auto offset = roundup(buffer_size, 16); - amax_offsets.push_back(offset); - buffer_size = offset + 4; - } - - // Allocate full buffer - auto buffer = std::make_shared( - at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); - - // Construct tensor views - for (size_t i = 0; i < num_tensors; ++i) { - rowwise_data_list.emplace_back(make_torch_view(buffer, to_fp4_shape(rowwise_data_shapes[i]), - data_offsets[i], torch::kUInt8)); - rowwise_scale_list.emplace_back( - make_torch_view(buffer, rowwise_scale_shapes[i], scale_offsets[i], torch::kUInt8)); - amax_rowwise_list.emplace_back( - make_torch_view(buffer, std::vector{1}, amax_offsets[i], torch::kFloat32)); - } - } - - // Allocate column-wise data - std::vector columnwise_data_list, columnwise_scale_list, amax_columnwise_list; - std::vector> columnwise_data_shapes, columnwise_scale_shapes; - if (columnwise_usage) { - // Tensor sizes - for (size_t i = 0; i < num_tensors; ++i) { - // push the transposed shape into NVFP4 columnwise shape - // NVFP4 on SM100 is TN only - columnwise_data_shapes.emplace_back(); - auto &shape = columnwise_data_shapes.back(); - shape.push_back(shape_list[i].back()); - for (size_t j = 0; j < shape_list[i].size() - 1; ++j) { - shape.push_back(shape_list[i][j]); - } - columnwise_scale_shapes.emplace_back( - quantizer_cpp_list[i]->get_scale_shape(shape_list[i], true)); - } - - // Offsets in full buffer - size_t buffer_size = 0; - std::vector data_offsets, scale_offsets, amax_offsets; - for (size_t i = 0; i < num_tensors; ++i) { - // FP4 data is aligned to 256B - const auto offset = roundup(buffer_size, 256); - if (offset != buffer_size) { - contiguous_data_and_scale = false; - } - data_offsets.push_back(offset); - buffer_size = offset + (product(columnwise_data_shapes[i]) + 1) / 2; - } - for (size_t i = 0; i < num_tensors; ++i) { - // Scales are aligned to 16B - const 
auto offset = roundup(buffer_size, 16); - if (offset != buffer_size) { - contiguous_data_and_scale = false; - } - scale_offsets.push_back(offset); - buffer_size = offset + product(columnwise_scale_shapes[i]) * scale_elem_size; - } - for (size_t i = 0; i < num_tensors; ++i) { - // Amaxes (FP32) are aligned to 16B - // Note: Multi-quantize kernel does not require contiguous amaxes. - const auto offset = roundup(buffer_size, 16); - amax_offsets.push_back(offset); - buffer_size = offset + 4; - } - - // Allocate full buffer - auto buffer = std::make_shared( - at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); - - // Construct tensor views - for (size_t i = 0; i < num_tensors; ++i) { - columnwise_data_list.emplace_back(make_torch_view( - buffer, to_fp4_shape(columnwise_data_shapes[i]), data_offsets[i], torch::kUInt8)); - columnwise_scale_list.emplace_back( - make_torch_view(buffer, columnwise_scale_shapes[i], scale_offsets[i], torch::kUInt8)); - amax_columnwise_list.emplace_back( - make_torch_view(buffer, std::vector{1}, amax_offsets[i], torch::kFloat32)); - } + if (output_scale.has_value() && output_scale->numel() > 0) { + output_cu.set_scale(output_scale->data_ptr(), DType::kFloat32, scalar_shape); } - // Construct nvfp4 tensors - py::handle NVFP4TensorClass(reinterpret_cast(NVFP4TensorStoragePythonClass)); - for (size_t i = 0; i < num_tensors; ++i) { - // Create tensor objects with proper reference counting - py::object rowwise_data = rowwise_usage ? py::cast(rowwise_data_list[i]) : py::none(); - py::object rowwise_scale = rowwise_usage ? py::cast(rowwise_scale_list[i]) : py::none(); - py::object columnwise_data = - (columnwise_usage ? py::cast(columnwise_data_list[i]) : py::none()); - py::object columnwise_scale = - (columnwise_usage ? py::cast(columnwise_scale_list[i]) : py::none()); - py::object amax_rowwise = rowwise_usage ? py::cast(amax_rowwise_list[i]) : py::none(); - py::object amax_columnwise = columnwise_usage ? 
py::cast(amax_columnwise_list[i]) : py::none(); - - // Construct Python tensor - tensor_py_list.emplace_back(NVFP4TensorClass( - rowwise_data, rowwise_scale, columnwise_data, columnwise_scale, amax_rowwise, - amax_columnwise, fp4_dtype, quantizer_py_list[i], with_gemm_swizzled_scales)); - - // Construct C++ tensor - // Use a TensorWrapper variable to hold the output of makeTransformerEngineTensor, - // then set the amax and amax_columnwise values. - { - auto tensor_wrapper = makeTransformerEngineTensor( - rowwise_usage ? rowwise_data_list[i].data_ptr() : nullptr, - columnwise_usage ? columnwise_data_list[i].data_ptr() : nullptr, - rowwise_usage ? rowwise_data_shapes[i] : std::vector{0}, - columnwise_usage ? columnwise_data_shapes[i] : std::vector{0}, fp4_dtype, - /*amax_ptr=*/nullptr, - /*scale_ptr=*/nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr, - columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr, - rowwise_usage ? rowwise_scale_shapes[i] : std::vector{0}, - columnwise_usage ? 
columnwise_scale_shapes[i] : std::vector{0}, scaling_mode); - tensor_wrapper.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); - - // Set the amax rowwise and amax columnwise if available - if (rowwise_usage) { - tensor_wrapper.set_amax(amax_rowwise_list[i].data_ptr(), DType::kFloat32, - std::vector{1}); - } - if (columnwise_usage) { - tensor_wrapper.set_columnwise_amax(amax_columnwise_list[i].data_ptr(), DType::kFloat32, - std::vector{1}); - } - - tensor_cpp_list.emplace_back(std::move(tensor_wrapper)); - } + QuantizationConfigWrapper quant_config; + std::optional noop_cu; + if (noop_flag.has_value()) { + noop_cu.emplace(makeTransformerEngineTensor(noop_flag.value())); + quant_config.set_noop_tensor(noop_cu->data()); } + quant_config.set_force_pow_2_scales(force_pow_2_scales); + quant_config.set_amax_epsilon(static_cast(amax_epsilon)); + quant_config.set_nvfp4_2d_quantization(nvfp4_2d_quantization); - return retval; + auto stream = getCurrentCUDAStreamRaw(input.get_device_index()); + nvte_quantize_v2(input_cu.data(), output_cu.data(), quant_config, stream); } -// Owns all allocations/wrappers backing quant_config_list[*].set_rng_state(...). -struct StochasticRngStateResources { - at::Tensor rng_states_tensor; // [2 * num_tensors], int64, CUDA - at::Tensor rng_states_tensor_colwise; // optional, same shape/dtype/device - std::vector te_rng_state_list; - std::vector te_rng_state_list_colwise; - - bool enabled{false}; - bool need_separate_rng_states{false}; - bool with_bulk_generate_rng_states{false}; -}; - -// Populates quant_config_list (+ optional colwise list) with rng_state pointers and stochastic flag. 
-static StochasticRngStateResources setup_stochastic_rounding_rng_states_helper( - size_t num_tensors, bool stochastic_rounding, bool with_bulk_generate_rng_states, - bool need_separate_rng_states, - std::vector &quant_config_list_rowwise, - std::vector &quant_config_list_colwise) { - // the return object will be used to keep rng states alive - StochasticRngStateResources res; - res.enabled = stochastic_rounding; - res.need_separate_rng_states = need_separate_rng_states; - res.with_bulk_generate_rng_states = with_bulk_generate_rng_states; - - if (!stochastic_rounding) return res; - - // Basic sanity: caller usually pre-sizes these to num_tensors. - TORCH_CHECK(quant_config_list_rowwise.size() == num_tensors, - "quant_config_list_rowwise must be sized to num_tensors"); - if (need_separate_rng_states) { - TORCH_CHECK(quant_config_list_colwise.size() == num_tensors, - "quant_config_list_colwise must be sized to num_tensors when " - "need_separate_rng_states=true"); - } - - const size_t rng_elts_per_thread = - res.with_bulk_generate_rng_states ? (1024 * num_tensors) : 1024; - - auto opts = at::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA); - res.rng_states_tensor = torch::empty({static_cast(2 * num_tensors)}, opts); - if (need_separate_rng_states) { - res.rng_states_tensor_colwise = torch::empty({static_cast(2 * num_tensors)}, opts); +// ============================================================================ +// Quantize with amax: compute amax → compute scale → quantize +// This is the full current-scaling quantization pipeline. 
+// ============================================================================ + +void quantize_with_amax(Tensor input, Tensor output_data, int64_t output_te_dtype, + Tensor output_amax, Tensor output_scale, + std::optional output_scale_inv, int64_t scaling_mode, + bool force_pow_2_scales, double amax_epsilon, + std::optional noop_flag) { + auto shape = getStableTensorShape(input); + auto te_dtype = static_cast(output_te_dtype); + auto nvte_scaling = static_cast(scaling_mode); + + auto input_cu = makeTransformerEngineTensor(input); + auto output_cu = makeQuantizedTensorWrapper(output_data, te_dtype, shape, output_amax, + output_scale, output_scale_inv, nvte_scaling); + + QuantizationConfigWrapper quant_config; + std::optional noop_cu; + if (noop_flag.has_value()) { + noop_cu.emplace(makeTransformerEngineTensor(noop_flag.value())); + quant_config.set_noop_tensor(noop_cu->data()); } + quant_config.set_force_pow_2_scales(force_pow_2_scales); + quant_config.set_amax_epsilon(static_cast(amax_epsilon)); - res.te_rng_state_list.reserve(num_tensors); - if (need_separate_rng_states) res.te_rng_state_list_colwise.reserve(num_tensors); - - for (size_t i = 0; i < num_tensors; ++i) { - auto gen = at::get_generator_or_default( - std::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); - - // Rowwise RNG state - at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread); - int64_t *rng_state_ptr = static_cast(res.rng_states_tensor.data_ptr()) + i * 2; - philox_unpack(philox_args, rng_state_ptr); - - res.te_rng_state_list.push_back(makeTransformerEngineTensor( - static_cast(rng_state_ptr), std::vector{2}, DType::kInt64)); - quant_config_list_rowwise[i].set_rng_state(res.te_rng_state_list[i].data()); - quant_config_list_rowwise[i].set_stochastic_rounding(true); - - // Colwise RNG state (only if you truly need a different sequence) - if (need_separate_rng_states) { - // re-initialize philox_args for colwise RNG state - at::PhiloxCudaState philox_args_col = 
init_philox_state(gen, rng_elts_per_thread); - int64_t *rng_state_ptr_colwise = - static_cast(res.rng_states_tensor_colwise.data_ptr()) + i * 2; + auto stream = getCurrentCUDAStreamRaw(input.get_device_index()); - philox_unpack(philox_args_col, rng_state_ptr_colwise); + // Step 1: Compute amax from input, store in output's amax buffer + nvte_compute_amax_with_config(input_cu.data(), output_cu.data(), quant_config, stream); - res.te_rng_state_list_colwise.push_back(makeTransformerEngineTensor( - static_cast(rng_state_ptr_colwise), std::vector{2}, DType::kInt64)); - quant_config_list_colwise[i].set_rng_state(res.te_rng_state_list_colwise[i].data()); - quant_config_list_colwise[i].set_stochastic_rounding(true); - } + // Step 2: Compute scale from amax + nvte_compute_scale_from_amax(output_cu.data(), quant_config, stream); - // break the loop if we are using bulk generate rng states - if (res.with_bulk_generate_rng_states) break; - } - - return res; + // Step 3: Quantize using computed scale + // Clear amax before quantize to avoid atomic conflicts + output_cu.set_amax(nullptr, DType::kFloat32, std::vector{1}); + nvte_quantize_v2(input_cu.data(), output_cu.data(), quant_config, stream); } -// Implements split-quantize NVFP4 with Row/Column-wise Hadamard Transform (RHT) -void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input, - const std::vector &input_list, - std::vector &output_list, - const std::vector &split_sections, - const std::vector &quantizers, - cudaStream_t stream) { - const size_t num_tensors = split_sections.size(); - const auto &quantizer = *quantizers.front(); - - std::vector nvte_tensor_input_list; - std::vector nvte_tensor_output_list; - for (size_t i = 0; i < num_tensors; ++i) { - nvte_tensor_input_list.push_back(input_list[i].data()); - nvte_tensor_output_list.push_back(output_list[i].data()); - } - - // trigger the row-col fusion when the split-sections shapes are all 128 aligned for max performance - bool all_aligned_token_dim = 
- std::all_of(split_sections.begin(), split_sections.end(), - [](size_t split_section) { return split_section % 128 == 0; }); - - // in the case when rowwise and colwise cannot be fused, we have to generate the RNG states twice - // so that rowwise and colwise will have different random numbers - bool need_separate_rng_states = - (!all_aligned_token_dim) && quantizer.rowwise_usage && quantizer.columnwise_usage; - - // Objects for TE C API - std::vector quant_config_list; - std::vector quant_config_list_colwise; - for (size_t i = 0; i < num_tensors; ++i) { - quant_config_list.emplace_back(QuantizationConfigWrapper()); - quant_config_list_colwise.emplace_back(QuantizationConfigWrapper()); - } - - // this is true because we have already built grouped kernels for rowwise and colwise quantization with RHT - bool with_bulk_generate_rng_states = true; - - // Stochastic rounding - bool need_stochastic_rounding = quantizer.stochastic_rounding; - auto stochastic_rng_state_resources = setup_stochastic_rounding_rng_states_helper( - num_tensors, need_stochastic_rounding, with_bulk_generate_rng_states, - need_separate_rng_states, quant_config_list, quant_config_list_colwise); - - // Enable NVFP4 kernels to use math operations that sacrifice - // accuracy for performance. These optimizations are experimental - // and inconsistently implemented. - // What math is accelerated? Only the high precision math, so numerical impact is minimal - // 1. replace 1 / x by reciprocal_approximate_ftz(x) - // 2. 
when RHT cast fusion is available, fusion allows cast to be performed on FP32 data, - // this will essentially remove a round trip between FP32 to BF16 then FP32 - const auto use_fast_math = transformer_engine::getenv("NVTE_USE_FAST_MATH"); - if (use_fast_math) { - for (auto &config : quant_config_list) { - config.set_use_fast_math(true); - } - for (auto &config : quant_config_list_colwise) { - config.set_use_fast_math(true); - } - } - - auto &quant_config_list_colwise_to_use = - need_separate_rng_states ? quant_config_list_colwise : quant_config_list; - - // Compute amaxes - if (quantizer.with_post_rht_amax) { - // We need: - // 1. Rowwise amax = amax for input - // 2. Columnwise amax = amax for RHT(input.t) - nvte_group_hadamard_transform_amax( - input.data(), reinterpret_cast(nvte_tensor_output_list.data()), - split_sections.data(), num_tensors, 0, quantizer.rht_matrix_random_sign_mask_t, stream); - } else { - // RHT is enabled, but amax is pre-RHT amax - NVTE_ERROR("NVFP4 split-quantize does not yet support pre-RHT amax"); +// ============================================================================ +// Quantize from pre-computed amax (skip amax computation) +// Used after FUSED_NORM_AMAX path where norm kernel already wrote amax. 
+// ============================================================================ + +void quantize_from_amax(Tensor input, Tensor output_data, int64_t output_te_dtype, + Tensor output_amax, Tensor output_scale, + std::optional output_scale_inv, int64_t scaling_mode, + bool force_pow_2_scales, double amax_epsilon, + std::optional noop_flag) { + auto shape = getStableTensorShape(input); + auto te_dtype = static_cast(output_te_dtype); + auto nvte_scaling = static_cast(scaling_mode); + + auto input_cu = makeTransformerEngineTensor(input); + auto output_cu = makeQuantizedTensorWrapper(output_data, te_dtype, shape, output_amax, + output_scale, output_scale_inv, nvte_scaling); + + QuantizationConfigWrapper quant_config; + std::optional noop_cu; + if (noop_flag.has_value()) { + noop_cu.emplace(makeTransformerEngineTensor(noop_flag.value())); + quant_config.set_noop_tensor(noop_cu->data()); } + quant_config.set_force_pow_2_scales(force_pow_2_scales); + quant_config.set_amax_epsilon(static_cast(amax_epsilon)); - // Check that RHT matrix is available - NVTE_CHECK(quantizer.rht_matrix.defined() && quantizer.rht_matrix.numel() > 0, - "RHT matrix is not available."); - auto rht_matrix_nvte = makeTransformerEngineTensor(quantizer.rht_matrix); - - if (all_aligned_token_dim) { - // allocate a tile scheduler workspace - auto tile_scheduler_workspace_torch = - at::empty({1}, at::device(at::kCUDA).dtype(torch::kInt32)); - auto nvte_tile_scheduler_workspace = - makeTransformerEngineTensor(tile_scheduler_workspace_torch); - // call the fully-fused grouped kernel for rowwise quantization & colwise RHT quantization transpose - nvte_group_hadamard_transform_cast_fusion( - input.data(), reinterpret_cast(nvte_tensor_output_list.data()), - rht_matrix_nvte.data(), split_sections.data(), num_tensors, quant_config_list[0], - nvte_tile_scheduler_workspace.data(), stream); - } else { - // Separate quantization for rowwise usage and columnwise usage - // Rowwise quantization fusion with grouped 
version - if (quantizer.rowwise_usage) { - std::vector out_identity_list; - std::vector nvte_tensor_out_identity_list; - for (size_t i = 0; i < num_tensors; i++) { - bool is_empty_split = input_list[i].numel() == 0; - TensorWrapper out_identity(output_list[i].scaling_mode()); - auto out_identity_data = output_list[i].get_rowwise_data(); - auto out_identity_scale_inv = output_list[i].get_rowwise_scale_inv(); - auto out_identity_amax = output_list[i].get_amax(); - if (!is_empty_split) { - out_identity.set_rowwise_data(out_identity_data.data_ptr, - static_cast(out_identity_data.dtype), - out_identity_data.shape); - out_identity.set_rowwise_scale_inv(out_identity_scale_inv.data_ptr, - static_cast(out_identity_scale_inv.dtype), - out_identity_scale_inv.shape); - out_identity.set_amax(out_identity_amax.data_ptr, - static_cast(out_identity_amax.dtype), - out_identity_amax.shape); - } - out_identity_list.emplace_back(std::move(out_identity)); - nvte_tensor_out_identity_list.push_back(out_identity_list.back().data()); - } - nvte_group_nvfp4_quantize_with_amax(input.data(), nvte_tensor_out_identity_list.data(), - split_sections.data(), num_tensors, quant_config_list[0], - stream); - } - - // Columnwise RHT quantization fusion with grouped version - if (quantizer.columnwise_usage) { - std::vector out_transpose_list; - std::vector nvte_tensor_out_transpose_list; - for (size_t i = 0; i < num_tensors; i++) { - bool is_empty_split = input_list[i].numel() == 0; - auto out_columnwise_data = output_list[i].get_columnwise_data(); - auto out_columnwise_scale_inv = output_list[i].get_columnwise_scale_inv(); - auto out_columnwise_amax = output_list[i].get_columnwise_amax(); + auto stream = getCurrentCUDAStreamRaw(input.get_device_index()); - // Create a wrapper for the columnwise output, as the rowwise output. Input is in transposed layout. 
- TensorWrapper out_transpose(output_list[i].scaling_mode()); - if (!is_empty_split) { - auto colwise_data_shape = out_columnwise_data.shape; - std::vector colwise_data_shape_2d; - colwise_data_shape_2d.push_back(colwise_data_shape.data[0]); - size_t last_dim = 1; - for (size_t j = 1; j < colwise_data_shape.ndim; ++j) { - last_dim *= colwise_data_shape.data[j]; - } - colwise_data_shape_2d.push_back(last_dim); - - out_transpose.set_rowwise_data(out_columnwise_data.data_ptr, - static_cast(out_columnwise_data.dtype), - colwise_data_shape_2d); - out_transpose.set_rowwise_scale_inv(out_columnwise_scale_inv.data_ptr, - static_cast(out_columnwise_scale_inv.dtype), - out_columnwise_scale_inv.shape); - out_transpose.set_amax(out_columnwise_amax.data_ptr, - static_cast(out_columnwise_amax.dtype), - out_columnwise_amax.shape); - } - out_transpose_list.emplace_back(std::move(out_transpose)); - nvte_tensor_out_transpose_list.push_back(out_transpose_list.back().data()); - } - nvte_group_hadamard_transform_cast_fusion_columnwise( - input.data(), reinterpret_cast(nvte_tensor_out_transpose_list.data()), - rht_matrix_nvte.data(), split_sections.data(), num_tensors, - quant_config_list_colwise_to_use[0], stream); - } - } + // Amax is already computed (by fused norm kernel) — just compute scale + quantize + nvte_compute_scale_from_amax(output_cu.data(), quant_config, stream); + output_cu.set_amax(nullptr, DType::kFloat32, std::vector{1}); + nvte_quantize_v2(input_cu.data(), output_cu.data(), quant_config, stream); } -void split_quantize_nvfp4_impl_helper(const TensorWrapper &input, - const std::vector &input_list, - std::vector &output_list, - const std::vector &split_sections, - const std::vector &quantizers, - cudaStream_t stream) { - const size_t num_tensors = input_list.size(); - const auto &quantizer = *quantizers.front(); - - std::vector nvte_tensor_input_list; - std::vector nvte_tensor_output_list; - for (size_t i = 0; i < num_tensors; ++i) { - 
nvte_tensor_input_list.push_back(input_list[i].data()); - nvte_tensor_output_list.push_back(output_list[i].data()); - } +// ============================================================================ +// Dequantize: input (fp8) → output (hp) +// ============================================================================ - // In this case without RHT, the rowwise and colwise quantization are fused - // we don't need separate rng states for rowwise and colwise - bool need_separate_rng_states = false; +Tensor dequantize(Tensor input_data, int64_t input_te_dtype, std::optional input_scale_inv, + std::optional input_amax, int64_t scaling_mode, int64_t output_te_dtype) { + auto shape = getStableTensorShape(input_data); + auto in_te_dtype = static_cast(input_te_dtype); + auto out_te_dtype = static_cast(output_te_dtype); + auto nvte_scaling = static_cast(scaling_mode); - // Objects for TE C API - std::vector quant_config_list; - for (size_t i = 0; i < num_tensors; ++i) { - quant_config_list.emplace_back(QuantizationConfigWrapper()); + // FP4 data is packed (2 elements per byte). Report logical element count. 
+ if (is_fp4_dtype(in_te_dtype) && !shape.empty()) { + shape.back() *= 2; } - // TODO: this is only true because the non-RHT path doesn't have grouped kernels yet, which we can be optimized - // so that we can generate all rng states at once - bool with_bulk_generate_rng_states = false; - - bool need_stochastic_rounding = quantizer.stochastic_rounding; + auto input_cu = makeQuantizedTensorWrapper(input_data, in_te_dtype, shape, input_amax, + std::nullopt, input_scale_inv, nvte_scaling); - // place holder for colwise rng states, which are not needed in this case - std::vector dummy_quant_config_list_colwise; + auto output = allocateStableTensor(std::vector(shape.begin(), shape.end()), out_te_dtype, + input_data.get_device_index()); + auto output_cu = makeTransformerEngineTensor(output); - auto stochastic_rng_state_resources = setup_stochastic_rounding_rng_states_helper( - num_tensors, need_stochastic_rounding, with_bulk_generate_rng_states, - need_separate_rng_states, quant_config_list, - dummy_quant_config_list_colwise); // colwise rng states are not needed in this case + nvte_dequantize(input_cu.data(), output_cu.data(), + getCurrentCUDAStreamRaw(input_data.get_device_index())); - // We need: - // 1. Rowwise amax = amax for input - // 2. Columnwise amax = amax for input too - // Columnwise amax will be filled with a fused D2D copy from rowwise amax - // Note that the multi compute amax API expects rowwise amax pointer to be not null - // So we need to set the pointer accordingly to make colwise-only quantization work - std::vector orig_amax_ptr_list; - for (size_t i = 0; i < num_tensors; i++) { - auto rowwise_amax_ptr = output_list[i].get_amax().data_ptr; - orig_amax_ptr_list.push_back(rowwise_amax_ptr); - auto columnwise_amax_ptr = output_list[i].get_columnwise_amax().data_ptr; - void *amax_ptr = rowwise_amax_ptr != nullptr ? 
rowwise_amax_ptr : columnwise_amax_ptr; - NVTE_CHECK(amax_ptr != nullptr, "Could not find amax pointer"); - output_list[i].set_amax(amax_ptr, DType::kFloat32, std::vector{1}); - } - nvte_group_amax(input.data(), reinterpret_cast(nvte_tensor_output_list.data()), - split_sections.data(), num_tensors, stream); - for (size_t i = 0; i < num_tensors; i++) { - output_list[i].set_amax(orig_amax_ptr_list[i], DType::kFloat32, std::vector{1}); - } - - // Quantize tensors individually - for (size_t i = 0; i < num_tensors; i++) { - // skip this round if input is empty - if (input_list[i].numel() == 0) { - continue; - } - nvte_quantize_v2(input_list[i].data(), output_list[i].data(), quant_config_list[i], stream); - } + return output; } -void split_quantize_nvfp4_impl(const TensorWrapper &input, - const std::vector &input_list, - std::vector &output_list, - const std::vector &split_sections, - const std::vector &quantizers) { - // Check tensor lists - const size_t num_tensors = split_sections.size(); - NVTE_CHECK(input_list.size() == num_tensors, "Expected ", num_tensors, " input tensors, but got ", - input_list.size(), "."); - NVTE_CHECK(output_list.size() == num_tensors, "Expected ", num_tensors, - " output tensors, but got ", output_list.size(), "."); - NVTE_CHECK(quantizers.size() == num_tensors, "Expected ", num_tensors, - " NVFP4 quantizers, but got ", quantizers.size(), "."); - - // sanity check all the quantizers have the same scaling mode - bool all_same_scaling_mode = - std::all_of(quantizers.begin(), quantizers.end(), [&](const NVFP4Quantizer *quantizer) { - return quantizer->get_scaling_mode() == quantizers.front()->get_scaling_mode(); - }); - NVTE_CHECK(all_same_scaling_mode, "All quantizers must have the same scaling mode"); - - // Trivial cases - if (num_tensors == 0) { - return; - } - if (input.numel() == 0) { - for (const auto &tensor : input_list) { - NVTE_CHECK(tensor.numel() == 0, - "Input tensor has zero elements but got split with non-zero elements"); - } - 
return; - } - - // Assume all quantizers have identical config - const auto &quantizer = *quantizers.front(); - NVTE_CHECK(!quantizer.with_2d_quantization, - "NVFP4 split-quantize does not support 2D quantization"); - NVTE_CHECK(!quantizer.with_amax_reduction, - "NVFP4 split-quantize does not support amax reduction"); - - // Check input tensor shape - const size_t input_last_dim = input.ndim() > 0 ? input.size(input.ndim() - 1) : 1; - NVTE_CHECK(input_last_dim % 128 == 0, - "NVFP4 multi-quantize requires inner dim to be multiple of 128."); - - // CUDA stream - auto stream = at::cuda::getCurrentCUDAStream(); - - // Perform multi-tensor quantization - NVTE_SCOPED_GIL_RELEASE({ - if (quantizer.with_rht) { // Quantize row-wise data, RHT+quantize column-wise data - // Check that config is supported - NVTE_CHECK(input.dtype() == DType::kBFloat16, "RHT is only supported for bfloat16 input"); - // Fuse the rowwise and colwise into one when the kernel is ready - split_quantize_nvfp4_impl_with_rht_helper(input, input_list, output_list, split_sections, - quantizers, stream); - } else { // NVFP4 quantize - // Fuse the rowwise and colwise into one when the kernel is ready - split_quantize_nvfp4_impl_helper(input, input_list, output_list, split_sections, quantizers, - stream); - } - }); +} // namespace transformer_engine::pytorch::stable + +STABLE_TORCH_LIBRARY_FRAGMENT(transformer_engine_stable, m) { + m.def( + "quantize(Tensor input, Tensor output_data, int output_te_dtype, Tensor? output_amax, " + "Tensor? output_scale, Tensor? output_scale_inv, int scaling_mode, bool force_pow_2_scales, " + "float amax_epsilon, Tensor? noop_flag, bool nvfp4_2d_quantization=False) -> ()"); + m.def( + "quantize_with_amax(Tensor input, Tensor output_data, int output_te_dtype, Tensor " + "output_amax, Tensor output_scale, Tensor? output_scale_inv, int scaling_mode, bool " + "force_pow_2_scales, float amax_epsilon, Tensor? 
noop_flag) -> ()"); + m.def( + "quantize_from_amax(Tensor input, Tensor output_data, int output_te_dtype, Tensor " + "output_amax, Tensor output_scale, Tensor? output_scale_inv, int scaling_mode, bool " + "force_pow_2_scales, float amax_epsilon, Tensor? noop_flag) -> ()"); + m.def( + "quantize_bidirectional(Tensor input, Tensor output_rowwise_data, int output_te_dtype, " + "Tensor? output_amax, Tensor? output_scale, Tensor output_rowwise_scale_inv, " + "Tensor output_columnwise_data, Tensor output_columnwise_scale_inv, " + "int scaling_mode, bool force_pow_2_scales, float amax_epsilon, Tensor? noop_flag, " + "bool nvfp4_2d_quantization=False) -> ()"); + m.def( + "dequantize(Tensor input_data, int input_te_dtype, Tensor? input_scale_inv, Tensor? " + "input_amax, int scaling_mode, int output_te_dtype) -> Tensor"); } -} // namespace - -std::vector split_quantize(const at::Tensor &tensor, - const std::vector &split_sections, - std::vector quantizer_list, - bool disable_bulk_allocation) { - init_extension(); - - // Check number of tensors - const size_t num_splits = split_sections.size(); - NVTE_CHECK(quantizer_list.size() == num_splits, "Expected ", num_splits, " quantizers, but got ", - quantizer_list.size()); - if (num_splits == 0) { - return {}; - } - - // Input tensor properties - auto input_py = tensor.contiguous(); - uint8_t *input_dptr = reinterpret_cast(input_py.data_ptr()); - auto input_dtype = GetTransformerEngineDType(input_py.scalar_type()); - std::vector input_shape; - size_t input_size = 1; - for (const auto &d : input_py.sizes()) { - input_shape.push_back(d); - input_size *= d; - } - NVTE_CHECK(input_shape.size() > 0, "Input tensor has 0 dims"); - - // Split input tensor along dim 0 - std::vector input_list; - std::vector> split_shapes; - size_t dim0_offset = 0; - const size_t dim0_stride = - input_shape[0] == 0 ? 
0 : input_py.element_size() * input_size / input_shape[0]; - for (size_t i = 0; i < num_splits; ++i) { - NVTE_CHECK(dim0_offset + split_sections[i] <= input_shape[0], - "Attempted to split tensor with shape=", input_shape, - " along dim 0 with split_sections=", split_sections); - split_shapes.push_back(input_shape); - auto &split_shape = split_shapes.back(); - split_shape[0] = split_sections[i]; - void *split_dptr = static_cast(input_dptr + dim0_offset * dim0_stride); - input_list.emplace_back(makeTransformerEngineTensor(split_dptr, split_shape, input_dtype)); - dim0_offset += split_sections[i]; - } - - // Convert quantizers to C++ objects - std::vector> quantizer_cpp_list; - for (size_t i = 0; i < num_splits; i++) { - quantizer_cpp_list.push_back(convert_quantizer(quantizer_list[i])); - } - - // Choose implementation for allocating and populating tensors - enum class AllocationMethod { UNFUSED, BULK_FP8_BLOCKWISE, BULK_MXFP8, BULK_NVFP4 }; - enum class QuantizationMethod { UNFUSED, FUSED_NVFP4 }; - AllocationMethod allocation_method = AllocationMethod::UNFUSED; - QuantizationMethod quantization_method = QuantizationMethod::UNFUSED; - if (!disable_bulk_allocation) { - if (std::all_of(quantizer_list.begin(), quantizer_list.end(), - [](const py::handle &quantizer) -> bool { - return detail::IsFloat8BlockwiseQuantizers(quantizer.ptr()); - })) { - allocation_method = AllocationMethod::BULK_FP8_BLOCKWISE; - } else if (std::all_of(quantizer_list.begin(), quantizer_list.end(), - [](const py::handle &quantizer) -> bool { - return detail::IsMXFP8Quantizers(quantizer.ptr()); - })) { - allocation_method = AllocationMethod::BULK_MXFP8; - } else if (std::all_of(quantizer_list.begin(), quantizer_list.end(), - [](const py::handle &quantizer) -> bool { - return detail::IsNVFP4Quantizers(quantizer.ptr()); - })) { - allocation_method = AllocationMethod::BULK_NVFP4; - quantization_method = QuantizationMethod::FUSED_NVFP4; - } - } - - // Allocate output tensors - std::vector 
output_cpp_list; - std::vector output_py_list; - switch (allocation_method) { - case AllocationMethod::BULK_FP8_BLOCKWISE: { - // Bulk allocation for FP8 block-scaling tensors - std::vector blockwise_quantizers; - for (auto &quantizer : quantizer_cpp_list) { - blockwise_quantizers.push_back(static_cast(quantizer.get())); - } - std::tie(output_py_list, output_cpp_list) = - bulk_allocate_fp8_blockwise_tensors(split_shapes, quantizer_list, blockwise_quantizers); - break; - } - case AllocationMethod::BULK_MXFP8: { - // Bulk allocation for MXFP8 tensors - std::vector mxfp8_quantizers; - for (auto &quantizer : quantizer_cpp_list) { - mxfp8_quantizers.push_back(static_cast(quantizer.get())); - } - std::tie(output_py_list, output_cpp_list) = - bulk_allocate_mxfp8_tensors(split_shapes, quantizer_list, mxfp8_quantizers); - break; - } - case AllocationMethod::BULK_NVFP4: { - // Bulk allocation for NVFP4 tensors - std::vector nvfp4_quantizers; - for (auto &quantizer : quantizer_cpp_list) { - nvfp4_quantizers.push_back(static_cast(quantizer.get())); - } - bool contiguous_data_and_scale = false; - std::tie(output_py_list, output_cpp_list, contiguous_data_and_scale) = - bulk_allocate_nvfp4_tensors(split_shapes, quantizer_list, nvfp4_quantizers); - if (!input_shape.empty() && input_shape.back() % 128 != 0) { - static std::once_flag once_unfused_nvfp4_fallback_warning; - std::call_once(once_unfused_nvfp4_fallback_warning, []() { - NVTE_WARN( - "Unfused NVFP4 quantization fallback is triggered because the input tensor inner " - "dimension is not a multiple of 128, disabling NVFP4 grouped kernel fusion. 
" - "NVFP4 might bring performance regressions for this input tensor shape."); - }); - quantization_method = QuantizationMethod::UNFUSED; - } - if (!contiguous_data_and_scale) { - // Avoid fused quantize kernel if data is not contiguous - quantization_method = QuantizationMethod::UNFUSED; - } - break; - } - default: { - // Allocate output tensors individually - for (size_t i = 0; i < num_splits; ++i) { - auto [output_cpp, output_py] = - quantizer_cpp_list[i]->create_tensor(split_shapes[i], input_dtype); - output_cpp_list.emplace_back(std::move(output_cpp)); - output_py_list.emplace_back(std::move(output_py)); - } - } - } - - // Quantize into output tensors - switch (quantization_method) { - case QuantizationMethod::FUSED_NVFP4: { - // Fused NVFP4 quantize kernel - auto input_nvte = makeTransformerEngineTensor(input_dptr, input_shape, input_dtype); - std::vector nvfp4_quantizers; - for (auto &quantizer : quantizer_cpp_list) { - nvfp4_quantizers.push_back(static_cast(quantizer.get())); - } - split_quantize_nvfp4_impl(input_nvte, input_list, output_cpp_list, split_sections, - nvfp4_quantizers); - break; - } - default: - // General multi-tensor quantization - multi_tensor_quantize_impl(input_list, quantizer_list, quantizer_cpp_list, output_cpp_list); - } - - return output_py_list; +STABLE_TORCH_LIBRARY_IMPL(transformer_engine_stable, CUDA, m) { + using namespace transformer_engine::pytorch::stable; + m.impl("quantize", TORCH_BOX(quantize)); + m.impl("quantize_with_amax", TORCH_BOX(quantize_with_amax)); + m.impl("quantize_from_amax", TORCH_BOX(quantize_from_amax)); + m.impl("quantize_bidirectional", TORCH_BOX(quantize_bidirectional)); + m.impl("dequantize", TORCH_BOX(dequantize)); } - -} // namespace pytorch -} // namespace transformer_engine diff --git a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp index a126ab0d60..10275fc22e 100644 --- 
a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp +++ b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp @@ -4,317 +4,530 @@ * See LICENSE for license information. ************************************************************************/ -#include "../extensions.h" -#include "transformer_engine/transformer_engine.h" +#include +#include -#define HALF_BYTES 2 -#define UB_MAX_SM 32 +#include +#include -using namespace torch::indexing; -using namespace std::placeholders; +#include "../stable_common.h" +namespace transformer_engine::pytorch::stable { + +using Tensor = torch::stable::Tensor; namespace te = transformer_engine; -/*************************************************************************************************** - * CommOverlapHelper - **************************************************************************************************/ - -CommOverlapHelper::CommOverlapHelper() { -#ifndef NVTE_UB_WITH_MPI - NVTE_ERROR("Internal TE error: Dummy CommOverlapHelper init without NVTE_UB_WITH_MPI=1!"); -#endif -} // empty constructor for NVTE_UB_WITH_MPI=1 - -CommOverlapHelper::CommOverlapHelper(c10d::ProcessGroup *world_group, - std::optional intra_domain_group) { -#ifndef NVTE_UB_WITH_MPI - pgs.insert({"world", world_group}); - myrank = pgs["world"]->getRank(); - numranks = pgs["world"]->getSize(); - c10d::ProcessGroup::BackendType backend = pgs["world"]->getBackendType(); - backend_is_nccl = (backend == c10d::ProcessGroup::BackendType::NCCL); - - if (intra_domain_group.has_value()) { - // Get local rank on node and number of local ranks - NVTE_CHECK(intra_domain_group.value()->getBackendType() == backend, - "Internal TE error: Intra-node group must be on the same backend (%s) as the world ", - "group!", pgs["world"]->getBackendName()); - pgs.insert({"intra", intra_domain_group.value()}); - mylocal = pgs["intra"]->getRank(); - numlocal = pgs["intra"]->getSize(); - - if (numlocal == numranks) { - // Intra-node group is same as the 
world group so there can only be 1 node - NVTE_CHECK( - mylocal == myrank, - "Internal TE error: Local rank must be equal to global rank when intra-node group size ", - "is equal to the world group size!"); - mynode = 0; - numnodes = 1; - } else { - // Get node ID and number of nodes - mynode = myrank / numlocal; - numnodes = numranks / numlocal; - } - } else { - // Intra-node group is not set so we assume there is only 1 node - mylocal = myrank; - numlocal = numranks; - pgs.insert({"intra", world_group}); +// ============================================================================ +// CommOverlap object registry +// +// CommOverlap objects are created by the Python shim and stored here. +// The stable ABI passes opaque int64_t handles (pointers cast to int). +// ============================================================================ + +static std::mutex g_comm_overlap_mutex; +static std::unordered_map> g_comm_overlaps; +static std::unordered_map> g_comm_overlaps_p2p; +static int64_t g_next_handle = 1; + +// ============================================================================ +// Allgather/barrier callback registration +// +// Python registers a callback pair that implements allgather/barrier using +// torch.distributed. The callbacks are stored here and passed to +// CommOverlapCore during construction. 
+// ============================================================================ + +using AllgatherCallback = void (*)(void* global, size_t global_bytes, void* local, + size_t local_bytes, const char* group); +using BarrierCallback = void (*)(const char* group); + +static AllgatherCallback g_allgather_cb = nullptr; +static BarrierCallback g_barrier_cb = nullptr; + +void register_comm_callbacks(int64_t allgather_fn_ptr, int64_t barrier_fn_ptr) { + g_allgather_cb = reinterpret_cast(allgather_fn_ptr); + g_barrier_cb = reinterpret_cast(barrier_fn_ptr); +} - mynode = 0; - numnodes = 1; +// ============================================================================ +// CommOverlapBase construction/destruction +// ============================================================================ + +int64_t create_comm_overlap(std::vector buffer_shape, int64_t buffer_dtype, int64_t myrank, + int64_t numranks, int64_t mylocal, int64_t numlocal, int64_t mynode, + int64_t numnodes, int64_t tp_size, int64_t num_splits, + int64_t num_max_streams, int64_t comm_cga_size, int64_t gemm_priority, + int64_t comm_priority, int64_t num_comm_sm, bool set_sm_margin, + bool atomic_gemm, bool rs_overlap_first_gemm) { + std::vector shape(buffer_shape.begin(), buffer_shape.end()); + auto dtype = static_cast(buffer_dtype); + + ExtAllgatherOp allgather_op; + ExtBarrierOp barrier_op; + + if (g_allgather_cb && g_barrier_cb) { + allgather_op = [](void* g, size_t gb, void* l, size_t lb, ExtComm comm) { + g_allgather_cb(g, gb, l, lb, comm); + }; + barrier_op = [](ExtComm comm) { g_barrier_cb(comm); }; } - initialized = true; -#else - NVTE_ERROR("Internal TE error: CommOverlapHelper cannot be initialized with valid PyTorch ", - "distributed process groups when TE is compiled with NVTE_UB_WITH_MPI=1!"); -#endif + auto co = std::make_unique( + shape, dtype, static_cast(myrank), static_cast(numranks), static_cast(mylocal), + static_cast(numlocal), static_cast(mynode), static_cast(numnodes), + 
static_cast(tp_size), allgather_op, barrier_op, static_cast(num_splits), + static_cast(num_max_streams), static_cast(comm_cga_size), + static_cast(gemm_priority), static_cast(comm_priority), + static_cast(num_comm_sm), set_sm_margin, atomic_gemm, rs_overlap_first_gemm); + + std::lock_guard lock(g_comm_overlap_mutex); + int64_t handle = g_next_handle++; + g_comm_overlaps[handle] = std::move(co); + return handle; } -CommOverlapHelper::~CommOverlapHelper() { -#ifndef NVTE_UB_WITH_MPI - for (auto &pg : pgs) pg.second = nullptr; - backend_is_nccl = false; - initialized = false; -#endif +void destroy_comm_overlap(int64_t handle) { + std::lock_guard lock(g_comm_overlap_mutex); + g_comm_overlaps.erase(handle); } -void CommOverlapHelper::ub_allgather(void *globaldata, size_t globalbytes, void *localdata, - size_t localbytes, ExtComm group) { -#ifndef NVTE_UB_WITH_MPI - NVTE_CHECK(initialized, "Internal TE error: tex.CommOverlapHelper() is not initialized ", - "with valid process groups!"); - - auto localtensor = - torch::from_blob(localdata, {static_cast(localbytes / sizeof(uint8_t))}, - at::device(torch::kCPU).dtype(torch::kUInt8)); - auto localtmp = (backend_is_nccl) ? localtensor.cuda() : localtensor; - auto globaltensor = - torch::from_blob(globaldata, {static_cast(globalbytes / sizeof(uint8_t))}, - at::device(torch::kCPU).dtype(torch::kUInt8)); - auto globaltmp = (backend_is_nccl) ? 
globaltensor.cuda() : globaltensor; - - std::vector> globalchunks = {globaltmp.chunk(pgs[group]->getSize())}; - std::vector localchunk = {localtmp}; - auto work = pgs[group]->allgather(globalchunks, localchunk); - work->wait(); - - if (backend_is_nccl) { - globaltensor.copy_(globaltmp.cpu()); - globaltmp = torch::Tensor(); - localtmp = torch::Tensor(); +// ============================================================================ +// CommOverlapP2PBase construction/destruction +// ============================================================================ + +int64_t create_comm_overlap_p2p(std::vector buffer_shape, int64_t buffer_dtype, + int64_t myrank, int64_t numranks, int64_t mylocal, int64_t numlocal, + int64_t mynode, int64_t numnodes, int64_t tp_size, + int64_t comm_type, int64_t num_max_streams, int64_t comm_cga_size, + int64_t gemm_priority, int64_t comm_priority, int64_t num_comm_sm, + bool set_sm_margin, bool use_ce, bool atomic_gemm, bool aggregate) { + std::vector shape(buffer_shape.begin(), buffer_shape.end()); + auto dtype = static_cast(buffer_dtype); + + ExtAllgatherOp allgather_op; + ExtBarrierOp barrier_op; + + if (g_allgather_cb && g_barrier_cb) { + allgather_op = [](void* g, size_t gb, void* l, size_t lb, ExtComm comm) { + g_allgather_cb(g, gb, l, lb, comm); + }; + barrier_op = [](ExtComm comm) { g_barrier_cb(comm); }; } -#else - NVTE_ERROR("Internal TE error: CommOverlapHelper::ub_allgather is a no-op when TE is compiled ", - "with NVTE_UB_WITH_MPI=1!"); -#endif + + auto co = std::make_unique( + shape, dtype, static_cast(myrank), static_cast(numranks), static_cast(mylocal), + static_cast(numlocal), static_cast(mynode), static_cast(numnodes), + static_cast(tp_size), allgather_op, barrier_op, + static_cast(comm_type), static_cast(num_max_streams), + static_cast(comm_cga_size), static_cast(gemm_priority), + static_cast(comm_priority), static_cast(num_comm_sm), set_sm_margin, use_ce, + atomic_gemm, aggregate); + + std::lock_guard 
lock(g_comm_overlap_mutex); + int64_t handle = g_next_handle++; + g_comm_overlaps_p2p[handle] = std::move(co); + return handle; } -void CommOverlapHelper::ub_barrier(ExtComm group) { -#ifndef NVTE_UB_WITH_MPI - NVTE_CHECK(initialized, "Internal TE error: tex.CommOverlapHelper() is not initialized ", - "with valid process groups!"); - auto work = pgs[group]->barrier(); - work->wait(); -#else - NVTE_ERROR("Internal TE error: CommOverlapHelper::ub_barrier is a no-op when TE is compiled ", - "with NVTE_UB_WITH_MPI=1!"); -#endif +void destroy_comm_overlap_p2p(int64_t handle) { + std::lock_guard lock(g_comm_overlap_mutex); + g_comm_overlaps_p2p.erase(handle); } -/*************************************************************************************************** - * CommOverlap - **************************************************************************************************/ - -CommOverlap::CommOverlap(const std::vector &buffer_shape, at::ScalarType buffer_dtype, - CommOverlapHelper *helper, int tp_size, int num_splits, - int num_max_streams, int comm_cga_size, int gemm_priority, - int comm_priority, int num_comm_sm, bool set_sm_margin, bool atomic_gemm, - bool rs_overlap_first_gemm) - : te::CommOverlapBase(buffer_shape, te::pytorch::GetTransformerEngineDType(buffer_dtype), - helper->myrank, helper->numranks, helper->mylocal, helper->numlocal, - helper->mynode, helper->numnodes, tp_size, - std::bind(&CommOverlapHelper::ub_allgather, helper, _1, _2, _3, _4, _5), - std::bind(&CommOverlapHelper::ub_barrier, helper, _1), num_splits, - num_max_streams, comm_cga_size, gemm_priority, comm_priority, num_comm_sm, - set_sm_margin, atomic_gemm, rs_overlap_first_gemm) {} - -/* -** Helper function to copy input to _ubuf -*/ -void CommOverlap::copy_into_buffer(const at::Tensor &input, bool local_chunk) { - const auto &input_ = input.contiguous(); - - // Check element size - const size_t element_size = input.element_size(); - NVTE_CHECK(_ubuf.element_size() == element_size, - "Tried 
to copy data into a Userbuffers buffer but dtypes are not compatible ", - "(input dtype has ", element_size, " bytes, UB dtype has ", _ubuf.element_size(), - " bytes)"); - - // Input data - const size_t input_size = input_.numel(); - const void *src_ptr = input_.data_ptr(); - - // Userbuffers data - const size_t ubuf_size = _ubuf.numel(); - void *dst_ptr = _ubuf.dptr(); +// ============================================================================ +// Buffer operations (hot path wrappers — just pointer extraction) +// ============================================================================ + +static te::CommOverlapCore* get_core(int64_t handle) { + { + auto it = g_comm_overlaps.find(handle); + if (it != g_comm_overlaps.end()) return it->second.get(); + } + { + auto it = g_comm_overlaps_p2p.find(handle); + if (it != g_comm_overlaps_p2p.end()) return it->second.get(); + } + NVTE_ERROR("Invalid CommOverlap handle: ", handle); +} + +void comm_overlap_copy_into_buffer(Tensor input, int64_t handle, bool local_chunk) { + auto input_ = torch::stable::contiguous(input); + auto* co = get_core(handle); + + const size_t elem_size = input_.element_size(); + const size_t input_numel = static_cast(input_.numel()); + const void* src = input_.data_ptr(); + + const size_t ubuf_numel = co->get_ubuf().numel(); + void* dst = co->get_ubuf().dptr(); + int tp_size = co->get_tp_size(); + int tp_id = co->get_tp_id(); + if (local_chunk) { - NVTE_CHECK(input_size * _tp_size == ubuf_size, - "Tried to copy an invalid tensor into a local chunk of a Userbuffers buffer ", - "(input_size=", input_size, ", tensor_parallel_size=", _tp_size, - ", ubuf_size=", ubuf_size, ")"); - dst_ptr = (reinterpret_cast(dst_ptr) + (ubuf_size / _tp_size) * _tp_id * element_size); + NVTE_CHECK(input_numel * tp_size == ubuf_numel, "Invalid tensor for local chunk copy"); + dst = reinterpret_cast(dst) + (ubuf_numel / tp_size) * tp_id * elem_size; } else { - NVTE_CHECK(input_size == ubuf_size, - "Tried to copy an 
invalid tensor into a Userbuffers buffer ", - "(input_size=", input_size, ", ubuf_size=", ubuf_size, ")"); + NVTE_CHECK(input_numel == ubuf_numel, "Invalid tensor for buffer copy"); } - // Copy data - auto stream_main = at::cuda::getCurrentCUDAStream(); - NVTE_CHECK_CUDA(cudaEventRecord(_start_d2dcopy, (cudaStream_t)stream_main)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_d2dcopy, 0)); - NVTE_CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, input_size * element_size, - cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_comm)); + auto stream = getCurrentCUDAStreamRaw(input_.get_device_index()); + cudaMemcpyAsync(dst, src, input_numel * elem_size, cudaMemcpyDeviceToDevice, stream); } -at::Tensor CommOverlap::get_buffer(bool local_chunk, std::optional> shape) { - // Check buffer shape - const size_t ubuf_size = _ubuf.numel(); - if (shape) { - const size_t requested_size = transformer_engine::pytorch::product(*shape); - if (local_chunk) { - NVTE_CHECK(requested_size * _tp_size == ubuf_size, - "Invalid shape for local chunk of a Userbuffers buffer (requested shape=", *shape, - ", tensor_parallel_size=", _tp_size, ", ubuf_size=", ubuf_size, ")"); - } else { - NVTE_CHECK(requested_size == ubuf_size, - "Invalid shape for a Userbuffers buffer (requested shape=", *shape, - ", ubuf_size=", ubuf_size, ")"); - } - } else { - int64_t dim0 = _ubuf.size(0); - int64_t dim1 = _ubuf.size(1); - if (local_chunk) { - dim0 /= _tp_size; - } - shape = {dim0, dim1}; +Tensor comm_overlap_get_buffer(int64_t handle, bool local_chunk, int64_t dim0, int64_t dim1) { + auto* co = get_core(handle); + int tp_size = co->get_tp_size(); + int tp_id = co->get_tp_id(); + const auto& ubuf = co->get_ubuf(); + const size_t ubuf_numel = ubuf.numel(); + + if (dim0 <= 0 || dim1 <= 0) { + dim0 = static_cast(ubuf.size(0)); + dim1 = static_cast(ubuf.size(1)); + if (local_chunk) dim0 /= tp_size; } - // Data pointer - void *ubuf_ptr = _ubuf.dptr(); + void* ptr = ubuf.dptr(); if (local_chunk) 
{ - ubuf_ptr = (reinterpret_cast(ubuf_ptr) + - (ubuf_size / _tp_size) * _tp_id * _ubuf.element_size()); + ptr = reinterpret_cast(ptr) + (ubuf_numel / tp_size) * tp_id * ubuf.element_size(); } - // Construct PyTorch tensor - const auto dtype = transformer_engine::pytorch::GetATenDType(_ubuf.dtype()); - return torch::from_blob(ubuf_ptr, *shape, at::dtype(dtype).device(torch::kCUDA)); + auto dtype = GetStableScalarType(ubuf.dtype()); + auto device_idx = torch::stable::accelerator::getCurrentDeviceIndex(); + std::vector shape = {dim0, dim1}; + std::vector strides = {dim1, 1}; + torch::headeronly::IntHeaderOnlyArrayRef size_ref(shape.data(), shape.size()); + torch::headeronly::IntHeaderOnlyArrayRef stride_ref(strides.data(), strides.size()); + torch::stable::Device device(torch::headeronly::DeviceType::CUDA, device_idx); + + return torch::stable::from_blob(ptr, size_ref, stride_ref, device, dtype); } -std::pair CommOverlap::get_communication_stream() { - // Return the same stream for both send and recv - return {at::cuda::getStreamFromExternal(_stream_comm, at::cuda::current_device()), - at::cuda::getStreamFromExternal(_stream_comm, at::cuda::current_device())}; +// Return communication stream as raw cudaStream_t (cast to int64) +// Python wraps with torch.cuda.ExternalStream +int64_t comm_overlap_get_stream(int64_t handle) { + auto it = g_comm_overlaps.find(handle); + if (it != g_comm_overlaps.end()) { + return reinterpret_cast(it->second->get_comm_stream()); + } + NVTE_ERROR("Invalid CommOverlapBase handle: ", handle); } -/*************************************************************************************************** - * CommOverlapP2P - **************************************************************************************************/ - -CommOverlapP2P::CommOverlapP2P(const std::vector &buffer_shape, at::ScalarType buffer_dtype, - CommOverlapHelper *helper, int tp_size, - te::CommOverlapType comm_type, int num_max_streams, - int comm_cga_size, int gemm_priority, 
int comm_priority, - int num_comm_sm, bool set_sm_margin, bool atomic_gemm, bool use_ce, - bool aggregate) - : te::CommOverlapP2PBase( - buffer_shape, te::pytorch::GetTransformerEngineDType(buffer_dtype), helper->myrank, - helper->numranks, helper->mylocal, helper->numlocal, helper->mynode, helper->numnodes, - tp_size, std::bind(&CommOverlapHelper::ub_allgather, helper, _1, _2, _3, _4, _5), - std::bind(&CommOverlapHelper::ub_barrier, helper, _1), comm_type, num_max_streams, - comm_cga_size, gemm_priority, comm_priority, num_comm_sm, set_sm_margin, use_ce, - atomic_gemm, aggregate) {} - -/* -** Copy input to _ubufs[0] -*/ -void CommOverlapP2P::copy_into_buffer(const at::Tensor &input, bool local_chunk) { - const auto &input_ = input.contiguous(); - - // Check element size - const size_t element_size = input.element_size(); - NVTE_CHECK(_ubuf.element_size() == element_size, - "Tried to copy data into a Userbuffers buffer but dtypes are not compatible ", - "(input dtype has ", element_size, " bytes, UB dtype has ", _ubuf.element_size(), - " bytes)"); - - // Input data - const size_t input_size = input_.numel(); - const void *src_ptr = input_.data_ptr(); - - // Userbuffers data - void *dst_ptr; - if (local_chunk) { - NVTE_CHECK(_ubufs[_tp_id].numel() == input_size, - "Tried to copy an invalid tensor into a local chunk of a Userbuffers buffer ", - "(input_size=", input_size, ", local_ubuf_size=", _ubufs[_tp_id].numel(), ")"); - dst_ptr = _ubufs[_tp_id].dptr(); - } else { - NVTE_CHECK(_ubuf.numel() == input_size, - "Tried to copy an invalid tensor into a Userbuffers buffer ", - "(input_size=", input_size, ", ubuf_size=", _ubuf.numel(), ")"); - dst_ptr = _ubuf.dptr(); +std::tuple comm_overlap_p2p_get_streams(int64_t handle) { + auto it = g_comm_overlaps_p2p.find(handle); + if (it != g_comm_overlaps_p2p.end()) { + auto& streams = it->second->get_send_streams(); + auto recv = it->second->get_recv_stream(); + return std::make_tuple(reinterpret_cast(streams.empty() ? 
nullptr : streams[0]), + reinterpret_cast(recv)); + } + NVTE_ERROR("Invalid CommOverlapP2PBase handle: ", handle); +} + +// Bulk overlap AG with external GEMM +void bulk_overlap_ag_with_external_gemm(int64_t handle, int64_t send_stream_ptr, + int64_t recv_stream_ptr) { + auto it = g_comm_overlaps.find(handle); + NVTE_CHECK(it != g_comm_overlaps.end(), "Invalid CommOverlapBase handle"); + auto main_stream = getCurrentCUDAStreamRaw(); + it->second->bulk_overlap_external_ag(reinterpret_cast(send_stream_ptr), + reinterpret_cast(recv_stream_ptr), + main_stream); +} + +// Query helpers +int64_t comm_overlap_get_tp_size(int64_t handle) { return get_core(handle)->get_tp_size(); } + +bool comm_overlap_is_atomic_gemm(int64_t handle) { return get_core(handle)->is_atomic_gemm(); } + +bool comm_overlap_is_p2p(int64_t handle) { return get_core(handle)->is_p2p_overlap(); } + +bool comm_overlap_is_fp8_ubuf(int64_t handle) { return get_core(handle)->is_fp8_ubuf(); } + +// ============================================================================ +// GEMM helpers (mirrors stable/gemm.cpp, kept local to this TU) +// ============================================================================ + +namespace { + +bool requiresScaleSwizzle(NVTEScalingMode scaling_mode) { + switch (scaling_mode) { + case NVTE_MXFP8_1D_SCALING: + case NVTE_NVFP4_1D_SCALING: + return true; + case NVTE_INVALID_SCALING: + NVTE_ERROR("Invalid scaling mode for swizzling scaling factors."); + default: + return false; } +} - // Copy data - NVTE_CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, input_size * element_size, - cudaMemcpyDeviceToDevice, - (cudaStream_t)at::cuda::getCurrentCUDAStream())); +DType getScaleInvDtype(NVTEScalingMode scaling_mode) { + switch (scaling_mode) { + case NVTE_MXFP8_1D_SCALING: + return DType::kFloat8E8M0; + case NVTE_NVFP4_1D_SCALING: + return DType::kFloat8E4M3; + default: + return DType::kFloat32; + } } -at::Tensor CommOverlapP2P::get_buffer(bool local_chunk, std::optional> shape) { 
- // Check buffer shape - if (shape) { - const size_t requested_size = transformer_engine::pytorch::product(*shape); - if (local_chunk) { - NVTE_CHECK(requested_size == _ubufs[_tp_id].numel(), - "Invalid shape for local chunk of a Userbuffers buffer (requested shape=", *shape, - ", local_ubuf_size=", _ubufs[_tp_id].numel(), ")"); +Tensor swizzleScaleForGemm(const Tensor& data, int64_t te_dtype, const Tensor& scale_inv, + int64_t scaling_mode) { + auto tensor_dtype = static_cast(te_dtype); + auto tensor_scaling_mode = static_cast(scaling_mode); + auto data_shape = getStableTensorShape(data); + + auto input_tensor = makeQuantizedTensorWrapper(data, tensor_dtype, data_shape, std::nullopt, + std::nullopt, scale_inv, tensor_scaling_mode); + + auto input_scales_nvte = input_tensor.get_rowwise_scale_inv(); + auto scales_dtype = static_cast(input_scales_nvte.dtype); + std::vector scale_shape(input_scales_nvte.shape.data, + input_scales_nvte.shape.data + input_scales_nvte.shape.ndim); + auto output_scale_inv = allocateStableTensor(scale_shape, scales_dtype, data.get_device_index()); + + TensorWrapper input_nvte(tensor_scaling_mode); + input_nvte.set_rowwise_data(nullptr, tensor_dtype, data_shape); + input_nvte.set_rowwise_scale_inv(input_scales_nvte.data_ptr, scales_dtype, + input_scales_nvte.shape); + + TensorWrapper output_nvte(tensor_scaling_mode); + output_nvte.set_rowwise_data(nullptr, tensor_dtype, data_shape); + output_nvte.set_rowwise_scale_inv(output_scale_inv.data_ptr(), scales_dtype, + input_scales_nvte.shape); + output_nvte.set_with_gemm_swizzled_scales(true); + + nvte_swizzle_scaling_factors(input_nvte.data(), output_nvte.data(), + getCurrentCUDAStreamRaw(data.get_device_index())); + + return output_scale_inv; +} + +TensorWrapper buildInputTensorWrapper(const Tensor& rowwise_data, DType te_dtype, + const std::optional& rowwise_scale_inv, + const std::optional& colwise_data, + const std::optional& colwise_scale_inv, + NVTEScalingMode scaling_mode) { + DType 
si_dtype = getScaleInvDtype(scaling_mode); + + TensorWrapper out(scaling_mode); + if (rowwise_data.numel() > 0) { + auto shape = getStableTensorShape(rowwise_data); + out.set_rowwise_data(rowwise_data.data_ptr(), te_dtype, shape); + } + if (rowwise_scale_inv.has_value() && rowwise_scale_inv->numel() > 0) { + auto si_shape = getStableTensorShape(*rowwise_scale_inv); + out.set_rowwise_scale_inv(rowwise_scale_inv->data_ptr(), si_dtype, si_shape); + } + if (colwise_data.has_value() && colwise_data->numel() > 0) { + auto cw_shape = getStableTensorShape(*colwise_data); + out.set_columnwise_data(colwise_data->data_ptr(), te_dtype, cw_shape); + if (colwise_scale_inv.has_value() && colwise_scale_inv->numel() > 0) { + auto csi_shape = getStableTensorShape(*colwise_scale_inv); + out.set_columnwise_scale_inv(colwise_scale_inv->data_ptr(), si_dtype, csi_shape); + } + } + return out; +} + +} // namespace + +// ============================================================================ +// GEMM with comm overlap +// ============================================================================ + +void gemm_with_comm_overlap( + Tensor A_data, int64_t A_te_dtype, std::optional A_scale_inv, + std::optional A_colwise_data, std::optional A_colwise_scale_inv, + int64_t A_scaling_mode, bool A_with_gemm_swizzled_scales, bool transa, Tensor B_data, + int64_t B_te_dtype, std::optional B_scale_inv, std::optional B_colwise_data, + std::optional B_colwise_scale_inv, int64_t B_scaling_mode, + bool B_with_gemm_swizzled_scales, bool transb, Tensor D_data, int64_t D_te_dtype, + std::optional D_amax, std::optional D_scale, std::optional D_scale_inv, + int64_t D_scaling_mode, std::optional bias, int64_t bias_type, + std::optional pre_gelu_out, Tensor workspace, bool grad, bool accumulate, + bool use_split_accumulator, int64_t overlap_handle, int64_t comm_type, bool bulk_overlap_flag, + std::optional extra_output) { + auto A_te = static_cast(A_te_dtype); + auto B_te = static_cast(B_te_dtype); + auto 
D_te = static_cast(D_te_dtype); + auto A_sm = static_cast(A_scaling_mode); + auto B_sm = static_cast(B_scaling_mode); + auto D_sm = static_cast(D_scaling_mode); + + auto D_shape = getStableTensorShape(D_data); + + // Swizzle scales for MXFP8/NVFP4 if not already pre-swizzled. + std::vector swizzled_scale_inverses; + if (!A_with_gemm_swizzled_scales && requiresScaleSwizzle(A_sm)) { + if (transa) { + if (A_scale_inv.has_value() && A_scale_inv->numel() > 0) { + swizzled_scale_inverses.emplace_back( + swizzleScaleForGemm(A_data, A_te_dtype, *A_scale_inv, A_scaling_mode)); + A_scale_inv = swizzled_scale_inverses.back(); + } } else { - NVTE_CHECK(requested_size == _ubuf.numel(), - "Invalid shape for a Userbuffers buffer (requested shape=", *shape, - ", ubuf_size=", _ubuf.numel(), ")"); + if (A_colwise_data.has_value() && A_colwise_scale_inv.has_value() && + A_colwise_data->numel() > 0 && A_colwise_scale_inv->numel() > 0) { + swizzled_scale_inverses.emplace_back( + swizzleScaleForGemm(*A_colwise_data, A_te_dtype, *A_colwise_scale_inv, A_scaling_mode)); + A_colwise_scale_inv = swizzled_scale_inverses.back(); + } } - } else { - int64_t dim0 = _ubuf.size(0); - int64_t dim1 = _ubuf.size(1); - if (local_chunk) { - dim0 /= _tp_size; + A_with_gemm_swizzled_scales = true; + } + if (!B_with_gemm_swizzled_scales && requiresScaleSwizzle(B_sm)) { + if (!transb) { + if (B_scale_inv.has_value() && B_scale_inv->numel() > 0) { + swizzled_scale_inverses.emplace_back( + swizzleScaleForGemm(B_data, B_te_dtype, *B_scale_inv, B_scaling_mode)); + B_scale_inv = swizzled_scale_inverses.back(); + } + } else { + if (B_colwise_data.has_value() && B_colwise_scale_inv.has_value() && + B_colwise_data->numel() > 0 && B_colwise_scale_inv->numel() > 0) { + swizzled_scale_inverses.emplace_back( + swizzleScaleForGemm(*B_colwise_data, B_te_dtype, *B_colwise_scale_inv, B_scaling_mode)); + B_colwise_scale_inv = swizzled_scale_inverses.back(); + } } - shape = {dim0, dim1}; + B_with_gemm_swizzled_scales = true; 
+ } + + auto A_tensor = + buildInputTensorWrapper(A_data, A_te, A_scale_inv, A_colwise_data, A_colwise_scale_inv, A_sm); + auto B_tensor = + buildInputTensorWrapper(B_data, B_te, B_scale_inv, B_colwise_data, B_colwise_scale_inv, B_sm); + auto D_tensor = + makeQuantizedTensorWrapper(D_data, D_te, D_shape, D_amax, D_scale, D_scale_inv, D_sm); + A_tensor.set_with_gemm_swizzled_scales(A_with_gemm_swizzled_scales); + B_tensor.set_with_gemm_swizzled_scales(B_with_gemm_swizzled_scales); + + TensorWrapper bias_tensor; + if (bias.has_value()) { + auto bias_te = static_cast(bias_type); + auto bias_shape = getStableTensorShape(bias.value()); + bias_tensor = makeTransformerEngineTensor(bias->data_ptr(), bias_shape, bias_te); + } + + TensorWrapper pre_gelu_tensor; + if (pre_gelu_out.has_value()) { + pre_gelu_tensor = makeTransformerEngineTensor(pre_gelu_out.value()); + } + + auto ws_tensor = makeTransformerEngineTensor(workspace); + + TensorWrapper extra_out_tensor; + if (extra_output.has_value()) { + extra_out_tensor = makeTransformerEngineTensor(extra_output.value()); } - // Data pointer - void *ubuf_ptr = local_chunk ? 
_ubufs[_tp_id].dptr() : _ubuf.dptr(); + auto device_idx = A_data.get_device_index(); + auto stream = getCurrentCUDAStreamRaw(device_idx); + + auto* co = get_core(overlap_handle); + auto co_type = static_cast(comm_type); + + if (bulk_overlap_flag) { + co->bulk_overlap(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor, pre_gelu_tensor, + ws_tensor, grad, accumulate, use_split_accumulator, co_type, extra_out_tensor, + stream); + } else if (co_type == te::CommOverlapType::AG) { + if (co->is_atomic_gemm()) { + co->atomic_gemm_overlap_ag(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor, + pre_gelu_tensor, ws_tensor, grad, accumulate, + use_split_accumulator, extra_out_tensor, stream); + } else { + co->split_overlap_ag(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor, + pre_gelu_tensor, ws_tensor, grad, accumulate, use_split_accumulator, + extra_out_tensor, stream); + } + } else { + if (co->is_atomic_gemm()) { + co->atomic_gemm_overlap_rs(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor, + pre_gelu_tensor, ws_tensor, grad, accumulate, + use_split_accumulator, extra_out_tensor, stream); + } else { + co->split_overlap_rs(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor, + pre_gelu_tensor, ws_tensor, grad, accumulate, use_split_accumulator, + extra_out_tensor, stream); + } + } +} - // Construct PyTorch tensor - const auto dtype = transformer_engine::pytorch::GetATenDType(_ubuf.dtype()); - return torch::from_blob(ubuf_ptr, *shape, at::dtype(dtype).device(torch::kCUDA)); +} // namespace transformer_engine::pytorch::stable + +STABLE_TORCH_LIBRARY_FRAGMENT(transformer_engine_stable, m) { + // Callback registration + m.def("register_comm_callbacks(int allgather_fn_ptr, int barrier_fn_ptr) -> ()"); + // CommOverlapBase lifecycle + m.def( + "create_comm_overlap(int[] buffer_shape, int buffer_dtype, int myrank, int numranks, int " + "mylocal, int numlocal, int mynode, int numnodes, int tp_size, int num_splits, int " + "num_max_streams, int 
comm_cga_size, int gemm_priority, int comm_priority, int num_comm_sm, " + "bool set_sm_margin, bool atomic_gemm, bool rs_overlap_first_gemm) -> int"); + m.def("destroy_comm_overlap(int handle) -> ()"); + // CommOverlapP2PBase lifecycle + m.def( + "create_comm_overlap_p2p(int[] buffer_shape, int buffer_dtype, int myrank, int numranks, int " + "mylocal, int numlocal, int mynode, int numnodes, int tp_size, int comm_type, int " + "num_max_streams, int comm_cga_size, int gemm_priority, int comm_priority, int num_comm_sm, " + "bool set_sm_margin, bool use_ce, bool atomic_gemm, bool aggregate) -> int"); + m.def("destroy_comm_overlap_p2p(int handle) -> ()"); + // Buffer operations + m.def("comm_overlap_copy_into_buffer(Tensor input, int handle, bool local_chunk) -> ()"); + m.def("comm_overlap_get_buffer(int handle, bool local_chunk, int dim0, int dim1) -> Tensor"); + // Stream access + m.def("comm_overlap_get_stream(int handle) -> int"); + m.def("comm_overlap_p2p_get_streams(int handle) -> (int, int)"); + // Queries + m.def( + "bulk_overlap_ag_with_external_gemm(int handle, int send_stream_ptr, int recv_stream_ptr) -> " + "()"); + m.def("comm_overlap_get_tp_size(int handle) -> int"); + m.def("comm_overlap_is_atomic_gemm(int handle) -> bool"); + m.def("comm_overlap_is_p2p(int handle) -> bool"); + m.def("comm_overlap_is_fp8_ubuf(int handle) -> bool"); + // GEMM with comm overlap + m.def( + "gemm_with_comm_overlap(" + "Tensor A_data, int A_te_dtype, Tensor? A_scale_inv, " + "Tensor? A_colwise_data, Tensor? A_colwise_scale_inv, " + "int A_scaling_mode, bool A_with_gemm_swizzled_scales, bool transa, " + "Tensor B_data, int B_te_dtype, Tensor? B_scale_inv, " + "Tensor? B_colwise_data, Tensor? B_colwise_scale_inv, " + "int B_scaling_mode, bool B_with_gemm_swizzled_scales, bool transb, " + "Tensor D_data, int D_te_dtype, Tensor? D_amax, Tensor? D_scale, Tensor? D_scale_inv, " + "int D_scaling_mode, Tensor? bias, int bias_type, Tensor? 
pre_gelu_out, " + "Tensor workspace, bool grad, bool accumulate, bool use_split_accumulator, " + "int overlap_handle, int comm_type, bool bulk_overlap_flag, " + "Tensor? extra_output) -> ()"); } -std::pair CommOverlapP2P::get_communication_stream() { - return {at::cuda::getStreamFromExternal(_stream_send[0], at::cuda::current_device()), - at::cuda::getStreamFromExternal(_stream_recv, at::cuda::current_device())}; +// Ops with tensor INPUT arguments → CUDA dispatch key +STABLE_TORCH_LIBRARY_IMPL(transformer_engine_stable, CUDA, m) { + using namespace transformer_engine::pytorch::stable; + m.impl("comm_overlap_copy_into_buffer", TORCH_BOX(comm_overlap_copy_into_buffer)); + m.impl("gemm_with_comm_overlap", TORCH_BOX(gemm_with_comm_overlap)); } -void transformer_engine::pytorch::bulk_overlap_ag_with_external_gemm( - CommOverlap &allgather_communicator, at::Stream send_stream, at::Stream recv_stream) { - auto main_stream = at::cuda::getCurrentCUDAStream(); - allgather_communicator.bulk_overlap_external_ag(at::cuda::CUDAStream(send_stream), - at::cuda::CUDAStream(recv_stream), main_stream); +// Ops without tensor INPUT arguments need CompositeImplicitAutograd since +// PyTorch dispatches based on input tensors, not output tensors. 
+STABLE_TORCH_LIBRARY_IMPL(transformer_engine_stable, CompositeImplicitAutograd, m) { + using namespace transformer_engine::pytorch::stable; + m.impl("register_comm_callbacks", TORCH_BOX(register_comm_callbacks)); + m.impl("create_comm_overlap", TORCH_BOX(create_comm_overlap)); + m.impl("destroy_comm_overlap", TORCH_BOX(destroy_comm_overlap)); + m.impl("create_comm_overlap_p2p", TORCH_BOX(create_comm_overlap_p2p)); + m.impl("destroy_comm_overlap_p2p", TORCH_BOX(destroy_comm_overlap_p2p)); + m.impl("comm_overlap_get_buffer", TORCH_BOX(comm_overlap_get_buffer)); + m.impl("comm_overlap_get_stream", TORCH_BOX(comm_overlap_get_stream)); + m.impl("comm_overlap_p2p_get_streams", TORCH_BOX(comm_overlap_p2p_get_streams)); + m.impl("bulk_overlap_ag_with_external_gemm", TORCH_BOX(bulk_overlap_ag_with_external_gemm)); + m.impl("comm_overlap_get_tp_size", TORCH_BOX(comm_overlap_get_tp_size)); + m.impl("comm_overlap_is_atomic_gemm", TORCH_BOX(comm_overlap_is_atomic_gemm)); + m.impl("comm_overlap_is_p2p", TORCH_BOX(comm_overlap_is_p2p)); + m.impl("comm_overlap_is_fp8_ubuf", TORCH_BOX(comm_overlap_is_fp8_ubuf)); } diff --git a/transformer_engine/pytorch/csrc/extensions/dropout.cpp b/transformer_engine/pytorch/csrc/extensions/dropout.cpp index bea8f3a7b5..5179602884 100644 --- a/transformer_engine/pytorch/csrc/extensions/dropout.cpp +++ b/transformer_engine/pytorch/csrc/extensions/dropout.cpp @@ -4,86 +4,86 @@ * See LICENSE for license information. 
************************************************************************/ -#include "transformer_engine/dropout.h" - -#include -#include - -#include - -#include "../common.h" -#include "../extensions.h" -#include "../pybind.h" -#include "transformer_engine/transformer_engine.h" - -namespace transformer_engine { -namespace pytorch { - -std::vector dropout_fwd(const py::handle &input, float dropout_probability, - std::optional out) { - using namespace transformer_engine::pytorch::detail; - - // Input tensor - const TensorWrapper input_nvte = makeTransformerEngineTensor(input, py::none()); - - // Allocate output tensor if needed - if (!out) { - at::ScalarType dtype = GetATenDType(input_nvte.dtype()); - if (dtype == at::kFloat8_e4m3fn || dtype == at::kFloat8_e5m2) { - dtype = input.attr("dtype").cast(); - } - const auto shape_uint64 = convertShape(input_nvte.shape()); - const std::vector shape_int64(shape_uint64.begin(), shape_uint64.end()); - const auto opts = at::TensorOptions().dtype(dtype).device(torch::kCUDA); - out = at::empty(shape_int64, opts); - } - TensorWrapper out_nvte = makeTransformerEngineTensor(*out); - - // Mask tensor - auto mask_pyt = allocateTorchTensor(input_nvte.numel() / 8, DType::kByte); - auto mask_nvte = makeTransformerEngineTensor(mask_pyt); - - // RNG state tensor - auto gen = at::get_generator_or_default( - std::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); - at::PhiloxCudaState philox_args; - { - std::lock_guard lock(gen->mutex_); - constexpr int64_t rng_elts_per_thread = 4; - philox_args = gen->philox_cuda_state(rng_elts_per_thread); - } - auto rng_state_pyt = allocateTorchTensor(2, DType::kInt64); - NVTE_SCOPED_GIL_RELEASE({ - nvte_extract_seed_and_offset( - reinterpret_cast(rng_state_pyt.data_ptr()), philox_args.captured_, - philox_args.seed_.ptr, philox_args.seed_.val, philox_args.offset_.ptr, - philox_args.offset_.val, philox_args.offset_intragraph_, at::cuda::getCurrentCUDAStream()); - }); - auto rng_state_nvte = 
makeTransformerEngineTensor(rng_state_pyt); - - // Launch kernel - NVTE_SCOPED_GIL_RELEASE({ - nvte_dropout_fwd(input_nvte.data(), out_nvte.data(), mask_nvte.data(), rng_state_nvte.data(), - dropout_probability, at::cuda::getCurrentCUDAStream()); - }); - - return {py::cast(std::move(*out)), py::cast(mask_pyt)}; +#include + +#include "../stable_common.h" + +namespace transformer_engine::pytorch::stable { + +using Tensor = torch::stable::Tensor; + +// ============================================================================ +// Dropout forward — RNG state extracted in Python, passed as tensor +// +// Python shim does: +// gen = torch.cuda.default_generators[device] +// philox_state = gen.get_state() # or philox_cuda_state for graph capture +// seed, offset = extract_seed_offset(philox_state) +// rng_state = torch.tensor([seed, offset], dtype=torch.int64, device='cuda') +// ============================================================================ + +std::tuple dropout_fwd(Tensor input, Tensor rng_state, double dropout_probability) { + auto input_cu = makeTransformerEngineTensor(input); + + auto device_idx = input.get_device_index(); + auto shape = getStableTensorShape(input); + size_t total = 1; + for (auto s : shape) total *= s; + + // Mask: 1 bit per element, packed into uint8 + auto mask = + allocateStableTensor({static_cast((total + 7) / 8)}, ScalarType::Byte, device_idx); + + auto output = torch::stable::empty_like(input); + + auto output_cu = makeTransformerEngineTensor(output); + auto mask_cu = makeTransformerEngineTensor(mask); + auto rng_state_cu = makeTransformerEngineTensor(rng_state); + + nvte_dropout_fwd(input_cu.data(), output_cu.data(), mask_cu.data(), rng_state_cu.data(), + static_cast(dropout_probability), getCurrentCUDAStreamRaw(device_idx)); + + return std::make_tuple(output, mask); } -py::object dropout_bwd(const at::Tensor &grad_output, const at::Tensor &mask, - const float dropout_probability, std::optional grad_input) { - const auto 
grad_output_nvte = makeTransformerEngineTensor(grad_output); - const auto mask_nvte = makeTransformerEngineTensor(mask); - if (!grad_input) { - grad_input = at::empty_like(grad_output); +// ============================================================================ +// Dropout backward +// ============================================================================ + +Tensor dropout_bwd(Tensor grad_output, Tensor mask, double dropout_probability, + std::optional grad_input) { + auto grad_output_ = torch::stable::contiguous(grad_output); + + Tensor grad_in; + if (grad_input.has_value()) { + grad_in = grad_input.value(); + } else { + grad_in = torch::stable::empty_like(grad_output_); } - auto grad_input_nvte = makeTransformerEngineTensor(*grad_input); - NVTE_SCOPED_GIL_RELEASE({ - nvte_dropout_bwd(grad_output_nvte.data(), mask_nvte.data(), grad_input_nvte.data(), - dropout_probability, at::cuda::getCurrentCUDAStream()); - }); - return py::cast(std::move(*grad_input)); + + auto grad_output_cu = makeTransformerEngineTensor(grad_output_); + auto mask_cu = makeTransformerEngineTensor(mask); + auto grad_input_cu = makeTransformerEngineTensor(grad_in); + + nvte_dropout_bwd(grad_output_cu.data(), mask_cu.data(), grad_input_cu.data(), + static_cast(dropout_probability), + getCurrentCUDAStreamRaw(grad_output_.get_device_index())); + + return grad_in; } -} // namespace pytorch -} // namespace transformer_engine +} // namespace transformer_engine::pytorch::stable + +STABLE_TORCH_LIBRARY_FRAGMENT(transformer_engine_stable, m) { + m.def( + "dropout_fwd(Tensor input, Tensor rng_state, float dropout_probability) -> (Tensor, Tensor)"); + m.def( + "dropout_bwd(Tensor grad_output, Tensor mask, float dropout_probability, Tensor? 
grad_input) " + "-> Tensor"); +} + +STABLE_TORCH_LIBRARY_IMPL(transformer_engine_stable, CUDA, m) { + using namespace transformer_engine::pytorch::stable; + m.impl("dropout_fwd", TORCH_BOX(dropout_fwd)); + m.impl("dropout_bwd", TORCH_BOX(dropout_bwd)); +} diff --git a/transformer_engine/pytorch/csrc/extensions/fp8_partial_cast.cpp b/transformer_engine/pytorch/csrc/extensions/fp8_partial_cast.cpp deleted file mode 100644 index d6693a485e..0000000000 --- a/transformer_engine/pytorch/csrc/extensions/fp8_partial_cast.cpp +++ /dev/null @@ -1,89 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#include "../extensions.h" - -namespace transformer_engine::pytorch { - -void fp8_block_scaling_compute_partial_amax(const at::Tensor &tensor, at::Tensor amax, size_t h, - size_t w, size_t start_offset, size_t block_len) { - TORCH_CHECK(block_len == 128, "Currently only block_len = 128 is supported"); - TORCH_CHECK(amax.dim() == 2, "amax must be a 2D tensor"); - TORCH_CHECK(amax.scalar_type() == at::ScalarType::Float, "amax must be a float tensor"); - TORCH_CHECK(tensor.scalar_type() == at::ScalarType::Float || - tensor.scalar_type() == at::ScalarType::BFloat16, - "tensor must be a float or bfloat16 tensor"); - - const TensorWrapper tensor_cu = makeTransformerEngineTensor(tensor); - TensorWrapper amax_cu = makeTransformerEngineTensor(amax); - - nvte_fp8_block_scaling_compute_partial_amax(tensor_cu.data(), amax_cu.data(), h, w, - amax.stride(0), amax.stride(1), start_offset, - block_len, at::cuda::getCurrentCUDAStream()); -} - -void fp8_block_scaling_partial_cast(const at::Tensor &inp, at::Tensor out, const at::Tensor &scale, - size_t h, size_t w, size_t start_offset, size_t block_len, - const transformer_engine::DType out_dtype) { - 
TORCH_CHECK(block_len == 128, "Currently only block_len = 128 is supported"); - TORCH_CHECK(scale.dim() == 2, "scale must be a 2D tensor"); - TORCH_CHECK(scale.scalar_type() == at::ScalarType::Float, "scale must be a float tensor"); - TORCH_CHECK( - inp.scalar_type() == at::ScalarType::Float || inp.scalar_type() == at::ScalarType::BFloat16, - "input must be a float or bfloat16 tensor"); - TORCH_CHECK(out.scalar_type() == at::ScalarType::Byte, "output must be a uint8 tensor"); - TORCH_CHECK(out_dtype == transformer_engine::DType::kFloat8E4M3 || - out_dtype == transformer_engine::DType::kFloat8E5M2, - "out_dtype must be kFloat8E4M3 or kFloat8E5M2"); - - const TensorWrapper inp_cu = makeTransformerEngineTensor(inp); - TensorWrapper out_cu = makeTransformerEngineTensor(out); - const TensorWrapper scale_cu = makeTransformerEngineTensor(scale); - - nvte_fp8_block_scaling_partial_cast( - inp_cu.data(), out_cu.data(), scale_cu.data(), h, w, scale.stride(0), scale.stride(1), - start_offset, block_len, static_cast(out_dtype), at::cuda::getCurrentCUDAStream()); -} - -void mxfp8_scaling_compute_partial_amax(const at::Tensor &input, at::Tensor amax_rowwise, - at::Tensor amax_colwise, int rows, int cols, - size_t start_offset) { - TORCH_CHECK(input.is_contiguous(), "input must be contiguous"); - TORCH_CHECK(amax_rowwise.is_contiguous(), "amax_rowwise must be contiguous"); - TORCH_CHECK(amax_colwise.is_contiguous(), "amax_colwise must be contiguous"); - - const TensorWrapper input_cu = makeTransformerEngineTensor(input); - TensorWrapper amax_rowwise_cu = makeTransformerEngineTensor(amax_rowwise); - TensorWrapper amax_colwise_cu = makeTransformerEngineTensor(amax_colwise); - - nvte_mxfp8_scaling_compute_partial_amax(input_cu.data(), amax_rowwise_cu.data(), - amax_colwise_cu.data(), rows, cols, start_offset, - at::cuda::getCurrentCUDAStream()); -} - -void mxfp8_scaling_partial_cast(const at::Tensor &input, at::Tensor output_rowwise, - at::Tensor output_colwise, const at::Tensor 
&scale_inv_rowwise, - const at::Tensor &scale_inv_colwise, int rows, int cols, - size_t start_offset) { - TORCH_CHECK(input.is_contiguous(), "input must be contiguous"); - TORCH_CHECK(output_rowwise.is_contiguous(), "output_rowwise must be contiguous"); - TORCH_CHECK(output_colwise.is_contiguous(), "output_colwise must be contiguous"); - TORCH_CHECK(scale_inv_rowwise.is_contiguous(), "scale_inv_rowwise must be contiguous"); - TORCH_CHECK(scale_inv_colwise.is_contiguous(), "scale_inv_colwise must be contiguous"); - - const TensorWrapper input_cu = makeTransformerEngineTensor(input); - TensorWrapper output_rowwise_cu = makeTransformerEngineTensor(output_rowwise); - TensorWrapper output_colwise_cu = makeTransformerEngineTensor(output_colwise); - const TensorWrapper scale_inv_rowwise_cu = makeTransformerEngineTensor(scale_inv_rowwise); - const TensorWrapper scale_inv_colwise_cu = makeTransformerEngineTensor(scale_inv_colwise); - - nvte_mxfp8_scaling_partial_cast(input_cu.data(), output_rowwise_cu.data(), - output_colwise_cu.data(), scale_inv_rowwise_cu.data(), - scale_inv_colwise_cu.data(), rows, cols, start_offset, - at::cuda::getCurrentCUDAStream()); -} - -} // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 1431ebdfb4..e2e1f34866 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -4,786 +4,515 @@ * See LICENSE for license information. 
************************************************************************/ -#include +#include +#include -#include -#include - -#include "../common.h" -#include "../extensions.h" -#include "common.h" +#include "../stable_common.h" #include "common/util/cuda_runtime.h" -#include "common/util/system.h" -#include "pybind.h" -#include "transformer_engine/transformer_engine.h" -#include "util.h" - -namespace { - -void* get_data_ptr(transformer_engine::pytorch::MaybeTensor tensor) { - if (tensor.has_value()) return tensor->data_ptr(); - return nullptr; -} - -size_t get_size(transformer_engine::pytorch::MaybeTensor tensor, int dim) { - if (tensor.has_value()) return static_cast(tensor->size(dim)); - return 0; -} -} // namespace +namespace transformer_engine::pytorch::stable { -namespace transformer_engine::pytorch { +using Tensor = torch::stable::Tensor; -namespace detail { - -bool is_low_precision(const DType type) { - return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; -} +namespace { -std::vector getGemmOutputShape(const NVTEShape& A_shape, const bool transa, - const NVTEShape& B_shape, const bool transb) { - // Flatten outer dims to get 2D matrices - const size_t A0 = A_shape.ndim > 0 ? product(A_shape, 0, A_shape.ndim - 1) : 1; - const size_t A1 = A_shape.ndim > 0 ? A_shape.data[A_shape.ndim - 1] : 1; - const size_t B0 = B_shape.ndim > 0 ? product(B_shape, 0, B_shape.ndim - 1) : 1; - const size_t B1 = B_shape.ndim > 0 ? B_shape.data[B_shape.ndim - 1] : 1; - - // Check matrix dims - NVTE_CHECK((transa ? A1 : A0) == (transb ? 
B0 : B1), "Invalid matrix dimensions for GEMM (A=(", - A0, ",", A1, "), transa=", transa, ", B=(", B0, ",", B1, "), transb=", transb, ")"); - - // Construct output dims - std::vector ret; - if (transb) { - ret.emplace_back(B1); - } else { - // Unflatten B0 - for (size_t i = 0; i < B_shape.ndim - 1; ++i) { - ret.emplace_back(B_shape.data[i]); - } - } - if (transa) { - ret.emplace_back(A0); - } else { - ret.emplace_back(A1); +bool requiresScaleSwizzle(NVTEScalingMode scaling_mode) { + switch (scaling_mode) { + case NVTE_MXFP8_1D_SCALING: + case NVTE_NVFP4_1D_SCALING: + return true; + case NVTE_INVALID_SCALING: + NVTE_ERROR("Invalid scaling mode for swizzling scaling factors."); + default: + return false; } - return ret; } -bool checkGemmShape(const std::vector& expected, const NVTEShape& actual) { - if (expected.size() != actual.ndim) return false; - for (size_t i = 0; i < expected.size(); ++i) { - if (expected[i] != actual.data[i]) return false; +// Return the DType used for scale_inv given a scaling mode. 
+DType getScaleInvDtype(NVTEScalingMode scaling_mode) { + switch (scaling_mode) { + case NVTE_MXFP8_1D_SCALING: + return DType::kFloat8E8M0; + case NVTE_NVFP4_1D_SCALING: + return DType::kFloat8E4M3; + default: + return DType::kFloat32; } - return true; } -struct GroupedGemmConfig { - TensorWrapper te_alpha; - TensorWrapper te_beta; - TensorWrapper te_workspace_setup; - TensorWrapper te_workspace_cublas; - std::optional matmul_config; -}; - -GroupedGemmConfig prepare_grouped_gemm_config(at::Tensor alpha, at::Tensor beta, - at::Tensor workspace_setup, - at::Tensor workspace_cublas, size_t num_tensors, - int math_sm_count, bool use_split_accumulator) { - NVTE_CHECK(alpha.numel() == static_cast(num_tensors), - "Grouped GEMM expects alpha to have num_tensors elements."); - NVTE_CHECK(beta.numel() == static_cast(num_tensors), - "Grouped GEMM expects beta to have num_tensors elements."); - - GroupedGemmConfig grouped_gemm_config{ - makeTransformerEngineTensor(alpha), - makeTransformerEngineTensor(beta), - makeTransformerEngineTensor(workspace_setup.data_ptr(), - std::vector{static_cast(workspace_setup.numel())}, - DType::kByte), - makeTransformerEngineTensor( - workspace_cublas.data_ptr(), - std::vector{static_cast(workspace_cublas.numel())}, DType::kByte), - std::nullopt, - }; - - if (math_sm_count > 0 || use_split_accumulator) { - grouped_gemm_config.matmul_config.emplace(); - if (math_sm_count > 0) { - grouped_gemm_config.matmul_config->set_sm_count(math_sm_count); - } - grouped_gemm_config.matmul_config->set_use_split_accumulator(use_split_accumulator); +Tensor swizzleScaleForGemm(const Tensor& data, int64_t te_dtype, const Tensor& scale_inv, + int64_t scaling_mode, bool columnwise = false) { + auto tensor_dtype = static_cast(te_dtype); + auto tensor_scaling_mode = static_cast(scaling_mode); + auto data_shape = getStableTensorShape(data); + // FP4 data is packed (2 elements per byte). Double last dim to report + // logical element count. 
For rowwise [M, K/2] → [M, K]; for columnwise + // [K, M/2] → [K, M]. TensorWrapper::shape() handles the transpose. + if (is_fp4_dtype(tensor_dtype) && !data_shape.empty()) { + data_shape.back() *= 2; } + DType si_dtype = getScaleInvDtype(tensor_scaling_mode); + auto si_shape = getStableTensorShape(scale_inv); + + // Allocate output scale tensor with the same shape as input + std::vector si_shape_i64(si_shape.begin(), si_shape.end()); + auto output_scale_inv = allocateStableTensor(si_shape_i64, si_dtype, data.get_device_index()); + + TensorWrapper input_nvte(tensor_scaling_mode); + TensorWrapper output_nvte(tensor_scaling_mode); + if (columnwise) { + input_nvte.set_columnwise_data(nullptr, tensor_dtype, data_shape); + input_nvte.set_columnwise_scale_inv(scale_inv.data_ptr(), si_dtype, si_shape); + output_nvte.set_columnwise_data(nullptr, tensor_dtype, data_shape); + output_nvte.set_columnwise_scale_inv(output_scale_inv.data_ptr(), si_dtype, si_shape); + } else { + input_nvte.set_rowwise_data(nullptr, tensor_dtype, data_shape); + input_nvte.set_rowwise_scale_inv(scale_inv.data_ptr(), si_dtype, si_shape); + output_nvte.set_rowwise_data(nullptr, tensor_dtype, data_shape); + output_nvte.set_rowwise_scale_inv(output_scale_inv.data_ptr(), si_dtype, si_shape); + } + output_nvte.set_with_gemm_swizzled_scales(true); - return grouped_gemm_config; -} - -} // namespace detail + nvte_swizzle_scaling_factors(input_nvte.data(), output_nvte.data(), + getCurrentCUDAStreamRaw(data.get_device_index())); -std::pair createOutputTensor(const std::vector& shape, - DType dtype, py::handle quantizer) { - std::unique_ptr my_quantizer = convert_quantizer(quantizer); - return my_quantizer->create_tensor(shape, dtype); + return output_scale_inv; } -std::vector gemm(py::handle A, bool transa, py::handle B, bool transb, py::object D, - py::handle quantizer, std::optional out_dtype, MaybeTensor bias, - DType bias_type, bool gelu, MaybeTensor gelu_in, bool grad, - at::Tensor workspace, size_t 
workspaceSize, bool accumulate, - bool use_split_accumulator, CommOverlapCore* comm_overlap, - std::optional comm_type, MaybeTensor extra_output, - bool bulk_overlap, float alpha, std::optional beta) { - using namespace transformer_engine::pytorch::detail; - - // Ensure that cublasLt handle is created on the correct device, - // overriding torch.cuda.set_device calls from user side. - // Assumes all tensors passed are on the same device. - at::cuda::CUDAGuard device_guard(workspace.device()); - - // Input tensors - NVTE_CHECK(!A.is_none(), "Tensor A has not been provided"); - NVTE_CHECK(!B.is_none(), "Tensor B has not been provided"); - auto none = py::none(); - TensorWrapper A_tensor = makeTransformerEngineTensor(A, none); - TensorWrapper B_tensor = makeTransformerEngineTensor(B, none); - - const bool low_precision = - detail::is_low_precision(A_tensor.dtype()) || detail::is_low_precision(B_tensor.dtype()); - const bool fp8_block_scaling = A_tensor.scaling_mode() == NVTE_BLOCK_SCALING_1D || - A_tensor.scaling_mode() == NVTE_BLOCK_SCALING_2D || - B_tensor.scaling_mode() == NVTE_BLOCK_SCALING_1D || - B_tensor.scaling_mode() == NVTE_BLOCK_SCALING_2D; - - // Check tensor dimensions - const auto& A_shape = A_tensor.shape(); - const auto& B_shape = B_tensor.shape(); - const auto& D_shape = detail::getGemmOutputShape(A_shape, transa, B_shape, transb); - NVTE_CHECK(A_shape.ndim >= 1, "Tensor A needs to have at least 1 dimension"); - NVTE_CHECK(B_shape.ndim >= 1, "Tensor B needs to have at least 1 dimension"); - - // Check scaling factors - if (accumulate) { - if (!beta) { - beta = 1.0f; - } - } else { - if (!beta) { - beta = 0.0f; - } - NVTE_CHECK(beta == 0.0, "Trying to use non-zero beta while not accumulating ", - "into D tensor. Beta has nothing to be applied to."); - } +// Build a TensorWrapper with both rowwise and (optionally) columnwise data/scales. +// A_data always holds the rowwise buffer with the LOGICAL shape. 
+// When A_colwise_data is provided it is set as columnwise_data on the wrapper, +// allowing CanonicalizeGemmInput to select the right buffer based on transa/transb. +TensorWrapper buildInputTensorWrapper(const Tensor& rowwise_data, DType te_dtype, + const std::optional& rowwise_scale_inv, + const std::optional& colwise_data, + const std::optional& colwise_scale_inv, + NVTEScalingMode scaling_mode, + const std::optional& amax = std::nullopt, + const std::optional& colwise_amax = std::nullopt) { + DType si_dtype = getScaleInvDtype(scaling_mode); + std::vector scalar_shape = {1}; + + TensorWrapper out(scaling_mode); + // Only set rowwise data when there is actual data (numel > 0). + // When a Float8Tensor has columnwise-only storage (_data=None), the caller + // passes an empty tensor here; we skip set_rowwise_data so TensorWrapper + // has_data() returns false, and NVTE's CanonicalizeGemmInput uses the + // columnwise buffer instead. + // FP4 data is packed (2 elements per byte). The physical tensor has K/2 + // bytes but the logical element count is K. Report logical shape so that + // CheckScaleTensorShape (which derives expected scale shape from + // flat_last_dim()) and the swizzle kernel get the correct K. + const bool fp4_packed = is_fp4_dtype(te_dtype); + + // Flatten >2D shapes to 2D for GEMM. + // Only for MXFP8/DELAYED scaling where TensorWrapper::shape() returns + // data shape as-is (no transpose). + // Rowwise data is [M_dims..., K] → [M, K] (keep last dim). + // Columnwise data is [K, M_dims...] → [K, M] (keep first dim). 
+ auto flatten2D = [](std::vector& shape, NVTEScalingMode sm) { + if (shape.size() <= 2) return; + if (sm != NVTE_MXFP8_1D_SCALING && sm != NVTE_DELAYED_TENSOR_SCALING) return; + size_t K = shape.back(); + size_t M = 1; + for (size_t i = 0; i + 1 < shape.size(); ++i) M *= shape[i]; + shape = {M, K}; + }; + auto flatten2D_columnwise = [](std::vector& shape, NVTEScalingMode sm) { + if (shape.size() <= 2) return; + if (sm != NVTE_MXFP8_1D_SCALING && sm != NVTE_DELAYED_TENSOR_SCALING) return; + size_t K = shape.front(); + size_t M = 1; + for (size_t i = 1; i < shape.size(); ++i) M *= shape[i]; + shape = {K, M}; + }; - DType output_dtype = out_dtype ? *out_dtype : A_tensor.dtype(); - // Output tensor - TensorWrapper D_tensor; - if (D.is_none()) { - std::tie(D_tensor, D) = createOutputTensor(D_shape, output_dtype, quantizer); - } else { - D_tensor = makeTransformerEngineTensor(D, quantizer); - NVTE_CHECK(detail::checkGemmShape(D_shape, D_tensor.shape()), - "GEMM output has invalid dims (expected ", std::to_string(D_shape), ", got ", - std::to_string(D_tensor.shape()), ")"); - if (out_dtype) { - NVTE_CHECK(*out_dtype == D_tensor.dtype(), "GEMM output has invalid dtype (expected ", - static_cast(*out_dtype), ", found ", static_cast(D_tensor.dtype()), ")"); - } + if (rowwise_data.numel() > 0) { + auto shape = getStableTensorShape(rowwise_data); + if (fp4_packed && !shape.empty()) shape.back() *= 2; + flatten2D(shape, scaling_mode); + out.set_rowwise_data(rowwise_data.data_ptr(), te_dtype, shape); } - // maintain unquantized tensor in case we need unfused quantization support. - TensorWrapper unquantized_D_tensor; - py::object unquantized_out; - // Unfused quantization is needed in the following cases - // 1. Inputs: BF16, Output: FP8 (GEMM output has to be BF16, so FP8 quantization needed after that) - // 2. 
Inputs: FP8, Output: FP8 (For any quantization apart from delayed scaling, - // GEMM Output needs to be in BF16, to allow for unfused quantization) - bool unfused_quantization_needed = !quantizer.is_none(); - if (low_precision) { - // At the moment, only use-case for fused GEMM: - // Delayed scaling quantizer with per-tensor scaling inputs - bool is_per_tensor_scaling_input = IsFloat8Tensor(A.ptr()) || IsFloat8Tensor(B.ptr()); - if (IsFloat8Quantizers(quantizer.ptr()) && is_per_tensor_scaling_input) - unfused_quantization_needed = false; + if (rowwise_scale_inv.has_value() && rowwise_scale_inv->numel() > 0) { + auto si_shape = getStableTensorShape(*rowwise_scale_inv); + out.set_rowwise_scale_inv(rowwise_scale_inv->data_ptr(), si_dtype, si_shape); } - if (unfused_quantization_needed) { - NoneQuantizer q{none}; - std::tie(unquantized_D_tensor, unquantized_out) = q.create_tensor(D_shape, output_dtype); + // Set amax for NVFP4 tensors. The pybind path (NVTETensorFromNVFP4Tensor) + // sets amax on the TensorWrapper; cuBLAS uses it in the GEMM formula: + // output = fp4_value * scale_e4m3 * amax / (6 * 448) + if (amax.has_value() && amax->numel() > 0) { + out.set_amax(amax->data_ptr(), DType::kFloat32, scalar_shape); + } else if (colwise_amax.has_value() && colwise_amax->numel() > 0 && is_fp4_dtype(te_dtype)) { + // Columnwise-only NVFP4: set rowwise amax from columnwise amax. + // nvte_nvfp4_compute_per_tensor_scale reads columnwise_amax directly + // for NT layout, but other paths may check amax.dptr. + out.set_amax(colwise_amax->data_ptr(), DType::kFloat32, scalar_shape); } - TensorWrapper& out_tensor = unfused_quantization_needed ? 
unquantized_D_tensor : D_tensor; - // Bias tensor - TensorWrapper bias_tensor; - MaybeTensor bias_grad = std::nullopt; - if (bias.has_value()) { - if (grad) { - auto opts = - torch::TensorOptions().dtype(GetATenDType(out_tensor.dtype())).device(torch::kCUDA); - bias_grad = at::empty({static_cast(B_shape.data[B_shape.ndim - 1])}, opts); - bias_tensor = makeTransformerEngineTensor(*bias_grad); - } else { - if (!bias->is_contiguous()) { - bias = bias->contiguous(); - } - bias_tensor = makeTransformerEngineTensor(*bias); + if (colwise_data.has_value() && colwise_data->numel() > 0) { + auto cw_shape = getStableTensorShape(*colwise_data); + // FP4 columnwise data [K, M/2]: double last dim to get [K, M]. + // TensorWrapper::shape() transposes to [M, K], producing correct + // flat_first_dim()=M and flat_last_dim()=K for CheckScaleTensorShape. + if (fp4_packed && !cw_shape.empty()) cw_shape.back() *= 2; + // Do NOT flatten columnwise data. For DELAYED/Block scaling the columnwise + // buffer is the physical transpose: shape [K, M_dims...]. + // TensorWrapper::shape() moves the first dim to the end, giving the logical + // [M_dims..., K] shape from which flat_first_dim()=M and flat_last_dim()=K + // are computed correctly, even for 3D+ tensors. + // Applying flatten2D here would incorrectly treat the last dim as K. + out.set_columnwise_data(colwise_data->data_ptr(), te_dtype, cw_shape); + if (colwise_scale_inv.has_value() && colwise_scale_inv->numel() > 0) { + auto csi_shape = getStableTensorShape(*colwise_scale_inv); + out.set_columnwise_scale_inv(colwise_scale_inv->data_ptr(), si_dtype, csi_shape); + } + if (colwise_amax.has_value() && colwise_amax->numel() > 0) { + out.set_columnwise_amax(colwise_amax->data_ptr(), DType::kFloat32, scalar_shape); } } + return out; +} + +} // namespace - // Activation input tensor - MaybeTensor pre_gelu_out = std::nullopt; - DType gelu_type = low_precision ? 
bias_type : out_tensor.dtype(); - if (gelu) { - if (!grad) { - auto dtype = GetATenDType(gelu_type); - auto opts = torch::TensorOptions().dtype(dtype).device(torch::kCUDA); - std::vector torch_shape; - for (auto v : D_shape) { - torch_shape.push_back(v); +// ============================================================================ +// Core GEMM (no CommOverlap) +// +// This is the stable ABI version of nvte_cublas_gemm_v2. +// CommOverlap remains in the pybind11 module since it requires +// opaque class handles that can't cross the stable ABI boundary. +// +// The Python shim: +// 1. Extracts raw buffers from quantized A, B tensors +// 2. Pre-allocates output D (quantized or unquantized) +// 3. Calls this op for the core GEMM +// 4. Wraps output in the appropriate Python tensor type +// +// For FP8 tensors with separate rowwise/columnwise storage: +// - A_data / B_data always hold the ROWWISE (logical-shape) buffer +// - A_colwise_data / B_colwise_data (optional) hold the columnwise buffer +// - Both are set on the TensorWrapper; CanonicalizeGemmInput selects the +// right one based on transa/transb at run time +// ============================================================================ + +void gemm( + // Input A (rowwise data = logical shape; colwise optional for FP8) + Tensor A_data, int64_t A_te_dtype, std::optional A_scale_inv, + std::optional A_colwise_data, std::optional A_colwise_scale_inv, + int64_t A_scaling_mode, bool A_with_gemm_swizzled_scales, bool transa, + // Input B (rowwise data = logical shape; colwise optional for FP8) + Tensor B_data, int64_t B_te_dtype, std::optional B_scale_inv, + std::optional B_colwise_data, std::optional B_colwise_scale_inv, + int64_t B_scaling_mode, bool B_with_gemm_swizzled_scales, bool transb, + // Output D (pre-allocated) + Tensor D_data, int64_t D_te_dtype, std::optional D_amax, std::optional D_scale, + std::optional D_scale_inv, int64_t D_scaling_mode, + // Optional bias + std::optional bias, int64_t bias_type, 
+ // Optional pre-gelu output + std::optional pre_gelu_out, + // Workspace + Tensor workspace, + // Config + bool grad, bool accumulate, bool use_split_accumulator, double alpha, + // NVFP4 amax (per-tensor amax for GEMM formula: out = fp4 * scale * amax / 2688) + std::optional A_amax = std::nullopt, + std::optional A_colwise_amax = std::nullopt, + std::optional B_amax = std::nullopt, + std::optional B_colwise_amax = std::nullopt) { + auto A_te = static_cast(A_te_dtype); + auto B_te = static_cast(B_te_dtype); + auto D_te = static_cast(D_te_dtype); + auto A_sm = static_cast(A_scaling_mode); + auto B_sm = static_cast(B_scaling_mode); + auto D_sm = static_cast(D_scaling_mode); + + auto D_shape = getStableTensorShape(D_data); + + // Auto-swizzle scales if needed (MXFP8/NVFP4, not pre-swizzled). + // Swizzle the direction that CanonicalizeGemmInput will actually use: + // transa=True → CanonicalizeGemmInput uses rowwise data → swizzle rowwise scale + // transa=False → CanonicalizeGemmInput uses colwise data → swizzle colwise scale + // transb=False → CanonicalizeGemmInput uses rowwise data → swizzle rowwise scale + // transb=True → CanonicalizeGemmInput uses colwise data → swizzle colwise scale + // This matches pybind: swizzle_scales_for_gemm(A, transa, !transa) / (B, !transb, transb). 
+ std::vector swizzled_scale_inverses; + if (!A_with_gemm_swizzled_scales && requiresScaleSwizzle(A_sm)) { + if (transa) { + // transa=True → rowwise direction → swizzle rowwise scale + if (A_scale_inv.has_value() && A_scale_inv->numel() > 0) { + swizzled_scale_inverses.emplace_back( + swizzleScaleForGemm(A_data, A_te_dtype, *A_scale_inv, A_scaling_mode)); + A_scale_inv = swizzled_scale_inverses.back(); } - pre_gelu_out = at::empty(torch_shape, opts); } else { - if (gelu_in.has_value()) { - pre_gelu_out = *gelu_in; + // transa=False → colwise direction → swizzle colwise scale + if (A_colwise_data.has_value() && A_colwise_scale_inv.has_value() && + A_colwise_data->numel() > 0 && A_colwise_scale_inv->numel() > 0) { + swizzled_scale_inverses.emplace_back( + swizzleScaleForGemm(*A_colwise_data, A_te_dtype, *A_colwise_scale_inv, A_scaling_mode, + /*columnwise=*/true)); + A_colwise_scale_inv = swizzled_scale_inverses.back(); } } + A_with_gemm_swizzled_scales = true; } - const auto gelu_shape = gelu ? D_shape : std::vector{0}; - - auto te_pre_gelu_out = - makeTransformerEngineTensor(get_data_ptr(pre_gelu_out), gelu_shape, gelu_type); - - // Workspace - auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(), - std::vector{workspaceSize}, DType::kByte); - - // Set an external SM Margin to all the GEMMs. 
- // This comes in handy when DP is overlapped with GEMMs - const int device_id = at::cuda::current_device(); - const int sm_count = transformer_engine::cuda::sm_count(device_id); - int num_math_sms = sm_count - transformer_engine::getenv("NVTE_EXT_MARGIN_SM", sm_count); - - // Construct GEMM config - transformer_engine::MatmulConfigWrapper config; - if (grad) { - config.set_dbias_tensor(bias_tensor.data()); - config.set_with_dgelu_epilogue(gelu); - } else { - config.set_bias_tensor(bias_tensor.data()); - config.set_with_gelu_epilogue(gelu); - } - config.set_epilogue_aux_tensor(te_pre_gelu_out.data()); - config.set_use_split_accumulator(use_split_accumulator); - config.set_sm_count(num_math_sms); - - // Keep the swizzled scaling factor tensors alive during the GEMM. - std::vector> swizzled_scale_inverses_list; - auto main_stream = at::cuda::getCurrentCUDAStream(); - if (A_tensor.numel() != 0 && B_tensor.numel() != 0) { - // Optionally swizzle the scaling factors - auto [A_row_scales, A_col_scales] = swizzle_scales_for_gemm(A_tensor, transa, !transa); - auto [B_row_scales, B_col_scales] = swizzle_scales_for_gemm(B_tensor, !transb, transb); - swizzled_scale_inverses_list.emplace_back(std::move(A_row_scales)); - swizzled_scale_inverses_list.emplace_back(std::move(A_col_scales)); - swizzled_scale_inverses_list.emplace_back(std::move(B_row_scales)); - swizzled_scale_inverses_list.emplace_back(std::move(B_col_scales)); - - // Emulate the FP8 block scaling recipe with MXFP8 on Blackwell and newer - // as it is not natively supported by cublasLt - if (fp8_block_scaling && transformer_engine::cuda::sm_arch() >= 100) { - // Convert tensors to mxfp8 and swizzle their scaling factors - swizzled_scale_inverses_list.emplace_back( - std::move(convert_block_scaling_to_mxfp8_tensor(A_tensor, transa))); - swizzled_scale_inverses_list.emplace_back( - std::move(convert_block_scaling_to_mxfp8_tensor(B_tensor, !transb))); - // Use TN GEMM to avoid having to transpose data. 
- transa = true; - transb = false; - } - - if (comm_overlap) { - // Prepare extra output tensor - TensorWrapper extra_output_tensor; - if (extra_output.has_value()) { - extra_output_tensor = makeTransformerEngineTensor(*extra_output); - } else { - extra_output_tensor = - makeTransformerEngineTensor(nullptr, std::vector{0}, DType::kByte); - } - - // Direct GEMM call to the correct overlap - if (bulk_overlap) { - NVTE_SCOPED_GIL_RELEASE({ - comm_overlap->bulk_overlap(A_tensor, transa, B_tensor, transb, out_tensor, bias_tensor, - te_pre_gelu_out, te_workspace, grad, accumulate, - use_split_accumulator, comm_type.value(), extra_output_tensor, - main_stream); - }); - } else if (comm_type.value() == CommOverlapType::AG) { - if (comm_overlap->is_atomic_gemm()) { - NVTE_SCOPED_GIL_RELEASE({ - comm_overlap->atomic_gemm_overlap_ag(A_tensor, transa, B_tensor, transb, out_tensor, - bias_tensor, te_pre_gelu_out, te_workspace, grad, - accumulate, use_split_accumulator, - extra_output_tensor, main_stream); - }); - } else { - NVTE_SCOPED_GIL_RELEASE({ - comm_overlap->split_overlap_ag(A_tensor, transa, B_tensor, transb, out_tensor, - bias_tensor, te_pre_gelu_out, te_workspace, grad, - accumulate, use_split_accumulator, extra_output_tensor, - main_stream); - }); - } - } else { - if (comm_overlap->is_atomic_gemm()) { - NVTE_SCOPED_GIL_RELEASE({ - comm_overlap->atomic_gemm_overlap_rs(A_tensor, transa, B_tensor, transb, out_tensor, - bias_tensor, te_pre_gelu_out, te_workspace, grad, - accumulate, use_split_accumulator, - extra_output_tensor, main_stream); - }); - } else { - NVTE_SCOPED_GIL_RELEASE({ - comm_overlap->split_overlap_rs(A_tensor, transa, B_tensor, transb, out_tensor, - bias_tensor, te_pre_gelu_out, te_workspace, grad, - accumulate, use_split_accumulator, extra_output_tensor, - main_stream); - }); - } + if (!B_with_gemm_swizzled_scales && requiresScaleSwizzle(B_sm)) { + if (!transb) { + // transb=False → rowwise direction → swizzle rowwise scale + if (B_scale_inv.has_value() 
&& B_scale_inv->numel() > 0) { + swizzled_scale_inverses.emplace_back( + swizzleScaleForGemm(B_data, B_te_dtype, *B_scale_inv, B_scaling_mode)); + B_scale_inv = swizzled_scale_inverses.back(); } } else { - // Launch GEMM - NVTE_SCOPED_GIL_RELEASE({ - nvte_cublas_gemm_v2(transa, transb, &alpha, A_tensor.data(), B_tensor.data(), &beta.value(), - out_tensor.data(), out_tensor.data(), te_workspace.data(), config, - main_stream); - }); - } - } else { - if (out_tensor.numel() != 0 && !accumulate) { - out_tensor.zero_(main_stream); - } - if (bias.has_value()) { - if (bias->numel() != 0 && grad) { - bias_grad->zero_(); + // transb=True → colwise direction → swizzle colwise scale + if (B_colwise_data.has_value() && B_colwise_scale_inv.has_value() && + B_colwise_data->numel() > 0 && B_colwise_scale_inv->numel() > 0) { + swizzled_scale_inverses.emplace_back( + swizzleScaleForGemm(*B_colwise_data, B_te_dtype, *B_colwise_scale_inv, B_scaling_mode, + /*columnwise=*/true)); + B_colwise_scale_inv = swizzled_scale_inverses.back(); } } + B_with_gemm_swizzled_scales = true; } - if (unfused_quantization_needed) { - // Quantize the output - std::unique_ptr my_quantizer = convert_quantizer(quantizer); - my_quantizer->quantize(unquantized_D_tensor, D_tensor); - } - // Pack outputs - std::vector out; - out.emplace_back(std::move(D)); - out.emplace_back(py::cast(bias_grad)); - if (gelu && !grad) { - out.emplace_back(py::cast(*pre_gelu_out)); - } else { - out.emplace_back(py::none()); - } - if (extra_output.has_value()) { - out.emplace_back(py::cast(extra_output)); - } else { - out.emplace_back(py::none()); - } - return out; -} - -void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type, - std::vector A_scaling_mode, bool transa, at::Tensor B, - at::Tensor B_scale_inverse, DType B_type, std::vector B_scaling_mode, - bool transb, at::Tensor D, at::Tensor D_scale, DType D_type, at::Tensor D_amax, - at::Tensor bias, DType bias_type, at::Tensor pre_gelu_out, bool grad, - 
at::Tensor workspace, size_t workspaceSize, bool accumulate, - bool use_split_accumulator, int math_sm_count, int m_split, int n_split, - bool gemm_producer, at::Tensor counter) { - // Ensure that cublasLt handle is created on the correct device, - // overriding torch.cuda.set_device calls from user side. - // Assumes all tensors passed are on the same device. - at::cuda::CUDAGuard device_guard(workspace.device()); - - // TODO: Handle scaling modes - NVTEScalingMode nvte_scaling_modeA = NVTE_DELAYED_TENSOR_SCALING; - NVTEScalingMode nvte_scaling_modeB = NVTE_DELAYED_TENSOR_SCALING; - - auto te_A = makeTransformerEngineTensor( - A.data_ptr(), {static_cast(A.size(0)), static_cast(A.size(1))}, A_type, - nullptr, nullptr, A_scale_inverse.data_ptr(), getTensorShape(A_scale_inverse), - nvte_scaling_modeA); - auto te_B = makeTransformerEngineTensor( - B.data_ptr(), {static_cast(B.size(0)), static_cast(B.size(1))}, B_type, - nullptr, nullptr, B_scale_inverse.data_ptr(), getTensorShape(B_scale_inverse), - nvte_scaling_modeB); - // TODO: D_scale_inv cannot be nullptr when D_type is FP8. - auto te_D = makeTransformerEngineTensor( - D.data_ptr(), - std::vector{static_cast(D.size(0)), static_cast(D.size(1))}, D_type, - D_amax.data_ptr(), D_scale.data_ptr(), nullptr); - auto te_bias = makeTransformerEngineTensor( - bias.data_ptr(), std::vector{static_cast(bias.size(0))}, bias_type); - auto te_counter = makeTransformerEngineTensor( - counter.data_ptr(), std::vector{static_cast(counter.size(0))}, DType::kInt32); - - const auto gelu_shape = pre_gelu_out.data_ptr() == nullptr - ? 
std::vector{static_cast(pre_gelu_out.size(0))} - : std::vector{static_cast(pre_gelu_out.size(0)), - static_cast(pre_gelu_out.size(1))}; - auto te_pre_gelu_out = makeTransformerEngineTensor( - pre_gelu_out.data_ptr(), gelu_shape, GetTransformerEngineDType(pre_gelu_out.scalar_type())); - auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(), - std::vector{workspaceSize}, DType::kByte); - - NVTE_SCOPED_GIL_RELEASE({ - nvte_cublas_atomic_gemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(), - te_pre_gelu_out.data(), transa, transb, grad, te_workspace.data(), - accumulate, use_split_accumulator, math_sm_count, m_split, n_split, - gemm_producer, te_counter.data(), at::cuda::getCurrentCUDAStream()); - }); -} - -std::optional> te_general_grouped_gemm( - std::vector A, bool transa, std::vector B, bool transb, - std::optional> D, DType D_type, std::vector m_splits, - std::vector bias, DType bias_type, bool single_output, - std::vector pre_gelu_out, bool grad, std::vector workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count) { - if (single_output && D == std::nullopt) { - NVTE_ERROR("not implemented, D should be allocated for single output case."); - } - - // Ensure that cublasLt handle is created on the correct device, - // overriding torch.cuda.set_device calls from user side. - // Assumes all tensors passed are on the same device. 
- at::cuda::CUDAGuard device_guard(workspace[0].device()); - void* output_data_ptr = nullptr; - if (single_output) { - output_data_ptr = (*D)[0].data_ptr(); - } - - const auto none = py::none(); - std::vector te_A_wrappers, te_B_wrappers, te_D_wrappers, te_bias_wrappers, - te_pre_gelu_out_wrappers; - std::vector D_vectors; - for (size_t i = 0; i < A.size(); i++) { - auto te_A = makeTransformerEngineTensor(A[i], none); - auto te_B = makeTransformerEngineTensor(B[i], none); - - // if there is single output - at::Tensor out_tensor; - auto size_t_shape = - pytorch::detail::getGemmOutputShape(te_A.shape(), transa, te_B.shape(), transb); - bool D_numel_is_zero = false; - std::vector D_shape; - for (size_t t : size_t_shape) { - D_shape.push_back(t); - if (t == 0) { - D_numel_is_zero = true; + // On Blackwell (sm >= 100), cuBLAS does not natively support BLOCK_SCALING. + // Convert to MXFP8, matching the pybind path's convert_block_scaling_to_mxfp8_tensor. + const bool fp8_block_scaling = (A_sm == NVTE_BLOCK_SCALING_1D || A_sm == NVTE_BLOCK_SCALING_2D || + B_sm == NVTE_BLOCK_SCALING_1D || B_sm == NVTE_BLOCK_SCALING_2D); + std::vector block_to_mxfp8_buffers; // keep alive during GEMM + + auto convertOneDirection = [&](const Tensor& sel_data, DType dtype, const Tensor& sel_si, + NVTEScalingMode orig_sm, int device_idx, + bool is_columnwise) -> Tensor { + auto d_shape = getStableTensorShape(sel_data); + const bool is_fp4 = is_fp4_dtype(dtype); + if (is_fp4 && !d_shape.empty()) d_shape.back() *= 2; + size_t flat_first = 1, flat_last = 1; + if (!d_shape.empty()) { + if (is_columnwise) { + // Columnwise data: first dim is K, rest are M → [K, M_dims...] 
+ flat_first = d_shape[0]; + for (size_t i = 1; i < d_shape.size(); ++i) flat_last *= d_shape[i]; + } else { + // Rowwise data: last dim is K, rest are M → [M_dims..., K] + for (size_t i = 0; i + 1 < d_shape.size(); ++i) flat_first *= d_shape[i]; + flat_last = d_shape.back(); } } - auto dtype = GetATenDType(D_type); - auto opts = torch::TensorOptions().dtype(dtype).device(torch::kCUDA); - if (single_output) { - if (output_data_ptr == nullptr) { - out_tensor = at::empty(D_shape, opts); - } else { - // We need to check !D_numel_is_zero because if the final input portion has zero elements, - // output_data_ptr would point beyond the allocated memory of D. This would cause - // at::from_blob to fail as it would reference memory not allocated by CUDA. - if (!D_numel_is_zero) { - out_tensor = at::from_blob(output_data_ptr, D_shape, opts); - } + std::vector data_2d = {flat_first, flat_last}; + TensorWrapper input_cu(orig_sm); + input_cu.set_rowwise_data(sel_data.data_ptr(), dtype, data_2d); + auto si_shape = getStableTensorShape(sel_si); + input_cu.set_rowwise_scale_inv(sel_si.data_ptr(), DType::kFloat32, si_shape); + + TensorWrapper output_cu(NVTE_MXFP8_1D_SCALING); + output_cu.set_rowwise_data(sel_data.data_ptr(), dtype, data_2d); + size_t sw_first = ((flat_first + 127) / 128) * 128; + size_t sw_last = ((flat_last + 127) / 128) * 4; + std::vector sw_shape = {sw_first, sw_last}; + auto sw_si = allocateStableTensor( + {static_cast(sw_first), static_cast(sw_last)}, DType::kByte, device_idx); + output_cu.set_rowwise_scale_inv(sw_si.data_ptr(), DType::kFloat8E8M0, sw_shape); + output_cu.set_with_gemm_swizzled_scales(true); + + nvte_swizzle_block_scaling_to_mxfp8_scaling_factors(input_cu.data(), output_cu.data(), + getCurrentCUDAStreamRaw(device_idx)); + return sw_si; + }; + + auto convertBlockToMxfp8 = [&](Tensor& data, DType dtype, std::optional& scale_inv, + std::optional& cw_data, std::optional& cw_si, + NVTEScalingMode& sm, bool& swizzled, bool use_rowwise) { + if (sm != 
NVTE_BLOCK_SCALING_1D && sm != NVTE_BLOCK_SCALING_2D) return; + int device_idx = data.numel() > 0 ? data.get_device_index() + : (cw_data.has_value() ? cw_data->get_device_index() : 0); + + // Convert the direction that will be used by CanonicalizeGemmInput + if (use_rowwise) { + if (scale_inv.has_value() && scale_inv->numel() > 0) { + auto sw = convertOneDirection(data, dtype, *scale_inv, sm, device_idx, + /*is_columnwise=*/false); + scale_inv = sw; + block_to_mxfp8_buffers.push_back(std::move(sw)); } - char* char_ptr = reinterpret_cast(output_data_ptr); - char_ptr += D_shape[0] * D_shape[1] * (*D)[0].element_size(); - output_data_ptr = reinterpret_cast(char_ptr); - D_vectors.emplace_back(out_tensor); } else { - if (D == std::nullopt) { - auto opts = torch::TensorOptions().dtype(dtype).device(torch::kCUDA); - out_tensor = at::empty(D_shape, opts); - D_vectors.emplace_back(out_tensor); - } else { - out_tensor = (*D)[i]; + if (cw_data.has_value() && cw_si.has_value() && cw_si->numel() > 0) { + auto sw = convertOneDirection(*cw_data, dtype, *cw_si, sm, device_idx, + /*is_columnwise=*/true); + cw_si = sw; + // Also convert rowwise scale if present (for CanonicalizeGemmInput + // which may fall back to rowwise on Blackwell). 
+ if (scale_inv.has_value() && scale_inv->numel() > 0 && data.numel() > 0) { + auto sw2 = convertOneDirection(data, dtype, *scale_inv, sm, device_idx, + /*is_columnwise=*/false); + scale_inv = sw2; + block_to_mxfp8_buffers.push_back(std::move(sw2)); + } + block_to_mxfp8_buffers.push_back(std::move(sw)); } } + sm = NVTE_MXFP8_1D_SCALING; + swizzled = true; + }; - if (te_A.numel() == 0 || te_B.numel() == 0) { - if (out_tensor.numel() != 0 && !accumulate) out_tensor.zero_(); - if (bias[i].numel() != 0 && grad) { - bias[i].zero_(); - } - if (pre_gelu_out[i].numel() != 0) pre_gelu_out[i].zero_(); - continue; + // Build a MXFP8 TensorWrapper from block-scaling data, matching the pybind + // convert_block_scaling_to_mxfp8_tensor pattern: select rowwise or colwise data, + // flatten to 2D, set as rowwise data on MXFP8 wrapper. + auto buildMxfp8FromBlock = [&](const Tensor& sel_data, DType dtype, const Tensor& mxfp8_si, + bool is_colwise) -> TensorWrapper { + auto d_shape = getStableTensorShape(sel_data); + const bool is_fp4 = is_fp4_dtype(dtype); + if (is_fp4 && !d_shape.empty()) d_shape.back() *= 2; + size_t flat_first, flat_last; + if (is_colwise) { + // Colwise: [K, M_dims...] → [K, M] + flat_first = d_shape.empty() ? 1 : d_shape[0]; + flat_last = 1; + for (size_t i = 1; i < d_shape.size(); ++i) flat_last *= d_shape[i]; + } else { + // Rowwise: [M_dims..., K] → [M, K] + flat_first = 1; + for (size_t i = 0; i + 1 < d_shape.size(); ++i) flat_first *= d_shape[i]; + flat_last = d_shape.empty() ? 
1 : d_shape.back(); } + std::vector data_2d = {flat_first, flat_last}; + auto si_shape = getStableTensorShape(mxfp8_si); + TensorWrapper out(NVTE_MXFP8_1D_SCALING); + out.set_rowwise_data(sel_data.data_ptr(), dtype, data_2d); + out.set_rowwise_scale_inv(mxfp8_si.data_ptr(), DType::kFloat8E8M0, si_shape); + out.set_with_gemm_swizzled_scales(true); + return out; + }; - auto te_D = makeTransformerEngineTensor(out_tensor); - auto te_bias = makeTransformerEngineTensor(bias[i]); - auto te_pre_gelu_out = makeTransformerEngineTensor(pre_gelu_out[i]); - - const auto gelu_shape = pre_gelu_out[i].data_ptr() == nullptr - ? std::vector{static_cast(te_pre_gelu_out.size(0))} - : std::vector{static_cast(te_pre_gelu_out.size(0)), - static_cast(te_pre_gelu_out.size(1))}; - - DType gelu_type = bias_type; - te_pre_gelu_out = - makeTransformerEngineTensor(get_data_ptr(pre_gelu_out[i]), gelu_shape, gelu_type); - - te_A_wrappers.emplace_back(std::move(te_A)); - te_B_wrappers.emplace_back(std::move(te_B)); - te_D_wrappers.emplace_back(std::move(te_D)); - te_bias_wrappers.emplace_back(std::move(te_bias)); - te_pre_gelu_out_wrappers.emplace_back(std::move(te_pre_gelu_out)); - } - - // Keep the swizzled scaling factor tensors alive during the GEMM. 
- std::vector> swizzled_scale_inverses_list; - - // Optionally swizzle the scaling factors - swizzled_scale_inverses_list.emplace_back( - multi_tensor_swizzle_scales_for_gemm(te_A_wrappers, transa, !transa)); - swizzled_scale_inverses_list.emplace_back( - multi_tensor_swizzle_scales_for_gemm(te_B_wrappers, !transb, transb)); - - // Emulate the FP8 block scaling recipe with MXFP8 on Blackwell and newer - // as it is not natively supported by cublasLt - if (transformer_engine::cuda::sm_arch() >= 100) { - // Check if is using FP8 block scaling - bool exists_tensor_using_fp8_block_scaling = false; - bool exists_tensor_not_using_fp8_block_scaling = false; - for (const auto& tensor_wrappers : {&te_A_wrappers, &te_B_wrappers}) { - for (const TensorWrapper& tensor : *tensor_wrappers) { - const NVTEScalingMode scaling_mode = tensor.scaling_mode(); - if (scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D) - exists_tensor_using_fp8_block_scaling = true; - else - exists_tensor_not_using_fp8_block_scaling = true; - } - } - if (exists_tensor_using_fp8_block_scaling) { - NVTE_CHECK(!exists_tensor_not_using_fp8_block_scaling, - "Either all tensors or no tensor must be FP8 block scaling tensors"); - // Convert tensors to mxfp8 and swizzle their scaling factors - for (TensorWrapper& A_tensor : te_A_wrappers) { - swizzled_scale_inverses_list.emplace_back( - convert_block_scaling_to_mxfp8_tensor(A_tensor, transa)); - } - for (TensorWrapper& B_tensor : te_B_wrappers) { - swizzled_scale_inverses_list.emplace_back( - convert_block_scaling_to_mxfp8_tensor(B_tensor, !transb)); + TensorWrapper A_tensor, B_tensor; + + if (fp8_block_scaling && transformer_engine::cuda::sm_arch() >= 100) { + // Convert block scaling to MXFP8 and build TensorWrappers directly. + // Select the direction CanonicalizeGemmInput will use, convert its + // scale format, then force TN layout (matching pybind path). 
+ auto convertAndBuild = [&](Tensor& data, DType dtype, std::optional& si, + std::optional& cw_data, std::optional& cw_si, + NVTEScalingMode sm, bool use_rowwise) -> TensorWrapper { + if (sm != NVTE_BLOCK_SCALING_1D && sm != NVTE_BLOCK_SCALING_2D) { + // Not block scaling — build normally + auto t = buildInputTensorWrapper(data, dtype, si, cw_data, cw_si, sm); + return t; } - // Use TN GEMM to avoid having to transpose data. + int dev = data.numel() > 0 ? data.get_device_index() + : (cw_data.has_value() ? cw_data->get_device_index() : 0); + auto& sel_data = use_rowwise ? data : *cw_data; + auto& sel_si = use_rowwise ? si : cw_si; + auto sw = convertOneDirection(sel_data, dtype, *sel_si, sm, dev, use_rowwise ? false : true); + block_to_mxfp8_buffers.push_back(sw); + return buildMxfp8FromBlock(sel_data, dtype, sw, !use_rowwise); + }; + + // Only convert and force TN when BOTH tensors are block-scaling + const bool A_is_block = (A_sm == NVTE_BLOCK_SCALING_1D || A_sm == NVTE_BLOCK_SCALING_2D); + const bool B_is_block = (B_sm == NVTE_BLOCK_SCALING_1D || B_sm == NVTE_BLOCK_SCALING_2D); + if (A_is_block && B_is_block) { + A_tensor = convertAndBuild(A_data, A_te, A_scale_inv, A_colwise_data, A_colwise_scale_inv, + A_sm, transa); + B_tensor = convertAndBuild(B_data, B_te, B_scale_inv, B_colwise_data, B_colwise_scale_inv, + B_sm, !transb); transa = true; transb = false; + } else { + // Mixed: one is block scaling, one isn't — shouldn't happen but handle gracefully + A_tensor = buildInputTensorWrapper(A_data, A_te, A_scale_inv, A_colwise_data, + A_colwise_scale_inv, A_sm, A_amax, A_colwise_amax); + B_tensor = buildInputTensorWrapper(B_data, B_te, B_scale_inv, B_colwise_data, + B_colwise_scale_inv, B_sm, B_amax, B_colwise_amax); } + } else { + A_tensor = buildInputTensorWrapper(A_data, A_te, A_scale_inv, A_colwise_data, + A_colwise_scale_inv, A_sm, A_amax, A_colwise_amax); + B_tensor = buildInputTensorWrapper(B_data, B_te, B_scale_inv, B_colwise_data, + B_colwise_scale_inv, 
B_sm, B_amax, B_colwise_amax); } - std::vector te_A_vector, te_B_vector, te_D_vector, te_bias_vector, - te_pre_gelu_out_vector; - for (size_t i = 0; i < te_A_wrappers.size(); i++) { - te_A_vector.emplace_back(te_A_wrappers[i].data()); - te_B_vector.emplace_back(te_B_wrappers[i].data()); - te_D_vector.emplace_back(te_D_wrappers[i].data()); - te_bias_vector.emplace_back(te_bias_wrappers[i].data()); - te_pre_gelu_out_vector.emplace_back(te_pre_gelu_out_wrappers[i].data()); + auto D_tensor = + makeQuantizedTensorWrapper(D_data, D_te, D_shape, D_amax, D_scale, D_scale_inv, D_sm); + // For block-to-MXFP8 path, swizzled flag is already set by buildMxfp8FromBlock. + // For non-block path, set from the Python-side flag. + if (!fp8_block_scaling || transformer_engine::cuda::sm_arch() < 100) { + A_tensor.set_with_gemm_swizzled_scales(A_with_gemm_swizzled_scales); + B_tensor.set_with_gemm_swizzled_scales(B_with_gemm_swizzled_scales); } - std::vector te_workspace_vector; - std::vector te_workspace_wrappers; - for (size_t i = 0; i < workspace.size(); i++) { - auto wsp = makeTransformerEngineTensor(workspace[i].data_ptr(), - std::vector{workspaceSize}, DType::kByte); - te_workspace_vector.emplace_back(wsp.data()); - te_workspace_wrappers.emplace_back(std::move(wsp)); + TensorWrapper bias_tensor; + if (bias.has_value()) { + auto bias_te = static_cast(bias_type); + auto bias_shape = getStableTensorShape(bias.value()); + bias_tensor = makeTransformerEngineTensor(bias->data_ptr(), bias_shape, bias_te); } - // For now, we only have multi-stream cublas backend. 
- NVTE_SCOPED_GIL_RELEASE({ - nvte_multi_tensor_gemm(te_A_vector.data(), te_B_vector.data(), te_D_vector.data(), - te_bias_vector.data(), te_pre_gelu_out_vector.data(), te_A_vector.size(), - transa, transb, grad, te_workspace_vector.data(), accumulate, - use_split_accumulator, math_sm_count, at::cuda::getCurrentCUDAStream()); - }); - return bias; -} - -py::object te_general_grouped_gemm_for_grouped_tensor( - py::handle A, bool transa, py::handle B, bool transb, py::handle D, py::object bias, - at::Tensor alpha, at::Tensor beta, at::Tensor workspace_setup, at::Tensor workspace_cublas, - bool use_split_accumulator, int math_sm_count) { - using namespace transformer_engine::pytorch::detail; - - init_extension(); - - // Ensure that cublasLt handle is created on the correct device, - // overriding torch.cuda.set_device calls from user side. - // Assumes all tensors passed are on the same device. - at::cuda::CUDAGuard device_guard(workspace_cublas.device()); - - auto grouped_A = GroupedTensorFromPyTorchGroupedTensor(A); - auto grouped_B = GroupedTensorFromPyTorchGroupedTensor(B); - auto grouped_D = GroupedTensorFromPyTorchGroupedTensor(D); - - const size_t num_tensors = grouped_A.num_tensors(); - NVTE_CHECK(num_tensors > 0, "Grouped GEMM requires non-empty inputs."); - NVTE_CHECK(grouped_B.num_tensors() == num_tensors, - "Grouped GEMM requires A and B to have the same num_tensors."); - NVTE_CHECK(grouped_D.num_tensors() == num_tensors, - "Grouped GEMM requires D to have the same num_tensors as inputs."); - - auto gemm_config = prepare_grouped_gemm_config(alpha, beta, workspace_setup, workspace_cublas, - num_tensors, math_sm_count, use_split_accumulator); - - [[maybe_unused]] auto swizzled_scales_A = maybe_swizzle_grouped_tensor_for_gemm(grouped_A); - [[maybe_unused]] auto swizzled_scales_B = maybe_swizzle_grouped_tensor_for_gemm(grouped_B); - - NVTE_SCOPED_GIL_RELEASE({ - nvte_grouped_gemm(grouped_A.data(), transa, grouped_B.data(), transb, grouped_D.data(), - 
grouped_D.data(), gemm_config.te_alpha.data(), gemm_config.te_beta.data(), - gemm_config.te_workspace_setup.data(), gemm_config.te_workspace_cublas.data(), - gemm_config.matmul_config.has_value() - ? static_cast(*gemm_config.matmul_config) - : nullptr, - at::cuda::getCurrentCUDAStream()); - }); - - if (!bias.is_none()) { - auto grouped_bias = GroupedTensorFromPyTorchGroupedTensor(bias); - NVTE_SCOPED_GIL_RELEASE({ - nvte_grouped_bias_add(grouped_D.data(), grouped_bias.data(), - at::cuda::getCurrentCUDAStream()); - }); + TensorWrapper pre_gelu_tensor; + if (pre_gelu_out.has_value()) { + pre_gelu_tensor = makeTransformerEngineTensor(pre_gelu_out.value()); } - return py::reinterpret_borrow(D); -} + auto ws_tensor = makeTransformerEngineTensor(workspace); -py::object te_general_grouped_gemm_for_discrete_in(py::handle A, bool transa, py::handle B, - bool transb, py::handle D, py::object bias, - at::Tensor alpha, at::Tensor beta, - at::Tensor workspace_setup, - at::Tensor workspace_cublas, - bool use_split_accumulator, int math_sm_count) { - using namespace transformer_engine::pytorch::detail; - - init_extension(); - - // Ensure that cublasLt handle is created on the correct device, - // overriding torch.cuda.set_device calls from user side. - // Assumes all tensors passed are on the same device. 
- at::cuda::CUDAGuard device_guard(workspace_cublas.device()); - - auto grouped_B = GroupedTensorFromPyTorchGroupedTensor(B); - auto grouped_D = GroupedTensorFromPyTorchGroupedTensor(D); - - const auto A_list = py::cast>(A); - const size_t num_tensors = grouped_B.num_tensors(); - NVTE_CHECK(num_tensors > 0, "Grouped GEMM requires non-empty inputs."); - NVTE_CHECK(A_list.size() == num_tensors, - "Grouped GEMM requires A_list to have num_tensors elements."); - NVTE_CHECK(grouped_D.num_tensors() == num_tensors, - "Grouped GEMM requires D to have the same num_tensors as inputs."); - - auto gemm_config = prepare_grouped_gemm_config(alpha, beta, workspace_setup, workspace_cublas, - num_tensors, math_sm_count, use_split_accumulator); - - std::vector te_A_wrappers; - std::vector te_A_vector; - te_A_wrappers.reserve(num_tensors); - te_A_vector.reserve(num_tensors); - const auto none = py::none(); - for (const auto& tensor : A_list) { - te_A_wrappers.emplace_back(makeTransformerEngineTensor(tensor, none)); - te_A_vector.emplace_back(te_A_wrappers.back().data()); - } + auto device_idx = A_data.get_device_index(); + auto stream = getCurrentCUDAStreamRaw(device_idx); + + float alpha_f = static_cast(alpha); + float beta_f = accumulate ? 1.0f : 0.0f; - std::vector> swizzled_scale_inverses_list; - swizzled_scale_inverses_list.emplace_back( - multi_tensor_swizzle_scales_for_gemm(te_A_wrappers, transa, !transa)); - - [[maybe_unused]] auto swizzled_scales_B = maybe_swizzle_grouped_tensor_for_gemm(grouped_B); - - NVTE_SCOPED_GIL_RELEASE({ - nvte_grouped_gemm_with_discrete_inputA( - te_A_vector.data(), num_tensors, transa, grouped_B.data(), transb, grouped_D.data(), - grouped_D.data(), gemm_config.te_alpha.data(), gemm_config.te_beta.data(), - gemm_config.te_workspace_setup.data(), gemm_config.te_workspace_cublas.data(), - gemm_config.matmul_config.has_value() - ? 
static_cast(*gemm_config.matmul_config) - : nullptr, - at::cuda::getCurrentCUDAStream()); - }); - - if (!bias.is_none()) { - auto grouped_bias = GroupedTensorFromPyTorchGroupedTensor(bias); - NVTE_SCOPED_GIL_RELEASE({ - nvte_grouped_bias_add(grouped_D.data(), grouped_bias.data(), - at::cuda::getCurrentCUDAStream()); - }); + // Configure GEMM + MatmulConfigWrapper config; + bool gelu_flag = pre_gelu_out.has_value() && pre_gelu_out->numel() > 0; + if (bias.has_value()) { + if (grad) { + config.set_dbias_tensor(bias_tensor.data()); + } else { + config.set_bias_tensor(bias_tensor.data()); + } + } + if (grad) { + config.set_with_dgelu_epilogue(gelu_flag); + } else { + config.set_with_gelu_epilogue(gelu_flag); } + if (pre_gelu_out.has_value()) { + config.set_epilogue_aux_tensor(pre_gelu_tensor.data()); + } + config.set_use_split_accumulator(use_split_accumulator); - return py::reinterpret_borrow(D); + nvte_cublas_gemm_v2(transa, transb, &alpha_f, A_tensor.data(), B_tensor.data(), &beta_f, + D_tensor.data(), D_tensor.data(), ws_tensor.data(), config, stream); } -py::object te_general_grouped_gemm_for_discrete_out(py::handle A, bool transa, py::handle B, - bool transb, py::handle D, py::object bias, - at::Tensor alpha, at::Tensor beta, - at::Tensor workspace_setup, - at::Tensor workspace_cublas, - bool use_split_accumulator, int math_sm_count) { - using namespace transformer_engine::pytorch::detail; - - init_extension(); - - // Ensure that cublasLt handle is created on the correct device, - // overriding torch.cuda.set_device calls from user side. - // Assumes all tensors passed are on the same device. 
- at::cuda::CUDAGuard device_guard(workspace_cublas.device()); - - NVTE_CHECK(bias.is_none(), "Bias is not supported for discrete output grouped GEMM."); - - auto grouped_A = GroupedTensorFromPyTorchGroupedTensor(A); - auto grouped_B = GroupedTensorFromPyTorchGroupedTensor(B); - - const auto D_list = py::cast>(D); - const size_t num_tensors = grouped_A.num_tensors(); - NVTE_CHECK(num_tensors > 0, "Grouped GEMM requires non-empty inputs."); - NVTE_CHECK(grouped_B.num_tensors() == num_tensors, - "Grouped GEMM requires A and B to have the same num_tensors."); - NVTE_CHECK(D_list.size() == num_tensors, - "Grouped GEMM requires D_list to have num_tensors elements."); - - auto gemm_config = prepare_grouped_gemm_config(alpha, beta, workspace_setup, workspace_cublas, - num_tensors, math_sm_count, use_split_accumulator); - - std::vector te_D_wrappers; - std::vector te_D_vector; - te_D_wrappers.reserve(num_tensors); - te_D_vector.reserve(num_tensors); - const auto none = py::none(); - for (const auto& tensor : D_list) { - te_D_wrappers.emplace_back(makeTransformerEngineTensor(tensor, none)); - te_D_vector.emplace_back(te_D_wrappers.back().data()); - } - - [[maybe_unused]] auto swizzled_scales_A = maybe_swizzle_grouped_tensor_for_gemm(grouped_A); - [[maybe_unused]] auto swizzled_scales_B = maybe_swizzle_grouped_tensor_for_gemm(grouped_B); - - NVTE_SCOPED_GIL_RELEASE({ - nvte_grouped_gemm_with_discrete_out( - grouped_A.data(), transa, grouped_B.data(), transb, te_D_vector.data(), num_tensors, - te_D_vector.data(), num_tensors, gemm_config.te_alpha.data(), gemm_config.te_beta.data(), - gemm_config.te_workspace_setup.data(), gemm_config.te_workspace_cublas.data(), - gemm_config.matmul_config.has_value() - ? 
static_cast(*gemm_config.matmul_config) - : nullptr, - at::cuda::getCurrentCUDAStream()); - }); - - return py::reinterpret_borrow(D); +} // namespace transformer_engine::pytorch::stable + +STABLE_TORCH_LIBRARY_FRAGMENT(transformer_engine_stable, m) { + m.def( + "gemm(" + "Tensor A_data, int A_te_dtype, Tensor? A_scale_inv, " + "Tensor? A_colwise_data, Tensor? A_colwise_scale_inv, " + "int A_scaling_mode, bool A_with_gemm_swizzled_scales, bool transa, " + "Tensor B_data, int B_te_dtype, Tensor? B_scale_inv, " + "Tensor? B_colwise_data, Tensor? B_colwise_scale_inv, " + "int B_scaling_mode, bool B_with_gemm_swizzled_scales, bool transb, " + "Tensor D_data, int D_te_dtype, Tensor? D_amax, Tensor? D_scale, Tensor? D_scale_inv, " + "int D_scaling_mode, Tensor? bias, int bias_type, Tensor? pre_gelu_out, " + "Tensor workspace, bool grad, bool accumulate, bool use_split_accumulator, " + "float alpha, " + "Tensor? A_amax=None, Tensor? A_colwise_amax=None, " + "Tensor? B_amax=None, Tensor? B_colwise_amax=None) -> ()"); + m.def( + "swizzle_scale_for_gemm(Tensor data, Tensor scale_inv, int te_dtype, int scaling_mode) -> " + "Tensor"); } -} // namespace transformer_engine::pytorch +STABLE_TORCH_LIBRARY_IMPL(transformer_engine_stable, CUDA, m) { + using namespace transformer_engine::pytorch::stable; + m.impl("gemm", TORCH_BOX(gemm)); + m.impl("swizzle_scale_for_gemm", TORCH_BOX(swizzleScaleForGemm)); +} diff --git a/transformer_engine/pytorch/csrc/extensions/grouped_gemm.cpp b/transformer_engine/pytorch/csrc/extensions/grouped_gemm.cpp new file mode 100644 index 0000000000..9c0a3d63b6 --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/grouped_gemm.cpp @@ -0,0 +1,458 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include +#include + +#include "../stable_common.h" + +namespace transformer_engine::pytorch::stable { + +using Tensor = torch::stable::Tensor; + +namespace { + +// ============================================================================ +// GroupedTensorWrapper construction helpers +// ============================================================================ + +// Return the DType used for scale_inv given a scaling mode. +DType getGroupedScaleInvDtype(NVTEScalingMode scaling_mode) { + switch (scaling_mode) { + case NVTE_MXFP8_1D_SCALING: + return DType::kFloat8E8M0; + case NVTE_NVFP4_1D_SCALING: + return DType::kFloat8E4M3; + default: + return DType::kFloat32; + } +} + +// Build a GroupedTensorWrapper from the flat buffer tensors extracted from a +// Python GroupedTensor. All optional Tensor? args may have numel()==0 to +// indicate "not present". +GroupedTensorWrapper buildGroupedTensorWrapper( + const std::optional& rowwise_data, const std::optional& colwise_data, + const std::optional& scale_inv, const std::optional& colwise_scale_inv, + const std::optional& first_dims, const std::optional& last_dims, + const std::optional& tensor_offsets, int64_t te_dtype, int64_t scaling_mode, + int64_t logical_0, int64_t logical_1, int64_t num_tensors, bool with_gemm_swizzled_scales) { + auto dtype = static_cast(te_dtype); + auto sm = static_cast(scaling_mode); + DType si_dtype = getGroupedScaleInvDtype(sm); + + std::vector logical_shape = {static_cast(logical_0), + static_cast(logical_1)}; + GroupedTensorWrapper gtw(static_cast(num_tensors), logical_shape, sm); + + if (rowwise_data.has_value() && rowwise_data->numel() > 0) { + auto shape = getStableTensorShape(*rowwise_data); + gtw.set_rowwise_data(rowwise_data->data_ptr(), dtype, shape); + } + if (colwise_data.has_value() && colwise_data->numel() > 0) { + auto shape = getStableTensorShape(*colwise_data); + 
gtw.set_columnwise_data(colwise_data->data_ptr(), dtype, shape); + } + if (scale_inv.has_value() && scale_inv->numel() > 0) { + auto shape = getStableTensorShape(*scale_inv); + gtw.set_rowwise_scale_inv(scale_inv->data_ptr(), si_dtype, shape); + } + if (colwise_scale_inv.has_value() && colwise_scale_inv->numel() > 0) { + auto shape = getStableTensorShape(*colwise_scale_inv); + gtw.set_columnwise_scale_inv(colwise_scale_inv->data_ptr(), si_dtype, shape); + } + if (first_dims.has_value() && first_dims->numel() > 0) { + auto shape = getStableTensorShape(*first_dims); + gtw.set_first_dims(first_dims->data_ptr(), DType::kInt64, shape); + } + if (last_dims.has_value() && last_dims->numel() > 0) { + auto shape = getStableTensorShape(*last_dims); + gtw.set_last_dims(last_dims->data_ptr(), DType::kInt64, shape); + } + if (tensor_offsets.has_value() && tensor_offsets->numel() > 0) { + auto shape = getStableTensorShape(*tensor_offsets); + gtw.set_tensor_offsets(tensor_offsets->data_ptr(), DType::kInt64, shape); + } + gtw.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); + return gtw; +} + +// Build a GroupedMatmulConfigWrapper with the given options. +GroupedMatmulConfigWrapper buildGroupedGemmConfig(bool use_split_accumulator, int64_t sm_count) { + GroupedMatmulConfigWrapper config; + config.set_use_split_accumulator(use_split_accumulator); + if (sm_count > 0) { + config.set_sm_count(static_cast(sm_count)); + } + return config; +} + +} // namespace + +// ============================================================================ +// grouped_gemm_for_grouped_tensor +// +// Wraps nvte_grouped_gemm (Blackwell+). +// A, B, D are Python GroupedTensors whose flat buffers are passed individually. +// Optional bias GroupedTensor is controlled by has_bias. 
+// ============================================================================ + +void grouped_gemm_for_grouped_tensor( + // A (GroupedTensor) — 13 fields + transa + std::optional A_rowwise, std::optional A_colwise, std::optional A_si, + std::optional A_colwise_si, std::optional A_first_dims, + std::optional A_last_dims, std::optional A_tensor_offsets, int64_t A_te_dtype, + int64_t A_scaling_mode, int64_t A_logical_0, int64_t A_logical_1, int64_t A_num_tensors, + bool A_swizzled, bool transa, + // B (GroupedTensor) — 13 fields + transb + std::optional B_rowwise, std::optional B_colwise, std::optional B_si, + std::optional B_colwise_si, std::optional B_first_dims, + std::optional B_last_dims, std::optional B_tensor_offsets, int64_t B_te_dtype, + int64_t B_scaling_mode, int64_t B_logical_0, int64_t B_logical_1, int64_t B_num_tensors, + bool B_swizzled, bool transb, + // D (GroupedTensor) — 13 fields (no trans) + std::optional D_rowwise, std::optional D_colwise, std::optional D_si, + std::optional D_colwise_si, std::optional D_first_dims, + std::optional D_last_dims, std::optional D_tensor_offsets, int64_t D_te_dtype, + int64_t D_scaling_mode, int64_t D_logical_0, int64_t D_logical_1, int64_t D_num_tensors, + // Config + Tensor alpha, Tensor beta, Tensor workspace_setup, Tensor workspace_cublas, + bool use_split_accumulator, int64_t sm_count, + // Optional bias (GroupedTensor) — 13 fields, guarded by has_bias + bool has_bias, std::optional bias_rowwise, std::optional bias_colwise, + std::optional bias_si, std::optional bias_colwise_si, + std::optional bias_first_dims, std::optional bias_last_dims, + std::optional bias_tensor_offsets, int64_t bias_te_dtype, int64_t bias_scaling_mode, + int64_t bias_logical_0, int64_t bias_logical_1, int64_t bias_num_tensors, bool bias_swizzled) { + auto A_gt = buildGroupedTensorWrapper(A_rowwise, A_colwise, A_si, A_colwise_si, A_first_dims, + A_last_dims, A_tensor_offsets, A_te_dtype, A_scaling_mode, + A_logical_0, A_logical_1, 
A_num_tensors, A_swizzled); + auto B_gt = buildGroupedTensorWrapper(B_rowwise, B_colwise, B_si, B_colwise_si, B_first_dims, + B_last_dims, B_tensor_offsets, B_te_dtype, B_scaling_mode, + B_logical_0, B_logical_1, B_num_tensors, B_swizzled); + auto D_gt = buildGroupedTensorWrapper(D_rowwise, D_colwise, D_si, D_colwise_si, D_first_dims, + D_last_dims, D_tensor_offsets, D_te_dtype, D_scaling_mode, + D_logical_0, D_logical_1, D_num_tensors, false); + + auto alpha_tw = makeTransformerEngineTensor(alpha); + auto beta_tw = makeTransformerEngineTensor(beta); + auto ws_setup_tw = makeTransformerEngineTensor(workspace_setup); + auto ws_cublas_tw = makeTransformerEngineTensor(workspace_cublas); + + auto config = buildGroupedGemmConfig(use_split_accumulator, sm_count); + + // Determine device from whichever data buffer is present + int device_idx = 0; + if (A_rowwise.has_value() && A_rowwise->numel() > 0) + device_idx = A_rowwise->get_device_index(); + else if (B_rowwise.has_value() && B_rowwise->numel() > 0) + device_idx = B_rowwise->get_device_index(); + + nvte_grouped_gemm(A_gt.data(), static_cast(transa), B_gt.data(), static_cast(transb), + D_gt.data(), D_gt.data(), alpha_tw.data(), beta_tw.data(), ws_setup_tw.data(), + ws_cublas_tw.data(), static_cast(config), + getCurrentCUDAStreamRaw(device_idx)); + + if (has_bias) { + auto bias_gt = buildGroupedTensorWrapper(bias_rowwise, bias_colwise, bias_si, bias_colwise_si, + bias_first_dims, bias_last_dims, bias_tensor_offsets, + bias_te_dtype, bias_scaling_mode, bias_logical_0, + bias_logical_1, bias_num_tensors, bias_swizzled); + nvte_grouped_bias_add(D_gt.data(), bias_gt.data(), getCurrentCUDAStreamRaw(device_idx)); + } +} + +// ============================================================================ +// grouped_gemm_for_discrete_in +// +// Wraps nvte_grouped_gemm_with_discrete_inputA (Blackwell+). +// A is provided as packed pointer arrays (rowwise_ptrs, colwise_ptrs, etc.) +// B and D are GroupedTensors. 
+// ============================================================================ + +void grouped_gemm_for_discrete_in( + // A — packed pointer arrays, one entry per expert tensor + Tensor A_rowwise_ptrs, Tensor A_colwise_ptrs, Tensor A_si_ptrs, Tensor A_csi_ptrs, + Tensor A_shapes, // (num_a, 2) int64: [rows, cols] per tensor + Tensor A_te_dtypes, // (num_a,) int32 + Tensor A_scaling_modes, // (num_a,) int32 + int64_t num_a_tensors, + // B (GroupedTensor) — 13 fields + transb + std::optional B_rowwise, std::optional B_colwise, std::optional B_si, + std::optional B_colwise_si, std::optional B_first_dims, + std::optional B_last_dims, std::optional B_tensor_offsets, int64_t B_te_dtype, + int64_t B_scaling_mode, int64_t B_logical_0, int64_t B_logical_1, int64_t B_num_tensors, + bool B_swizzled, bool transb, + // D (GroupedTensor) — 13 fields + std::optional D_rowwise, std::optional D_colwise, std::optional D_si, + std::optional D_colwise_si, std::optional D_first_dims, + std::optional D_last_dims, std::optional D_tensor_offsets, int64_t D_te_dtype, + int64_t D_scaling_mode, int64_t D_logical_0, int64_t D_logical_1, int64_t D_num_tensors, + // Config + Tensor alpha, Tensor beta, Tensor workspace_setup, Tensor workspace_cublas, + bool use_split_accumulator, int64_t sm_count, + // Optional bias + bool has_bias, std::optional bias_rowwise, std::optional bias_colwise, + std::optional bias_si, std::optional bias_colwise_si, + std::optional bias_first_dims, std::optional bias_last_dims, + std::optional bias_tensor_offsets, int64_t bias_te_dtype, int64_t bias_scaling_mode, + int64_t bias_logical_0, int64_t bias_logical_1, int64_t bias_num_tensors, bool bias_swizzled) { + // Build the individual A TensorWrappers from the packed pointer arrays + const auto* rw_ptrs = reinterpret_cast(A_rowwise_ptrs.data_ptr()); + const auto* cw_ptrs = reinterpret_cast(A_colwise_ptrs.data_ptr()); + const auto* si_ptrs = reinterpret_cast(A_si_ptrs.data_ptr()); + const auto* csi_ptrs = 
reinterpret_cast(A_csi_ptrs.data_ptr()); + const auto* shapes = reinterpret_cast(A_shapes.data_ptr()); + const auto* dtypes = reinterpret_cast(A_te_dtypes.data_ptr()); + const auto* modes = reinterpret_cast(A_scaling_modes.data_ptr()); + + std::vector A_wrappers; + std::vector A_nvte; + A_wrappers.reserve(static_cast(num_a_tensors)); + A_nvte.reserve(static_cast(num_a_tensors)); + + for (int64_t i = 0; i < num_a_tensors; ++i) { + auto dtype = static_cast(dtypes[i]); + auto sm = static_cast(modes[i]); + DType si_dtype = getGroupedScaleInvDtype(sm); + int64_t rows = shapes[2 * i]; + int64_t cols = shapes[2 * i + 1]; + std::vector shape = {static_cast(rows), static_cast(cols)}; + + TensorWrapper tw(sm); + if (rw_ptrs[i] != 0) { + tw.set_rowwise_data(reinterpret_cast(rw_ptrs[i]), dtype, shape); + } + if (cw_ptrs[i] != 0) { + tw.set_columnwise_data(reinterpret_cast(cw_ptrs[i]), dtype, shape); + } + if (si_ptrs[i] != 0) { + // Scale shape for a (rows, cols) tensor: (rows, ceil(cols/block)) etc. + // We pass a placeholder shape; NVTE uses the tensor's logical shape for scale counts. + std::vector si_shape = {static_cast(rows)}; + tw.set_rowwise_scale_inv(reinterpret_cast(si_ptrs[i]), si_dtype, si_shape); + } + if (csi_ptrs[i] != 0) { + std::vector csi_shape = {static_cast(cols)}; + tw.set_columnwise_scale_inv(reinterpret_cast(csi_ptrs[i]), si_dtype, csi_shape); + } + A_wrappers.emplace_back(std::move(tw)); + A_nvte.emplace_back(A_wrappers.back().data()); + } + + // NOTE: transa is not passed to this function per the NVTE API convention — it is + // implied by the physical layout set via set_rowwise_data vs set_columnwise_data. + // The Python side calls _extract_gemm_operand(Ai, transa) so the right buffer + // (rowwise or columnwise) is already selected. We always pass transa=false here + // since CanonicalizeGemmInput in the C++ kernel handles orientation. 
+ int transa_int = 0; + + auto B_gt = buildGroupedTensorWrapper(B_rowwise, B_colwise, B_si, B_colwise_si, B_first_dims, + B_last_dims, B_tensor_offsets, B_te_dtype, B_scaling_mode, + B_logical_0, B_logical_1, B_num_tensors, B_swizzled); + auto D_gt = buildGroupedTensorWrapper(D_rowwise, D_colwise, D_si, D_colwise_si, D_first_dims, + D_last_dims, D_tensor_offsets, D_te_dtype, D_scaling_mode, + D_logical_0, D_logical_1, D_num_tensors, false); + + auto alpha_tw = makeTransformerEngineTensor(alpha); + auto beta_tw = makeTransformerEngineTensor(beta); + auto ws_setup_tw = makeTransformerEngineTensor(workspace_setup); + auto ws_cublas_tw = makeTransformerEngineTensor(workspace_cublas); + auto config = buildGroupedGemmConfig(use_split_accumulator, sm_count); + + int device_idx = A_rowwise_ptrs.get_device_index(); + + nvte_grouped_gemm_with_discrete_inputA( + A_nvte.data(), static_cast(num_a_tensors), transa_int, B_gt.data(), + static_cast(transb), D_gt.data(), D_gt.data(), alpha_tw.data(), beta_tw.data(), + ws_setup_tw.data(), ws_cublas_tw.data(), static_cast(config), + getCurrentCUDAStreamRaw(device_idx)); + + if (has_bias) { + auto bias_gt = buildGroupedTensorWrapper(bias_rowwise, bias_colwise, bias_si, bias_colwise_si, + bias_first_dims, bias_last_dims, bias_tensor_offsets, + bias_te_dtype, bias_scaling_mode, bias_logical_0, + bias_logical_1, bias_num_tensors, bias_swizzled); + nvte_grouped_bias_add(D_gt.data(), bias_gt.data(), getCurrentCUDAStreamRaw(device_idx)); + } +} + +// ============================================================================ +// grouped_gemm_for_discrete_out +// +// Wraps nvte_grouped_gemm_with_discrete_out (Blackwell+). +// A and B are GroupedTensors; D is provided as packed pointer arrays. 
+// ============================================================================ + +void grouped_gemm_for_discrete_out( + // A (GroupedTensor) — 13 fields + transa + std::optional A_rowwise, std::optional A_colwise, std::optional A_si, + std::optional A_colwise_si, std::optional A_first_dims, + std::optional A_last_dims, std::optional A_tensor_offsets, int64_t A_te_dtype, + int64_t A_scaling_mode, int64_t A_logical_0, int64_t A_logical_1, int64_t A_num_tensors, + bool A_swizzled, bool transa, + // B (GroupedTensor) — 13 fields + transb + std::optional B_rowwise, std::optional B_colwise, std::optional B_si, + std::optional B_colwise_si, std::optional B_first_dims, + std::optional B_last_dims, std::optional B_tensor_offsets, int64_t B_te_dtype, + int64_t B_scaling_mode, int64_t B_logical_0, int64_t B_logical_1, int64_t B_num_tensors, + bool B_swizzled, bool transb, + // D — packed pointer arrays + Tensor D_rowwise_ptrs, Tensor D_si_ptrs, + Tensor D_shapes, // (num_d, 2) int64: [rows, cols] per tensor + Tensor D_te_dtypes, // (num_d,) int32 + Tensor D_scaling_modes, // (num_d,) int32 + int64_t num_d_tensors, + // Config + Tensor alpha, Tensor beta, Tensor workspace_setup, Tensor workspace_cublas, + bool use_split_accumulator, int64_t sm_count) { + auto A_gt = buildGroupedTensorWrapper(A_rowwise, A_colwise, A_si, A_colwise_si, A_first_dims, + A_last_dims, A_tensor_offsets, A_te_dtype, A_scaling_mode, + A_logical_0, A_logical_1, A_num_tensors, A_swizzled); + auto B_gt = buildGroupedTensorWrapper(B_rowwise, B_colwise, B_si, B_colwise_si, B_first_dims, + B_last_dims, B_tensor_offsets, B_te_dtype, B_scaling_mode, + B_logical_0, B_logical_1, B_num_tensors, B_swizzled); + + // Build D TensorWrapper array from packed pointers + const auto* rw_ptrs = reinterpret_cast(D_rowwise_ptrs.data_ptr()); + const auto* si_ptrs = reinterpret_cast(D_si_ptrs.data_ptr()); + const auto* shapes = reinterpret_cast(D_shapes.data_ptr()); + const auto* dtypes = 
reinterpret_cast(D_te_dtypes.data_ptr()); + + std::vector D_wrappers; + std::vector D_nvte; + D_wrappers.reserve(static_cast(num_d_tensors)); + D_nvte.reserve(static_cast(num_d_tensors)); + + for (int64_t i = 0; i < num_d_tensors; ++i) { + auto dtype = static_cast(dtypes[i]); + int64_t rows = shapes[2 * i]; + int64_t cols = shapes[2 * i + 1]; + std::vector shape = {static_cast(rows), static_cast(cols)}; + + TensorWrapper tw; + if (rw_ptrs[i] != 0) { + tw.set_rowwise_data(reinterpret_cast(rw_ptrs[i]), dtype, shape); + } + if (si_ptrs[i] != 0) { + std::vector si_shape = {static_cast(rows)}; + tw.set_rowwise_scale_inv(reinterpret_cast(si_ptrs[i]), DType::kFloat32, si_shape); + } + D_wrappers.emplace_back(std::move(tw)); + D_nvte.emplace_back(D_wrappers.back().data()); + } + + auto alpha_tw = makeTransformerEngineTensor(alpha); + auto beta_tw = makeTransformerEngineTensor(beta); + auto ws_setup_tw = makeTransformerEngineTensor(workspace_setup); + auto ws_cublas_tw = makeTransformerEngineTensor(workspace_cublas); + auto config = buildGroupedGemmConfig(use_split_accumulator, sm_count); + + int device_idx = 0; + if (A_rowwise.has_value() && A_rowwise->numel() > 0) device_idx = A_rowwise->get_device_index(); + + nvte_grouped_gemm_with_discrete_out( + A_gt.data(), static_cast(transa), B_gt.data(), static_cast(transb), D_nvte.data(), + static_cast(num_d_tensors), D_nvte.data(), static_cast(num_d_tensors), + alpha_tw.data(), beta_tw.data(), ws_setup_tw.data(), ws_cublas_tw.data(), + static_cast(config), getCurrentCUDAStreamRaw(device_idx)); +} + +} // namespace transformer_engine::pytorch::stable + +// ============================================================================ +// Op registration +// ============================================================================ + +STABLE_TORCH_LIBRARY_FRAGMENT(transformer_engine_stable, m) { + // grouped_gemm_for_grouped_tensor: A(13) + transa + B(13) + transb + D(13) + + // alpha + beta + ws_setup + ws_cublas + 
use_split_accum + sm_count + + // has_bias + bias(13) = 58 args total + m.def( + "grouped_gemm_for_grouped_tensor(" + // A + "Tensor? A_rowwise, Tensor? A_colwise, Tensor? A_si, Tensor? A_colwise_si, " + "Tensor? A_first_dims, Tensor? A_last_dims, Tensor? A_tensor_offsets, " + "int A_te_dtype, int A_scaling_mode, int A_logical_0, int A_logical_1, " + "int A_num_tensors, bool A_swizzled, bool transa, " + // B + "Tensor? B_rowwise, Tensor? B_colwise, Tensor? B_si, Tensor? B_colwise_si, " + "Tensor? B_first_dims, Tensor? B_last_dims, Tensor? B_tensor_offsets, " + "int B_te_dtype, int B_scaling_mode, int B_logical_0, int B_logical_1, " + "int B_num_tensors, bool B_swizzled, bool transb, " + // D + "Tensor? D_rowwise, Tensor? D_colwise, Tensor? D_si, Tensor? D_colwise_si, " + "Tensor? D_first_dims, Tensor? D_last_dims, Tensor? D_tensor_offsets, " + "int D_te_dtype, int D_scaling_mode, int D_logical_0, int D_logical_1, " + "int D_num_tensors, " + // config + "Tensor alpha, Tensor beta, Tensor workspace_setup, Tensor workspace_cublas, " + "bool use_split_accumulator, int sm_count, " + // bias + "bool has_bias, " + "Tensor? bias_rowwise, Tensor? bias_colwise, Tensor? bias_si, Tensor? bias_colwise_si, " + "Tensor? bias_first_dims, Tensor? bias_last_dims, Tensor? bias_tensor_offsets, " + "int bias_te_dtype, int bias_scaling_mode, int bias_logical_0, int bias_logical_1, " + "int bias_num_tensors, bool bias_swizzled" + ") -> ()"); + + // grouped_gemm_for_discrete_in: A_ptrs(7) + A_meta(3) + num_a + + // B(13) + transb + D(13) + config(6) + has_bias + bias(13) = 57 args + m.def( + "grouped_gemm_for_discrete_in(" + // A packed pointers + "Tensor A_rowwise_ptrs, Tensor A_colwise_ptrs, Tensor A_si_ptrs, Tensor A_csi_ptrs, " + "Tensor A_shapes, Tensor A_te_dtypes, Tensor A_scaling_modes, int num_a_tensors, " + // B + "Tensor? B_rowwise, Tensor? B_colwise, Tensor? B_si, Tensor? B_colwise_si, " + "Tensor? B_first_dims, Tensor? B_last_dims, Tensor? 
B_tensor_offsets, " + "int B_te_dtype, int B_scaling_mode, int B_logical_0, int B_logical_1, " + "int B_num_tensors, bool B_swizzled, bool transb, " + // D + "Tensor? D_rowwise, Tensor? D_colwise, Tensor? D_si, Tensor? D_colwise_si, " + "Tensor? D_first_dims, Tensor? D_last_dims, Tensor? D_tensor_offsets, " + "int D_te_dtype, int D_scaling_mode, int D_logical_0, int D_logical_1, " + "int D_num_tensors, " + // config + "Tensor alpha, Tensor beta, Tensor workspace_setup, Tensor workspace_cublas, " + "bool use_split_accumulator, int sm_count, " + // bias + "bool has_bias, " + "Tensor? bias_rowwise, Tensor? bias_colwise, Tensor? bias_si, Tensor? bias_colwise_si, " + "Tensor? bias_first_dims, Tensor? bias_last_dims, Tensor? bias_tensor_offsets, " + "int bias_te_dtype, int bias_scaling_mode, int bias_logical_0, int bias_logical_1, " + "int bias_num_tensors, bool bias_swizzled" + ") -> ()"); + + // grouped_gemm_for_discrete_out: A(13) + transa + B(13) + transb + + // D_ptrs(5) + num_d + config(6) = 51 args + m.def( + "grouped_gemm_for_discrete_out(" + // A + "Tensor? A_rowwise, Tensor? A_colwise, Tensor? A_si, Tensor? A_colwise_si, " + "Tensor? A_first_dims, Tensor? A_last_dims, Tensor? A_tensor_offsets, " + "int A_te_dtype, int A_scaling_mode, int A_logical_0, int A_logical_1, " + "int A_num_tensors, bool A_swizzled, bool transa, " + // B + "Tensor? B_rowwise, Tensor? B_colwise, Tensor? B_si, Tensor? B_colwise_si, " + "Tensor? B_first_dims, Tensor? B_last_dims, Tensor? 
B_tensor_offsets, " + "int B_te_dtype, int B_scaling_mode, int B_logical_0, int B_logical_1, " + "int B_num_tensors, bool B_swizzled, bool transb, " + // D packed pointers + "Tensor D_rowwise_ptrs, Tensor D_si_ptrs, " + "Tensor D_shapes, Tensor D_te_dtypes, Tensor D_scaling_modes, int num_d_tensors, " + // config + "Tensor alpha, Tensor beta, Tensor workspace_setup, Tensor workspace_cublas, " + "bool use_split_accumulator, int sm_count" + ") -> ()"); +} + +STABLE_TORCH_LIBRARY_IMPL(transformer_engine_stable, CUDA, m) { + using namespace transformer_engine::pytorch::stable; + m.impl("grouped_gemm_for_grouped_tensor", TORCH_BOX(grouped_gemm_for_grouped_tensor)); + m.impl("grouped_gemm_for_discrete_in", TORCH_BOX(grouped_gemm_for_discrete_in)); + m.impl("grouped_gemm_for_discrete_out", TORCH_BOX(grouped_gemm_for_discrete_out)); +} diff --git a/transformer_engine/pytorch/csrc/extensions/misc.cpp b/transformer_engine/pytorch/csrc/extensions/misc.cpp index c5707fa53c..10044385ee 100644 --- a/transformer_engine/pytorch/csrc/extensions/misc.cpp +++ b/transformer_engine/pytorch/csrc/extensions/misc.cpp @@ -4,30 +4,35 @@ * See LICENSE for license information. 
************************************************************************/ -#include "../extensions.h" +#include -namespace transformer_engine::pytorch { +#include "../stable_common.h" -size_t get_cublasLt_version() { return cublasLtGetVersion(); } +namespace transformer_engine::pytorch::stable { -size_t get_cudnn_version() { return cudnnGetVersion(); } +using Tensor = torch::stable::Tensor; -at::Tensor splits_to_offsets(const at::Tensor &first_dims, int64_t logical_last_dim) { - NVTE_CHECK(first_dims.is_cuda(), "first_dims must be on CUDA."); - NVTE_CHECK(first_dims.scalar_type() == at::kLong, "first_dims must have dtype int64."); - NVTE_CHECK(first_dims.dim() == 1, "first_dims must be a 1D tensor."); - NVTE_CHECK(logical_last_dim > 0, "logical_last_dim must be greater than 0."); +Tensor splits_to_offsets(Tensor first_dims, int64_t logical_last_dim) { + STD_TORCH_CHECK(first_dims.is_cuda(), "first_dims must be on CUDA."); + STD_TORCH_CHECK(first_dims.scalar_type() == ScalarType::Long, + "first_dims must have dtype int64."); + STD_TORCH_CHECK(first_dims.dim() == 1, "first_dims must be a 1D tensor."); + STD_TORCH_CHECK(logical_last_dim > 0, "logical_last_dim must be greater than 0."); - auto first_dims_contiguous = first_dims.contiguous(); - const auto num_tensors = static_cast(first_dims_contiguous.numel()); - auto output = at::empty({static_cast(num_tensors) + 1}, - first_dims_contiguous.options().dtype(at::kLong)); + auto first_dims_c = torch::stable::contiguous(first_dims); + const auto num_tensors = static_cast(first_dims_c.numel()); + auto output = allocateStableTensor({static_cast(num_tensors) + 1}, ScalarType::Long, + first_dims_c.get_device_index()); - nvte_splits_to_offsets(static_cast(first_dims_contiguous.data_ptr()), - static_cast(output.data_ptr()), num_tensors, logical_last_dim, - at::cuda::getCurrentCUDAStream()); + nvte_splits_to_offsets(static_cast(first_dims_c.data_ptr()), + static_cast(output.data_ptr()), num_tensors, logical_last_dim, + 
getCurrentCUDAStreamRaw(first_dims_c.get_device_index())); return output; } -} // namespace transformer_engine::pytorch +STABLE_TORCH_LIBRARY_IMPL(transformer_engine_stable, CUDA, m) { + m.impl("splits_to_offsets", TORCH_BOX(splits_to_offsets)); +} + +} // namespace transformer_engine::pytorch::stable diff --git a/transformer_engine/pytorch/csrc/extensions/multi_tensor.cpp b/transformer_engine/pytorch/csrc/extensions/multi_tensor.cpp new file mode 100644 index 0000000000..f3bf690bf5 --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/multi_tensor.cpp @@ -0,0 +1,397 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include + +#include "../stable_common.h" + +namespace transformer_engine::pytorch::stable { + +using Tensor = torch::stable::Tensor; + +// ============================================================================ +// Multi-tensor helper: reconstruct NVTETensor** from flat pointer/shape tensors +// +// Python packs tensor_lists as: +// ptrs: int64 tensor [num_lists * num_tensors] — data_ptr() for each tensor +// shapes: int64 tensor [num_lists * num_tensors * 2] — (numel, element_size) +// dtypes: int64 tensor [num_lists * num_tensors] — TE DType values +// +// C++ reconstructs the 2D NVTETensor** structure. 
+// ============================================================================ + +namespace { + +struct MultiTensorPack { + std::vector wrappers; + std::vector> lists; + std::vector list_ptrs; + + void build(const Tensor& ptrs, const Tensor& shapes, const Tensor& dtypes, int64_t num_lists, + int64_t num_tensors) { + auto ptrs_cpu = ptrs; // already on CPU or we read via data_ptr + auto shapes_cpu = shapes; + auto dtypes_cpu = dtypes; + + const int64_t* p = static_cast(ptrs_cpu.data_ptr()); + const int64_t* s = static_cast(shapes_cpu.data_ptr()); + const int64_t* d = static_cast(dtypes_cpu.data_ptr()); + + wrappers.reserve(num_lists * num_tensors); + lists.resize(num_lists); + + for (int64_t li = 0; li < num_lists; ++li) { + lists[li].reserve(num_tensors); + for (int64_t ti = 0; ti < num_tensors; ++ti) { + int64_t idx = li * num_tensors + ti; + void* data = reinterpret_cast(p[idx]); + size_t numel = static_cast(s[idx * 2]); + auto dtype = static_cast(d[idx]); + wrappers.emplace_back(makeTransformerEngineTensor(data, std::vector{numel}, dtype)); + lists[li].push_back(wrappers.back().data()); + } + } + + list_ptrs.reserve(num_lists); + for (auto& l : lists) { + list_ptrs.push_back(l.data()); + } + } +}; + +} // namespace + +// ============================================================================ +// Multi-tensor scale +// ============================================================================ + +void multi_tensor_scale(int64_t chunk_size, Tensor is_infinite, Tensor ptrs, Tensor shapes, + Tensor dtypes, int64_t num_lists, int64_t num_tensors, double scale) { + MultiTensorPack pack; + pack.build(ptrs, shapes, dtypes, num_lists, num_tensors); + + auto is_inf_cu = makeTransformerEngineTensor(is_infinite); + nvte_multi_tensor_scale_cuda(static_cast(chunk_size), is_inf_cu.data(), + pack.list_ptrs.data(), static_cast(num_lists), + static_cast(num_tensors), static_cast(scale), + getCurrentCUDAStreamRaw(is_infinite.get_device_index())); +} + +void 
multi_tensor_scale_tensor(int64_t chunk_size, Tensor is_infinite, Tensor ptrs, Tensor shapes, + Tensor dtypes, int64_t num_lists, int64_t num_tensors, + Tensor scale) { + MultiTensorPack pack; + pack.build(ptrs, shapes, dtypes, num_lists, num_tensors); + + auto is_inf_cu = makeTransformerEngineTensor(is_infinite); + auto scale_cu = makeTransformerEngineTensor(scale); + nvte_multi_tensor_scale_tensor_cuda(static_cast(chunk_size), is_inf_cu.data(), + pack.list_ptrs.data(), static_cast(num_lists), + static_cast(num_tensors), scale_cu.data(), + getCurrentCUDAStreamRaw(is_infinite.get_device_index())); +} + +// ============================================================================ +// Multi-tensor L2 norm +// ============================================================================ + +std::tuple multi_tensor_l2norm(int64_t chunk_size, Tensor noop_flag, Tensor ptrs, + Tensor shapes, Tensor dtypes, int64_t num_lists, + int64_t num_tensors, bool per_tensor) { + MultiTensorPack pack; + pack.build(ptrs, shapes, dtypes, num_lists, num_tensors); + + auto device_idx = noop_flag.get_device_index(); + auto noop_cu = makeTransformerEngineTensor(noop_flag); + + // Max chunks per tensor + int max_chunks_per_tensor = -1; + const int64_t* s = static_cast(shapes.data_ptr()); + for (int64_t ti = 0; ti < num_tensors; ++ti) { + int chunks = (static_cast(s[ti * 2]) + chunk_size - 1) / static_cast(chunk_size); + if (chunks > max_chunks_per_tensor) max_chunks_per_tensor = chunks; + } + + auto output = allocateStableTensorZeros({320}, ScalarType::Float, device_idx); + auto ret = allocateStableTensor({1}, ScalarType::Float, device_idx); + auto output_per_tensor = + per_tensor + ? allocateStableTensorZeros({static_cast(num_tensors) * max_chunks_per_tensor}, + ScalarType::Float, device_idx) + : allocateStableTensor({1}, ScalarType::Float, device_idx); + auto ret_per_tensor = per_tensor ? 
allocateStableTensor({static_cast(num_tensors)}, + ScalarType::Float, device_idx) + : allocateStableTensor({1}, ScalarType::Float, device_idx); + + auto output_cu = makeTransformerEngineTensor(output); + auto ret_cu = makeTransformerEngineTensor(ret); + auto opt_cu = makeTransformerEngineTensor(output_per_tensor); + auto rpt_cu = makeTransformerEngineTensor(ret_per_tensor); + + nvte_multi_tensor_l2norm_cuda(static_cast(chunk_size), noop_cu.data(), pack.list_ptrs.data(), + static_cast(num_lists), static_cast(num_tensors), + output_cu.data(), opt_cu.data(), ret_cu.data(), rpt_cu.data(), + per_tensor, max_chunks_per_tensor, + getCurrentCUDAStreamRaw(device_idx)); + + return std::make_tuple(ret, ret_per_tensor); +} + +// ============================================================================ +// Multi-tensor Adam +// ============================================================================ + +void multi_tensor_adam(int64_t chunk_size, Tensor noop_flag, Tensor ptrs, Tensor shapes, + Tensor dtypes, int64_t num_lists, int64_t num_tensors, double lr, + double beta1, double beta2, double epsilon, int64_t step, int64_t mode, + int64_t bias_correction, double weight_decay) { + MultiTensorPack pack; + pack.build(ptrs, shapes, dtypes, num_lists, num_tensors); + auto noop_cu = makeTransformerEngineTensor(noop_flag); + + nvte_multi_tensor_adam_cuda( + static_cast(chunk_size), noop_cu.data(), pack.list_ptrs.data(), + static_cast(num_lists), static_cast(num_tensors), static_cast(lr), + static_cast(beta1), static_cast(beta2), static_cast(epsilon), + static_cast(step), static_cast(mode), static_cast(bias_correction), + static_cast(weight_decay), getCurrentCUDAStreamRaw(noop_flag.get_device_index())); +} + +void multi_tensor_adam_capturable(int64_t chunk_size, Tensor noop_flag, Tensor ptrs, Tensor shapes, + Tensor dtypes, int64_t num_lists, int64_t num_tensors, Tensor lr, + double beta1, double beta2, double epsilon, Tensor step, + int64_t mode, int64_t bias_correction, double 
weight_decay, + Tensor inv_scale) { + MultiTensorPack pack; + pack.build(ptrs, shapes, dtypes, num_lists, num_tensors); + auto noop_cu = makeTransformerEngineTensor(noop_flag); + auto lr_cu = makeTransformerEngineTensor(lr); + auto step_cu = makeTransformerEngineTensor(step); + auto inv_cu = makeTransformerEngineTensor(inv_scale); + + nvte_multi_tensor_adam_capturable_cuda( + static_cast(chunk_size), noop_cu.data(), pack.list_ptrs.data(), + static_cast(num_lists), static_cast(num_tensors), lr_cu.data(), + static_cast(beta1), static_cast(beta2), static_cast(epsilon), + step_cu.data(), static_cast(mode), static_cast(bias_correction), + static_cast(weight_decay), inv_cu.data(), + getCurrentCUDAStreamRaw(noop_flag.get_device_index())); +} + +// ============================================================================ +// Multi-tensor SGD +// ============================================================================ + +void multi_tensor_sgd(int64_t chunk_size, Tensor noop_flag, Tensor ptrs, Tensor shapes, + Tensor dtypes, int64_t num_lists, int64_t num_tensors, double wd, + double momentum, double dampening, double lr, bool nesterov, bool first_run, + bool wd_after_momentum, double scale) { + MultiTensorPack pack; + pack.build(ptrs, shapes, dtypes, num_lists, num_tensors); + auto noop_cu = makeTransformerEngineTensor(noop_flag); + + nvte_multi_tensor_sgd_cuda(static_cast(chunk_size), noop_cu.data(), pack.list_ptrs.data(), + static_cast(num_lists), static_cast(num_tensors), + static_cast(wd), static_cast(momentum), + static_cast(dampening), static_cast(lr), nesterov, + first_run, wd_after_momentum, static_cast(scale), + getCurrentCUDAStreamRaw(noop_flag.get_device_index())); +} + +// ============================================================================ +// Remaining Adam variants +// ============================================================================ + +void multi_tensor_adam_param_remainder(int64_t chunk_size, Tensor noop_flag, Tensor ptrs, + Tensor 
shapes, Tensor dtypes, int64_t num_lists, + int64_t num_tensors, double lr, double beta1, double beta2, + double epsilon, int64_t step, int64_t mode, + int64_t bias_correction, double weight_decay) { + MultiTensorPack pack; + pack.build(ptrs, shapes, dtypes, num_lists, num_tensors); + auto noop_cu = makeTransformerEngineTensor(noop_flag); + nvte_multi_tensor_adam_param_remainder_cuda( + static_cast(chunk_size), noop_cu.data(), pack.list_ptrs.data(), + static_cast(num_lists), static_cast(num_tensors), static_cast(lr), + static_cast(beta1), static_cast(beta2), static_cast(epsilon), + static_cast(step), static_cast(mode), static_cast(bias_correction), + static_cast(weight_decay), getCurrentCUDAStreamRaw(noop_flag.get_device_index())); +} + +void multi_tensor_adam_fp8(int64_t chunk_size, Tensor noop_flag, Tensor ptrs, Tensor shapes, + Tensor dtypes, int64_t num_lists, int64_t num_tensors, double lr, + double beta1, double beta2, double epsilon, int64_t step, int64_t mode, + int64_t bias_correction, double weight_decay, int64_t fp8_dtype) { + MultiTensorPack pack; + pack.build(ptrs, shapes, dtypes, num_lists, num_tensors); + auto noop_cu = makeTransformerEngineTensor(noop_flag); + nvte_multi_tensor_adam_fp8_cuda( + static_cast(chunk_size), noop_cu.data(), pack.list_ptrs.data(), + static_cast(num_lists), static_cast(num_tensors), static_cast(lr), + static_cast(beta1), static_cast(beta2), static_cast(epsilon), + static_cast(step), static_cast(mode), static_cast(bias_correction), + static_cast(weight_decay), static_cast(fp8_dtype), + getCurrentCUDAStreamRaw(noop_flag.get_device_index())); +} + +void multi_tensor_adam_capturable_master(int64_t chunk_size, Tensor noop_flag, Tensor ptrs, + Tensor shapes, Tensor dtypes, int64_t num_lists, + int64_t num_tensors, Tensor lr, double beta1, double beta2, + double epsilon, Tensor step, int64_t mode, + int64_t bias_correction, double weight_decay, + Tensor inv_scale) { + MultiTensorPack pack; + pack.build(ptrs, shapes, dtypes, 
num_lists, num_tensors); + auto noop_cu = makeTransformerEngineTensor(noop_flag); + auto lr_cu = makeTransformerEngineTensor(lr); + auto step_cu = makeTransformerEngineTensor(step); + auto inv_cu = makeTransformerEngineTensor(inv_scale); + nvte_multi_tensor_adam_capturable_master_cuda( + static_cast(chunk_size), noop_cu.data(), pack.list_ptrs.data(), + static_cast(num_lists), static_cast(num_tensors), lr_cu.data(), + static_cast(beta1), static_cast(beta2), static_cast(epsilon), + step_cu.data(), static_cast(mode), static_cast(bias_correction), + static_cast(weight_decay), inv_cu.data(), + getCurrentCUDAStreamRaw(noop_flag.get_device_index())); +} + +// ============================================================================ +// Multi-tensor scale computation +// ============================================================================ + +void multi_tensor_compute_scale_and_scale_inv(int64_t chunk_size, Tensor noop_flag, Tensor ptrs, + Tensor shapes, Tensor dtypes, int64_t num_lists, + int64_t num_tensors, double max_fp8, + bool force_pow_2_scales, double epsilon) { + MultiTensorPack pack; + pack.build(ptrs, shapes, dtypes, num_lists, num_tensors); + auto noop_cu = makeTransformerEngineTensor(noop_flag); + nvte_multi_tensor_compute_scale_and_scale_inv_cuda( + static_cast(chunk_size), noop_cu.data(), pack.list_ptrs.data(), + static_cast(num_lists), static_cast(num_tensors), static_cast(max_fp8), + force_pow_2_scales, static_cast(epsilon), + getCurrentCUDAStreamRaw(noop_flag.get_device_index())); +} + +void multi_tensor_compute_scale_inv_e8m0(int64_t chunk_size, + Tensor dummy_cuda, // dummy CUDA tensor for dispatch + Tensor ptrs, Tensor shapes, Tensor dtypes, + int64_t num_lists, int64_t num_tensors) { + MultiTensorPack pack; + pack.build(ptrs, shapes, dtypes, num_lists, num_tensors); + nvte_multi_tensor_compute_scale_inv_e8m0_cuda( + static_cast(chunk_size), pack.list_ptrs.data(), static_cast(num_lists), + static_cast(num_tensors), getCurrentCUDAStreamRaw()); 
+} + +std::tuple multi_tensor_unscale_l2norm(int64_t chunk_size, Tensor noop_flag, + Tensor ptrs, Tensor shapes, Tensor dtypes, + int64_t num_lists, int64_t num_tensors, + Tensor inv_scale, bool per_tensor) { + MultiTensorPack pack; + pack.build(ptrs, shapes, dtypes, num_lists, num_tensors); + + auto device_idx = noop_flag.get_device_index(); + auto noop_cu = makeTransformerEngineTensor(noop_flag); + auto inv_cu = makeTransformerEngineTensor(inv_scale); + + int max_chunks_per_tensor = -1; + const int64_t* s = static_cast(shapes.data_ptr()); + for (int64_t ti = 0; ti < num_tensors; ++ti) { + int chunks = (static_cast(s[ti * 2]) + static_cast(chunk_size) - 1) / + static_cast(chunk_size); + if (chunks > max_chunks_per_tensor) max_chunks_per_tensor = chunks; + } + + auto output = allocateStableTensorZeros({320}, ScalarType::Float, device_idx); + auto ret = allocateStableTensor({1}, ScalarType::Float, device_idx); + auto opt = per_tensor ? allocateStableTensorZeros({num_tensors * max_chunks_per_tensor}, + ScalarType::Float, device_idx) + : allocateStableTensor({1}, ScalarType::Float, device_idx); + auto rpt = per_tensor ? 
allocateStableTensor({num_tensors}, ScalarType::Float, device_idx) + : allocateStableTensor({1}, ScalarType::Float, device_idx); + + auto output_cu = makeTransformerEngineTensor(output); + auto ret_cu = makeTransformerEngineTensor(ret); + auto opt_cu = makeTransformerEngineTensor(opt); + auto rpt_cu = makeTransformerEngineTensor(rpt); + + nvte_multi_tensor_unscale_l2norm_cuda( + static_cast(chunk_size), noop_cu.data(), pack.list_ptrs.data(), + static_cast(num_lists), static_cast(num_tensors), output_cu.data(), + opt_cu.data(), ret_cu.data(), rpt_cu.data(), inv_cu.data(), per_tensor, max_chunks_per_tensor, + getCurrentCUDAStreamRaw(device_idx)); + + return std::make_tuple(ret, rpt); +} + +} // namespace transformer_engine::pytorch::stable + +STABLE_TORCH_LIBRARY_FRAGMENT(transformer_engine_stable, m) { + m.def( + "multi_tensor_scale(int chunk_size, Tensor is_infinite, Tensor ptrs, Tensor shapes, Tensor " + "dtypes, int num_lists, int num_tensors, float scale) -> ()"); + m.def( + "multi_tensor_scale_tensor(int chunk_size, Tensor is_infinite, Tensor ptrs, Tensor shapes, " + "Tensor dtypes, int num_lists, int num_tensors, Tensor scale) -> ()"); + m.def( + "multi_tensor_l2norm(int chunk_size, Tensor noop_flag, Tensor ptrs, Tensor shapes, Tensor " + "dtypes, int num_lists, int num_tensors, bool per_tensor) -> (Tensor, Tensor)"); + m.def( + "multi_tensor_adam(int chunk_size, Tensor noop_flag, Tensor ptrs, Tensor shapes, Tensor " + "dtypes, int num_lists, int num_tensors, float lr, float beta1, float beta2, float epsilon, " + "int step, int mode, int bias_correction, float weight_decay) -> ()"); + m.def( + "multi_tensor_adam_capturable(int chunk_size, Tensor noop_flag, Tensor ptrs, Tensor shapes, " + "Tensor dtypes, int num_lists, int num_tensors, Tensor lr, float beta1, float beta2, float " + "epsilon, Tensor step, int mode, int bias_correction, float weight_decay, Tensor inv_scale) " + "-> ()"); + m.def( + "multi_tensor_sgd(int chunk_size, Tensor noop_flag, Tensor ptrs, 
Tensor shapes, Tensor " + "dtypes, int num_lists, int num_tensors, float wd, float momentum, float dampening, float " + "lr, bool nesterov, bool first_run, bool wd_after_momentum, float scale) -> ()"); + m.def( + "multi_tensor_adam_param_remainder(int chunk_size, Tensor noop_flag, Tensor ptrs, Tensor " + "shapes, Tensor dtypes, int num_lists, int num_tensors, float lr, float beta1, float beta2, " + "float epsilon, int step, int mode, int bias_correction, float weight_decay) -> ()"); + m.def( + "multi_tensor_adam_fp8(int chunk_size, Tensor noop_flag, Tensor ptrs, Tensor shapes, Tensor " + "dtypes, int num_lists, int num_tensors, float lr, float beta1, float beta2, float epsilon, " + "int step, int mode, int bias_correction, float weight_decay, int fp8_dtype) -> ()"); + m.def( + "multi_tensor_adam_capturable_master(int chunk_size, Tensor noop_flag, Tensor ptrs, Tensor " + "shapes, Tensor dtypes, int num_lists, int num_tensors, Tensor lr, float beta1, float beta2, " + "float epsilon, Tensor step, int mode, int bias_correction, float weight_decay, Tensor " + "inv_scale) -> ()"); + m.def( + "multi_tensor_compute_scale_and_scale_inv(int chunk_size, Tensor noop_flag, Tensor ptrs, " + "Tensor shapes, Tensor dtypes, int num_lists, int num_tensors, float max_fp8, bool " + "force_pow_2_scales, float epsilon) -> ()"); + m.def( + "multi_tensor_compute_scale_inv_e8m0(int chunk_size, Tensor dummy_cuda, Tensor ptrs, Tensor " + "shapes, Tensor dtypes, int num_lists, int num_tensors) -> ()"); + m.def( + "multi_tensor_unscale_l2norm(int chunk_size, Tensor noop_flag, Tensor ptrs, Tensor shapes, " + "Tensor dtypes, int num_lists, int num_tensors, Tensor inv_scale, bool per_tensor) -> " + "(Tensor, Tensor)"); +} + +STABLE_TORCH_LIBRARY_IMPL(transformer_engine_stable, CUDA, m) { + using namespace transformer_engine::pytorch::stable; + m.impl("multi_tensor_scale", TORCH_BOX(multi_tensor_scale)); + m.impl("multi_tensor_scale_tensor", TORCH_BOX(multi_tensor_scale_tensor)); + 
m.impl("multi_tensor_l2norm", TORCH_BOX(multi_tensor_l2norm)); + m.impl("multi_tensor_adam", TORCH_BOX(multi_tensor_adam)); + m.impl("multi_tensor_adam_capturable", TORCH_BOX(multi_tensor_adam_capturable)); + m.impl("multi_tensor_sgd", TORCH_BOX(multi_tensor_sgd)); + m.impl("multi_tensor_adam_param_remainder", TORCH_BOX(multi_tensor_adam_param_remainder)); + m.impl("multi_tensor_adam_fp8", TORCH_BOX(multi_tensor_adam_fp8)); + m.impl("multi_tensor_adam_capturable_master", TORCH_BOX(multi_tensor_adam_capturable_master)); + m.impl("multi_tensor_compute_scale_and_scale_inv", + TORCH_BOX(multi_tensor_compute_scale_and_scale_inv)); + m.impl("multi_tensor_compute_scale_inv_e8m0", TORCH_BOX(multi_tensor_compute_scale_inv_e8m0)); + m.impl("multi_tensor_unscale_l2norm", TORCH_BOX(multi_tensor_unscale_l2norm)); +} diff --git a/transformer_engine/pytorch/csrc/extensions/multi_tensor/adam.cpp b/transformer_engine/pytorch/csrc/extensions/multi_tensor/adam.cpp deleted file mode 100644 index 145e1d4b40..0000000000 --- a/transformer_engine/pytorch/csrc/extensions/multi_tensor/adam.cpp +++ /dev/null @@ -1,92 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. 
- ************************************************************************/ - -#include "../../extensions.h" - -namespace transformer_engine::pytorch { - -void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag, - std::vector> tensor_lists, const float lr, - const float beta1, const float beta2, const float epsilon, - const int step, const int mode, const int bias_correction, - const float weight_decay) { - auto noop_flag_cu = makeTransformerEngineTensor(noop_flag); - auto [_, __, tensor_lists_ptr, num_lists, num_tensors] = - makeTransformerEngineTensorList(tensor_lists); - - nvte_multi_tensor_adam_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, - num_tensors, lr, beta1, beta2, epsilon, step, mode, bias_correction, - weight_decay, at::cuda::getCurrentCUDAStream()); -} - -void multi_tensor_adam_param_remainder_cuda(int chunk_size, at::Tensor noop_flag, - std::vector> tensor_lists, - const float lr, const float beta1, const float beta2, - const float epsilon, const int step, const int mode, - const int bias_correction, const float weight_decay) { - auto noop_flag_cu = makeTransformerEngineTensor(noop_flag); - auto [_, __, tensor_lists_ptr, num_lists, num_tensors] = - makeTransformerEngineTensorList(tensor_lists); - - nvte_multi_tensor_adam_param_remainder_cuda( - chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors, lr, beta1, - beta2, epsilon, step, mode, bias_correction, weight_decay, at::cuda::getCurrentCUDAStream()); -} - -void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag, - std::vector> tensor_lists, const float lr, - const float beta1, const float beta2, const float epsilon, - const int step, const int mode, const int bias_correction, - const float weight_decay, DType fp8_dtype) { - auto noop_flag_cu = makeTransformerEngineTensor(noop_flag); - auto [_, __, tensor_lists_ptr, num_lists, num_tensors] = - makeTransformerEngineTensorList(tensor_lists); - - 
nvte_multi_tensor_adam_fp8_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), - num_lists, num_tensors, lr, beta1, beta2, epsilon, step, mode, - bias_correction, weight_decay, static_cast(fp8_dtype), - at::cuda::getCurrentCUDAStream()); -} - -void multi_tensor_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag, - std::vector> tensor_lists, - at::Tensor lr, const float beta1, const float beta2, - const float epsilon, at::Tensor step, const int mode, - const int bias_correction, const float weight_decay, - at::Tensor inv_scale) { - auto noop_flag_cu = makeTransformerEngineTensor(noop_flag); - auto [_, __, tensor_lists_ptr, num_lists, num_tensors] = - makeTransformerEngineTensorList(tensor_lists); - auto lr_cu = makeTransformerEngineTensor(lr); - auto step_cu = makeTransformerEngineTensor(step); - auto inv_scale_cu = makeTransformerEngineTensor(inv_scale); - - nvte_multi_tensor_adam_capturable_cuda( - chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors, - lr_cu.data(), beta1, beta2, epsilon, step_cu.data(), mode, bias_correction, weight_decay, - inv_scale_cu.data(), at::cuda::getCurrentCUDAStream()); -} - -void multi_tensor_adam_capturable_master_cuda(int chunk_size, at::Tensor noop_flag, - std::vector> tensor_lists, - at::Tensor lr, const float beta1, const float beta2, - const float epsilon, at::Tensor step, const int mode, - const int bias_correction, const float weight_decay, - at::Tensor inv_scale) { - auto noop_flag_cu = makeTransformerEngineTensor(noop_flag); - auto [_, __, tensor_lists_ptr, num_lists, num_tensors] = - makeTransformerEngineTensorList(tensor_lists); - auto lr_cu = makeTransformerEngineTensor(lr); - auto step_cu = makeTransformerEngineTensor(step); - auto inv_scale_cu = makeTransformerEngineTensor(inv_scale); - - nvte_multi_tensor_adam_capturable_master_cuda( - chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors, - lr_cu.data(), beta1, beta2, epsilon, step_cu.data(), 
mode, bias_correction, weight_decay, - inv_scale_cu.data(), at::cuda::getCurrentCUDAStream()); -} - -} // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/multi_tensor/compute_scale.cpp b/transformer_engine/pytorch/csrc/extensions/multi_tensor/compute_scale.cpp deleted file mode 100644 index 328970ffa8..0000000000 --- a/transformer_engine/pytorch/csrc/extensions/multi_tensor/compute_scale.cpp +++ /dev/null @@ -1,33 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#include "../../extensions.h" - -namespace transformer_engine::pytorch { - -void multi_tensor_compute_scale_and_scale_inv_cuda( - int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, - float max_fp8, bool force_pow_2_scales, float epsilon) { - auto noop_flag_cu = makeTransformerEngineTensor(noop_flag); - auto [_, __, tensor_lists_ptr, num_lists, num_tensors] = - makeTransformerEngineTensorList(tensor_lists); - - nvte_multi_tensor_compute_scale_and_scale_inv_cuda( - chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors, max_fp8, - force_pow_2_scales, epsilon, at::cuda::getCurrentCUDAStream()); -} - -void multi_tensor_compute_scale_inv_e8m0_cuda(int chunk_size, const py::object &dummy, - std::vector> tensor_lists) { - NVTE_CHECK(dummy.is_none(), "No-op flag is not supported."); - auto [_, __, tensor_lists_ptr, num_lists, num_tensors] = - makeTransformerEngineTensorList(tensor_lists); - - nvte_multi_tensor_compute_scale_inv_e8m0_cuda(chunk_size, tensor_lists_ptr.data(), num_lists, - num_tensors, at::cuda::getCurrentCUDAStream()); -} - -} // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/multi_tensor/l2norm.cpp 
b/transformer_engine/pytorch/csrc/extensions/multi_tensor/l2norm.cpp deleted file mode 100644 index b02cf1fbba..0000000000 --- a/transformer_engine/pytorch/csrc/extensions/multi_tensor/l2norm.cpp +++ /dev/null @@ -1,102 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#include "../../extensions.h" - -namespace transformer_engine::pytorch { - -std::tuple multi_tensor_l2norm_cuda( - int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, - at::optional per_tensor_python) { - bool per_tensor = per_tensor_python.has_value() ? per_tensor_python.value() : false; - - auto float_options = tensor_lists[0][0].options().dtype(at::kFloat); - auto output = at::zeros({320}, float_options); - - at::Tensor output_per_tensor; - at::Tensor ret_per_tensor; - auto ret = at::empty({1}, output.options()); - - int ntensors = tensor_lists[0].size(); - int max_chunks_per_tensor = -1; - - if (per_tensor) { - for (int t = 0; t < ntensors; t++) { - int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size; - if (max_chunks_this_tensor > max_chunks_per_tensor) - max_chunks_per_tensor = max_chunks_this_tensor; - } - output_per_tensor = at::zeros({ntensors * max_chunks_per_tensor}, float_options); - ret_per_tensor = at::empty({ntensors}, float_options); - } else { - output_per_tensor = at::empty({0}, float_options); - ret_per_tensor = at::empty({0}, float_options); - } - - auto noop_flag_cu = makeTransformerEngineTensor(noop_flag); - auto [_, __, tensor_lists_ptr, num_lists, num_tensors] = - makeTransformerEngineTensorList(tensor_lists); - auto output_cu = makeTransformerEngineTensor(output); - auto output_per_tensor_cu = makeTransformerEngineTensor(output_per_tensor); - auto ret_cu = makeTransformerEngineTensor(ret); - 
auto ret_per_tensor_cu = makeTransformerEngineTensor(ret_per_tensor); - - nvte_multi_tensor_l2norm_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, - num_tensors, output_cu.data(), output_per_tensor_cu.data(), - ret_cu.data(), ret_per_tensor_cu.data(), per_tensor, - max_chunks_per_tensor, at::cuda::getCurrentCUDAStream()); - - return std::tuple(ret, ret_per_tensor); -} - -std::tuple multi_tensor_unscale_l2norm_cuda( - int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, - at::Tensor inv_scale, at::optional per_tensor_python) { - bool per_tensor = per_tensor_python.has_value() ? per_tensor_python.value() : false; - - auto float_options = tensor_lists[0][0].options().dtype(at::kFloat); - auto output = at::zeros({320}, float_options); - - at::Tensor output_per_tensor; - at::Tensor ret_per_tensor; - - int ntensors = tensor_lists[0].size(); - int max_chunks_per_tensor = -1; - - // Create output tensors for multi scale L2 norm kernel. - if (per_tensor) { - for (int t = 0; t < ntensors; t++) { - int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size; - if (max_chunks_this_tensor > max_chunks_per_tensor) - max_chunks_per_tensor = max_chunks_this_tensor; - } - output_per_tensor = at::zeros({ntensors * max_chunks_per_tensor}, float_options); - ret_per_tensor = at::empty({ntensors}, float_options); - } else { - output_per_tensor = at::empty({0}, float_options); - ret_per_tensor = at::empty({0}, float_options); - } - - auto ret = at::empty({1}, output.options()); - - auto noop_flag_cu = makeTransformerEngineTensor(noop_flag); - auto [_, __, tensor_lists_ptr, num_lists, num_tensors] = - makeTransformerEngineTensorList(tensor_lists); - auto output_cu = makeTransformerEngineTensor(output); - auto output_per_tensor_cu = makeTransformerEngineTensor(output_per_tensor); - auto ret_cu = makeTransformerEngineTensor(ret); - auto ret_per_tensor_cu = makeTransformerEngineTensor(ret_per_tensor); - auto inv_scale_cu = 
makeTransformerEngineTensor(inv_scale); - - nvte_multi_tensor_unscale_l2norm_cuda( - chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors, - output_cu.data(), output_per_tensor_cu.data(), ret_cu.data(), ret_per_tensor_cu.data(), - inv_scale_cu.data(), per_tensor, max_chunks_per_tensor, at::cuda::getCurrentCUDAStream()); - - return std::tuple(ret, ret_per_tensor); -} - -} // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/multi_tensor/scale.cpp b/transformer_engine/pytorch/csrc/extensions/multi_tensor/scale.cpp deleted file mode 100644 index 687eb34f32..0000000000 --- a/transformer_engine/pytorch/csrc/extensions/multi_tensor/scale.cpp +++ /dev/null @@ -1,33 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#include "../../extensions.h" - -namespace transformer_engine::pytorch { - -void multi_tensor_scale_cuda(int chunk_size, at::Tensor is_infinite, - std::vector> tensor_lists, float scale) { - auto is_infinite_cu = makeTransformerEngineTensor(is_infinite); - auto [_, __, tensor_lists_ptr, num_lists, num_tensors] = - makeTransformerEngineTensorList(tensor_lists); - - nvte_multi_tensor_scale_cuda(chunk_size, is_infinite_cu.data(), tensor_lists_ptr.data(), - num_lists, num_tensors, scale, at::cuda::getCurrentCUDAStream()); -} - -void multi_tensor_scale_tensor_cuda(int chunk_size, at::Tensor is_infinite, - std::vector> tensor_lists, - at::Tensor scale) { - auto is_infinite_cu = makeTransformerEngineTensor(is_infinite); - auto scale_cu = makeTransformerEngineTensor(scale); - auto [_, __, tensor_lists_ptr, num_lists, num_tensors] = - makeTransformerEngineTensorList(tensor_lists); - nvte_multi_tensor_scale_tensor_cuda(chunk_size, is_infinite_cu.data(), 
tensor_lists_ptr.data(), - num_lists, num_tensors, scale_cu.data(), - at::cuda::getCurrentCUDAStream()); -} - -} // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/multi_tensor/sgd.cpp b/transformer_engine/pytorch/csrc/extensions/multi_tensor/sgd.cpp deleted file mode 100644 index a70fe12b56..0000000000 --- a/transformer_engine/pytorch/csrc/extensions/multi_tensor/sgd.cpp +++ /dev/null @@ -1,24 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#include "../../extensions.h" - -namespace transformer_engine::pytorch { - -void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag, - std::vector> tensor_lists, float wd, - float momentum, float dampening, float lr, bool nesterov, bool first_run, - bool wd_after_momentum, float scale) { - auto noop_flag_cu = makeTransformerEngineTensor(noop_flag); - auto [_, __, tensor_lists_ptr, num_lists, num_tensors] = - makeTransformerEngineTensorList(tensor_lists); - - nvte_multi_tensor_sgd_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, - num_tensors, wd, momentum, dampening, lr, nesterov, first_run, - wd_after_momentum, scale, at::cuda::getCurrentCUDAStream()); -} - -} // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index 3214c3a9db..6e12572b2a 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -4,25 +4,30 @@ * See LICENSE for license information. 
************************************************************************/ -#include "../extensions.h" -#include "common/util/system.h" -#include "pybind.h" - -namespace transformer_engine::pytorch { - -std::vector layernorm_bwd(const at::Tensor &dz, const at::Tensor &x, - const at::Tensor &mu, const at::Tensor &rsigma, - const at::Tensor &gamma, const int sm_margin, - const bool zero_centered_gamma) { - const auto &dz_ = dz.contiguous(); - const auto &x_ = x.contiguous(); - const auto &mu_ = mu.contiguous(); - const auto &rsigma_ = rsigma.contiguous(); - const auto &gamma_ = gamma.contiguous(); - - auto dx = at::empty_like(x_); - auto dgamma = at::empty_like(gamma_); - auto dbeta = at::empty_like(gamma_); +#include + +#include "../stable_common.h" + +namespace transformer_engine::pytorch::stable { + +using Tensor = torch::stable::Tensor; + +// ============================================================================ +// Layernorm backward +// ============================================================================ + +std::tuple layernorm_bwd(Tensor dz, Tensor x, Tensor mu, Tensor rsigma, + Tensor gamma, int64_t sm_margin, + bool zero_centered_gamma) { + auto dz_ = torch::stable::contiguous(dz); + auto x_ = torch::stable::contiguous(x); + auto mu_ = torch::stable::contiguous(mu); + auto rsigma_ = torch::stable::contiguous(rsigma); + auto gamma_ = torch::stable::contiguous(gamma); + + auto dx = torch::stable::empty_like(x_); + auto dgamma = torch::stable::empty_like(gamma_); + auto dbeta = torch::stable::empty_like(gamma_); TensorWrapper workspace; auto dz_cu = makeTransformerEngineTensor(dz_); @@ -34,194 +39,109 @@ std::vector layernorm_bwd(const at::Tensor &dz, const at::Tensor &x, auto dgamma_cu = makeTransformerEngineTensor(dgamma); auto dbeta_cu = makeTransformerEngineTensor(dbeta); - // This call populates tensors with the required config. 
- NVTE_SCOPED_GIL_RELEASE({ - nvte_layernorm_bwd(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(), - dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), workspace.data(), - at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, - zero_centered_gamma, at::cuda::getCurrentCUDAStream()); - }); - - // Alloc space for Tensors. - auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - workspace = - makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); - - // Actual call to bwd kernel. - NVTE_SCOPED_GIL_RELEASE({ - nvte_layernorm_bwd(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(), - dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), workspace.data(), - at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, - zero_centered_gamma, at::cuda::getCurrentCUDAStream()); - }); - - return {py::cast(dx), py::cast(dgamma), py::cast(dbeta)}; + auto device_idx = dz_.get_device_index(); + int sm_count = getSMCount(device_idx) - static_cast(sm_margin); + auto stream = getCurrentCUDAStreamRaw(device_idx); + + // First call: query workspace size + nvte_layernorm_bwd(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(), + dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), workspace.data(), sm_count, + zero_centered_gamma, stream); + + // Allocate workspace + auto ws_shape = workspace.shape(); + auto ws_dtype = workspace.dtype(); + auto workspace_data = allocateStableTensor( + std::vector(ws_shape.data, ws_shape.data + ws_shape.ndim), ws_dtype, device_idx); + workspace = makeTransformerEngineTensor( + workspace_data.data_ptr(), std::vector(ws_shape.data, ws_shape.data + ws_shape.ndim), + ws_dtype); + + // Second call: actual computation + nvte_layernorm_bwd(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(), + dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), workspace.data(), sm_count, + 
zero_centered_gamma, stream); + + return std::make_tuple(dx, dgamma, dbeta); } -std::vector layernorm_fwd(py::handle input, py::handle weight, MaybeTensor bias, - float eps, py::object out, py::handle quantizer, - DType out_dtype, const int sm_margin, - const bool zero_centered_gamma) { - using namespace transformer_engine::pytorch::detail; - - // Ensure that cuDNN handle is created on the correct device, - // overriding torch.cuda.set_device calls from user side. - // Assumes all tensors passed are on the same device. - at::cuda::CUDAGuard device_guard(input.cast().device()); - - // Input and param tensors - auto none = py::none(); - const TensorWrapper &input_nvte = makeTransformerEngineTensor(input, none); - const TensorWrapper &weight_nvte = makeTransformerEngineTensor(weight, none); - TensorWrapper bias_nvte; +// ============================================================================ +// Layernorm forward (unquantized output) +// ============================================================================ + +std::tuple layernorm_fwd(Tensor input, Tensor weight, + std::optional bias, double eps, + int64_t sm_margin, bool zero_centered_gamma) { + auto input_ = torch::stable::contiguous(input); + auto weight_ = torch::stable::contiguous(weight); + + auto input_cu = makeTransformerEngineTensor(input_); + auto weight_cu = makeTransformerEngineTensor(weight_); + // bias_ must outlive the kernel launch — declaring at function scope ensures + // the contiguous tensor (if created) stays alive until after the kernel. 
+ Tensor bias_contiguous; + TensorWrapper bias_cu; if (bias.has_value()) { - bias_nvte = makeTransformerEngineTensor(*bias); + bias_contiguous = torch::stable::contiguous(bias.value()); + bias_cu = makeTransformerEngineTensor(bias_contiguous); } - // Tensor dimensions - const auto shape = nvte_shape_to_vector(input_nvte.shape()); - const auto outer_size = product(shape) / shape.back(); - const auto inner_size = shape.back(); - - // Tensors to save for backward pass - at::Tensor mu_py = at::empty({static_cast(outer_size)}, at::CUDA(at::kFloat)); - at::Tensor rsigma_py = at::empty({static_cast(outer_size)}, at::CUDA(at::kFloat)); - TensorWrapper mu_nvte = makeTransformerEngineTensor(mu_py); - TensorWrapper rsigma_nvte = makeTransformerEngineTensor(rsigma_py); - - // Quantizer - auto quantizer_cpp = convert_quantizer(quantizer); - - // Choose implementation - enum class Impl { - // Compute norm in high precision, then quantize - UNFUSED, - // Compute norm directly - FULLY_FUSED, - // Compute norm and amax in high precision, then quantize to FP8 - FUSED_NORM_AMAX_FP8, - // Compute norm and amax in high precision, then quantize to NVFP4 - FUSED_NORM_AMAX_NVFP4 - }; - Impl impl = Impl::UNFUSED; - if (quantizer.is_none() || IsFloat8Quantizers(quantizer.ptr())) { - impl = Impl::FULLY_FUSED; - } else if (IsMXFP8Quantizers(quantizer.ptr())) { - if (transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN") && outer_size % 128 == 0 && - inner_size % 128 == 0) { - // cuDNN MXFP8 kernel requires full 128x128 tiles - impl = Impl::FULLY_FUSED; - } - } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr()) && - !transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { - auto fp8_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); - NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer"); - impl = Impl::FUSED_NORM_AMAX_FP8; - } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { - auto nvfp4_quantizer_cpp = 
dynamic_cast(quantizer_cpp.get()); - NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); - if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) { - // Post-RHT amax is handled within NVFP4 quantizer - impl = Impl::UNFUSED; - } else if (!transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { - // TE kernel supports amax output - impl = Impl::FUSED_NORM_AMAX_NVFP4; - } - } - - // Output tensor - TensorWrapper out_nvte; - if (out.is_none()) { - if (impl == Impl::FULLY_FUSED) { - // FP8 has no special logic to optimize for GEMM, MXFP8 cuDNN - // kernel does not support GEMM swizzled scales - quantizer_cpp->optimize_for_gemm = false; - } - std::tie(out_nvte, out) = quantizer_cpp->create_tensor(shape, out_dtype); - } else { - out_nvte = makeTransformerEngineTensor(out, quantizer); - } + auto shape = getStableTensorShape(input_); + size_t outer_size = 1; + for (size_t i = 0; i + 1 < shape.size(); ++i) outer_size *= shape[i]; - // Construct unquantized output tensor if needed - TensorWrapper unquantized_out_nvte; - py::object unquantized_out; - TensorWrapper *kernel_out_nvte = &out_nvte; - switch (impl) { - case Impl::UNFUSED: { - NoneQuantizer q{none}; - std::tie(unquantized_out_nvte, unquantized_out) = q.create_tensor(shape, out_dtype); - kernel_out_nvte = &unquantized_out_nvte; - } break; - case Impl::FUSED_NORM_AMAX_FP8: { - auto fp8_quantizer_cpp = static_cast(quantizer_cpp.get()); - std::tie(unquantized_out_nvte, unquantized_out) = - fp8_quantizer_cpp->create_unquantized_tensor_with_amax(shape, out_dtype); - kernel_out_nvte = &unquantized_out_nvte; - } break; - case Impl::FUSED_NORM_AMAX_NVFP4: { - auto nvfp4_quantizer_cpp = static_cast(quantizer_cpp.get()); - std::tie(unquantized_out_nvte, unquantized_out) = - nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax(out_nvte, out_dtype); - kernel_out_nvte = &unquantized_out_nvte; - } break; - default: { - } - } + auto device_idx = input_.get_device_index(); + 
auto output = torch::stable::empty_like(input_); + auto mu = allocateStableTensor({static_cast(outer_size)}, ScalarType::Float, device_idx); + auto rsigma = + allocateStableTensor({static_cast(outer_size)}, ScalarType::Float, device_idx); - // Query workspace size + auto output_cu = makeTransformerEngineTensor(output); + auto mu_cu = makeTransformerEngineTensor(mu); + auto rsigma_cu = makeTransformerEngineTensor(rsigma); TensorWrapper workspace; - NVTE_SCOPED_GIL_RELEASE({ - nvte_layernorm_fwd(input_nvte.data(), weight_nvte.data(), bias_nvte.data(), eps, - kernel_out_nvte->data(), mu_nvte.data(), rsigma_nvte.data(), - workspace.data(), - at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, - zero_centered_gamma, at::cuda::getCurrentCUDAStream()); - }); - // Allocate workspace - auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - workspace = - makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); - - // Launch kernel - NVTE_SCOPED_GIL_RELEASE({ - nvte_layernorm_fwd(input_nvte.data(), weight_nvte.data(), bias_nvte.data(), eps, - kernel_out_nvte->data(), mu_nvte.data(), rsigma_nvte.data(), - workspace.data(), - at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, - zero_centered_gamma, at::cuda::getCurrentCUDAStream()); - }); - - // Quantize output if needed - switch (impl) { - case Impl::UNFUSED: { - quantizer_cpp->quantize(unquantized_out_nvte, out_nvte); - } break; - case Impl::FUSED_NORM_AMAX_FP8: { - auto fp8_quantizer_cpp = static_cast(quantizer_cpp.get()); - fp8_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte); - } break; - case Impl::FUSED_NORM_AMAX_NVFP4: { - auto nvfp4_quantizer_cpp = static_cast(quantizer_cpp.get()); - nvfp4_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte); - } break; - default: { - } + int sm_count = getSMCount(device_idx) - static_cast(sm_margin); + auto stream = getCurrentCUDAStreamRaw(device_idx); + 
+ // First call: query workspace + nvte_layernorm_fwd(input_cu.data(), weight_cu.data(), bias_cu.data(), static_cast(eps), + output_cu.data(), mu_cu.data(), rsigma_cu.data(), workspace.data(), sm_count, + zero_centered_gamma, stream); + + // workspace_data must outlive the second kernel call — hoist out of if block. + Tensor workspace_data; + auto ws_shape = workspace.shape(); + auto ws_dtype = workspace.dtype(); + if (ws_shape.ndim > 0) { + workspace_data = allocateStableTensor( + std::vector(ws_shape.data, ws_shape.data + ws_shape.ndim), ws_dtype, device_idx); + workspace = makeTransformerEngineTensor( + workspace_data.data_ptr(), + std::vector(ws_shape.data, ws_shape.data + ws_shape.ndim), ws_dtype); } - return {out, py::cast(mu_py), py::cast(rsigma_py)}; + // Second call: actual computation + nvte_layernorm_fwd(input_cu.data(), weight_cu.data(), bias_cu.data(), static_cast(eps), + output_cu.data(), mu_cu.data(), rsigma_cu.data(), workspace.data(), sm_count, + zero_centered_gamma, stream); + + return std::make_tuple(output, mu, rsigma); } -std::vector rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x, - const at::Tensor &rsigma, const at::Tensor &gamma, - const int sm_margin, const bool zero_centered_gamma) { - const auto &dz_ = dz.contiguous(); - const auto &x_ = x.contiguous(); - const auto &rsigma_ = rsigma.contiguous(); - const auto &gamma_ = gamma.contiguous(); +// ============================================================================ +// RMSnorm backward +// ============================================================================ - auto dx = at::empty_like(x_); - auto dgamma = at::empty_like(gamma_); +std::tuple rmsnorm_bwd(Tensor dz, Tensor x, Tensor rsigma, Tensor gamma, + int64_t sm_margin, bool zero_centered_gamma) { + auto dz_ = torch::stable::contiguous(dz); + auto x_ = torch::stable::contiguous(x); + auto rsigma_ = torch::stable::contiguous(rsigma); + auto gamma_ = torch::stable::contiguous(gamma); + + auto dx = 
torch::stable::empty_like(x_); + auto dgamma = torch::stable::empty_like(gamma_); TensorWrapper workspace; auto dz_cu = makeTransformerEngineTensor(dz_); @@ -231,42 +151,91 @@ std::vector rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x, auto dx_cu = makeTransformerEngineTensor(dx); auto dgamma_cu = makeTransformerEngineTensor(dgamma); - // This call populates tensors with the required config. - NVTE_SCOPED_GIL_RELEASE({ - nvte_rmsnorm_bwd(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(), - dgamma_cu.data(), workspace.data(), - at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, - zero_centered_gamma, at::cuda::getCurrentCUDAStream()); - }); - - // Alloc space for Tensors. - auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - workspace = - makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); - - // Actual call to bwd kernel. - NVTE_SCOPED_GIL_RELEASE({ - nvte_rmsnorm_bwd(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(), - dgamma_cu.data(), workspace.data(), - at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, - zero_centered_gamma, at::cuda::getCurrentCUDAStream()); - }); - - return {py::cast(dx), py::cast(dgamma)}; + auto device_idx = dz_.get_device_index(); + int sm_count = getSMCount(device_idx) - static_cast(sm_margin); + auto stream = getCurrentCUDAStreamRaw(device_idx); + + nvte_rmsnorm_bwd(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(), + dgamma_cu.data(), workspace.data(), sm_count, zero_centered_gamma, stream); + + auto ws_shape = workspace.shape(); + auto ws_dtype = workspace.dtype(); + auto workspace_data = allocateStableTensor( + std::vector(ws_shape.data, ws_shape.data + ws_shape.ndim), ws_dtype, device_idx); + workspace = makeTransformerEngineTensor( + workspace_data.data_ptr(), std::vector(ws_shape.data, ws_shape.data + ws_shape.ndim), + ws_dtype); + + 
nvte_rmsnorm_bwd(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(), + dgamma_cu.data(), workspace.data(), sm_count, zero_centered_gamma, stream); + + return std::make_tuple(dx, dgamma); } -std::vector rmsnorm_bwd_add(const at::Tensor &dz, const at::Tensor &x, - const at::Tensor &add, const at::Tensor &rsigma, - const at::Tensor &gamma, const int sm_margin, - const bool zero_centered_gamma) { - const auto &dz_ = dz.contiguous(); - const auto &x_ = x.contiguous(); - const auto &add_ = add.contiguous(); - const auto &rsigma_ = rsigma.contiguous(); - const auto &gamma_ = gamma.contiguous(); - - auto dx = at::empty_like(x_); - auto dgamma = at::empty_like(gamma_); +// ============================================================================ +// RMSnorm forward (unquantized output) +// ============================================================================ + +std::tuple rmsnorm_fwd(Tensor input, Tensor weight, double eps, int64_t sm_margin, + bool zero_centered_gamma) { + auto input_ = torch::stable::contiguous(input); + auto weight_ = torch::stable::contiguous(weight); + + auto input_cu = makeTransformerEngineTensor(input_); + auto weight_cu = makeTransformerEngineTensor(weight_); + + auto shape = getStableTensorShape(input_); + size_t outer_size = 1; + for (size_t i = 0; i + 1 < shape.size(); ++i) outer_size *= shape[i]; + + auto device_idx = input_.get_device_index(); + auto output = torch::stable::empty_like(input_); + auto rsigma = + allocateStableTensor({static_cast(outer_size)}, ScalarType::Float, device_idx); + + auto output_cu = makeTransformerEngineTensor(output); + auto rsigma_cu = makeTransformerEngineTensor(rsigma); + TensorWrapper workspace; + + int sm_count = getSMCount(device_idx) - static_cast(sm_margin); + auto stream = getCurrentCUDAStreamRaw(device_idx); + + nvte_rmsnorm_fwd(input_cu.data(), weight_cu.data(), static_cast(eps), output_cu.data(), + rsigma_cu.data(), workspace.data(), sm_count, zero_centered_gamma, 
stream); + + // workspace_data must outlive the second kernel call — hoist out of if block. + Tensor workspace_data; + auto ws_shape = workspace.shape(); + auto ws_dtype = workspace.dtype(); + if (ws_shape.ndim > 0) { + workspace_data = allocateStableTensor( + std::vector(ws_shape.data, ws_shape.data + ws_shape.ndim), ws_dtype, device_idx); + workspace = makeTransformerEngineTensor( + workspace_data.data_ptr(), + std::vector(ws_shape.data, ws_shape.data + ws_shape.ndim), ws_dtype); + } + + nvte_rmsnorm_fwd(input_cu.data(), weight_cu.data(), static_cast(eps), output_cu.data(), + rsigma_cu.data(), workspace.data(), sm_count, zero_centered_gamma, stream); + + return std::make_tuple(output, rsigma); +} + +// ============================================================================ +// RMSnorm backward with add +// ============================================================================ + +std::tuple rmsnorm_bwd_add(Tensor dz, Tensor x, Tensor add, Tensor rsigma, + Tensor gamma, int64_t sm_margin, + bool zero_centered_gamma) { + auto dz_ = torch::stable::contiguous(dz); + auto x_ = torch::stable::contiguous(x); + auto add_ = torch::stable::contiguous(add); + auto rsigma_ = torch::stable::contiguous(rsigma); + auto gamma_ = torch::stable::contiguous(gamma); + + auto dx = torch::stable::empty_like(x_); + auto dgamma = torch::stable::empty_like(gamma_); TensorWrapper workspace; auto dz_cu = makeTransformerEngineTensor(dz_); @@ -277,173 +246,166 @@ std::vector rmsnorm_bwd_add(const at::Tensor &dz, const at::Tensor & auto dx_cu = makeTransformerEngineTensor(dx); auto dgamma_cu = makeTransformerEngineTensor(dgamma); - // This call populates tensors with the required config. 
- NVTE_SCOPED_GIL_RELEASE({ - nvte_rmsnorm_bwd_add(dz_cu.data(), x_cu.data(), add_cu.data(), rsigma_cu.data(), - gamma_cu.data(), dx_cu.data(), dgamma_cu.data(), workspace.data(), - at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, - zero_centered_gamma, at::cuda::getCurrentCUDAStream()); - }); - - // Alloc space for Tensors. - auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - workspace = - makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); - - // Actual call to bwd kernel. - NVTE_SCOPED_GIL_RELEASE({ - nvte_rmsnorm_bwd_add(dz_cu.data(), x_cu.data(), add_cu.data(), rsigma_cu.data(), - gamma_cu.data(), dx_cu.data(), dgamma_cu.data(), workspace.data(), - at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, - zero_centered_gamma, at::cuda::getCurrentCUDAStream()); - }); - - return {py::cast(dx), py::cast(dgamma)}; + auto device_idx = dz_.get_device_index(); + int sm_count = getSMCount(device_idx) - static_cast(sm_margin); + auto stream = getCurrentCUDAStreamRaw(device_idx); + + nvte_rmsnorm_bwd_add(dz_cu.data(), x_cu.data(), add_cu.data(), rsigma_cu.data(), gamma_cu.data(), + dx_cu.data(), dgamma_cu.data(), workspace.data(), sm_count, + zero_centered_gamma, stream); + + auto ws_shape = workspace.shape(); + auto ws_dtype = workspace.dtype(); + auto workspace_data = allocateStableTensor( + std::vector(ws_shape.data, ws_shape.data + ws_shape.ndim), ws_dtype, device_idx); + workspace = makeTransformerEngineTensor( + workspace_data.data_ptr(), std::vector(ws_shape.data, ws_shape.data + ws_shape.ndim), + ws_dtype); + + nvte_rmsnorm_bwd_add(dz_cu.data(), x_cu.data(), add_cu.data(), rsigma_cu.data(), gamma_cu.data(), + dx_cu.data(), dgamma_cu.data(), workspace.data(), sm_count, + zero_centered_gamma, stream); + + return std::make_tuple(dx, dgamma); } -std::vector rmsnorm_fwd(const py::handle &input, const py::handle &weight, float eps, - py::object out, py::handle 
quantizer, DType out_dtype, - const int sm_margin, const bool zero_centered_gamma) { - using namespace transformer_engine::pytorch::detail; - - // Ensure that cuDNN handle is created on the correct device, - // overriding torch.cuda.set_device calls from user side. - // Assumes all tensors passed are on the same device. - at::cuda::CUDAGuard device_guard(input.cast().device()); - - // Input and param tensors - auto none = py::none(); - const TensorWrapper &input_nvte = makeTransformerEngineTensor(input, none); - const TensorWrapper &weight_nvte = makeTransformerEngineTensor(weight, none); - - // Tensor dimensions - const auto shape = nvte_shape_to_vector(input_nvte.shape()); - const auto outer_size = product(shape) / shape.back(); - const auto inner_size = shape.back(); - - // Tensors to save for backward pass - at::Tensor rsigma_py = at::empty({static_cast(outer_size)}, at::CUDA(at::kFloat)); - TensorWrapper rsigma_nvte = makeTransformerEngineTensor(rsigma_py); - - // Quantizer - auto quantizer_cpp = convert_quantizer(quantizer); - - // Choose implementation - enum class Impl { - // Compute norm in high precision, then quantize - UNFUSED, - // Compute norm directly - FULLY_FUSED, - // Compute norm and amax in high precision, then quantize to FP8 - FUSED_NORM_AMAX_FP8, - // Compute norm and amax in high precision, then quantize to NVFP4 - FUSED_NORM_AMAX_NVFP4 - }; - Impl impl = Impl::UNFUSED; - if (quantizer.is_none() || IsFloat8Quantizers(quantizer.ptr())) { - impl = Impl::FULLY_FUSED; - } else if (IsMXFP8Quantizers(quantizer.ptr())) { - if (transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN") && outer_size % 128 == 0 && - inner_size % 128 == 0) { - // cuDNN MXFP8 kernel requires full 128x128 tiles - impl = Impl::FULLY_FUSED; - } - } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr()) && - !transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { - auto fp8_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); - NVTE_CHECK(fp8_quantizer_cpp != 
nullptr, "Could not cast to FP8 current scaling quantizer"); - impl = Impl::FUSED_NORM_AMAX_FP8; - } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { - auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); - NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); - if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) { - // Post-RHT amax is handled within NVFP4 quantizer - impl = Impl::UNFUSED; - } else if (!transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { - // TE kernel supports amax output - impl = Impl::FUSED_NORM_AMAX_NVFP4; - } +// ============================================================================ +// Layernorm forward — no-alloc variant for quantized output +// ============================================================================ +// +// The caller pre-allocates output_data and all quantization buffers. +// The NVTE kernel writes to the output TensorWrapper, which is configured +// from the raw buffer arguments. This preserves all kernel fusion: +// +// FULLY_FUSED: pass quantized output_data + amax + scale + scale_inv +// NORM+AMAX fused: pass hp output_data + amax (no scale/scale_inv) +// UNFUSED: pass hp output_data only (no quantization buffers) +// +// The Python shim decides which buffers to provide based on quantizer type. 
+ +std::tuple layernorm_fwd_noalloc( + Tensor input, Tensor weight, std::optional bias, double eps, + // Pre-allocated output buffer + Tensor output_data, + int64_t output_te_dtype, // transformer_engine::DType as int + // Optional quantization metadata (pass empty tensors if unused) + std::optional output_amax, std::optional output_scale, + std::optional output_scale_inv, + int64_t scaling_mode, // NVTEScalingMode as int + // mu/rsigma pre-allocated by caller + Tensor mu, Tensor rsigma, int64_t sm_margin, bool zero_centered_gamma) { + auto input_ = torch::stable::contiguous(input); + auto weight_ = torch::stable::contiguous(weight); + + auto input_cu = makeTransformerEngineTensor(input_); + auto weight_cu = makeTransformerEngineTensor(weight_); + // bias_contiguous must outlive the kernel — hoist out of if block. + Tensor bias_contiguous; + TensorWrapper bias_cu; + if (bias.has_value()) { + bias_contiguous = torch::stable::contiguous(bias.value()); + bias_cu = makeTransformerEngineTensor(bias_contiguous); } - // Output tensor - TensorWrapper out_nvte; - if (out.is_none()) { - if (impl == Impl::FULLY_FUSED) { - // FP8 has no special logic to optimize for GEMM, MXFP8 cuDNN - // kernel does not support GEMM swizzled scales - quantizer_cpp->optimize_for_gemm = false; - } - std::tie(out_nvte, out) = quantizer_cpp->create_tensor(shape, out_dtype); - } else { - out_nvte = makeTransformerEngineTensor(out, quantizer); - } + auto shape = getStableTensorShape(input_); + auto te_dtype = static_cast(output_te_dtype); + auto nvte_scaling = static_cast(scaling_mode); - // Construct unquantized output tensor if needed - TensorWrapper unquantized_out_nvte; - py::object unquantized_out; - TensorWrapper *kernel_out_nvte = &out_nvte; - switch (impl) { - case Impl::UNFUSED: { - NoneQuantizer q{none}; - std::tie(unquantized_out_nvte, unquantized_out) = q.create_tensor(shape, out_dtype); - kernel_out_nvte = &unquantized_out_nvte; - } break; - case Impl::FUSED_NORM_AMAX_FP8: { - auto 
fp8_quantizer_cpp = static_cast(quantizer_cpp.get()); - std::tie(unquantized_out_nvte, unquantized_out) = - fp8_quantizer_cpp->create_unquantized_tensor_with_amax(shape, out_dtype); - kernel_out_nvte = &unquantized_out_nvte; - } break; - case Impl::FUSED_NORM_AMAX_NVFP4: { - auto nvfp4_quantizer_cpp = static_cast(quantizer_cpp.get()); - std::tie(unquantized_out_nvte, unquantized_out) = - nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax(out_nvte, out_dtype); - kernel_out_nvte = &unquantized_out_nvte; - } break; - default: { - } - } + auto output_cu = makeQuantizedTensorWrapper(output_data, te_dtype, shape, output_amax, + output_scale, output_scale_inv, nvte_scaling); + auto mu_cu = makeTransformerEngineTensor(mu); + auto rsigma_cu = makeTransformerEngineTensor(rsigma); - // Query workspace size - TensorWrapper workspace; - NVTE_SCOPED_GIL_RELEASE({ - nvte_rmsnorm_fwd(input_nvte.data(), weight_nvte.data(), eps, kernel_out_nvte->data(), - rsigma_nvte.data(), workspace.data(), - at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, - zero_centered_gamma, at::cuda::getCurrentCUDAStream()); - }); + auto device_idx = input_.get_device_index(); + int sm_count = getSMCount(device_idx) - static_cast(sm_margin); + auto stream = getCurrentCUDAStreamRaw(device_idx); - // Allocate workspace - auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - workspace = - makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); - - // Launch kernel - NVTE_SCOPED_GIL_RELEASE({ - nvte_rmsnorm_fwd(input_nvte.data(), weight_nvte.data(), eps, kernel_out_nvte->data(), - rsigma_nvte.data(), workspace.data(), - at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, - zero_centered_gamma, at::cuda::getCurrentCUDAStream()); - }); - - // Quantize output if needed - switch (impl) { - case Impl::UNFUSED: { - quantizer_cpp->quantize(unquantized_out_nvte, out_nvte); - } break; - case 
Impl::FUSED_NORM_AMAX_FP8: { - auto fp8_quantizer_cpp = static_cast(quantizer_cpp.get()); - fp8_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte); - } break; - case Impl::FUSED_NORM_AMAX_NVFP4: { - auto nvfp4_quantizer_cpp = static_cast(quantizer_cpp.get()); - nvfp4_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte); - } break; - default: { - } - } + runWithWorkspace( + [&](NVTETensor ws) { + nvte_layernorm_fwd(input_cu.data(), weight_cu.data(), bias_cu.data(), + static_cast(eps), output_cu.data(), mu_cu.data(), + rsigma_cu.data(), ws, sm_count, zero_centered_gamma, stream); + }, + device_idx); - return {out, py::none(), py::cast(rsigma_py)}; + return std::make_tuple(mu, rsigma); } -} // namespace transformer_engine::pytorch +// ============================================================================ +// RMSnorm forward — no-alloc variant for quantized output +// ============================================================================ + +Tensor rmsnorm_fwd_noalloc(Tensor input, Tensor weight, double eps, Tensor output_data, + int64_t output_te_dtype, std::optional output_amax, + std::optional output_scale, + std::optional output_scale_inv, int64_t scaling_mode, + Tensor rsigma, int64_t sm_margin, bool zero_centered_gamma) { + auto input_ = torch::stable::contiguous(input); + auto weight_ = torch::stable::contiguous(weight); + + auto input_cu = makeTransformerEngineTensor(input_); + auto weight_cu = makeTransformerEngineTensor(weight_); + + auto shape = getStableTensorShape(input_); + auto te_dtype = static_cast(output_te_dtype); + auto nvte_scaling = static_cast(scaling_mode); + + auto output_cu = makeQuantizedTensorWrapper(output_data, te_dtype, shape, output_amax, + output_scale, output_scale_inv, nvte_scaling); + auto rsigma_cu = makeTransformerEngineTensor(rsigma); + + auto device_idx = input_.get_device_index(); + int sm_count = getSMCount(device_idx) - static_cast(sm_margin); + auto stream = 
getCurrentCUDAStreamRaw(device_idx); + + runWithWorkspace( + [&](NVTETensor ws) { + nvte_rmsnorm_fwd(input_cu.data(), weight_cu.data(), static_cast(eps), + output_cu.data(), rsigma_cu.data(), ws, sm_count, zero_centered_gamma, + stream); + }, + device_idx); + + return rsigma; +} + +} // namespace transformer_engine::pytorch::stable + +// Schema definitions (added to the transformer_engine_stable library) +STABLE_TORCH_LIBRARY_FRAGMENT(transformer_engine_stable, m) { + m.def( + "layernorm_bwd(Tensor dz, Tensor x, Tensor mu, Tensor rsigma, Tensor gamma, int sm_margin, " + "bool zero_centered_gamma) -> (Tensor, Tensor, Tensor)"); + m.def( + "layernorm_fwd(Tensor input, Tensor weight, Tensor? bias, float eps, int sm_margin, bool " + "zero_centered_gamma) -> (Tensor, Tensor, Tensor)"); + m.def( + "layernorm_fwd_noalloc(Tensor input, Tensor weight, Tensor? bias, float eps, Tensor " + "output_data, int output_te_dtype, Tensor? output_amax, Tensor? output_scale, Tensor? " + "output_scale_inv, int scaling_mode, Tensor mu, Tensor rsigma, int sm_margin, bool " + "zero_centered_gamma) -> (Tensor, Tensor)"); + m.def( + "rmsnorm_bwd(Tensor dz, Tensor x, Tensor rsigma, Tensor gamma, int sm_margin, bool " + "zero_centered_gamma) -> (Tensor, Tensor)"); + m.def( + "rmsnorm_fwd(Tensor input, Tensor weight, float eps, int sm_margin, bool " + "zero_centered_gamma) -> (Tensor, Tensor)"); + m.def( + "rmsnorm_fwd_noalloc(Tensor input, Tensor weight, float eps, Tensor output_data, int " + "output_te_dtype, Tensor? output_amax, Tensor? output_scale, Tensor? 
output_scale_inv, int " + "scaling_mode, Tensor rsigma, int sm_margin, bool zero_centered_gamma) -> Tensor"); + m.def( + "rmsnorm_bwd_add(Tensor dz, Tensor x, Tensor add, Tensor rsigma, Tensor gamma, int " + "sm_margin, bool zero_centered_gamma) -> (Tensor, Tensor)"); +} + +STABLE_TORCH_LIBRARY_IMPL(transformer_engine_stable, CUDA, m) { + using namespace transformer_engine::pytorch::stable; + m.impl("layernorm_bwd", TORCH_BOX(layernorm_bwd)); + m.impl("layernorm_fwd", TORCH_BOX(layernorm_fwd)); + m.impl("layernorm_fwd_noalloc", TORCH_BOX(layernorm_fwd_noalloc)); + m.impl("rmsnorm_bwd", TORCH_BOX(rmsnorm_bwd)); + m.impl("rmsnorm_fwd", TORCH_BOX(rmsnorm_fwd)); + m.impl("rmsnorm_fwd_noalloc", TORCH_BOX(rmsnorm_fwd_noalloc)); + m.impl("rmsnorm_bwd_add", TORCH_BOX(rmsnorm_bwd_add)); +} diff --git a/transformer_engine/pytorch/csrc/extensions/nvfp4_2d_partial_cast.cpp b/transformer_engine/pytorch/csrc/extensions/nvfp4_2d_partial_cast.cpp deleted file mode 100644 index 685250d137..0000000000 --- a/transformer_engine/pytorch/csrc/extensions/nvfp4_2d_partial_cast.cpp +++ /dev/null @@ -1,156 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. 
- ************************************************************************/ - -#include "../extensions.h" - -namespace transformer_engine::pytorch { - -void nvfp4_2d_compute_partial_amax(const at::Tensor& tensor, at::Tensor amax, size_t h, size_t w, - size_t start_offset, size_t block_len) { - TORCH_CHECK(block_len == 16, "Currently only block_len = 16 is supported for NVFP4 2D"); - TORCH_CHECK(amax.dim() == 2, "amax must be a 2D tensor"); - TORCH_CHECK(amax.scalar_type() == at::ScalarType::Float, "amax must be a float tensor"); - TORCH_CHECK(tensor.scalar_type() == at::ScalarType::Float || - tensor.scalar_type() == at::ScalarType::BFloat16, - "tensor must be a float or bfloat16 tensor"); - - const TensorWrapper tensor_cu = makeTransformerEngineTensor(tensor.contiguous()); - TensorWrapper amax_cu = makeTransformerEngineTensor(amax); - - nvte_nvfp4_2d_compute_partial_amax(tensor_cu.data(), amax_cu.data(), h, w, amax.stride(0), - amax.stride(1), start_offset, block_len, - at::cuda::getCurrentCUDAStream()); -} - -void nvfp4_2d_partial_cast(const at::Tensor& inp, py::handle out, const at::Tensor& scale, - const at::Tensor& global_scale, size_t h, size_t w, size_t start_offset, - size_t block_len) { - TORCH_CHECK(block_len == 16, "Currently only block_len = 16 is supported for NVFP4 2D"); - TORCH_CHECK(scale.dim() == 2, "scale must be a 2D tensor"); - TORCH_CHECK(scale.scalar_type() == at::ScalarType::Float, "scale must be a float tensor"); - TORCH_CHECK(global_scale.numel() == 1, "global_scale must be a scalar tensor"); - TORCH_CHECK(global_scale.scalar_type() == at::ScalarType::Float, - "global_scale must be a float tensor"); - TORCH_CHECK( - inp.scalar_type() == at::ScalarType::Float || inp.scalar_type() == at::ScalarType::BFloat16, - "input must be a float or bfloat16 tensor"); - - const TensorWrapper inp_cu = makeTransformerEngineTensor(inp.contiguous()); - const TensorWrapper out_cu = makeTransformerEngineTensor(out, py::none()); - const TensorWrapper scale_cu = 
makeTransformerEngineTensor(scale); - const TensorWrapper global_scale_cu = makeTransformerEngineTensor(global_scale); - - nvte_nvfp4_2d_partial_cast(inp_cu.data(), out_cu.data(), scale_cu.data(), global_scale_cu.data(), - h, w, scale.stride(0), scale.stride(1), start_offset, block_len, - at::cuda::getCurrentCUDAStream()); -} - -void nvfp4_multi_tensor_2d_partial_cast(std::vector inp_list, - std::vector out_list, - std::vector scale_list, - std::vector global_scale_list, - std::vector h_list, std::vector w_list, - std::vector start_offset_list, int64_t block_len) { - TORCH_CHECK(block_len == 16, "Currently only block_len = 16 is supported for NVFP4 2D"); - - const size_t num_tensors = inp_list.size(); - TORCH_CHECK(out_list.size() == num_tensors, "out_list size mismatch"); - TORCH_CHECK(scale_list.size() == num_tensors, "scale_list size mismatch"); - TORCH_CHECK(global_scale_list.size() == num_tensors, "global_scale_list size mismatch"); - TORCH_CHECK(h_list.size() == num_tensors, "h_list size mismatch"); - TORCH_CHECK(w_list.size() == num_tensors, "w_list size mismatch"); - TORCH_CHECK(start_offset_list.size() == num_tensors, "start_offset_list size mismatch"); - - if (num_tensors == 0) { - return; - } - - auto stream = at::cuda::getCurrentCUDAStream(); - - for (size_t i = 0; i < num_tensors; ++i) { - const auto& inp = inp_list[i]; - const auto& out = out_list[i]; - const auto& scale = scale_list[i]; - const auto& global_scale = global_scale_list[i]; - const size_t h = static_cast(h_list[i]); - const size_t w = static_cast(w_list[i]); - const size_t start_offset = static_cast(start_offset_list[i]); - - TORCH_CHECK(scale.dim() == 2, "scale must be a 2D tensor"); - TORCH_CHECK(scale.scalar_type() == at::ScalarType::Float, "scale must be a float tensor"); - TORCH_CHECK(global_scale.numel() == 1, "global_scale must be a scalar tensor"); - TORCH_CHECK(global_scale.scalar_type() == at::ScalarType::Float, - "global_scale must be a float tensor"); - TORCH_CHECK( - 
inp.scalar_type() == at::ScalarType::Float || inp.scalar_type() == at::ScalarType::BFloat16, - "input must be a float or bfloat16 tensor"); - - const TensorWrapper inp_cu = makeTransformerEngineTensor(inp.contiguous()); - const TensorWrapper out_cu = makeTransformerEngineTensor(out); - const TensorWrapper scale_cu = makeTransformerEngineTensor(scale); - const TensorWrapper global_scale_cu = makeTransformerEngineTensor(global_scale); - - nvte_nvfp4_2d_partial_cast(inp_cu.data(), out_cu.data(), scale_cu.data(), - global_scale_cu.data(), h, w, scale.stride(0), scale.stride(1), - start_offset, static_cast(block_len), stream); - } -} - -void nvfp4_multi_tensor_compute_partial_amax( - std::vector master_weight_list, std::vector partial_amax_list, - std::vector global_amax_list, std::vector h_list, - std::vector w_list, std::vector start_offset_list, int64_t block_len) { - TORCH_CHECK(block_len == 16, "Currently only block_len = 16 is supported for NVFP4 2D"); - - const size_t num_tensors = master_weight_list.size(); - TORCH_CHECK(partial_amax_list.size() == num_tensors, "partial_amax_list size mismatch"); - TORCH_CHECK(global_amax_list.size() == num_tensors, "global_amax_list size mismatch"); - TORCH_CHECK(h_list.size() == num_tensors, "h_list size mismatch"); - TORCH_CHECK(w_list.size() == num_tensors, "w_list size mismatch"); - TORCH_CHECK(start_offset_list.size() == num_tensors, "start_offset_list size mismatch"); - - if (num_tensors == 0) { - return; - } - - auto stream = at::cuda::getCurrentCUDAStream(); - - for (size_t i = 0; i < num_tensors; ++i) { - const auto& master_weight = master_weight_list[i]; - auto& partial_amax = partial_amax_list[i]; - auto& global_amax = global_amax_list[i]; - const size_t h = static_cast(h_list[i]); - const size_t w = static_cast(w_list[i]); - const size_t start_offset = static_cast(start_offset_list[i]); - - TORCH_CHECK(partial_amax.dim() == 2, "partial_amax must be a 2D tensor"); - TORCH_CHECK(partial_amax.scalar_type() == 
at::ScalarType::Float, - "partial_amax must be a float tensor"); - TORCH_CHECK(master_weight.scalar_type() == at::ScalarType::Float || - master_weight.scalar_type() == at::ScalarType::BFloat16, - "master_weight must be a float or bfloat16 tensor"); - TORCH_CHECK(global_amax.scalar_type() == at::ScalarType::Float, - "global_amax must be a float tensor"); - TORCH_CHECK(global_amax.numel() == 1, "global_amax must have exactly one element"); - - // Compute partial amax (per-block amax) - const TensorWrapper tensor_cu = makeTransformerEngineTensor(master_weight.contiguous()); - TensorWrapper amax_cu = makeTransformerEngineTensor(partial_amax); - - nvte_nvfp4_2d_compute_partial_amax(tensor_cu.data(), amax_cu.data(), h, w, - partial_amax.stride(0), partial_amax.stride(1), start_offset, - static_cast(block_len), stream); - - // Compute global amax - auto* global_amax_ptr = global_amax.data_ptr(); - TensorWrapper fake_te_output( - /*dptr=*/nullptr, tensor_cu.shape(), DType::kFloat32, global_amax_ptr); - - nvte_compute_amax(tensor_cu.data(), fake_te_output.data(), stream); - } -} - -} // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/nvshmem_comm.cpp b/transformer_engine/pytorch/csrc/extensions/nvshmem_comm.cpp index ac68727ac8..a49130e639 100644 --- a/transformer_engine/pytorch/csrc/extensions/nvshmem_comm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/nvshmem_comm.cpp @@ -4,128 +4,153 @@ * See LICENSE for license information. 
************************************************************************/ -#include "../extensions.h" +#include + +#include "../stable_common.h" #ifdef NVTE_ENABLE_NVSHMEM +// Include only host headers — this is compiled as C++, not CUDA +#define NVSHMEM_HOSTLIB_ONLY #include -#include #include +#undef NVSHMEM_HOSTLIB_ONLY #endif -#include -#include -#include -#include - -namespace transformer_engine::pytorch { +namespace transformer_engine::pytorch::stable { -void init_nvshmem_backend(c10d::ProcessGroup *process_group) { -#ifdef NVTE_ENABLE_NVSHMEM - nvshmemx_init_attr_t attr = {}; - nvshmemx_uniqueid_t id = {}; +using Tensor = torch::stable::Tensor; +using ScalarType = torch::headeronly::ScalarType; - int my_rank = process_group->getRank(); - int num_ranks = process_group->getSize(); - if (my_rank == 0) { - nvshmemx_get_uniqueid(&id); - } +// ============================================================================ +// NVSHMEM create tensor in shared memory +// ============================================================================ - auto backend_is_nccl = (process_group->getBackendType() == c10d::ProcessGroup::BackendType::NCCL); - NVTE_CHECK(backend_is_nccl, "Currently only support NCCL boostrap for NVSHMEM"); - auto datatensor = - torch::from_blob(reinterpret_cast(&id), - {static_cast(sizeof(nvshmemx_uniqueid_t) / sizeof(uint8_t))}, - at::device(torch::kCPU).dtype(torch::kUInt8)); - auto datatmp = (backend_is_nccl) ? 
datatensor.cuda() : datatensor; - - c10d::BroadcastOptions bcast_opts; - bcast_opts.rootRank = 0; - std::vector datachunk = {datatmp}; - auto work = process_group->broadcast(datachunk, bcast_opts); - work->wait(); - - if (backend_is_nccl) { - datatensor.copy_(datatmp.cpu()); - datatmp = torch::Tensor(); +Tensor nvshmem_create_tensor(int64_t num_elements, int64_t scalar_type_int, int64_t device_idx) { +#ifdef NVTE_ENABLE_NVSHMEM + auto dtype = static_cast(scalar_type_int); + size_t elem_size = 0; + switch (dtype) { + case ScalarType::Float: + elem_size = 4; + break; + case ScalarType::Half: + elem_size = 2; + break; + case ScalarType::BFloat16: + elem_size = 2; + break; + case ScalarType::Byte: + elem_size = 1; + break; + case ScalarType::Long: + elem_size = 8; + break; + case ScalarType::Int: + elem_size = 4; + break; + case ScalarType::Double: + elem_size = 8; + break; + default: + STD_TORCH_CHECK(false, "Unsupported dtype for nvshmem_create_tensor"); } - - nvshmemx_set_attr_uniqueid_args(my_rank, num_ranks, &id, &attr); - nvshmemx_init_attr(NVSHMEMX_INIT_WITH_UNIQUEID, &attr); - - NVTE_CHECK(my_rank == nvshmem_my_pe(), "my_rank: ", my_rank, - " != nvshmem_my_pe(): ", nvshmem_my_pe()); - NVTE_CHECK(num_ranks == nvshmem_n_pes(), "num_ranks: ", num_ranks, - " != nvshmem_n_pes(): ", nvshmem_n_pes()); + size_t total_bytes = static_cast(num_elements) * elem_size; + void *ptr = nvshmem_malloc(total_bytes); + STD_TORCH_CHECK(ptr != nullptr, "nvshmem_malloc failed for ", total_bytes, " bytes"); + + // Wrap with from_blob and nvshmem_free as deleter + std::vector shape = {num_elements}; + std::vector strides = {1}; + auto device = torch::stable::Device(torch::stable::DeviceType::CUDA, device_idx); + return torch::stable::from_blob(ptr, shape, strides, device, dtype, nvshmem_free); #else - NVTE_ERROR("Internal TE error: init_nvshmem_backend cannot be initialized with valid PyTorch ", - "distributed process groups when TE is compiled with NVTE_ENABLE_NVSHMEM=1!"); + 
STD_TORCH_CHECK(false, "NVSHMEM not available. Build with NVTE_ENABLE_NVSHMEM=1."); #endif } -void nvshmem_wait_on_current_stream(torch::Tensor signal, const std::string &wait_kind) { +// ============================================================================ +// NVSHMEM wait on current CUDA stream +// +// Uses CUDA driver API (cuStreamWaitValue64/cuStreamWriteValue64) which +// doesn't require NVSHMEM device code. Supports stream_wait and nvshmem_wait +// modes. kernel_wait mode requires device code and is not supported. +// ============================================================================ + +void nvshmem_wait_on_current_stream(Tensor signal, int64_t wait_kind_int) { #ifdef NVTE_ENABLE_NVSHMEM uint64_t *sig_addr = reinterpret_cast(signal.data_ptr()); - cudaStream_t cur_stream = (cudaStream_t)at::cuda::getCurrentCUDAStream(); - - WaitKind wait_kind_enum = WaitKind::STREAM_WAIT; - - if (wait_kind == "kernel") { - wait_kind_enum = WaitKind::KERNEL_WAIT; - } else if (wait_kind == "nvshmem") { - wait_kind_enum = WaitKind::NVSHMEM_WAIT; - } else if (wait_kind == "stream") { - wait_kind_enum = WaitKind::STREAM_WAIT; - } else { - NVTE_ERROR("Invalid wait kind: ", wait_kind); + auto stream = getCurrentCUDAStreamRaw(signal.get_device_index()); + uint64_t wait_value = 1; + uint64_t signal_reset = 0; + + // WaitKind: 0=KERNEL_WAIT, 1=NVSHMEM_WAIT, 2=STREAM_WAIT + switch (wait_kind_int) { + case 0: // KERNEL_WAIT — requires device code + STD_TORCH_CHECK(false, + "KERNEL_WAIT mode requires NVSHMEM device code. 
" + "Use 'stream' or 'nvshmem' wait_kind instead."); + break; + case 1: // NVSHMEM_WAIT — use nvshmemx host API + driver reset + nvshmemx_uint64_wait_until_on_stream(sig_addr, NVSHMEM_CMP_EQ, wait_value, stream); + { + CUresult res = cuStreamWriteValue64( + reinterpret_cast(stream), reinterpret_cast(sig_addr), + static_cast(signal_reset), CU_STREAM_WRITE_VALUE_DEFAULT); + STD_TORCH_CHECK(res == CUDA_SUCCESS, "cuStreamWriteValue64 failed"); + } + break; + case 2: // STREAM_WAIT — pure CUDA driver API + default: { + CUresult res = cuStreamWaitValue64( + reinterpret_cast(stream), reinterpret_cast(sig_addr), + static_cast(wait_value), CU_STREAM_WAIT_VALUE_GEQ); + STD_TORCH_CHECK(res == CUDA_SUCCESS, "cuStreamWaitValue64 failed"); + res = cuStreamWriteValue64( + reinterpret_cast(stream), reinterpret_cast(sig_addr), + static_cast(signal_reset), CU_STREAM_WRITE_VALUE_DEFAULT); + STD_TORCH_CHECK(res == CUDA_SUCCESS, "cuStreamWriteValue64 failed"); + } break; } - nvshmem_wait_on_stream(sig_addr, wait_kind_enum, cur_stream); - #else - NVTE_ERROR( - "Internal TE error: nvshmem_wait_on_current_stream cannot be initialized with valid PyTorch ", - "distributed process groups when TE is compiled with NVTE_ENABLE_NVSHMEM=1!"); + STD_TORCH_CHECK(false, "NVSHMEM not available. 
Build with NVTE_ENABLE_NVSHMEM=1."); #endif } -torch::Tensor create_nvshmem_tensor(const std::vector &shape, c10::ScalarType dtype) { -#ifdef NVTE_ENABLE_NVSHMEM - auto option_gpu = - at::TensorOptions().dtype(dtype).device(at::kCUDA).device_index(c10::cuda::current_device()); - auto size = torch::elementSize(dtype) * - std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<>()); - return at::from_blob( - nvshmem_malloc(size), shape, [](void *ptr) { nvshmem_free(ptr); }, option_gpu); -#else - NVTE_ERROR("Internal TE error: create_nvshmem_tensor cannot be initialized with valid PyTorch ", - "distributed process groups when TE is compiled with NVTE_ENABLE_NVSHMEM=1!"); -#endif -} +// ============================================================================ +// NVSHMEM send with signal on current CUDA stream +// ============================================================================ -void nvshmem_send_on_current_stream(torch::Tensor src, torch::Tensor dst, int peer, - torch::Tensor signal) { +void nvshmem_send_on_current_stream(Tensor src, Tensor dst, int64_t peer, Tensor signal) { #ifdef NVTE_ENABLE_NVSHMEM - void *src_ptr = reinterpret_cast(src.data_ptr()); - void *dst_ptr = reinterpret_cast(dst.data_ptr()); + void *src_ptr = src.data_ptr(); + void *dst_ptr = dst.data_ptr(); uint64_t *sig_addr = reinterpret_cast(signal.data_ptr()); - auto nelement = src.numel() * src.element_size(); + size_t nelement = static_cast(src.numel()) * src.element_size(); uint64_t sigval = 1; - at::cuda::CUDAStream cur_stream = at::cuda::getCurrentCUDAStream(); - + auto stream = getCurrentCUDAStreamRaw(src.get_device_index()); nvshmemx_putmem_signal_on_stream(dst_ptr, src_ptr, nelement, sig_addr, sigval, NVSHMEM_SIGNAL_SET, - peer, (cudaStream_t)cur_stream); + static_cast(peer), stream); #else - NVTE_ERROR( - "Internal TE error: nvshmem_send_on_current_stream cannot be initialized with valid PyTorch ", - "distributed process groups when TE is compiled with 
NVTE_ENABLE_NVSHMEM=1!"); + STD_TORCH_CHECK(false, "NVSHMEM not available. Build with NVTE_ENABLE_NVSHMEM=1."); #endif } -void nvshmem_finalize() { -#ifdef NVTE_ENABLE_NVSHMEM - nvshmem_finalize(); -#else - NVTE_ERROR("Internal TE error: nvshmem_finalize cannot be initialized with valid PyTorch ", - "distributed process groups when TE is compiled with NVTE_ENABLE_NVSHMEM=1!"); -#endif + +} // namespace transformer_engine::pytorch::stable + +STABLE_TORCH_LIBRARY_FRAGMENT(transformer_engine_stable, m) { + m.def("nvshmem_create_tensor(int num_elements, int scalar_type, int device_idx) -> Tensor"); + m.def("nvshmem_wait_on_current_stream(Tensor signal, int wait_kind) -> ()"); + m.def("nvshmem_send_on_current_stream(Tensor src, Tensor dst, int peer, Tensor signal) -> ()"); } -} // namespace transformer_engine::pytorch +STABLE_TORCH_LIBRARY_IMPL(transformer_engine_stable, CUDA, m) { + using namespace transformer_engine::pytorch::stable; + m.impl("nvshmem_wait_on_current_stream", TORCH_BOX(nvshmem_wait_on_current_stream)); + m.impl("nvshmem_send_on_current_stream", TORCH_BOX(nvshmem_send_on_current_stream)); +} + +// nvshmem_create_tensor has no tensor input args, use CompositeImplicitAutograd +STABLE_TORCH_LIBRARY_IMPL(transformer_engine_stable, CompositeImplicitAutograd, m) { + using namespace transformer_engine::pytorch::stable; + m.impl("nvshmem_create_tensor", TORCH_BOX(nvshmem_create_tensor)); +} diff --git a/transformer_engine/pytorch/csrc/extensions/padding.cpp b/transformer_engine/pytorch/csrc/extensions/padding.cpp index 6c66fda015..66733d76aa 100644 --- a/transformer_engine/pytorch/csrc/extensions/padding.cpp +++ b/transformer_engine/pytorch/csrc/extensions/padding.cpp @@ -4,52 +4,52 @@ * See LICENSE for license information. 
************************************************************************/ -#include "../extensions.h" -#include "pybind.h" +#include -namespace transformer_engine::pytorch { +#include "../stable_common.h" -void fused_multi_row_padding(at::Tensor input, at::Tensor output, - std::vector input_row_list, - std::vector padded_input_row_list) { +namespace transformer_engine::pytorch::stable { + +using Tensor = torch::stable::Tensor; + +void fused_multi_row_padding(Tensor input, Tensor output, std::vector input_row_list, + std::vector padded_input_row_list) { NVTE_CHECK(input_row_list.size() == padded_input_row_list.size(), "Number of input row list and padded row list must match."); NVTE_CHECK(input.dim() == 2, "Dimension of input must equal 2."); - NVTE_CHECK(output.dim() == 2, "Dimension of output must equal 2."); + NVTE_CHECK(output.dim() == 2, "Dimension of output must equal 2."); const auto num_tensors = input_row_list.size(); - // Extract properties from PyTorch tensors std::vector input_dptr_list, output_dptr_list; std::vector> input_shape_list, output_shape_list; std::vector input_type_list; - void* d_input_ptr = reinterpret_cast(input.data_ptr()); - void* d_output_ptr = reinterpret_cast(output.data_ptr()); + void* d_input_ptr = input.data_ptr(); + void* d_output_ptr = output.data_ptr(); + for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) { input_dptr_list.push_back(d_input_ptr); output_dptr_list.push_back(d_output_ptr); - // Move the input pointer to the next split. 
char* input_char_ptr = reinterpret_cast(d_input_ptr); - const size_t input_dptr_offset = - input_row_list[tensor_id] * input.size(1) * input.element_size(); + const size_t input_dptr_offset = static_cast(input_row_list[tensor_id]) * + static_cast(input.size(1)) * input.element_size(); input_char_ptr += input_dptr_offset; d_input_ptr = reinterpret_cast(input_char_ptr); - input_shape_list.push_back({input_row_list[tensor_id], static_cast(input.size(1))}); + input_shape_list.push_back( + {static_cast(input_row_list[tensor_id]), static_cast(input.size(1))}); input_type_list.push_back(GetTransformerEngineDType(input.scalar_type())); - // Move the output pointer to the next split. char* output_char_ptr = reinterpret_cast(d_output_ptr); - const size_t output_dptr_offset = - padded_input_row_list[tensor_id] * output.size(1) * output.element_size(); + const size_t output_dptr_offset = static_cast(padded_input_row_list[tensor_id]) * + static_cast(output.size(1)) * output.element_size(); output_char_ptr += output_dptr_offset; d_output_ptr = reinterpret_cast(output_char_ptr); - output_shape_list.push_back( - {padded_input_row_list[tensor_id], static_cast(output.size(1))}); + output_shape_list.push_back({static_cast(padded_input_row_list[tensor_id]), + static_cast(output.size(1))}); } - // Construct TE tensors std::vector nvte_input_list, nvte_output_list; std::vector tensor_wrappers; auto make_tensor = [&tensor_wrappers](void* dptr, const std::vector& shape, @@ -65,70 +65,59 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output, make_tensor(input_dptr_list[i], input_shape_list[i], input_type_list[i])); nvte_output_list.emplace_back( make_tensor(output_dptr_list[i], output_shape_list[i], input_type_list[i])); - padded_num_rows_list.emplace_back(padded_input_row_list[i]); + padded_num_rows_list.emplace_back(static_cast(padded_input_row_list[i])); } - // Check tensor lists NVTE_CHECK(nvte_output_list.size() == nvte_input_list.size(), "Number of input and output 
tensors must match"); - NVTE_CHECK(padded_num_rows_list.size() == nvte_input_list.size() && - "Number of input and padded row list must match"); - - // Launch TE kernel - NVTE_SCOPED_GIL_RELEASE({ - nvte_multi_padding(nvte_input_list.size(), nvte_input_list.data(), nvte_output_list.data(), - padded_num_rows_list.data(), at::cuda::getCurrentCUDAStream()); - }); -} -void fused_multi_row_unpadding(at::Tensor input, at::Tensor output, - std::vector input_row_list, - std::vector unpadded_input_row_list) { - using namespace transformer_engine; - using namespace transformer_engine::pytorch; + nvte_multi_padding(nvte_input_list.size(), nvte_input_list.data(), nvte_output_list.data(), + padded_num_rows_list.data(), + getCurrentCUDAStreamRaw(input.get_device_index())); +} +void fused_multi_row_unpadding(Tensor input, Tensor output, std::vector input_row_list, + std::vector unpadded_input_row_list) { NVTE_CHECK(input_row_list.size() == unpadded_input_row_list.size(), "Number of input row list and padded row list must match."); NVTE_CHECK(input.dim() == 2, "Dimension of input must equal 2."); - NVTE_CHECK(output.dim() == 2, "Dimension of output must equal 2."); + NVTE_CHECK(output.dim() == 2, "Dimension of output must equal 2."); const auto num_tensors = input_row_list.size(); - // Extract properties from PyTorch tensors std::vector input_dptr_list, output_dptr_list; std::vector> input_shape_list, output_shape_list; - std::vector input_type_list; - void* d_input_ptr = reinterpret_cast(input.data_ptr()); - void* d_output_ptr = reinterpret_cast(output.data_ptr()); + std::vector input_type_list; + void* d_input_ptr = input.data_ptr(); + void* d_output_ptr = output.data_ptr(); + for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) { input_dptr_list.push_back(d_input_ptr); output_dptr_list.push_back(d_output_ptr); - // Move the input pointer to the next split. 
char* input_char_ptr = reinterpret_cast(d_input_ptr); - const size_t input_dptr_offset = - input_row_list[tensor_id] * input.size(1) * input.element_size(); + const size_t input_dptr_offset = static_cast(input_row_list[tensor_id]) * + static_cast(input.size(1)) * input.element_size(); input_char_ptr += input_dptr_offset; d_input_ptr = reinterpret_cast(input_char_ptr); - input_shape_list.push_back({input_row_list[tensor_id], static_cast(input.size(1))}); + input_shape_list.push_back( + {static_cast(input_row_list[tensor_id]), static_cast(input.size(1))}); input_type_list.push_back(GetTransformerEngineDType(input.scalar_type())); - // Move the output pointer to the next split. char* output_char_ptr = reinterpret_cast(d_output_ptr); - const size_t output_dptr_offset = - unpadded_input_row_list[tensor_id] * output.size(1) * output.element_size(); + const size_t output_dptr_offset = static_cast(unpadded_input_row_list[tensor_id]) * + static_cast(output.size(1)) * output.element_size(); output_char_ptr += output_dptr_offset; d_output_ptr = reinterpret_cast(output_char_ptr); - output_shape_list.push_back( - {unpadded_input_row_list[tensor_id], static_cast(output.size(1))}); + output_shape_list.push_back({static_cast(unpadded_input_row_list[tensor_id]), + static_cast(output.size(1))}); } - // Construct TE tensors std::vector nvte_input_list, nvte_output_list; - std::vector tensor_wrappers; + std::vector tensor_wrappers; auto make_tensor = [&tensor_wrappers](void* dptr, const std::vector& shape, - transformer_engine::DType dtype) -> NVTETensor { + DType dtype) -> NVTETensor { tensor_wrappers.emplace_back(makeTransformerEngineTensor(dptr, shape, dtype)); return tensor_wrappers.back().data(); }; @@ -140,18 +129,20 @@ void fused_multi_row_unpadding(at::Tensor input, at::Tensor output, make_tensor(input_dptr_list[i], input_shape_list[i], input_type_list[i])); nvte_output_list.emplace_back( make_tensor(output_dptr_list[i], output_shape_list[i], input_type_list[i])); - 
unpadded_num_rows_list.emplace_back(unpadded_input_row_list[i]); + unpadded_num_rows_list.emplace_back(static_cast(unpadded_input_row_list[i])); } - // Check tensor lists NVTE_CHECK(nvte_output_list.size() == nvte_input_list.size(), "Number of input and output tensors must match"); - NVTE_CHECK(unpadded_num_rows_list.size() == nvte_input_list.size() && - "Number of input and padded row list must match"); - // Launch TE kernel nvte_multi_unpadding(nvte_input_list.size(), nvte_input_list.data(), nvte_output_list.data(), - unpadded_num_rows_list.data(), at::cuda::getCurrentCUDAStream()); + unpadded_num_rows_list.data(), + getCurrentCUDAStreamRaw(input.get_device_index())); +} + +STABLE_TORCH_LIBRARY_IMPL(transformer_engine_stable, CUDA, m) { + m.impl("fused_multi_row_padding", TORCH_BOX(fused_multi_row_padding)); + m.impl("fused_multi_row_unpadding", TORCH_BOX(fused_multi_row_unpadding)); } -} // namespace transformer_engine::pytorch +} // namespace transformer_engine::pytorch::stable diff --git a/transformer_engine/pytorch/csrc/extensions/partial_cast.cpp b/transformer_engine/pytorch/csrc/extensions/partial_cast.cpp new file mode 100644 index 0000000000..6673cf1434 --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/partial_cast.cpp @@ -0,0 +1,122 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include +#include +#include + +#include "../stable_common.h" + +namespace transformer_engine::pytorch::stable { + +using Tensor = torch::stable::Tensor; + +// FP8 block scaling +void fp8_block_scaling_compute_partial_amax(Tensor tensor, Tensor amax, int64_t h, int64_t w, + int64_t start_offset, int64_t block_len) { + auto t_cu = makeTransformerEngineTensor(tensor); + auto a_cu = makeTransformerEngineTensor(amax); + nvte_fp8_block_scaling_compute_partial_amax(t_cu.data(), a_cu.data(), h, w, amax.stride(0), + amax.stride(1), start_offset, block_len, + getCurrentCUDAStreamRaw(tensor.get_device_index())); +} + +void fp8_block_scaling_partial_cast(Tensor inp, Tensor out, Tensor scale, int64_t h, int64_t w, + int64_t start_offset, int64_t block_len, int64_t out_dtype) { + auto i_cu = makeTransformerEngineTensor(inp); + auto o_cu = makeTransformerEngineTensor(out); + auto s_cu = makeTransformerEngineTensor(scale); + nvte_fp8_block_scaling_partial_cast(i_cu.data(), o_cu.data(), s_cu.data(), h, w, scale.stride(0), + scale.stride(1), start_offset, block_len, + static_cast(out_dtype), + getCurrentCUDAStreamRaw(inp.get_device_index())); +} + +// MXFP8 scaling +void mxfp8_scaling_compute_partial_amax(Tensor input, Tensor amax_rowwise, Tensor amax_colwise, + int64_t rows, int64_t cols, int64_t start_offset) { + auto i_cu = makeTransformerEngineTensor(input); + auto ar_cu = makeTransformerEngineTensor(amax_rowwise); + auto ac_cu = makeTransformerEngineTensor(amax_colwise); + nvte_mxfp8_scaling_compute_partial_amax(i_cu.data(), ar_cu.data(), ac_cu.data(), rows, cols, + start_offset, + getCurrentCUDAStreamRaw(input.get_device_index())); +} + +void mxfp8_scaling_partial_cast(Tensor input, Tensor output_rowwise, Tensor output_colwise, + Tensor scale_inv_rowwise, Tensor scale_inv_colwise, int64_t rows, + int64_t cols, int64_t start_offset) { + auto i_cu = makeTransformerEngineTensor(input); + auto or_cu = 
makeTransformerEngineTensor(output_rowwise); + auto oc_cu = makeTransformerEngineTensor(output_colwise); + auto sr_cu = makeTransformerEngineTensor(scale_inv_rowwise); + auto sc_cu = makeTransformerEngineTensor(scale_inv_colwise); + nvte_mxfp8_scaling_partial_cast(i_cu.data(), or_cu.data(), oc_cu.data(), sr_cu.data(), + sc_cu.data(), rows, cols, start_offset, + getCurrentCUDAStreamRaw(input.get_device_index())); +} + +// NVFP4 2D +void nvfp4_2d_compute_partial_amax(Tensor tensor, Tensor amax, int64_t h, int64_t w, + int64_t start_offset, int64_t block_len) { + auto t_cu = makeTransformerEngineTensor(tensor); + auto a_cu = makeTransformerEngineTensor(amax); + nvte_nvfp4_2d_compute_partial_amax(t_cu.data(), a_cu.data(), h, w, amax.stride(0), amax.stride(1), + start_offset, block_len, + getCurrentCUDAStreamRaw(tensor.get_device_index())); +} + +void nvfp4_2d_partial_cast_noalloc(Tensor inp, Tensor out_data, int64_t out_dtype, + std::optional out_scale_inv, int64_t out_scaling_mode, + Tensor scale, Tensor global_scale, int64_t h, int64_t w, + int64_t start_offset, int64_t block_len) { + auto i_cu = makeTransformerEngineTensor(inp); + auto out_shape = getStableTensorShape(out_data); + auto o_cu = makeQuantizedTensorWrapper(out_data, static_cast(out_dtype), out_shape, + std::nullopt, std::nullopt, out_scale_inv, + static_cast(out_scaling_mode)); + auto s_cu = makeTransformerEngineTensor(scale); + auto gs_cu = makeTransformerEngineTensor(global_scale); + nvte_nvfp4_2d_partial_cast(i_cu.data(), o_cu.data(), s_cu.data(), gs_cu.data(), h, w, + scale.stride(0), scale.stride(1), start_offset, block_len, + getCurrentCUDAStreamRaw(inp.get_device_index())); +} + +} // namespace transformer_engine::pytorch::stable + +STABLE_TORCH_LIBRARY_FRAGMENT(transformer_engine_stable, m) { + m.def( + "fp8_block_scaling_compute_partial_amax(Tensor tensor, Tensor amax, int h, int w, int " + "start_offset, int block_len) -> ()"); + m.def( + "fp8_block_scaling_partial_cast(Tensor inp, Tensor out, 
Tensor scale, int h, int w, int " + "start_offset, int block_len, int out_dtype) -> ()"); + m.def( + "mxfp8_scaling_compute_partial_amax(Tensor input, Tensor amax_rowwise, Tensor amax_colwise, " + "int rows, int cols, int start_offset) -> ()"); + m.def( + "mxfp8_scaling_partial_cast(Tensor input, Tensor output_rowwise, Tensor output_colwise, " + "Tensor scale_inv_rowwise, Tensor scale_inv_colwise, int rows, int cols, int start_offset) " + "-> ()"); + m.def( + "nvfp4_2d_compute_partial_amax(Tensor tensor, Tensor amax, int h, int w, int start_offset, " + "int block_len) -> ()"); + m.def( + "nvfp4_2d_partial_cast_noalloc(Tensor inp, Tensor out_data, int out_dtype, Tensor? " + "out_scale_inv, int out_scaling_mode, Tensor scale, Tensor global_scale, int h, int w, int " + "start_offset, int block_len) -> ()"); +} + +STABLE_TORCH_LIBRARY_IMPL(transformer_engine_stable, CUDA, m) { + using namespace transformer_engine::pytorch::stable; + m.impl("fp8_block_scaling_compute_partial_amax", + TORCH_BOX(fp8_block_scaling_compute_partial_amax)); + m.impl("fp8_block_scaling_partial_cast", TORCH_BOX(fp8_block_scaling_partial_cast)); + m.impl("mxfp8_scaling_compute_partial_amax", TORCH_BOX(mxfp8_scaling_compute_partial_amax)); + m.impl("mxfp8_scaling_partial_cast", TORCH_BOX(mxfp8_scaling_partial_cast)); + m.impl("nvfp4_2d_compute_partial_amax", TORCH_BOX(nvfp4_2d_compute_partial_amax)); + m.impl("nvfp4_2d_partial_cast_noalloc", TORCH_BOX(nvfp4_2d_partial_cast_noalloc)); +} diff --git a/transformer_engine/pytorch/csrc/extensions/permutation.cpp b/transformer_engine/pytorch/csrc/extensions/permutation.cpp index 226705b169..32c6d86b59 100644 --- a/transformer_engine/pytorch/csrc/extensions/permutation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/permutation.cpp @@ -4,155 +4,149 @@ * See LICENSE for license information. 
************************************************************************/ -#include "../extensions.h" - -namespace transformer_engine::pytorch { - -std::tuple> moe_permute_fwd( - at::Tensor input, const DType dtype, at::Tensor indices, int64_t num_out_tokens, - std::vector workspace, int64_t max_expanded_token_num) { - const int num_tokens = input.size(0); - int num_cols = input.size(1); - const int topK = indices.size(1); - - // Initialize the workspace on the first run - if (workspace.empty()) { - auto options = - torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false); - - at::Tensor sorted_indices = torch::empty(max_expanded_token_num, options); - at::Tensor row_id = torch::range(0, max_expanded_token_num - 1, 1, options); - at::Tensor sorted_row_id = - torch::empty(max_expanded_token_num, - torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); - - size_t temp_storage_bytes = 0; - nvte_device_radix_sort_pairs(nullptr, &temp_storage_bytes, nullptr, nullptr, nullptr, nullptr, - max_expanded_token_num); - at::Tensor temp_storage = torch::empty( - temp_storage_bytes, torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false)); - - workspace.push_back(sorted_indices); - workspace.push_back(row_id); - workspace.push_back(sorted_row_id); - workspace.push_back(temp_storage); - } - - void *indices_ptr = getDataPtr(indices, 0); - void *sorted_indices_ptr = getDataPtr(workspace[0], 0); - void *row_id_ptr = getDataPtr(workspace[1], 0); - void *sorted_row_id_ptr = getDataPtr(workspace[2], 0); - - void *d_temp_storage = getDataPtr(workspace[3], 0); - size_t temp_storage_bytes = std::numeric_limits::max(); - - nvte_device_radix_sort_pairs( - d_temp_storage, &temp_storage_bytes, reinterpret_cast(indices_ptr), - reinterpret_cast(sorted_indices_ptr), reinterpret_cast(row_id_ptr), - reinterpret_cast(sorted_row_id_ptr), num_tokens * topK); - - // Output buffer alloc - num_out_tokens = (num_out_tokens > 0) ? 
num_out_tokens : num_tokens * topK; - at::Tensor permuted_output = - torch::empty({num_out_tokens, num_cols}, - torch::dtype(input.scalar_type()).device(torch::kCUDA).requires_grad(false)); - at::Tensor row_id_map = torch::empty( - {num_tokens * topK}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); - - auto stream = at::cuda::getCurrentCUDAStream().stream(); +#include + +#include "../stable_common.h" + +namespace transformer_engine::pytorch::stable { + +using Tensor = torch::stable::Tensor; + +// ============================================================================ +// MOE Permutation forward +// +// The workspace tensors (sorted_indices, row_id, sorted_row_id, temp_storage) +// are allocated on first call and reused. In stable ABI, the Python shim +// manages the workspace list. +// ============================================================================ + +std::tuple moe_permute_fwd(Tensor input, int64_t dtype, Tensor sorted_row_id, + Tensor row_id_map, int64_t num_tokens, int64_t topK, + int64_t num_out_tokens) { + auto te_dtype = static_cast(dtype); + auto shape = getStableTensorShape(input); + NVTE_CHECK(shape.size() == 2, "Permutation input must be 2D."); + const size_t num_cols = shape[1]; + + auto device_idx = input.get_device_index(); + int64_t actual_out_tokens = (num_out_tokens > 0) ? 
num_out_tokens : num_tokens * topK; + + auto permuted_output = allocateStableTensor({actual_out_tokens, static_cast(num_cols)}, + GetStableScalarType(te_dtype), device_idx); auto input_cu = makeTransformerEngineTensor( - input.data_ptr(), - std::vector{static_cast(input.size(0)), static_cast(num_cols)}, - dtype); - auto permuted_output_cu = - makeTransformerEngineTensor(permuted_output.data_ptr(), - std::vector{static_cast(permuted_output.size(0)), - static_cast(num_cols)}, - dtype); - auto sorted_row_id_cu = makeTransformerEngineTensor( - sorted_row_id_ptr, std::vector{static_cast(num_tokens * topK)}, - DType::kInt32); + input.data_ptr(), std::vector{static_cast(num_tokens * topK), num_cols}, + te_dtype); + auto output_cu = makeTransformerEngineTensor( + permuted_output.data_ptr(), + std::vector{static_cast(actual_out_tokens), num_cols}, te_dtype); + auto sorted_row_id_cu = makeTransformerEngineTensor(sorted_row_id); auto row_id_map_cu = makeTransformerEngineTensor(row_id_map); + TensorWrapper empty; - nvte_permute(input_cu.data(), permuted_output_cu.data(), sorted_row_id_cu.data(), - row_id_map_cu.data(), TensorWrapper().data(), TensorWrapper().data(), - TensorWrapper().data(), num_tokens, topK, num_cols, num_out_tokens, stream); + auto stream = getCurrentCUDAStreamRaw(device_idx); + nvte_permute(input_cu.data(), output_cu.data(), sorted_row_id_cu.data(), row_id_map_cu.data(), + empty.data(), empty.data(), empty.data(), static_cast(num_tokens), + static_cast(topK), num_cols, static_cast(actual_out_tokens), stream); - return std::make_tuple(permuted_output, row_id_map, workspace); + return std::make_tuple(permuted_output, row_id_map); } -at::Tensor moe_permute_bwd(at::Tensor input, const DType dtype, at::Tensor row_id_map, - at::Tensor prob, int64_t num_tokens, int64_t topK) { - return moe_unpermute_fwd(input, dtype, row_id_map, prob, num_tokens, topK); -} +// ============================================================================ +// MOE Unpermute forward 
(also used as permute backward) +// ============================================================================ -at::Tensor moe_unpermute_fwd(at::Tensor input, const DType dtype, at::Tensor row_id_map, - at::Tensor prob, int64_t num_tokens, int64_t topK) { - int num_cols = input.size(1); +Tensor moe_unpermute_fwd(Tensor input, int64_t dtype, Tensor row_id_map, Tensor prob, + int64_t num_tokens, int64_t topK) { + auto te_dtype = static_cast(dtype); + auto shape = getStableTensorShape(input); + NVTE_CHECK(shape.size() == 2, "Unpermutation input must be 2D."); + const size_t num_cols = shape[1]; - // Output buffer alloc - at::Tensor unpermuted_output = - torch::empty({num_tokens, num_cols}, - torch::dtype(input.scalar_type()).device(torch::kCUDA).requires_grad(false)); - - auto stream = at::cuda::getCurrentCUDAStream().stream(); + auto device_idx = input.get_device_index(); + auto unpermuted_output = allocateStableTensor({num_tokens, static_cast(num_cols)}, + GetStableScalarType(te_dtype), device_idx); auto input_cu = makeTransformerEngineTensor( input.data_ptr(), - std::vector{static_cast(input.size(0)), static_cast(num_cols)}, - dtype); - auto unpermuted_output_cu = makeTransformerEngineTensor( - unpermuted_output.data_ptr(), - std::vector{static_cast(unpermuted_output.size(0)), - static_cast(num_cols)}, - dtype); + std::vector{static_cast(num_tokens) * static_cast(topK), num_cols}, + te_dtype); + auto output_cu = makeTransformerEngineTensor( + unpermuted_output.data_ptr(), std::vector{static_cast(num_tokens), num_cols}, + te_dtype); auto row_id_map_cu = makeTransformerEngineTensor(row_id_map); auto prob_cu = makeTransformerEngineTensor(prob); - nvte_unpermute(input_cu.data(), unpermuted_output_cu.data(), row_id_map_cu.data(), prob_cu.data(), - num_tokens, topK, num_cols, stream); + nvte_unpermute(input_cu.data(), output_cu.data(), row_id_map_cu.data(), prob_cu.data(), + static_cast(num_tokens), static_cast(topK), num_cols, + getCurrentCUDAStreamRaw(device_idx)); 
return unpermuted_output; } -std::tuple moe_unpermute_bwd(at::Tensor input_bwd, at::Tensor input_fwd, - const DType dtype, at::Tensor row_id_map, - at::Tensor prob) { - const int topK = (prob.numel() > 0) ? prob.size(1) : 1; - const int num_tokens = (prob.numel() > 0) ? prob.size(0) : row_id_map.size(0); - int num_cols = input_bwd.size(1); +// ============================================================================ +// MOE Unpermute backward +// ============================================================================ + +std::tuple moe_unpermute_bwd(Tensor input_bwd, Tensor input_fwd, int64_t dtype, + Tensor row_id_map, Tensor prob) { + auto te_dtype = static_cast(dtype); + auto bwd_shape = getStableTensorShape(input_bwd); + NVTE_CHECK(bwd_shape.size() == 2, "Input must be 2D."); + const size_t num_cols = bwd_shape[1]; - // Output buffer alloc - at::Tensor act_grad = - torch::empty({input_fwd.size(0), num_cols}, - torch::dtype(input_bwd.scalar_type()).device(torch::kCUDA).requires_grad(false)); - at::Tensor prob_grad = torch::empty( - {num_tokens, topK}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(false)); + auto prob_shape = getStableTensorShape(prob); + const size_t topK = (prob.numel() > 0) ? prob_shape[1] : 1; + const size_t num_tokens = + (prob.numel() > 0) ? 
prob_shape[0] : getStableTensorShape(row_id_map)[0]; - auto stream = at::cuda::getCurrentCUDAStream().stream(); + auto fwd_shape = getStableTensorShape(input_fwd); + + auto device_idx = input_bwd.get_device_index(); + auto act_grad = + allocateStableTensor({static_cast(fwd_shape[0]), static_cast(num_cols)}, + GetStableScalarType(te_dtype), device_idx); + auto prob_grad = + allocateStableTensorZeros({static_cast(num_tokens), static_cast(topK)}, + ScalarType::Float, device_idx); auto input_bwd_cu = makeTransformerEngineTensor( - input_bwd.data_ptr(), - std::vector{static_cast(input_bwd.size(0)), static_cast(num_cols)}, - dtype); + input_bwd.data_ptr(), std::vector{bwd_shape[0], num_cols}, te_dtype); auto act_grad_cu = makeTransformerEngineTensor( - act_grad.data_ptr(), - std::vector{static_cast(act_grad.size(0)), static_cast(num_cols)}, - dtype); - auto input_fwd_cu = makeTransformerEngineTensor( - input_fwd.data_ptr(), - std::vector{static_cast(input_fwd.size(0)), static_cast(num_cols)}, - dtype); + act_grad.data_ptr(), std::vector{static_cast(fwd_shape[0]), num_cols}, + te_dtype); auto row_id_map_cu = makeTransformerEngineTensor(row_id_map); auto prob_cu = makeTransformerEngineTensor(prob); auto prob_grad_cu = makeTransformerEngineTensor(prob_grad); + auto input_fwd_cu = makeTransformerEngineTensor( + input_fwd.data_ptr(), std::vector{static_cast(fwd_shape[0]), num_cols}, + te_dtype); + TensorWrapper empty; - nvte_permute(input_bwd_cu.data(), act_grad_cu.data(), TensorWrapper().data(), - row_id_map_cu.data(), prob_cu.data(), prob_grad_cu.data(), input_fwd_cu.data(), - num_tokens, topK, num_cols, 0, stream); + nvte_permute(input_bwd_cu.data(), act_grad_cu.data(), empty.data(), row_id_map_cu.data(), + prob_cu.data(), prob_grad_cu.data(), input_fwd_cu.data(), num_tokens, topK, num_cols, + 0, getCurrentCUDAStreamRaw(device_idx)); return std::make_tuple(act_grad, prob_grad); } -} // namespace transformer_engine::pytorch +} // namespace 
transformer_engine::pytorch::stable + +STABLE_TORCH_LIBRARY_FRAGMENT(transformer_engine_stable, m) { + m.def( + "moe_permute_fwd(Tensor input, int dtype, Tensor sorted_row_id, Tensor row_id_map, int " + "num_tokens, int topK, int num_out_tokens) -> (Tensor, Tensor)"); + m.def( + "moe_unpermute_fwd(Tensor input, int dtype, Tensor row_id_map, Tensor prob, int num_tokens, " + "int topK) -> Tensor"); + m.def( + "moe_unpermute_bwd(Tensor input_bwd, Tensor input_fwd, int dtype, Tensor row_id_map, Tensor " + "prob) -> (Tensor, Tensor)"); +} + +STABLE_TORCH_LIBRARY_IMPL(transformer_engine_stable, CUDA, m) { + using namespace transformer_engine::pytorch::stable; + m.impl("moe_permute_fwd", TORCH_BOX(moe_permute_fwd)); + m.impl("moe_unpermute_fwd", TORCH_BOX(moe_unpermute_fwd)); + m.impl("moe_unpermute_bwd", TORCH_BOX(moe_unpermute_bwd)); +} diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp deleted file mode 100644 index c590a3c9e2..0000000000 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ /dev/null @@ -1,624 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. 
- ************************************************************************/ - -#include "pybind.h" - -#include -#include -#include -#include -#include - -#include -#include -#include - -#include "../common.h" -#include "../extensions.h" -#include "common.h" - -namespace transformer_engine::pytorch { - -PyTypeObject *Float8TensorPythonClass = nullptr; /// TODO Remove -PyTypeObject *Float8TensorStoragePythonClass = nullptr; -PyTypeObject *Float8QuantizerClass = nullptr; -PyTypeObject *Float8CurrentScalingQuantizerClass = nullptr; -PyTypeObject *MXFP8TensorPythonClass = nullptr; /// TODO Remove -PyTypeObject *MXFP8TensorStoragePythonClass = nullptr; -PyTypeObject *MXFP8QuantizerClass = nullptr; -PyTypeObject *Float8BlockwiseQTensorPythonClass = nullptr; -PyTypeObject *Float8BlockwiseQTensorStoragePythonClass = nullptr; -PyTypeObject *Float8BlockwiseQuantizerClass = nullptr; -PyTypeObject *NVFP4TensorPythonClass = nullptr; -PyTypeObject *NVFP4TensorStoragePythonClass = nullptr; -PyTypeObject *NVFP4QuantizerClass = nullptr; -PyTypeObject *GroupedTensorPythonClass = nullptr; -PyTypeObject *GroupedTensorStoragePythonClass = nullptr; -std::once_flag extension_init_flag; - -void init_float8_extension() { - auto fp8_module = py::module_::import("transformer_engine.pytorch.tensor.float8_tensor"); - Float8QuantizerClass = - reinterpret_cast(PyObject_GetAttrString(fp8_module.ptr(), "Float8Quantizer")); - Float8CurrentScalingQuantizerClass = reinterpret_cast( - PyObject_GetAttrString(fp8_module.ptr(), "Float8CurrentScalingQuantizer")); - Float8TensorPythonClass = - reinterpret_cast(PyObject_GetAttrString(fp8_module.ptr(), "Float8Tensor")); - auto fp8_base_module = - py::module_::import("transformer_engine.pytorch.tensor.storage.float8_tensor_storage"); - Float8TensorStoragePythonClass = reinterpret_cast( - PyObject_GetAttrString(fp8_base_module.ptr(), "Float8TensorStorage")); - NVTE_CHECK(Float8TensorPythonClass != nullptr, - "Internal error: could not initialize pyTorch Float8 
extension."); -} - -void init_mxfp8_extension() { - auto fp8_module = py::module_::import("transformer_engine.pytorch.tensor.mxfp8_tensor"); - MXFP8QuantizerClass = - reinterpret_cast(PyObject_GetAttrString(fp8_module.ptr(), "MXFP8Quantizer")); - MXFP8TensorPythonClass = - reinterpret_cast(PyObject_GetAttrString(fp8_module.ptr(), "MXFP8Tensor")); - auto fp8_base_module = - py::module_::import("transformer_engine.pytorch.tensor.storage.mxfp8_tensor_storage"); - MXFP8TensorStoragePythonClass = reinterpret_cast( - PyObject_GetAttrString(fp8_base_module.ptr(), "MXFP8TensorStorage")); - NVTE_CHECK(MXFP8TensorPythonClass != nullptr, - "Internal error: could not initialize pyTorch MXFP8 extension."); -} - -void init_float8blockwise_extension() { - auto fp8_module = - py::module_::import("transformer_engine.pytorch.tensor.float8_blockwise_tensor"); - auto fp8_base_module = py::module_::import( - "transformer_engine.pytorch.tensor.storage.float8_blockwise_tensor_storage"); - Float8BlockwiseQuantizerClass = reinterpret_cast( - PyObject_GetAttrString(fp8_module.ptr(), "Float8BlockQuantizer")); - Float8BlockwiseQTensorStoragePythonClass = reinterpret_cast( - PyObject_GetAttrString(fp8_base_module.ptr(), "Float8BlockwiseQTensorStorage")); - Float8BlockwiseQTensorPythonClass = reinterpret_cast( - PyObject_GetAttrString(fp8_module.ptr(), "Float8BlockwiseQTensor")); - - NVTE_CHECK(Float8BlockwiseQuantizerClass != nullptr, - "Internal error: could not initialize pyTorch float8blockwise extension."); - NVTE_CHECK(Float8BlockwiseQTensorStoragePythonClass != nullptr, - "Internal error: could not initialize pyTorch float8blockwise extension."); - NVTE_CHECK(Float8BlockwiseQTensorPythonClass != nullptr, - "Internal error: could not initialize pyTorch float8blockwise extension."); -} - -void init_nvfp4_extensions() { - auto nvfp4_module = py::module_::import("transformer_engine.pytorch.tensor.nvfp4_tensor"); - NVFP4QuantizerClass = reinterpret_cast( - 
PyObject_GetAttrString(nvfp4_module.ptr(), "NVFP4Quantizer")); - NVFP4TensorPythonClass = - reinterpret_cast(PyObject_GetAttrString(nvfp4_module.ptr(), "NVFP4Tensor")); - auto nvfp4_base_module = - py::module_::import("transformer_engine.pytorch.tensor.storage.nvfp4_tensor_storage"); - NVFP4TensorStoragePythonClass = reinterpret_cast( - PyObject_GetAttrString(nvfp4_base_module.ptr(), "NVFP4TensorStorage")); - NVTE_CHECK(NVFP4TensorPythonClass != nullptr, - "Internal error: could not initialize pyTorch NVFP4 extension."); -} - -void init_grouped_tensor_extension() { - if (GroupedTensorPythonClass && GroupedTensorStoragePythonClass) return; - auto grouped_tensor_module = - py::module_::import("transformer_engine.pytorch.tensor.grouped_tensor"); - GroupedTensorPythonClass = reinterpret_cast( - PyObject_GetAttrString(grouped_tensor_module.ptr(), "GroupedTensor")); - auto grouped_tensor_storage_module = - py::module_::import("transformer_engine.pytorch.tensor.storage.grouped_tensor_storage"); - GroupedTensorStoragePythonClass = reinterpret_cast( - PyObject_GetAttrString(grouped_tensor_storage_module.ptr(), "GroupedTensorStorage")); - NVTE_CHECK(GroupedTensorPythonClass != nullptr, - "Internal error: could not initialize pyTorch grouped tensor extension."); - NVTE_CHECK(GroupedTensorStoragePythonClass != nullptr, - "Internal error: could not initialize pyTorch grouped tensor extension."); -} - -void init_extension() { - std::call_once(extension_init_flag, []() { - init_float8_extension(); - init_mxfp8_extension(); - init_float8blockwise_extension(); - init_nvfp4_extensions(); - init_grouped_tensor_extension(); - }); -} - -} // namespace transformer_engine::pytorch - -#include "common/util/pybind_helper.h" - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) - m.def("quantize", transformer_engine::pytorch::quantize, py::arg("tensor"), py::arg("quantizer"), - py::arg("output") = py::none(), py::arg("noop") = py::none()); - 
m.def("dequantize", &transformer_engine::pytorch::dequantize, "Dequantize", py::arg("input"), - py::arg("otype")); - m.def("group_quantize", transformer_engine::pytorch::group_quantize, py::arg("tensor"), - py::arg("quantizer"), py::arg("num_tensors"), py::arg("first_dims")); - m.def("bgrad_quantize", transformer_engine::pytorch::bgrad_quantize, - "Compute bias gradient and quantize", py::arg("input"), py::arg("quantizer")); - m.def("generic_gemm", transformer_engine::pytorch::gemm, "Compute GEMM (matrix-matrix multiply)", - py::arg("A"), py::arg("transA"), py::arg("B"), py::arg("transB"), py::arg("D"), - py::arg("quantizer"), py::arg("output_dtype"), py::arg("bias"), py::arg("bias_type"), - py::arg("gelu"), py::arg("gelu_in"), py::arg("grad"), py::arg("workspace"), - py::arg("workspace_size"), py::arg("accumulate"), py::arg("use_split_accumulator"), - py::arg("comm_overlap") = nullptr, py::arg("comm_type") = std::nullopt, - py::arg("extra_output") = std::nullopt, py::arg("bulk_overlap") = false, - py::arg("alpha") = 1.0f, py::arg("beta") = std::nullopt); - /* GLU (sigmoid gate) */ - m.def("glu", transformer_engine::pytorch::glu, "GLU activation", py::arg("input"), - py::arg("quantizer")); - /* GELU and variants*/ - m.def("gelu", transformer_engine::pytorch::gelu, "GeLU activation", py::arg("input"), - py::arg("quantizer")); - m.def("geglu", transformer_engine::pytorch::geglu, "GeGLU activation", py::arg("input"), - py::arg("quantizer")); - m.def("qgelu", transformer_engine::pytorch::qgelu, "QuickGELU activation", py::arg("input"), - py::arg("quantizer")); - m.def("qgeglu", transformer_engine::pytorch::qgeglu, "QuickGeGLU activation", py::arg("input"), - py::arg("quantizer")); - /* ReLU and variants */ - m.def("relu", transformer_engine::pytorch::relu, "ReLU activation", py::arg("input"), - py::arg("quantizer")); - m.def("reglu", transformer_engine::pytorch::reglu, "ReGLU activation", py::arg("input"), - py::arg("quantizer")); - m.def("srelu", 
transformer_engine::pytorch::srelu, "Squared ReLU activation", py::arg("input"), - py::arg("quantizer")); - m.def("sreglu", transformer_engine::pytorch::sreglu, "Squared ReGLU activation", py::arg("input"), - py::arg("quantizer")); - /* SwiGLU and variants */ - m.def("silu", transformer_engine::pytorch::silu, "SiLU activation", py::arg("input"), - py::arg("quantizer")); - m.def("swiglu", transformer_engine::pytorch::swiglu, "SwiGLU activation", py::arg("input"), - py::arg("quantizer")); - m.def("clamped_swiglu", transformer_engine::pytorch::clamped_swiglu, - "SwiGLU activation used in GPT OSS", py::arg("input"), py::arg("quantizer"), - py::arg("limit") = 7.0f, py::arg("alpha") = 1.702f); - /* Backward of GLU */ - m.def("dglu", transformer_engine::pytorch::dglu, "Backward of GLU", py::arg("grad"), - py::arg("fwd_input"), py::arg("quantizer")); - /* Backward of GELU and variants */ - m.def("dgelu", transformer_engine::pytorch::dgelu, "Backward of GeLU", py::arg("grad"), - py::arg("fwd_input"), py::arg("quantizer")); - m.def("dgeglu", transformer_engine::pytorch::dgeglu, "Backward of GeGLU", py::arg("grad"), - py::arg("fwd_input"), py::arg("quantizer")); - m.def("dqgelu", transformer_engine::pytorch::dqgelu, "Backward of QuickGELU", py::arg("grad"), - py::arg("fwd_input"), py::arg("quantizer")); - m.def("dqgeglu", transformer_engine::pytorch::dqgeglu, "Backward of QuickGeGLU", py::arg("grad"), - py::arg("fwd_input"), py::arg("quantizer")); - /* Backward of ReLU and variants */ - m.def("drelu", transformer_engine::pytorch::drelu, "Backward of ReLU", py::arg("grad"), - py::arg("fwd_input"), py::arg("quantizer")); - m.def("dreglu", transformer_engine::pytorch::dreglu, "Backward of ReGLU", py::arg("grad"), - py::arg("fwd_input"), py::arg("quantizer")); - m.def("dsrelu", transformer_engine::pytorch::dsrelu, "Backward of Squared ReLU", py::arg("grad"), - py::arg("fwd_input"), py::arg("quantizer")); - m.def("dsreglu", transformer_engine::pytorch::dsreglu, "Backward of 
Squared ReGLU", - py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); - /* Backward of SiLU and variants */ - m.def("dsilu", transformer_engine::pytorch::dsilu, "Backward of SiLU", py::arg("grad"), - py::arg("fwd_input"), py::arg("quantizer")); - m.def("dswiglu", transformer_engine::pytorch::dswiglu, "Backward of SwiGLU", py::arg("grad"), - py::arg("fwd_input"), py::arg("quantizer")); - m.def("clamped_dswiglu", transformer_engine::pytorch::clamped_dswiglu, - "Backward of SwiGLU used in GPT OSS", py::arg("grad"), py::arg("fwd_input"), - py::arg("quantizer"), py::arg("limit") = 7.0f, py::arg("alpha") = 1.702f); - /* DBias + DAct fusions*/ - m.def("dbias_dgelu", transformer_engine::pytorch::dbias_dgelu, "DGeLU + DBias + Quantize", - py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); - m.def("dbias_dsilu", transformer_engine::pytorch::dbias_dsilu, "DSiLU + DBias + Quantize", - py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); - m.def("dbias_drelu", transformer_engine::pytorch::dbias_drelu, "DReLU + DBias + Quantize", - py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); - m.def("dbias_dqgelu", transformer_engine::pytorch::dbias_dqgelu, "DQGeLU + DBias + Quantize", - py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); - m.def("dbias_dsrelu", transformer_engine::pytorch::dbias_dsrelu, - "DSquaredReLU + DBias + Quantize", py::arg("grad"), py::arg("fwd_input"), - py::arg("quantizer")); - - // Permutation functions - m.def("moe_permute_fwd", transformer_engine::pytorch::moe_permute_fwd, "MOE permute FWD", - py::call_guard()); - m.def("moe_permute_bwd", transformer_engine::pytorch::moe_permute_bwd, "MOE permute BWD", - py::call_guard()); - m.def("moe_unpermute_fwd", transformer_engine::pytorch::moe_unpermute_fwd, "MOE unpermute FWD", - py::call_guard()); - m.def("moe_unpermute_bwd", transformer_engine::pytorch::moe_unpermute_bwd, "MOE unpermute BWD", - py::call_guard()); - - // Softmax functions - 
m.def("scaled_softmax_forward", &transformer_engine::pytorch::scaled_softmax_forward, - "Scaled Softmax FWD", py::call_guard()); - m.def("scaled_softmax_backward", &transformer_engine::pytorch::scaled_softmax_backward, - "Scaled Softmax BWD", py::call_guard()); - m.def("scaled_masked_softmax_forward", - &transformer_engine::pytorch::scaled_masked_softmax_forward, "Scaled Masked Softmax FWD", - py::call_guard()); - m.def("scaled_masked_softmax_backward", - &transformer_engine::pytorch::scaled_masked_softmax_backward, "Scaled Masked Softmax BWD", - py::call_guard()); - m.def("scaled_upper_triang_masked_softmax_forward", - &transformer_engine::pytorch::scaled_upper_triang_masked_softmax_forward, - "Scaled Upper-Triangular Masked Softmax FWD", py::call_guard()); - m.def("scaled_upper_triang_masked_softmax_backward", - &transformer_engine::pytorch::scaled_upper_triang_masked_softmax_backward, - "Scaled Upper-Triangular Masked Softmax BWD", py::call_guard()); - m.def("scaled_aligned_causal_masked_softmax_forward", - &transformer_engine::pytorch::scaled_aligned_causal_masked_softmax_forward, - "Scaled Bottom-Right Corner Aligned Masked Softmax FWD", - py::call_guard()); - m.def("scaled_aligned_causal_masked_softmax_backward", - &transformer_engine::pytorch::scaled_aligned_causal_masked_softmax_backward, - "Scaled Bottom-Right Corner Aligned Masked Softmax BWD", - py::call_guard()); - - // Other granular functions - m.def("layernorm_fwd", &transformer_engine::pytorch::layernorm_fwd, "LayerNorm", py::arg("input"), - py::arg("weight"), py::arg("bias"), py::arg("eps"), py::arg("ln_out"), py::arg("quantizer"), - py::arg("otype"), py::arg("sm_margin"), py::arg("zero_centered_gamma")); - m.def("layernorm_bwd", &transformer_engine::pytorch::layernorm_bwd, "Backward of LayerNorm"); - m.def("rmsnorm_fwd", &transformer_engine::pytorch::rmsnorm_fwd, "RMSNorm", py::arg("input"), - py::arg("weight"), py::arg("eps"), py::arg("ln_out"), py::arg("quantizer"), - py::arg("otype"), 
py::arg("sm_margin"), py::arg("zero_centered_gamma")); - m.def("rmsnorm_bwd", &transformer_engine::pytorch::rmsnorm_bwd, "Backward of RMSNorm"); - m.def("rmsnorm_bwd_add", &transformer_engine::pytorch::rmsnorm_bwd_add, - "Fused backward of RMSNorm + add"); - m.def("multi_tensor_quantize", &transformer_engine::pytorch::multi_tensor_quantize, - "Multi-tensor quantize", py::arg("tensor_list"), py::arg("quantizer_list")); - m.def("split_quantize", &transformer_engine::pytorch::split_quantize, - "Split and multi-tensor quantize", py::arg("tensor"), py::arg("split_sections"), - py::arg("quantizer_list"), py::arg("disable_bulk_allocation") = false); - m.def("te_general_grouped_gemm", &transformer_engine::pytorch::te_general_grouped_gemm, - "Grouped GEMM"); - m.def("te_general_grouped_gemm_for_grouped_tensor", - &transformer_engine::pytorch::te_general_grouped_gemm_for_grouped_tensor, - "Grouped GEMM for GroupedTensor"); - m.def("te_general_grouped_gemm_for_discrete_in", - &transformer_engine::pytorch::te_general_grouped_gemm_for_discrete_in, - "Grouped GEMM for discrete A input list"); - m.def("te_general_grouped_gemm_for_discrete_out", - &transformer_engine::pytorch::te_general_grouped_gemm_for_discrete_out, - "Grouped GEMM for discrete output list"); - m.def("fp8_transpose", &transformer_engine::pytorch::fp8_transpose, "Transpose with FP8 I/O", - py::arg("input"), py::arg("dtype"), py::kw_only(), py::arg("out"), - py::call_guard()); - m.def("nvfp4_data_transpose", &transformer_engine::pytorch::nvfp4_data_transpose, - "Transpose NVFP4 packed data with nibble repacking", py::arg("input"), py::kw_only(), - py::arg("out"), py::call_guard()); - m.def( - "nvfp4_2d_scale_transpose", &transformer_engine::pytorch::nvfp4_2d_scale_transpose, - "Transpose NVFP4 tile-level scales (E4M3 stored as uint8) from rowwise to columnwise format", - py::arg("input"), py::arg("output"), py::arg("M_tiles"), py::arg("K_tiles"), - py::call_guard()); - m.def("nvfp4_expand_scale_to_fp8", 
&transformer_engine::pytorch::nvfp4_expand_scale_to_fp8, - "Expand tile-level scales to row-level scales and convert to FP8 E4M3", py::arg("input"), - py::arg("output"), py::arg("tile_rows"), py::arg("tile_cols"), py::arg("rows_padded"), - py::arg("block_len"), py::call_guard()); - m.def("nvfp4_compute_per_block_scale", - &transformer_engine::pytorch::nvfp4_compute_per_block_scale, - "Compute per-block decode scale from block amax and global amax", py::arg("block_amax"), - py::arg("scale"), py::arg("global_amax"), py::call_guard()); - m.def("nvfp4_compute_global_scale", &transformer_engine::pytorch::nvfp4_compute_global_scale, - "Compute global encode scale from global amax", py::arg("global_amax"), - py::arg("global_scale"), py::call_guard()); - m.def("nvfp4_fused_scale", &transformer_engine::pytorch::nvfp4_fused_scale, - "Fused kernel: compute per-block decode scale, copy global amax, expand to row-level FP8", - py::arg("block_amax"), py::arg("global_amax"), py::arg("per_block_scale"), - py::arg("target_scale"), py::arg("target_amax"), py::arg("tile_rows"), py::arg("tile_cols"), - py::arg("rows_padded"), py::arg("block_len"), py::call_guard()); - m.def("nvfp4_multi_tensor_fused_scale", - &transformer_engine::pytorch::nvfp4_multi_tensor_fused_scale, - "Batched fused scale: compute per-block decode scale, copy global amax, expand to FP8 for " - "multiple tensors", - py::arg("block_amax_list"), py::arg("global_amax_list"), py::arg("per_block_scale_list"), - py::arg("target_scale_list"), py::arg("target_amax_list"), py::arg("tile_rows_list"), - py::arg("tile_cols_list"), py::arg("rows_padded_list"), py::arg("block_len"), - py::call_guard()); - m.def("nvfp4_2d_multi_tensor_transpose", - &transformer_engine::pytorch::nvfp4_2d_multi_tensor_transpose, - "Batched NVFP4 columnwise creation: transpose data and scales for multiple tensors", - py::arg("rowwise_data_list"), py::arg("columnwise_data_list"), - py::arg("rowwise_scale_inv_list"), 
py::arg("columnwise_scale_inv_list"), py::arg("M_list"), - py::arg("K_list"), py::call_guard()); - m.def("swap_first_dims", &transformer_engine::pytorch::swap_first_dims, - "Swap first two tensor dimensions", py::arg("tensor"), py::kw_only(), py::arg("out"), - py::call_guard()); - m.def("get_fused_attn_backend", &transformer_engine::pytorch::get_fused_attn_backend, - "Get Fused Attention backend", py::call_guard()); - m.def("compute_amax", &transformer_engine::pytorch::compute_amax, - "Compute absolute max value in tensor", py::arg("input"), py::arg("amax"), - py::call_guard()); - m.def("fused_amax_and_scale_update_after_reduction", - &transformer_engine::pytorch::fused_amax_and_scale_update_after_reduction, - "Update amax history and FP8 scale/scale_inv after reduction", - py::call_guard()); - m.def("fp8_block_scaling_compute_partial_amax", - &transformer_engine::pytorch::fp8_block_scaling_compute_partial_amax, - "Compute partial amax from master weights for fp8 block scaling", py::arg("tensor"), - py::arg("amax"), py::arg("h"), py::arg("w"), py::arg("start_offset"), py::arg("block_len"), - py::call_guard()); - m.def("fp8_block_scaling_partial_cast", - &transformer_engine::pytorch::fp8_block_scaling_partial_cast, - "Partial cast from master weights for fp8 block scaling", py::arg("inp"), py::arg("out"), - py::arg("scale"), py::arg("h"), py::arg("w"), py::arg("start_offset"), py::arg("block_len"), - py::arg("out_dtype"), py::call_guard()); - // NVFP4 2D - m.def("nvfp4_2d_compute_partial_amax", - &transformer_engine::pytorch::nvfp4_2d_compute_partial_amax, - "Compute partial amax from master weights for NVFP4 2D", py::arg("tensor"), py::arg("amax"), - py::arg("h"), py::arg("w"), py::arg("start_offset"), py::arg("block_len") = 16, - py::call_guard()); - m.def("nvfp4_multi_tensor_compute_partial_amax", - &transformer_engine::pytorch::nvfp4_multi_tensor_compute_partial_amax, - "Batched compute partial and global amax from master weights for NVFP4 2D", - 
py::arg("master_weight_list"), py::arg("partial_amax_list"), py::arg("global_amax_list"), - py::arg("h_list"), py::arg("w_list"), py::arg("start_offset_list"), - py::arg("block_len") = 16, py::call_guard()); - m.def("nvfp4_2d_partial_cast", &transformer_engine::pytorch::nvfp4_2d_partial_cast, - "Partial cast from master weights for NVFP4 2D", py::arg("inp"), py::arg("out"), - py::arg("scale"), py::arg("global_scale"), py::arg("h"), py::arg("w"), - py::arg("start_offset"), py::arg("block_len") = 16, - py::call_guard()); - m.def("nvfp4_multi_tensor_2d_partial_cast", - &transformer_engine::pytorch::nvfp4_multi_tensor_2d_partial_cast, - "Batched partial cast from master weights for NVFP4 2D", py::arg("inp_list"), - py::arg("out_list"), py::arg("scale_list"), py::arg("global_scale_list"), py::arg("h_list"), - py::arg("w_list"), py::arg("start_offset_list"), py::arg("block_len") = 16, - py::call_guard()); - m.def("mxfp8_scaling_compute_partial_amax", - &transformer_engine::pytorch::mxfp8_scaling_compute_partial_amax, - "Compute partial amax from master weights for fp8 mxfp8 scaling", py::arg("input"), - py::arg("amax_rowwise"), py::arg("amax_colwise"), py::arg("rows"), py::arg("cols"), - py::arg("start_offset"), py::call_guard()); - m.def("mxfp8_scaling_partial_cast", &transformer_engine::pytorch::mxfp8_scaling_partial_cast, - "Partial cast from master weights for fp8 mxfp8 scaling", py::arg("input"), - py::arg("output_rowwise"), py::arg("output_colwise"), py::arg("scale_inv_rowwise"), - py::arg("scale_inv_colwise"), py::arg("rows"), py::arg("cols"), py::arg("start_offset"), - py::call_guard()); - m.def("fused_multi_row_padding", &transformer_engine::pytorch::fused_multi_row_padding, - "Fused Multi-tensor padding", py::call_guard()); - m.def("fused_multi_row_unpadding", &transformer_engine::pytorch::fused_multi_row_unpadding, - "Fused Multi-tensor unpadding", py::call_guard()); - m.def("swizzle_scales_for_gemm_", 
&transformer_engine::pytorch::inplace_swizzle_scale_for_gemm, - "Convert tensor block scales into GEMM swizzled format"); - - // attention kernels - m.def("fa_prepare_fwd", &transformer_engine::pytorch::fa_prepare_fwd, - "Prepare QKV for Flash Attention", py::call_guard()); - m.def("fa_prepare_bwd", &transformer_engine::pytorch::fa_prepare_bwd, - "Backward of QKV preparation for Flash Attention", - py::call_guard()); - m.def("fused_attn_fwd", &transformer_engine::pytorch::fused_attn_fwd, - "Fused Attention FP8/BF16/FP16 FWD with separate Q, K and V"); - m.def("fused_attn_bwd", &transformer_engine::pytorch::fused_attn_bwd, - "Fused Attention FP8/BF16/FP16 BWD with separate Q, K and V"); - m.def("copy_to_kv_cache", &transformer_engine::pytorch::copy_to_kv_cache, - "Copy new KV tokens to KV cache", py::call_guard()); - m.def("convert_thd_to_bshd", &transformer_engine::pytorch::convert_thd_to_bshd, - "Convert a tensor from THD to BSHD", py::call_guard()); - m.def("convert_bshd_to_thd", &transformer_engine::pytorch::convert_bshd_to_thd, - "Convert a tesnor from BSHD to THD", py::call_guard()); - - // fused apply rope - m.def("fused_rope_forward", &transformer_engine::pytorch::fused_rope_forward, - "Fused Apply RoPE FWD", py::call_guard()); - m.def("fused_rope_backward", &transformer_engine::pytorch::fused_rope_backward, - "Fused Apply RoPE BWD", py::call_guard()); - m.def("fused_qkv_rope_forward", &transformer_engine::pytorch::fused_qkv_rope_forward, - "Fused Apply QKV RoPE FWD", py::call_guard()); - m.def("fused_qkv_rope_backward", &transformer_engine::pytorch::fused_qkv_rope_backward, - "Fused Apply QKV RoPE BWD", py::call_guard()); - - // fused router - m.def("fused_topk_with_score_function_fwd", - &transformer_engine::pytorch::fused_topk_with_score_function_fwd, py::arg("logits"), - py::arg("topk"), py::arg("use_pre_softmax"), py::arg("num_groups"), py::arg("group_topk"), - py::arg("scaling_factor"), py::arg("score_function"), py::arg("expert_bias"), - "Fused topk 
with score function fwd"); - m.def("fused_topk_with_score_function_bwd", - &transformer_engine::pytorch::fused_topk_with_score_function_bwd, py::arg("num_tokens"), - py::arg("num_experts"), py::arg("routing_map"), py::arg("intermediate_output"), - py::arg("grad_probs"), py::arg("grad_logits"), py::arg("topk"), py::arg("use_pre_softmax"), - py::arg("scaling_factor"), py::arg("score_function"), "Fused topk with score function bwd"); - m.def("fused_score_for_moe_aux_loss_fwd", - &transformer_engine::pytorch::fused_score_for_moe_aux_loss_fwd, py::arg("logits"), - py::arg("topk"), py::arg("score_function"), "Fused aux loss with score function fwd"); - m.def("fused_score_for_moe_aux_loss_bwd", - &transformer_engine::pytorch::fused_score_for_moe_aux_loss_bwd, py::arg("num_tokens"), - py::arg("num_experts"), py::arg("intermediate_output"), py::arg("grad_scores"), - py::arg("grad_logits"), py::arg("topk"), py::arg("score_function"), - "Fused aux loss with score function bwd"); - m.def("fused_moe_aux_loss_fwd", &transformer_engine::pytorch::fused_moe_aux_loss_fwd, - py::arg("probs"), py::arg("tokens_per_expert"), py::arg("total_num_tokens"), - py::arg("num_experts"), py::arg("num_rows"), py::arg("num_cols"), py::arg("topk"), - py::arg("coeff"), "Fused aux loss fwd"); - m.def("fused_moe_aux_loss_bwd", &transformer_engine::pytorch::fused_moe_aux_loss_bwd, - py::arg("Const_buf"), py::arg("tokens_per_expert"), py::arg("num_rows"), - py::arg("num_cols"), py::arg("grad_aux_loss"), "Fused aux loss bwd"); - - // Dropout - m.def("dropout_fwd", transformer_engine::pytorch::dropout_fwd, "Dropout forward with 8-bit RNG", - py::arg("input"), py::arg("dropout_probability"), py::arg("out") = std::nullopt); - m.def("dropout_bwd", transformer_engine::pytorch::dropout_bwd, "Dropout backward with 8-bit RNG", - py::arg("grad_output"), py::arg("mask"), py::arg("dropout_probability"), - py::arg("grad_input") = std::nullopt); - - // Misc - m.def("get_cublasLt_version", 
&transformer_engine::pytorch::get_cublasLt_version, - "Get cublasLt version", py::call_guard()); - m.def("get_cudnn_version", &transformer_engine::pytorch::get_cudnn_version, "Get cuDNN version", - py::call_guard()); - m.def("splits_to_offsets", &transformer_engine::pytorch::splits_to_offsets, - "Compute grouped tensor offsets from split sizes", py::arg("first_dims"), - py::arg("logical_last_dim"), py::call_guard()); - m.def("get_num_cublas_streams", &nvte_get_num_compute_streams, "Get number of compute streams", - py::call_guard()); - - // Support THD format for Context Parallel - m.def("thd_read_half_tensor", &transformer_engine::pytorch::thd_read_half_tensor, - "Read the first half(half_idx=0) or the second half(half_idx=1) of each sequence in a THD " - "tensor", - py::call_guard()); - m.def("thd_second_half_lse_correction", - &transformer_engine::pytorch::thd_second_half_lse_correction, - "Correct the second half of the softmax_lse", py::call_guard()); - m.def("thd_read_second_half_lse", &transformer_engine::pytorch::thd_read_second_half_lse, - "Read the second half of the softmax_lse", py::call_guard()); - m.def("thd_out_correction", &transformer_engine::pytorch::thd_out_correction, - "Correct the THD format output of context parallelism in forward pass", - py::call_guard()); - m.def("thd_grad_correction", &transformer_engine::pytorch::thd_grad_correction, - "Correct the THD format gradients of context parallelism in backward pass", - py::call_guard()); - m.def("thd_get_partitioned_indices", &transformer_engine::pytorch::thd_get_partitioned_indices, - "Generate partitioned indices for inputs in THD format", - py::call_guard()); - - // nvshmem functions - m.def("init_nvshmem_backend", &transformer_engine::pytorch::init_nvshmem_backend, - "Initialize nvshmem backend with Pytorch distributed process groups", - py::call_guard()); - m.def("create_nvshmem_tensor", &transformer_engine::pytorch::create_nvshmem_tensor, - "Create a tensor in NVSHMEM shared memory", 
py::call_guard()); - m.def("nvshmem_send_on_current_stream", - &transformer_engine::pytorch::nvshmem_send_on_current_stream, - "Asynchronously send tensor data to a remote PE using NVSHMEM on the current CUDA stream", - py::call_guard()); - m.def("nvshmem_wait_on_current_stream", - &transformer_engine::pytorch::nvshmem_wait_on_current_stream, - "Wait for a signal value to be updated by a remote PE using NVSHMEM on the current CUDA " - "stream", - py::call_guard()); - m.def("nvshmem_finalize", &transformer_engine::pytorch::nvshmem_finalize, - "Clean up and finalize the NVSHMEM communication backend and free associated resources", - py::call_guard()); - - // multi-tensor functions - m.def("multi_tensor_scale", &transformer_engine::pytorch::multi_tensor_scale_cuda, - "Fused overflow check + scale for a list of contiguous tensors", - py::call_guard()); - m.def("multi_tensor_scale_tensor", &transformer_engine::pytorch::multi_tensor_scale_tensor_cuda, - "Fused overflow check + scale for a list of contiguous tensors with scale passed as tensor", - py::call_guard()); - m.def("multi_tensor_l2norm", &transformer_engine::pytorch::multi_tensor_l2norm_cuda, - "Computes L2 norm for a list of contiguous tensors", - py::call_guard()); - m.def("multi_tensor_unscale_l2norm", - &transformer_engine::pytorch::multi_tensor_unscale_l2norm_cuda, - "Computes L2 norm for a list of contiguous tensors after unscaling (unscaling is only " - "performed for L2 norm computation, and tensors are not updated)", - py::call_guard()); - m.def("multi_tensor_adam", &transformer_engine::pytorch::multi_tensor_adam_cuda, - "Compute and apply gradient update to parameters for Adam optimizer", - py::call_guard()); - m.def("multi_tensor_adam_param_remainder", - &transformer_engine::pytorch::multi_tensor_adam_param_remainder_cuda, - "Compute and apply gradient update to parameters for Adam optimizer" - "where the master parameters only store the remainder bits", - py::call_guard()); - 
m.def("multi_tensor_adam_fp8", &transformer_engine::pytorch::multi_tensor_adam_fp8_cuda, - "Compute and apply gradient update to parameters for Adam optimizer", - py::call_guard()); - m.def("multi_tensor_adam_capturable", - &transformer_engine::pytorch::multi_tensor_adam_capturable_cuda, - "Compute and apply gradient update to parameters for Adam optimizer with CUDA graph " - "support and LR scheduling", - py::call_guard()); - m.def("multi_tensor_adam_capturable_master", - &transformer_engine::pytorch::multi_tensor_adam_capturable_master_cuda, - "Compute and apply gradient update to parameters for Adam optimizer with CUDA graph " - "support, LR scheduling and FP32 master weights", - py::call_guard()); - m.def("multi_tensor_sgd", &transformer_engine::pytorch::multi_tensor_sgd_cuda, - "Fused SGD optimizer for list of contiguous tensors", - py::call_guard()); - m.def("multi_tensor_compute_scale_and_scale_inv", - &transformer_engine::pytorch::multi_tensor_compute_scale_and_scale_inv_cuda, - "Fused compute scale and scale_inv from amax", py::call_guard()); - m.def("multi_tensor_compute_scale_inv_e8m0", - &transformer_engine::pytorch::multi_tensor_compute_scale_inv_e8m0_cuda, - "Fused compute E8M0 scale_inv from amax", py::call_guard()); - - // Comm+GEMM Overlap - m.def("bulk_overlap_ag_with_external_gemm", - &transformer_engine::pytorch::bulk_overlap_ag_with_external_gemm, - "Bulk overlap All-Gather with a GEMM operation launched by another communicator", - py::call_guard(), py::arg("allgather_communicator"), - py::arg("send_stream"), py::arg("recv_stream")); - - // Data structures - py::class_(m, "FP8TensorMeta") - .def(py::init<>()) - .def_readwrite("scale", &transformer_engine::pytorch::FP8TensorMeta::scale) - .def_readwrite("scale_inv", &transformer_engine::pytorch::FP8TensorMeta::scale_inv) - .def_readwrite("amax_history", &transformer_engine::pytorch::FP8TensorMeta::amax_history); - - py::enum_(m, "FP8FwdTensors") - .value("GEMM1_INPUT", 
transformer_engine::pytorch::FP8FwdTensors::GEMM1_INPUT) - .value("GEMM1_WEIGHT", transformer_engine::pytorch::FP8FwdTensors::GEMM1_WEIGHT) - .value("GEMM1_OUTPUT", transformer_engine::pytorch::FP8FwdTensors::GEMM1_OUTPUT) - .value("GEMM2_INPUT", transformer_engine::pytorch::FP8FwdTensors::GEMM2_INPUT) - .value("GEMM2_WEIGHT", transformer_engine::pytorch::FP8FwdTensors::GEMM2_WEIGHT) - .value("GEMM2_OUTPUT", transformer_engine::pytorch::FP8FwdTensors::GEMM2_OUTPUT) - .value("GEMM3_INPUT", transformer_engine::pytorch::FP8FwdTensors::GEMM3_INPUT) - .value("GEMM3_WEIGHT", transformer_engine::pytorch::FP8FwdTensors::GEMM3_WEIGHT) - .value("GEMM3_OUTPUT", transformer_engine::pytorch::FP8FwdTensors::GEMM3_OUTPUT); - - py::enum_(m, "FP8BwdTensors") - .value("GRAD_OUTPUT1", transformer_engine::pytorch::FP8BwdTensors::GRAD_OUTPUT1) - .value("GRAD_INPUT1", transformer_engine::pytorch::FP8BwdTensors::GRAD_INPUT1) - .value("GRAD_OUTPUT2", transformer_engine::pytorch::FP8BwdTensors::GRAD_OUTPUT2) - .value("GRAD_INPUT2", transformer_engine::pytorch::FP8BwdTensors::GRAD_INPUT2) - .value("GRAD_OUTPUT3", transformer_engine::pytorch::FP8BwdTensors::GRAD_OUTPUT3) - .value("GRAD_INPUT3", transformer_engine::pytorch::FP8BwdTensors::GRAD_INPUT3); - - py::class_(m, "CommOverlapHelper") - .def(py::init<>(), py::call_guard()) - .def(py::init>(), - py::call_guard(), py::arg("world_group"), - py::arg("intra_node_group") = py::none()); - - py::class_, transformer_engine::CommOverlapBase, - transformer_engine::CommOverlapCore>(m, "CommOverlap") - .def(py::init &, at::ScalarType, CommOverlapHelper *, int, int, int, - int, int, int, int, bool, bool, bool>(), - py::call_guard(), py::arg("buffer_shape"), - py::arg("buffer_dtype"), py::arg("helper"), py::arg("tp_size"), - py::arg("num_splits") = 3, py::arg("num_max_streams") = NVTE_COMM_OVERLAP_MAX_STREAMS, - py::arg("comm_cga_size") = 2, py::arg("gemm_priority") = 0, py::arg("comm_priority") = 0, - py::arg("num_comm_sm") = 16, 
py::arg("set_sm_margin") = true, - py::arg("atomic_gemm") = false, py::arg("rs_overlap_first_gemm") = false) - .def("copy_into_buffer", - static_cast( - &CommOverlap::copy_into_buffer), - py::arg("input"), py::arg("local_chunk") = false) - .def("get_buffer", &CommOverlap::get_buffer, py::arg("local_chunk") = false, - py::arg("shape") = std::nullopt) - .def("get_communication_stream", &CommOverlap::get_communication_stream); - - py::class_, - transformer_engine::CommOverlapP2PBase, transformer_engine::CommOverlapCore>( - m, "CommOverlapP2P") - .def(py::init &, at::ScalarType, CommOverlapHelper *, int, - transformer_engine::CommOverlapType, int, int, int, int, int, bool, bool, bool, - bool>(), - py::call_guard(), py::arg("buffer_shape"), - py::arg("buffer_dtype"), py::arg("helper"), py::arg("tp_size"), py::arg("comm_type"), - py::arg("num_max_streams") = NVTE_COMM_OVERLAP_MAX_STREAMS, py::arg("comm_cga_size") = 1, - py::arg("gemm_priority") = 0, py::arg("comm_priority") = 0, py::arg("num_comm_sm") = 1, - py::arg("set_sm_margin") = false, py::arg("atomic_gemm") = false, - py::arg("use_ce") = true, py::arg("aggregate") = false) - .def("copy_into_buffer", - static_cast( - &CommOverlapP2P::copy_into_buffer), - py::arg("input"), py::arg("local_chunk") = false) - .def("get_buffer", &CommOverlapP2P::get_buffer, py::arg("local_chunk") = false, - py::arg("shape") = std::nullopt) - .def("get_communication_stream", &CommOverlapP2P::get_communication_stream); -} diff --git a/transformer_engine/pytorch/csrc/extensions/recipe.cpp b/transformer_engine/pytorch/csrc/extensions/recipe.cpp index c02d2ec616..6fb8d94d18 100644 --- a/transformer_engine/pytorch/csrc/extensions/recipe.cpp +++ b/transformer_engine/pytorch/csrc/extensions/recipe.cpp @@ -4,66 +4,161 @@ * See LICENSE for license information. 
************************************************************************/ -#include -#include +#include +#include +#include +#include +#include +#include -#include +#include "../stable_common.h" +#include "common/util/cuda_runtime.h" -#include "../extensions.h" -#include "transformer_engine/transformer_engine.h" +namespace transformer_engine::pytorch::stable { -namespace transformer_engine::pytorch { +using Tensor = torch::stable::Tensor; -void compute_amax(const at::Tensor& tensor, at::Tensor& amax) { - auto input_tensor = tensor.contiguous(); - const TensorWrapper& te_input = makeTransformerEngineTensor(input_tensor); +int64_t get_cublasLt_version() { return static_cast(cublasLtGetVersion()); } - TORCH_CHECK(amax.scalar_type() == at::kFloat, "amax must be a float tensor"); - TORCH_CHECK(amax.numel() == 1, "amax must have exactly one element"); - auto* amax_ptr = amax.data_ptr(); - TensorWrapper fake_te_output( - /*dptr=*/nullptr, te_input.shape(), - DType::kFloat32, // It doesn't matter because we only compute amax. 
- amax_ptr); +int64_t get_cudnn_version() { return static_cast(cudnnGetVersion()); } - nvte_compute_amax(te_input.data(), fake_te_output.data(), at::cuda::getCurrentCUDAStream()); +void compute_amax(Tensor input, Tensor amax) { + auto input_ = torch::stable::contiguous(input); + auto input_cu = makeTransformerEngineTensor(input_); + auto shape = getStableTensorShape(input_); + + // Build output TensorWrapper with amax pointer + TensorWrapper fake_output(NVTE_DELAYED_TENSOR_SCALING); + fake_output.set_rowwise_data(nullptr, DType::kFloat32, + std::vector(shape.begin(), shape.end())); + fake_output.set_amax(amax.data_ptr(), DType::kFloat32, std::vector{1}); + + nvte_compute_amax(input_cu.data(), fake_output.data(), + getCurrentCUDAStreamRaw(input_.get_device_index())); } -void fused_amax_and_scale_update_after_reduction(const at::Tensor& amax_reduction_buffer, - std::vector amax_histories, - std::vector scales, - const std::string& amax_compute_algo, - DType fp8_dtype, float margin) { - size_t num_tensors = amax_histories.size(); +// fused_amax_and_scale_update_after_reduction uses the pointer-pack pattern: +// Python passes flat int64 tensors with data_ptr() values for amax_histories and scales. +// fused_amax_and_scale_update: use pointer+ndim+shape encoding +// shapes tensor: [num_tensors * 3] — (ndim, dim0, dim1) per tensor. dim1=0 for 1D. 
+void fused_amax_and_scale_update( + Tensor amax_reduction_buffer, + Tensor amax_history_ptrs, // int64 [num_tensors] — data_ptr() per history + Tensor amax_history_shapes, // int64 [num_tensors * 3] — (ndim, dim0, dim1) per history + Tensor scale_ptrs, // int64 [num_tensors] — data_ptr() per scale + Tensor scale_shapes, // int64 [num_tensors * 3] — (ndim, dim0, dim1) per scale + int64_t num_tensors, std::string amax_compute_algo, int64_t fp8_dtype, double margin) { + auto buf_cu = makeTransformerEngineTensor(amax_reduction_buffer); + + const int64_t* ah_ptrs = static_cast(amax_history_ptrs.data_ptr()); + const int64_t* ah_shapes = static_cast(amax_history_shapes.data_ptr()); + const int64_t* sc_ptrs = static_cast(scale_ptrs.data_ptr()); + const int64_t* sc_shapes = static_cast(scale_shapes.data_ptr()); + std::vector te_amax_histories; std::vector te_scales; te_amax_histories.reserve(num_tensors); te_scales.reserve(num_tensors); - for (size_t i = 0; i < num_tensors; i++) { + + for (int64_t i = 0; i < num_tensors; i++) { te_amax_histories.push_back(nvte_create_tensor(NVTE_DELAYED_TENSOR_SCALING)); - NVTETensor& amax_history = te_amax_histories.back(); - NVTEShape amax_shape = convertTorchShape(amax_histories[i].sizes()); - NVTEBasicTensor amax_history_data = {amax_histories[i].data_ptr(), - static_cast(DType::kFloat32), amax_shape}; - nvte_set_tensor_param(&amax_history, kNVTERowwiseData, &amax_history_data); + size_t ah_ndim = static_cast(ah_shapes[i * 3]); + size_t ah_dims[] = {static_cast(ah_shapes[i * 3 + 1]), + static_cast(ah_shapes[i * 3 + 2])}; + NVTEShape amax_shape = nvte_make_shape(ah_dims, ah_ndim); + NVTEBasicTensor amax_data = {reinterpret_cast(ah_ptrs[i]), + static_cast(DType::kFloat32), amax_shape}; + nvte_set_tensor_param(&te_amax_histories.back(), kNVTERowwiseData, &amax_data); te_scales.push_back(nvte_create_tensor(NVTE_DELAYED_TENSOR_SCALING)); - NVTETensor& scale = te_scales.back(); - NVTEShape scale_shape = convertTorchShape(scales[i].sizes()); 
- NVTEBasicTensor scale_data = {scales[i].data_ptr(), static_cast(DType::kFloat32), - scale_shape}; - nvte_set_tensor_param(&scale, kNVTERowwiseData, &scale_data); + size_t sc_ndim = static_cast(sc_shapes[i * 3]); + size_t sc_dims[] = {static_cast(sc_shapes[i * 3 + 1]), + static_cast(sc_shapes[i * 3 + 2])}; + NVTEShape scale_shape = nvte_make_shape(sc_dims, sc_ndim); + NVTEBasicTensor scale_data = {reinterpret_cast(sc_ptrs[i]), + static_cast(DType::kFloat32), scale_shape}; + nvte_set_tensor_param(&te_scales.back(), kNVTERowwiseData, &scale_data); } + nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction( - makeTransformerEngineTensor(amax_reduction_buffer).data(), te_amax_histories, te_scales, - amax_compute_algo.c_str(), static_cast(fp8_dtype), margin, - at::cuda::getCurrentCUDAStream()); - for (auto& t : te_amax_histories) { - nvte_destroy_tensor(t); - } - for (auto& t : te_scales) { - nvte_destroy_tensor(t); - } + buf_cu.data(), te_amax_histories, te_scales, amax_compute_algo.c_str(), + static_cast(fp8_dtype), static_cast(margin), + getCurrentCUDAStreamRaw(amax_reduction_buffer.get_device_index())); + + for (auto& t : te_amax_histories) nvte_destroy_tensor(t); + for (auto& t : te_scales) nvte_destroy_tensor(t); } -} // namespace transformer_engine::pytorch +int64_t get_fused_attn_backend(bool is_training, int64_t q_dtype, int64_t kv_dtype, + int64_t qkv_layout, int64_t bias_type, int64_t attn_mask_type, + int64_t softmax_type, double p_dropout, int64_t num_attn_heads, + int64_t num_gqa_groups, int64_t max_seqlen_q, int64_t max_seqlen_kv, + int64_t head_dim_qk, int64_t head_dim_v, int64_t window_size_left, + int64_t window_size_right, bool return_max_logit, bool cuda_graph, + bool deterministic) { + return static_cast(nvte_get_fused_attn_backend( + is_training, static_cast(q_dtype), static_cast(kv_dtype), + static_cast(qkv_layout), static_cast(bias_type), + static_cast(attn_mask_type), static_cast(softmax_type), + static_cast(p_dropout), 
static_cast(num_attn_heads), + static_cast(num_gqa_groups), static_cast(max_seqlen_q), + static_cast(max_seqlen_kv), static_cast(head_dim_qk), + static_cast(head_dim_v), window_size_left, window_size_right, return_max_logit, + cuda_graph, deterministic)); +} + +int64_t get_num_cublas_streams() { return static_cast(nvte_get_num_compute_streams()); } + +bool device_supports_multicast(int64_t device_id) { + return transformer_engine::cuda::supports_multicast(static_cast(device_id)); +} + +std::vector get_stream_priority_range(int64_t device_id) { + int low = 0, high = 0; + transformer_engine::cuda::stream_priority_range(&low, &high, static_cast(device_id)); + return {static_cast(low), static_cast(high)}; +} + +bool ubuf_built_with_mpi() { return transformer_engine::ubuf_built_with_mpi(); } + +} // namespace transformer_engine::pytorch::stable + +STABLE_TORCH_LIBRARY_FRAGMENT(transformer_engine_stable, m) { + m.def("get_cublasLt_version() -> int"); + m.def("get_cudnn_version() -> int"); + m.def("compute_amax(Tensor input, Tensor amax) -> ()"); + m.def( + "fused_amax_and_scale_update(Tensor amax_reduction_buffer, Tensor amax_history_ptrs, Tensor " + "amax_history_shapes, Tensor scale_ptrs, Tensor scale_shapes, int num_tensors, str " + "amax_compute_algo, int fp8_dtype, float margin) -> ()"); + // shapes format: [num_tensors * 3] — (ndim, dim0, dim1) per tensor + m.def("get_num_cublas_streams() -> int"); + m.def("device_supports_multicast(int device_id) -> bool"); + m.def("get_stream_priority_range(int device_id) -> int[]"); + m.def("ubuf_built_with_mpi() -> bool"); + m.def( + "get_fused_attn_backend(bool is_training, int q_dtype, int kv_dtype, int qkv_layout, int " + "bias_type, int attn_mask_type, int softmax_type, float p_dropout, int num_attn_heads, int " + "num_gqa_groups, int max_seqlen_q, int max_seqlen_kv, int head_dim_qk, int head_dim_v, int " + "window_size_left, int window_size_right, bool return_max_logit, bool cuda_graph, bool " + "deterministic) -> int"); 
+} + +// Ops with tensor arguments → CUDA dispatch key +STABLE_TORCH_LIBRARY_IMPL(transformer_engine_stable, CUDA, m) { + using namespace transformer_engine::pytorch::stable; + m.impl("compute_amax", TORCH_BOX(compute_amax)); + m.impl("fused_amax_and_scale_update", TORCH_BOX(fused_amax_and_scale_update)); +} + +// Ops without tensor arguments → CompositeImplicitAutograd fallback +STABLE_TORCH_LIBRARY_IMPL(transformer_engine_stable, CompositeImplicitAutograd, m) { + using namespace transformer_engine::pytorch::stable; + m.impl("get_cublasLt_version", TORCH_BOX(get_cublasLt_version)); + m.impl("get_cudnn_version", TORCH_BOX(get_cudnn_version)); + m.impl("get_num_cublas_streams", TORCH_BOX(get_num_cublas_streams)); + m.impl("device_supports_multicast", TORCH_BOX(device_supports_multicast)); + m.impl("get_stream_priority_range", TORCH_BOX(get_stream_priority_range)); + m.impl("ubuf_built_with_mpi", TORCH_BOX(ubuf_built_with_mpi)); + m.impl("get_fused_attn_backend", TORCH_BOX(get_fused_attn_backend)); +} diff --git a/transformer_engine/pytorch/csrc/extensions/registration.cpp b/transformer_engine/pytorch/csrc/extensions/registration.cpp new file mode 100644 index 0000000000..44b4e57689 --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/registration.cpp @@ -0,0 +1,80 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "../stable_common.h" + +// This file defines the transformer_engine_stable library namespace. +// All other stable ABI files use STABLE_TORCH_LIBRARY_FRAGMENT to add schemas +// and STABLE_TORCH_LIBRARY_IMPL to add implementations. 
+STABLE_TORCH_LIBRARY(transformer_engine_stable, m) { + // Softmax ops + m.def("scaled_softmax_forward(Tensor input, float scale_factor) -> Tensor"); + m.def( + "scaled_softmax_backward(Tensor output_grad, Tensor softmax_results, float scale_factor) -> " + "Tensor"); + m.def("scaled_masked_softmax_forward(Tensor input, Tensor mask, float scale_factor) -> Tensor"); + m.def( + "scaled_masked_softmax_backward(Tensor output_grad, Tensor softmax_results, float " + "scale_factor) -> Tensor"); + m.def("scaled_upper_triang_masked_softmax_forward(Tensor input, float scale_factor) -> Tensor"); + m.def( + "scaled_upper_triang_masked_softmax_backward(Tensor output_grads, Tensor softmax_results, " + "float scale_factor) -> Tensor"); + m.def("scaled_aligned_causal_masked_softmax_forward(Tensor input, float scale_factor) -> Tensor"); + m.def( + "scaled_aligned_causal_masked_softmax_backward(Tensor output_grad, Tensor softmax_results, " + "float scale_factor) -> Tensor"); + + // Padding ops + m.def( + "fused_multi_row_padding(Tensor input, Tensor output, int[] input_row_list, int[] " + "padded_input_row_list) -> ()"); + m.def( + "fused_multi_row_unpadding(Tensor input, Tensor output, int[] input_row_list, int[] " + "unpadded_input_row_list) -> ()"); + + // Misc ops + m.def("splits_to_offsets(Tensor first_dims, int logical_last_dim) -> Tensor"); + + // RoPE ops + m.def( + "fused_rope_forward(Tensor input, Tensor freqs, Tensor? start_positions, int qkv_format, " + "bool interleaved, Tensor? cu_seqlens, int cp_size, int cp_rank) -> Tensor"); + m.def( + "fused_rope_backward(Tensor output_grads, Tensor freqs, Tensor? start_positions, int " + "qkv_format, bool interleaved, Tensor? cu_seqlens, int cp_size, int cp_rank) -> Tensor"); + m.def( + "fused_qkv_rope_forward(Tensor qkv_input, Tensor q_freqs, Tensor k_freqs, Tensor? 
" + "start_positions, int[] qkv_split_arg_list, int qkv_format, bool interleaved, int cp_size, " + "int cp_rank) -> (Tensor, Tensor, Tensor)"); + m.def( + "fused_qkv_rope_backward(Tensor q_grad_out, Tensor k_grad_out, Tensor v_grad_out, Tensor " + "q_freqs, Tensor k_freqs, int[] qkv_split_arg_list, int qkv_format, bool interleaved, int " + "cp_size, int cp_rank) -> Tensor"); + + // Router ops + m.def( + "fused_topk_with_score_function_fwd(Tensor logits, int topk, bool use_pre_softmax, int " + "num_groups, int group_topk, float scaling_factor, str score_function, Tensor? expert_bias) " + "-> (Tensor, Tensor, Tensor)"); + m.def( + "fused_topk_with_score_function_bwd(int num_tokens, int num_experts, Tensor routing_map, " + "Tensor intermediate_output, Tensor grad_probs, Tensor grad_logits, int topk, bool " + "use_pre_softmax, float scaling_factor, str score_function) -> ()"); + m.def( + "fused_score_for_moe_aux_loss_fwd(Tensor logits, int topk, str score_function) -> (Tensor, " + "Tensor, Tensor)"); + m.def( + "fused_score_for_moe_aux_loss_bwd(int num_tokens, int num_experts, Tensor " + "intermediate_output, Tensor grad_scores, Tensor grad_logits, int topk, str score_function) " + "-> ()"); + m.def( + "fused_moe_aux_loss_fwd(Tensor probs, Tensor tokens_per_expert, int total_num_tokens, int " + "num_experts, int num_rows, int num_cols, int topk, float coeff) -> (Tensor, Tensor)"); + m.def( + "fused_moe_aux_loss_bwd(Tensor Const_buf, Tensor tokens_per_expert, int num_rows, int " + "num_cols, Tensor grad_aux_loss) -> Tensor"); +} diff --git a/transformer_engine/pytorch/csrc/extensions/router.cpp b/transformer_engine/pytorch/csrc/extensions/router.cpp index 94625c0f12..9fa074e5d4 100644 --- a/transformer_engine/pytorch/csrc/extensions/router.cpp +++ b/transformer_engine/pytorch/csrc/extensions/router.cpp @@ -4,77 +4,73 @@ * See LICENSE for license information. 
************************************************************************/ -#include "../extensions.h" -#include "common.h" +#include -namespace transformer_engine::pytorch { +#include + +#include "../stable_common.h" + +namespace transformer_engine::pytorch::stable { + +using Tensor = torch::stable::Tensor; static std::map score_function_map = { {"sigmoid", 0}, {"softmax", 1}, {"sqrtsoftplus", 2}}; -std::tuple fused_topk_with_score_function_fwd( - at::Tensor logits, int topk, bool use_pre_softmax, std::optional num_groups, - std::optional group_topk, std::optional scaling_factor, std::string score_function, - std::optional expert_bias) { - int num_tokens = logits.size(0); - int num_experts = logits.size(1); - // Check if the input is valid - TORCH_CHECK(num_tokens > 0 && num_experts > 0, - "num_tokens and num_experts must be greater than 0"); - // Expert bias only happens at the sigmoid case +std::tuple fused_topk_with_score_function_fwd( + Tensor logits, int64_t topk, bool use_pre_softmax, int64_t num_groups, int64_t group_topk, + double scaling_factor, std::string score_function, std::optional expert_bias) { + int64_t num_tokens = logits.size(0); + int64_t num_experts = logits.size(1); + + STD_TORCH_CHECK(num_tokens > 0 && num_experts > 0, + "num_tokens and num_experts must be greater than 0"); if (expert_bias.has_value()) { - TORCH_CHECK(score_function == "sigmoid" || score_function == "sqrtsoftplus", - "score_function must be sigmoid or sqrtsoftplus when expert_bias is not None"); - TORCH_CHECK(expert_bias.value().scalar_type() == at::kFloat, - "expert_bias must be a float32 tensor"); + STD_TORCH_CHECK(score_function == "sigmoid" || score_function == "sqrtsoftplus", + "score_function must be sigmoid or sqrtsoftplus when expert_bias is not None"); } - // Check if the score function is valid - TORCH_CHECK(score_function == "softmax" || score_function == "sigmoid" || - score_function == "sqrtsoftplus", - "score_function must be softmax, sigmoid or sqrtsoftplus for 
router fusion"); + STD_TORCH_CHECK(score_function == "softmax" || score_function == "sigmoid" || + score_function == "sqrtsoftplus", + "score_function must be softmax, sigmoid or sqrtsoftplus"); + if (score_function == "sigmoid" || score_function == "sqrtsoftplus") { - use_pre_softmax = false; // Pre-softmax only happens at the softmax case + use_pre_softmax = false; } - // Reformat the input to make it compatible with the kernel - int group_topk_value = group_topk.has_value() ? group_topk.value() : -1; - int num_groups_value = num_groups.has_value() ? num_groups.value() : -1; - float scaling_factor_value = scaling_factor.has_value() ? scaling_factor.value() : 1.0f; + int group_topk_value = static_cast(group_topk); + int num_groups_value = static_cast(num_groups); + float scaling_factor_value = static_cast(scaling_factor); - // Construct the output tensor - at::Tensor probs = - at::empty({num_tokens, num_experts}, at::dtype(logits.scalar_type()).device(at::kCUDA)); - at::Tensor routing_map = - at::empty({num_tokens, num_experts}, at::dtype(at::kBool).device(at::kCUDA)); - // Intermediate output is used to store the output of the softmax/sigmoid function - at::Tensor intermediate_output = - at::empty({num_tokens, num_experts}, at::dtype(at::kFloat).device(at::kCUDA)); + auto device_idx = logits.get_device_index(); + auto probs = allocateStableTensor({num_tokens, num_experts}, logits.scalar_type(), device_idx); + auto routing_map = allocateStableTensor({num_tokens, num_experts}, ScalarType::Bool, device_idx); + auto intermediate_output = + allocateStableTensor({num_tokens, num_experts}, ScalarType::Float, device_idx); auto logits_cu = makeTransformerEngineTensor(logits); auto probs_cu = makeTransformerEngineTensor(probs); auto routing_map_cu = makeTransformerEngineTensor(routing_map); auto intermediate_output_cu = makeTransformerEngineTensor(intermediate_output); - auto expert_bias_cu = TensorWrapper(); // empty expert_bias_cu tensor + auto expert_bias_cu = 
TensorWrapper(); if (expert_bias.has_value()) { expert_bias_cu = makeTransformerEngineTensor(expert_bias.value()); } nvte_fused_topk_with_score_function_forward( - logits_cu.data(), num_tokens, num_experts, topk, use_pre_softmax, num_groups_value, - group_topk_value, scaling_factor_value, score_function_map[score_function], - expert_bias_cu.data(), probs_cu.data(), routing_map_cu.data(), intermediate_output_cu.data(), - at::cuda::getCurrentCUDAStream()); + logits_cu.data(), static_cast(num_tokens), static_cast(num_experts), + static_cast(topk), use_pre_softmax, num_groups_value, group_topk_value, + scaling_factor_value, score_function_map[score_function], expert_bias_cu.data(), + probs_cu.data(), routing_map_cu.data(), intermediate_output_cu.data(), + getCurrentCUDAStreamRaw(device_idx)); return std::make_tuple(probs, routing_map, intermediate_output); } -void fused_topk_with_score_function_bwd(int num_tokens, int num_experts, at::Tensor routing_map, - at::Tensor intermediate_output, at::Tensor grad_probs, - at::Tensor grad_logits, int topk, bool use_pre_softmax, - std::optional scaling_factor, - std::string score_function) { - // Get the value of the parameters - auto scaling_factor_value = scaling_factor.has_value() ? 
scaling_factor.value() : 1.0f; +void fused_topk_with_score_function_bwd(int64_t num_tokens, int64_t num_experts, Tensor routing_map, + Tensor intermediate_output, Tensor grad_probs, + Tensor grad_logits, int64_t topk, bool use_pre_softmax, + double scaling_factor, std::string score_function) { + float scaling_factor_value = static_cast(scaling_factor); auto score_function_value = score_function_map[score_function]; auto routing_map_cu = makeTransformerEngineTensor(routing_map); @@ -83,31 +79,27 @@ void fused_topk_with_score_function_bwd(int num_tokens, int num_experts, at::Ten auto grad_logits_cu = makeTransformerEngineTensor(grad_logits); nvte_fused_topk_with_score_function_backward( - routing_map_cu.data(), intermediate_output_cu.data(), grad_probs_cu.data(), num_tokens, - num_experts, topk, use_pre_softmax, scaling_factor_value, score_function_value, - grad_logits_cu.data(), at::cuda::getCurrentCUDAStream()); + routing_map_cu.data(), intermediate_output_cu.data(), grad_probs_cu.data(), + static_cast(num_tokens), static_cast(num_experts), static_cast(topk), + use_pre_softmax, scaling_factor_value, score_function_value, grad_logits_cu.data(), + getCurrentCUDAStreamRaw(routing_map.get_device_index())); } -std::tuple fused_score_for_moe_aux_loss_fwd( - at::Tensor logits, int topk, std::string score_function) { - int num_tokens = logits.size(0); - int num_experts = logits.size(1); - // Check if the input is valid - TORCH_CHECK(num_tokens > 0 && num_experts > 0, - "num_tokens and num_experts must be greater than 0"); - TORCH_CHECK(topk > 0, "topk must be greater than 0"); - // Check if the score function is valid - TORCH_CHECK(score_function == "softmax" || score_function == "sigmoid" || - score_function == "sqrtsoftplus", - "score_function must be softmax, sigmoid or sqrtsoftplus for router fusion"); +std::tuple fused_score_for_moe_aux_loss_fwd(Tensor logits, int64_t topk, + std::string score_function) { + int64_t num_tokens = logits.size(0); + int64_t num_experts = 
logits.size(1); + + STD_TORCH_CHECK(num_tokens > 0 && num_experts > 0, + "num_tokens and num_experts must be greater than 0"); + STD_TORCH_CHECK(topk > 0, "topk must be greater than 0"); int score_function_value = score_function_map[score_function]; - // Construct the output tensor - at::Tensor scores = at::empty({num_tokens, num_experts}, at::dtype(at::kFloat).device(at::kCUDA)); - at::Tensor routing_map = - at::empty({num_tokens, num_experts}, at::dtype(at::kBool).device(at::kCUDA)); - at::Tensor intermediate_output = - at::empty({num_tokens, num_experts}, at::dtype(at::kFloat).device(at::kCUDA)); + auto device_idx = logits.get_device_index(); + auto scores = allocateStableTensor({num_tokens, num_experts}, ScalarType::Float, device_idx); + auto routing_map = allocateStableTensor({num_tokens, num_experts}, ScalarType::Bool, device_idx); + auto intermediate_output = + allocateStableTensor({num_tokens, num_experts}, ScalarType::Float, device_idx); auto logits_cu = makeTransformerEngineTensor(logits); auto scores_cu = makeTransformerEngineTensor(scores); @@ -115,17 +107,17 @@ std::tuple fused_score_for_moe_aux_loss_fwd( auto intermediate_output_cu = makeTransformerEngineTensor(intermediate_output); nvte_fused_score_for_moe_aux_loss_forward( - logits_cu.data(), num_tokens, num_experts, topk, score_function_value, scores_cu.data(), - routing_map_cu.data(), intermediate_output_cu.data(), at::cuda::getCurrentCUDAStream()); + logits_cu.data(), static_cast(num_tokens), static_cast(num_experts), + static_cast(topk), score_function_value, scores_cu.data(), routing_map_cu.data(), + intermediate_output_cu.data(), getCurrentCUDAStreamRaw(device_idx)); return std::make_tuple(scores, routing_map, intermediate_output); } -void fused_score_for_moe_aux_loss_bwd(int num_tokens, int num_experts, - at::Tensor intermediate_output, at::Tensor grad_scores, - at::Tensor grad_logits, int topk, +void fused_score_for_moe_aux_loss_bwd(int64_t num_tokens, int64_t num_experts, + Tensor 
intermediate_output, Tensor grad_scores, + Tensor grad_logits, int64_t topk, std::string score_function) { - // Get the value of the parameters int score_function_value = score_function_map[score_function]; auto intermediate_output_cu = makeTransformerEngineTensor(intermediate_output); @@ -133,52 +125,63 @@ void fused_score_for_moe_aux_loss_bwd(int num_tokens, int num_experts, auto grad_logits_cu = makeTransformerEngineTensor(grad_logits); nvte_fused_score_for_moe_aux_loss_backward( - intermediate_output_cu.data(), grad_scores_cu.data(), num_tokens, num_experts, topk, - score_function_value, grad_logits_cu.data(), at::cuda::getCurrentCUDAStream()); + intermediate_output_cu.data(), grad_scores_cu.data(), static_cast(num_tokens), + static_cast(num_experts), static_cast(topk), score_function_value, + grad_logits_cu.data(), getCurrentCUDAStreamRaw(intermediate_output.get_device_index())); } -std::tuple fused_moe_aux_loss_fwd(at::Tensor probs, - at::Tensor tokens_per_expert, - int total_num_tokens, int num_experts, - int num_rows, int num_cols, int topk, - float coeff) { - TORCH_CHECK(topk > 0, "topk must be greater than 0"); - TORCH_CHECK(total_num_tokens > 0, "total_num_tokens must be greater than 0"); - TORCH_CHECK(num_experts > 0, "num_experts must be greater than 0"); +std::tuple fused_moe_aux_loss_fwd(Tensor probs, Tensor tokens_per_expert, + int64_t total_num_tokens, int64_t num_experts, + int64_t num_rows, int64_t num_cols, int64_t topk, + double coeff) { + STD_TORCH_CHECK(topk > 0, "topk must be greater than 0"); + STD_TORCH_CHECK(total_num_tokens > 0, "total_num_tokens must be greater than 0"); - // Create the output tensor - at::Tensor aux_loss = at::empty({}, at::dtype(probs.scalar_type()).device(at::kCUDA)); - at::Tensor Const_buf = at::empty({}, at::dtype(at::kFloat).device(at::kCUDA)); + auto device_idx = probs.get_device_index(); + // Scalar tensors (0-dim) + auto aux_loss = allocateStableTensor({}, probs.scalar_type(), device_idx); + auto Const_buf = 
allocateStableTensor({}, ScalarType::Float, device_idx); auto probs_cu = makeTransformerEngineTensor(probs); auto tokens_per_expert_cu = makeTransformerEngineTensor(tokens_per_expert); auto aux_loss_cu = makeTransformerEngineTensor(aux_loss); auto Const_buf_cu = makeTransformerEngineTensor(Const_buf); - nvte_fused_moe_aux_loss_forward(probs_cu.data(), tokens_per_expert_cu.data(), total_num_tokens, - num_experts, num_rows, num_cols, topk, coeff, aux_loss_cu.data(), - Const_buf_cu.data(), at::cuda::getCurrentCUDAStream()); + nvte_fused_moe_aux_loss_forward( + probs_cu.data(), tokens_per_expert_cu.data(), static_cast(total_num_tokens), + static_cast(num_experts), static_cast(num_rows), static_cast(num_cols), + static_cast(topk), static_cast(coeff), aux_loss_cu.data(), Const_buf_cu.data(), + getCurrentCUDAStreamRaw(device_idx)); return std::make_tuple(aux_loss, Const_buf); } -at::Tensor fused_moe_aux_loss_bwd(at::Tensor Const_buf, at::Tensor tokens_per_expert, int num_rows, - int num_cols, at::Tensor grad_aux_loss) { - // Create the output tensor - at::Tensor grad_probs = - at::empty({num_rows, num_cols}, at::dtype(grad_aux_loss.scalar_type()).device(at::kCUDA)); +Tensor fused_moe_aux_loss_bwd(Tensor Const_buf, Tensor tokens_per_expert, int64_t num_rows, + int64_t num_cols, Tensor grad_aux_loss) { + auto device_idx = grad_aux_loss.get_device_index(); + auto grad_probs = + allocateStableTensor({num_rows, num_cols}, grad_aux_loss.scalar_type(), device_idx); auto Const_buf_cu = makeTransformerEngineTensor(Const_buf); auto tokens_per_expert_cu = makeTransformerEngineTensor(tokens_per_expert); auto grad_aux_loss_cu = makeTransformerEngineTensor(grad_aux_loss); auto grad_probs_cu = makeTransformerEngineTensor(grad_probs); - // Meta data for the kernel - nvte_fused_moe_aux_loss_backward(Const_buf_cu.data(), tokens_per_expert_cu.data(), num_rows, - num_cols, grad_aux_loss_cu.data(), grad_probs_cu.data(), - at::cuda::getCurrentCUDAStream()); + 
nvte_fused_moe_aux_loss_backward(Const_buf_cu.data(), tokens_per_expert_cu.data(), + static_cast(num_rows), static_cast(num_cols), + grad_aux_loss_cu.data(), grad_probs_cu.data(), + getCurrentCUDAStreamRaw(device_idx)); return grad_probs; } -} // namespace transformer_engine::pytorch +STABLE_TORCH_LIBRARY_IMPL(transformer_engine_stable, CUDA, m) { + m.impl("fused_topk_with_score_function_fwd", TORCH_BOX(fused_topk_with_score_function_fwd)); + m.impl("fused_topk_with_score_function_bwd", TORCH_BOX(fused_topk_with_score_function_bwd)); + m.impl("fused_score_for_moe_aux_loss_fwd", TORCH_BOX(fused_score_for_moe_aux_loss_fwd)); + m.impl("fused_score_for_moe_aux_loss_bwd", TORCH_BOX(fused_score_for_moe_aux_loss_bwd)); + m.impl("fused_moe_aux_loss_fwd", TORCH_BOX(fused_moe_aux_loss_fwd)); + m.impl("fused_moe_aux_loss_bwd", TORCH_BOX(fused_moe_aux_loss_bwd)); +} + +} // namespace transformer_engine::pytorch::stable diff --git a/transformer_engine/pytorch/csrc/extensions/softmax.cpp b/transformer_engine/pytorch/csrc/extensions/softmax.cpp index 3bb6a5e7b3..1c078f7604 100644 --- a/transformer_engine/pytorch/csrc/extensions/softmax.cpp +++ b/transformer_engine/pytorch/csrc/extensions/softmax.cpp @@ -4,234 +4,253 @@ * See LICENSE for license information. 
************************************************************************/ -#include "../extensions.h" +#include -namespace transformer_engine::pytorch { +#include "../stable_common.h" -at::Tensor scaled_softmax_forward(at::Tensor input, float scale_factor) { - AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); - AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || - (input.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); +namespace transformer_engine::pytorch::stable { - const int batches = input.size(0); - const int attn_heads = input.size(1); - const int query_seq_len = input.size(2); - const int key_seq_len = input.size(3); +using Tensor = torch::stable::Tensor; - AT_ASSERTM(key_seq_len <= 16384, "Key sequence length must be 16384 or less"); - AT_ASSERTM(key_seq_len % 8 == 0, "Key sequence length must be divisible by 8"); - AT_ASSERTM(query_seq_len > 1, "Query sequence length must be greater than 1"); +Tensor scaled_softmax_forward(Tensor input, double scale_factor) { + STD_TORCH_CHECK(input.dim() == 4, "expected 4D tensor"); + auto dtype = input.scalar_type(); + STD_TORCH_CHECK(dtype == ScalarType::Half || dtype == ScalarType::BFloat16, + "Only fp16 and bf16 are supported"); - // Output - auto act_options = input.options().requires_grad(false); - auto softmax_results = - torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); + const int64_t batches = input.size(0); + const int64_t attn_heads = input.size(1); + const int64_t query_seq_len = input.size(2); + const int64_t key_seq_len = input.size(3); + + STD_TORCH_CHECK(key_seq_len <= 16384, "Key sequence length must be 16384 or less"); + STD_TORCH_CHECK(key_seq_len % 8 == 0, "Key sequence length must be divisible by 8"); + STD_TORCH_CHECK(query_seq_len > 1, "Query sequence length must be greater than 1"); + + // Allocate output + std::vector out_shape = {batches, attn_heads, query_seq_len, key_seq_len}; + auto softmax_results = 
allocateStableTensor(out_shape, dtype, input.get_device_index()); auto input_cu = makeTransformerEngineTensor(input); auto softmax_results_cu = makeTransformerEngineTensor(softmax_results); - nvte_scaled_softmax_forward(input_cu.data(), softmax_results_cu.data(), scale_factor, - at::cuda::getCurrentCUDAStream()); + nvte_scaled_softmax_forward(input_cu.data(), softmax_results_cu.data(), + static_cast(scale_factor), + getCurrentCUDAStreamRaw(input.get_device_index())); return softmax_results; } -at::Tensor scaled_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_results_, - float scale_factor) { - auto output_grads = output_grad_.contiguous(); - auto softmax_results = softmax_results_.contiguous(); +Tensor scaled_softmax_backward(Tensor output_grad_, Tensor softmax_results_, double scale_factor) { + auto output_grads = torch::stable::contiguous(output_grad_); + auto softmax_results = torch::stable::contiguous(softmax_results_); - AT_ASSERTM(output_grads.dim() == 4, "expected 4D tensor"); - AT_ASSERTM(softmax_results.dim() == 4, "expected 4D tensor"); + STD_TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor"); + STD_TORCH_CHECK(softmax_results.dim() == 4, "expected 4D tensor"); - AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || - (output_grads.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || - (softmax_results.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); + auto og_dtype = output_grads.scalar_type(); + auto sr_dtype = softmax_results.scalar_type(); + STD_TORCH_CHECK(og_dtype == ScalarType::Half || og_dtype == ScalarType::BFloat16, + "Only fp16 and bf16 are supported"); + STD_TORCH_CHECK(sr_dtype == ScalarType::Half || sr_dtype == ScalarType::BFloat16, + "Only fp16 and bf16 are supported"); auto output_grads_cu = makeTransformerEngineTensor(output_grads); auto softmax_results_cu = 
makeTransformerEngineTensor(softmax_results); // Produce gradients in place. nvte_scaled_softmax_backward(output_grads_cu.data(), softmax_results_cu.data(), - output_grads_cu.data(), scale_factor, - at::cuda::getCurrentCUDAStream()); + output_grads_cu.data(), static_cast(scale_factor), + getCurrentCUDAStreamRaw(output_grads.get_device_index())); return output_grads; } -at::Tensor scaled_masked_softmax_forward(at::Tensor input, at::Tensor mask, float scale_factor) { - AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); - AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || - (input.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - AT_ASSERTM(mask.dim() == 4, "expected 4D tensor"); - if (!input.is_contiguous()) input = input.contiguous(); - if (!mask.is_contiguous()) mask = mask.contiguous(); - - const int batches = input.size(0); - const int pad_batches = mask.size(0); - const int attn_heads = input.size(1); - const int query_seq_len = input.size(2); - const int key_seq_len = input.size(3); - - AT_ASSERTM(key_seq_len <= 16384, "Key sequence length must be 16384 or less"); - AT_ASSERTM(key_seq_len % 8 == 0, "Key sequence length must be divisible by 8"); - AT_ASSERTM(query_seq_len > 1, "Query sequence length must be greater than 1"); - TORCH_CHECK(pad_batches == 1 || pad_batches == batches); - TORCH_CHECK(mask.size(1) == 1); - TORCH_CHECK(mask.size(2) == query_seq_len); - TORCH_CHECK(mask.size(3) == key_seq_len); - - auto act_options = input.options().requires_grad(false); - auto softmax_results = - torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); +Tensor scaled_masked_softmax_forward(Tensor input, Tensor mask, double scale_factor) { + STD_TORCH_CHECK(input.dim() == 4, "expected 4D tensor"); + auto dtype = input.scalar_type(); + STD_TORCH_CHECK(dtype == ScalarType::Half || dtype == ScalarType::BFloat16, + "Only fp16 and bf16 are supported"); + STD_TORCH_CHECK(mask.dim() == 4, "expected 4D tensor"); 
+ + if (!input.is_contiguous()) input = torch::stable::contiguous(input); + if (!mask.is_contiguous()) mask = torch::stable::contiguous(mask); + + const int64_t batches = input.size(0); + const int64_t pad_batches = mask.size(0); + const int64_t attn_heads = input.size(1); + const int64_t query_seq_len = input.size(2); + const int64_t key_seq_len = input.size(3); + + STD_TORCH_CHECK(key_seq_len <= 16384, "Key sequence length must be 16384 or less"); + STD_TORCH_CHECK(key_seq_len % 8 == 0, "Key sequence length must be divisible by 8"); + STD_TORCH_CHECK(query_seq_len > 1, "Query sequence length must be greater than 1"); + STD_TORCH_CHECK(pad_batches == 1 || pad_batches == batches, + "Mask batch size must be 1 or match input batch size"); + STD_TORCH_CHECK(mask.size(1) == 1, "Mask dim 1 must be 1"); + STD_TORCH_CHECK(mask.size(2) == query_seq_len, "Mask dim 2 must match query_seq_len"); + STD_TORCH_CHECK(mask.size(3) == key_seq_len, "Mask dim 3 must match key_seq_len"); + + std::vector out_shape = {batches, attn_heads, query_seq_len, key_seq_len}; + auto softmax_results = allocateStableTensor(out_shape, dtype, input.get_device_index()); auto input_cu = makeTransformerEngineTensor(input); auto mask_cu = makeTransformerEngineTensor(mask); auto softmax_results_cu = makeTransformerEngineTensor(softmax_results); nvte_scaled_masked_softmax_forward(input_cu.data(), mask_cu.data(), softmax_results_cu.data(), - scale_factor, at::cuda::getCurrentCUDAStream()); + static_cast(scale_factor), + getCurrentCUDAStreamRaw(input.get_device_index())); return softmax_results; } -at::Tensor scaled_masked_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_results_, - float scale_factor) { - auto output_grads = output_grad_.contiguous(); - auto softmax_results = softmax_results_.contiguous(); +Tensor scaled_masked_softmax_backward(Tensor output_grad_, Tensor softmax_results_, + double scale_factor) { + auto output_grads = torch::stable::contiguous(output_grad_); + auto 
softmax_results = torch::stable::contiguous(softmax_results_); - AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor"); - AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor"); + STD_TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor"); + STD_TORCH_CHECK(softmax_results.dim() == 4, "expected 4D tensor"); - AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || - (output_grads.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || - (softmax_results.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); + auto og_dtype = output_grads.scalar_type(); + auto sr_dtype = softmax_results.scalar_type(); + STD_TORCH_CHECK(og_dtype == ScalarType::Half || og_dtype == ScalarType::BFloat16, + "Only fp16 and bf16 are supported"); + STD_TORCH_CHECK(sr_dtype == ScalarType::Half || sr_dtype == ScalarType::BFloat16, + "Only fp16 and bf16 are supported"); auto output_grads_cu = makeTransformerEngineTensor(output_grads); auto softmax_results_cu = makeTransformerEngineTensor(softmax_results); - // Produce gradients in place. 
nvte_scaled_softmax_backward(output_grads_cu.data(), softmax_results_cu.data(), - output_grads_cu.data(), scale_factor, - at::cuda::getCurrentCUDAStream()); + output_grads_cu.data(), static_cast(scale_factor), + getCurrentCUDAStreamRaw(output_grads.get_device_index())); return output_grads; } -at::Tensor scaled_upper_triang_masked_softmax_forward(at::Tensor input, float scale_factor) { - AT_ASSERTM(input.dim() == 3, "expected 3D tensor"); - AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || - (input.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); +Tensor scaled_upper_triang_masked_softmax_forward(Tensor input, double scale_factor) { + STD_TORCH_CHECK(input.dim() == 3, "expected 3D tensor"); + auto dtype = input.scalar_type(); + STD_TORCH_CHECK(dtype == ScalarType::Half || dtype == ScalarType::BFloat16, + "Only fp16 and bf16 are supported"); - const int attn_batches = input.size(0); - const int seq_len = input.size(1); - AT_ASSERTM(seq_len <= 16384, "Sequence length must be 16384 or less"); + const int64_t attn_batches = input.size(0); + const int64_t seq_len = input.size(1); + STD_TORCH_CHECK(seq_len <= 16384, "Sequence length must be 16384 or less"); - // Output - auto act_options = input.options().requires_grad(false); - auto softmax_results = torch::empty({attn_batches, seq_len, seq_len}, act_options); + std::vector out_shape = {attn_batches, seq_len, seq_len}; + auto softmax_results = allocateStableTensor(out_shape, dtype, input.get_device_index()); auto input_cu = makeTransformerEngineTensor(input); auto softmax_results_cu = makeTransformerEngineTensor(softmax_results); - nvte_scaled_upper_triang_masked_softmax_forward(input_cu.data(), softmax_results_cu.data(), - scale_factor, at::cuda::getCurrentCUDAStream()); + nvte_scaled_upper_triang_masked_softmax_forward( + input_cu.data(), softmax_results_cu.data(), static_cast(scale_factor), + getCurrentCUDAStreamRaw(input.get_device_index())); return softmax_results; } 
-at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_, - at::Tensor softmax_results_, - float scale_factor) { - auto output_grads = output_grads_.contiguous(); - auto softmax_results = softmax_results_.contiguous(); +Tensor scaled_upper_triang_masked_softmax_backward(Tensor output_grads_, Tensor softmax_results_, + double scale_factor) { + auto output_grads = torch::stable::contiguous(output_grads_); + auto softmax_results = torch::stable::contiguous(softmax_results_); - AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor"); + STD_TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); + STD_TORCH_CHECK(softmax_results.dim() == 3, "expected 3D tensor"); - AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || - (output_grads.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || - (softmax_results.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); + auto og_dtype = output_grads.scalar_type(); + auto sr_dtype = softmax_results.scalar_type(); + STD_TORCH_CHECK(og_dtype == ScalarType::Half || og_dtype == ScalarType::BFloat16, + "Only fp16 and bf16 are supported"); + STD_TORCH_CHECK(sr_dtype == ScalarType::Half || sr_dtype == ScalarType::BFloat16, + "Only fp16 and bf16 are supported"); - TORCH_CHECK(output_grads.size(1) == output_grads.size(2)); + STD_TORCH_CHECK(output_grads.size(1) == output_grads.size(2), + "Output grads dim 1 and dim 2 must match"); auto output_grads_cu = makeTransformerEngineTensor(output_grads); auto softmax_results_cu = makeTransformerEngineTensor(softmax_results); - // Produce gradients in place. 
nvte_scaled_upper_triang_masked_softmax_backward( - output_grads_cu.data(), softmax_results_cu.data(), output_grads_cu.data(), scale_factor, - at::cuda::getCurrentCUDAStream()); + output_grads_cu.data(), softmax_results_cu.data(), output_grads_cu.data(), + static_cast(scale_factor), getCurrentCUDAStreamRaw(output_grads.get_device_index())); return output_grads; } -at::Tensor scaled_aligned_causal_masked_softmax_forward(at::Tensor input, float scale_factor) { - AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); - AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || - (input.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); +Tensor scaled_aligned_causal_masked_softmax_forward(Tensor input, double scale_factor) { + STD_TORCH_CHECK(input.dim() == 4, "expected 4D tensor"); + auto dtype = input.scalar_type(); + STD_TORCH_CHECK(dtype == ScalarType::Half || dtype == ScalarType::BFloat16, + "Only fp16 and bf16 are supported"); - const int batches = input.size(0); - const int attn_heads = input.size(1); - const int query_seq_len = input.size(2); - const int key_seq_len = input.size(3); + const int64_t batches = input.size(0); + const int64_t attn_heads = input.size(1); + const int64_t query_seq_len = input.size(2); + const int64_t key_seq_len = input.size(3); - AT_ASSERTM(key_seq_len <= 16384, "Key sequence length must be 16384 or less"); - AT_ASSERTM(key_seq_len % 8 == 0, "Key sequence length must be divisible by 8"); - AT_ASSERTM(query_seq_len >= 1, "Query sequence length must be greater or equal to 1"); + STD_TORCH_CHECK(key_seq_len <= 16384, "Key sequence length must be 16384 or less"); + STD_TORCH_CHECK(key_seq_len % 8 == 0, "Key sequence length must be divisible by 8"); + STD_TORCH_CHECK(query_seq_len >= 1, "Query sequence length must be greater or equal to 1"); - // Output - auto act_options = input.options().requires_grad(false); - auto softmax_results = - torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, 
act_options); + std::vector out_shape = {batches, attn_heads, query_seq_len, key_seq_len}; + auto softmax_results = allocateStableTensor(out_shape, dtype, input.get_device_index()); auto input_cu = makeTransformerEngineTensor(input); auto softmax_results_cu = makeTransformerEngineTensor(softmax_results); - nvte_scaled_aligned_causal_masked_softmax_forward(input_cu.data(), softmax_results_cu.data(), - scale_factor, at::cuda::getCurrentCUDAStream()); + nvte_scaled_aligned_causal_masked_softmax_forward( + input_cu.data(), softmax_results_cu.data(), static_cast(scale_factor), + getCurrentCUDAStreamRaw(input.get_device_index())); return softmax_results; } -at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grad_, - at::Tensor softmax_results_, - float scale_factor) { - auto output_grads = output_grad_.contiguous(); - auto softmax_results = softmax_results_.contiguous(); +Tensor scaled_aligned_causal_masked_softmax_backward(Tensor output_grad_, Tensor softmax_results_, + double scale_factor) { + auto output_grads = torch::stable::contiguous(output_grad_); + auto softmax_results = torch::stable::contiguous(softmax_results_); - AT_ASSERTM(output_grads.dim() == 4, "expected 4D tensor"); - AT_ASSERTM(softmax_results.dim() == 4, "expected 4D tensor"); + STD_TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor"); + STD_TORCH_CHECK(softmax_results.dim() == 4, "expected 4D tensor"); - AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || - (output_grads.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || - (softmax_results.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); + auto og_dtype = output_grads.scalar_type(); + auto sr_dtype = softmax_results.scalar_type(); + STD_TORCH_CHECK(og_dtype == ScalarType::Half || og_dtype == ScalarType::BFloat16, + "Only fp16 and bf16 are supported"); + 
STD_TORCH_CHECK(sr_dtype == ScalarType::Half || sr_dtype == ScalarType::BFloat16, + "Only fp16 and bf16 are supported"); auto output_grads_cu = makeTransformerEngineTensor(output_grads); auto softmax_results_cu = makeTransformerEngineTensor(softmax_results); - // Produce gradients in place. nvte_scaled_aligned_causal_masked_softmax_backward( - output_grads_cu.data(), softmax_results_cu.data(), output_grads_cu.data(), scale_factor, - at::cuda::getCurrentCUDAStream()); + output_grads_cu.data(), softmax_results_cu.data(), output_grads_cu.data(), + static_cast(scale_factor), getCurrentCUDAStreamRaw(output_grads.get_device_index())); return output_grads; } -} // namespace transformer_engine::pytorch +// ============================================================================ +// Op registration via stable ABI (schemas defined in registration.cpp) +// ============================================================================ + +STABLE_TORCH_LIBRARY_IMPL(transformer_engine_stable, CUDA, m) { + m.impl("scaled_softmax_forward", TORCH_BOX(scaled_softmax_forward)); + m.impl("scaled_softmax_backward", TORCH_BOX(scaled_softmax_backward)); + m.impl("scaled_masked_softmax_forward", TORCH_BOX(scaled_masked_softmax_forward)); + m.impl("scaled_masked_softmax_backward", TORCH_BOX(scaled_masked_softmax_backward)); + m.impl("scaled_upper_triang_masked_softmax_forward", + TORCH_BOX(scaled_upper_triang_masked_softmax_forward)); + m.impl("scaled_upper_triang_masked_softmax_backward", + TORCH_BOX(scaled_upper_triang_masked_softmax_backward)); + m.impl("scaled_aligned_causal_masked_softmax_forward", + TORCH_BOX(scaled_aligned_causal_masked_softmax_forward)); + m.impl("scaled_aligned_causal_masked_softmax_backward", + TORCH_BOX(scaled_aligned_causal_masked_softmax_backward)); +} + +} // namespace transformer_engine::pytorch::stable diff --git a/transformer_engine/pytorch/csrc/extensions/swizzle.cpp b/transformer_engine/pytorch/csrc/extensions/swizzle.cpp deleted file mode 100644 index 
7ff35d6b68..0000000000 --- a/transformer_engine/pytorch/csrc/extensions/swizzle.cpp +++ /dev/null @@ -1,478 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#include -#include -#include -#include - -#include "common.h" -#include "common/common.h" -#include "extensions.h" -#include "pybind.h" -#include "util.h" - -namespace transformer_engine { -namespace pytorch { - -namespace { - -void reset_tensor_data(transformer_engine::TensorWrapper &tensor, bool rowwise, bool columnwise) { - NVTEShape shape; - shape.ndim = 1; - shape.data[0] = 0; - const transformer_engine::DType dtype = transformer_engine::DType::kFloat32; - if (rowwise) { - tensor.set_rowwise_data(nullptr, dtype, shape); - tensor.set_rowwise_scale_inv(nullptr, dtype, shape); - } - if (columnwise) { - tensor.set_columnwise_data(nullptr, dtype, shape); - tensor.set_columnwise_scale_inv(nullptr, dtype, shape); - } -} - -bool is_empty_grouped_tensor_param(const NVTEBasicTensor &t) { - if (t.data_ptr == nullptr) { - return true; - } - return t.shape.ndim == 1 && t.shape.data[0] == 0; -} - -} // namespace - -std::tuple, std::optional> swizzle_scales_for_gemm( - transformer_engine::TensorWrapper &tensor, bool rowwise_usage, bool columnwise_usage) { - // Return early if scale swizzling is not required - const auto scaling_mode = tensor.scaling_mode(); - switch (scaling_mode) { - case NVTE_MXFP8_1D_SCALING: - case NVTE_NVFP4_1D_SCALING: - // Tensor format requires scale swizzling - break; - case NVTE_INVALID_SCALING: - NVTE_ERROR("Invalid scaling mode for swizzling scaling factors."); - default: - // Tensor format does not require scale swizzling for GEMM - return {std::nullopt, std::nullopt}; - } - - // Return early if scales are already swizzled - if 
(tensor.get_with_gemm_swizzled_scales()) { - return {std::nullopt, std::nullopt}; - } - - // CUDA stream - auto stream = at::cuda::getCurrentCUDAStream(); - - // Swizzle row-wise scales if needed - std::optional rowwise_scales_pyt; - if (rowwise_usage) { - // Buffer for unswizzled scales - const auto input_scales_nvte = tensor.get_rowwise_scale_inv(); - void *input_scales_dptr = input_scales_nvte.data_ptr; - const NVTEShape input_scales_shape = input_scales_nvte.shape; - const auto scales_dtype = static_cast(input_scales_nvte.dtype); - - // Allocate buffer for swizzled scales - const NVTEShape output_scales_shape = input_scales_shape; - rowwise_scales_pyt = allocateSpace(input_scales_shape, scales_dtype, false); - void *output_scales_dptr = getDataPtr(*rowwise_scales_pyt); - - // Initialize TE tensors with scales - const auto data_nvte = tensor.get_rowwise_data(); - const auto data_dtype = static_cast(data_nvte.dtype); - TensorWrapper input_nvte(scaling_mode); - input_nvte.set_rowwise_data(nullptr, data_dtype, data_nvte.shape); - input_nvte.set_rowwise_scale_inv(input_scales_dptr, scales_dtype, input_scales_shape); - TensorWrapper output_nvte(scaling_mode); - output_nvte.set_rowwise_data(nullptr, data_dtype, data_nvte.shape); - output_nvte.set_rowwise_scale_inv(output_scales_dptr, scales_dtype, output_scales_shape); - output_nvte.set_with_gemm_swizzled_scales(true); - - // Launch kernel - NVTE_SCOPED_GIL_RELEASE( - { nvte_swizzle_scaling_factors(input_nvte.data(), output_nvte.data(), stream); }); - - // Update tensor with swizzled scales - tensor.set_rowwise_scale_inv(output_scales_dptr, scales_dtype, output_scales_shape); - } - - // Swizzle column-wise scales if needed - std::optional columnwise_scales_pyt; - if (columnwise_usage) { - // Buffer for unswizzled scales - const auto input_scales_nvte = tensor.get_columnwise_scale_inv(); - void *input_scales_dptr = input_scales_nvte.data_ptr; - const NVTEShape input_scales_shape = input_scales_nvte.shape; - const auto 
scales_dtype = static_cast(input_scales_nvte.dtype); - - // Allocate buffer for swizzled scales - const NVTEShape output_scales_shape = input_scales_shape; - columnwise_scales_pyt = allocateSpace(input_scales_shape, scales_dtype, false); - void *output_scales_dptr = getDataPtr(*columnwise_scales_pyt); - - // Initialize TE tensors with scales - const auto data_nvte = tensor.get_columnwise_data(); - const auto data_dtype = static_cast(data_nvte.dtype); - TensorWrapper input_nvte(scaling_mode); - input_nvte.set_columnwise_data(nullptr, data_dtype, data_nvte.shape); - input_nvte.set_columnwise_scale_inv(input_scales_dptr, scales_dtype, input_scales_shape); - TensorWrapper output_nvte(scaling_mode); - output_nvte.set_columnwise_data(nullptr, data_dtype, data_nvte.shape); - output_nvte.set_columnwise_scale_inv(output_scales_dptr, scales_dtype, output_scales_shape); - output_nvte.set_with_gemm_swizzled_scales(true); - - // Launch kernel - NVTE_SCOPED_GIL_RELEASE( - { nvte_swizzle_scaling_factors(input_nvte.data(), output_nvte.data(), stream); }); - - // Update tensor with swizzled scales - tensor.set_columnwise_scale_inv(output_scales_dptr, scales_dtype, output_scales_shape); - } - - // Update tensor - reset_tensor_data(tensor, !rowwise_usage, !columnwise_usage); - tensor.set_with_gemm_swizzled_scales(true); - - return {std::move(rowwise_scales_pyt), std::move(columnwise_scales_pyt)}; -} - -std::optional multi_tensor_swizzle_scales_for_gemm( - std::vector &tensors, bool rowwise_usage, - bool columnwise_usage) { - // Checks and trivial cases - NVTE_CHECK(rowwise_usage != columnwise_usage, - "Expect exactly one of rowwise_usage=", rowwise_usage, - " and columnwise_usage=", columnwise_usage, "."); - if (tensors.empty()) { - return std::nullopt; - } - const auto scaling_mode = tensors.front().scaling_mode(); - for (const auto &tensor : tensors) { - NVTE_CHECK(tensor.scaling_mode() == scaling_mode, "Tensors have different scaling modes"); - } - - // Return early if scale 
swizzling is not required - switch (scaling_mode) { - case NVTE_MXFP8_1D_SCALING: - case NVTE_NVFP4_1D_SCALING: - // Tensor format requires scale swizzling - break; - case NVTE_INVALID_SCALING: - NVTE_ERROR("Invalid scaling mode for swizzling scaling factors."); - default: - // Tensor format does not require scale swizzling for GEMM - return std::nullopt; - } - - // Filter out tensors that already have swizzled scales - std::vector tensors_needing_swizzle; - for (auto &tensor : tensors) { - if (!tensor.get_with_gemm_swizzled_scales()) { - tensors_needing_swizzle.push_back(&tensor); - } - } - if (tensors_needing_swizzle.empty()) { - return std::nullopt; - } - - // Determine buffer size needed for swizzled scales - std::vector output_scales_offsets; - size_t output_scales_bytes = 0; - for (auto &tensor : tensors_needing_swizzle) { - const auto scales_nvte = - (rowwise_usage ? tensor->get_rowwise_scale_inv() : tensor->get_columnwise_scale_inv()); - const auto &shape = scales_nvte.shape; - const auto dtype = static_cast(scales_nvte.dtype); - const auto dtype_bits = transformer_engine::pytorch::typeToNumBits(dtype); - const auto size = product(shape, 0, shape.ndim); - output_scales_bytes = roundup(output_scales_bytes, 16); // align to 16B - output_scales_offsets.push_back(output_scales_bytes); - output_scales_bytes += ceildiv(size * dtype_bits, 8); - } - - // Allocate buffer for swizzled scales - auto output_scales_pyt = allocateSpace(std::vector{output_scales_bytes}, - transformer_engine::DType::kByte, false); - uint8_t *output_scales_dptr = reinterpret_cast(getDataPtr(output_scales_pyt)); - - // Construct TE tensors with only scales - std::vector inputs_nvte, outputs_nvte; - for (size_t i = 0; i < tensors_needing_swizzle.size(); ++i) { - auto &tensor = *tensors_needing_swizzle[i]; - inputs_nvte.emplace_back(scaling_mode); - outputs_nvte.emplace_back(scaling_mode); - auto &input_nvte = inputs_nvte.back(); - auto &output_nvte = outputs_nvte.back(); - 
output_nvte.set_with_gemm_swizzled_scales(true); - if (rowwise_usage) { - const auto data_nvte = tensor.get_rowwise_data(); - const auto scales_nvte = tensor.get_rowwise_scale_inv(); - const auto data_dtype = static_cast(data_nvte.dtype); - const auto scales_dtype = static_cast(scales_nvte.dtype); - input_nvte.set_rowwise_data(nullptr, data_dtype, data_nvte.shape); - input_nvte.set_rowwise_scale_inv(scales_nvte.data_ptr, scales_dtype, scales_nvte.shape); - output_nvte.set_rowwise_data(nullptr, data_dtype, data_nvte.shape); - output_nvte.set_rowwise_scale_inv(output_scales_dptr + output_scales_offsets[i], scales_dtype, - scales_nvte.shape); - } else { - const auto data_nvte = tensor.get_columnwise_data(); - const auto scales_nvte = tensor.get_columnwise_scale_inv(); - const auto data_dtype = static_cast(data_nvte.dtype); - const auto scales_dtype = static_cast(scales_nvte.dtype); - input_nvte.set_columnwise_data(nullptr, data_dtype, data_nvte.shape); - input_nvte.set_columnwise_scale_inv(scales_nvte.data_ptr, scales_dtype, scales_nvte.shape); - output_nvte.set_columnwise_data(nullptr, data_dtype, data_nvte.shape); - output_nvte.set_columnwise_scale_inv(output_scales_dptr + output_scales_offsets[i], - scales_dtype, scales_nvte.shape); - } - } - - // Pack raw NVTETensors into vectors - std::vector inputs_nvte_raw, outputs_nvte_raw; - for (auto &tensor : inputs_nvte) { - inputs_nvte_raw.emplace_back(tensor.data()); - } - for (auto &tensor : outputs_nvte) { - outputs_nvte_raw.emplace_back(tensor.data()); - } - - // Launch kernel - NVTE_SCOPED_GIL_RELEASE({ - nvte_multi_tensor_swizzle_scaling_factors(inputs_nvte_raw.data(), outputs_nvte_raw.data(), - inputs_nvte_raw.size(), - at::cuda::getCurrentCUDAStream()); - }); - - // Update tensors with swizzled scales - for (size_t i = 0; i < tensors_needing_swizzle.size(); ++i) { - auto &tensor = *tensors_needing_swizzle[i]; - reset_tensor_data(tensor, !rowwise_usage, !columnwise_usage); - 
tensor.set_with_gemm_swizzled_scales(true); - if (rowwise_usage) { - auto scales_nvte = outputs_nvte[i].get_rowwise_scale_inv(); - const auto scales_dtype = static_cast(scales_nvte.dtype); - tensor.set_rowwise_scale_inv(output_scales_dptr + output_scales_offsets[i], scales_dtype, - scales_nvte.shape); - } else { - auto scales_nvte = outputs_nvte[i].get_columnwise_scale_inv(); - const auto scales_dtype = static_cast(scales_nvte.dtype); - tensor.set_columnwise_scale_inv(output_scales_dptr + output_scales_offsets[i], scales_dtype, - scales_nvte.shape); - } - } - - return std::move(output_scales_pyt); -} - -at::Tensor convert_block_scaling_to_mxfp8_tensor(transformer_engine::TensorWrapper &input, - bool rowwise) { - // Check input tensor - const NVTEScalingMode scaling_mode = input.scaling_mode(); - NVTE_CHECK(scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D, - "Input tensor must be a block scaling tensor"); - - // Get tensor data - NVTEBasicTensor data; - size_t data_flat_first_dim = 1; - size_t data_flat_last_dim = 1; - if (rowwise) { - data = input.get_rowwise_data(); - for (size_t i = 0; i < data.shape.ndim - 1; ++i) { - data_flat_first_dim *= data.shape.data[i]; - } - data_flat_last_dim = data.shape.data[data.shape.ndim - 1]; - } else { - data = input.get_columnwise_data(); - data_flat_first_dim = data.shape.data[0]; - for (size_t i = 1; i < data.shape.ndim; ++i) { - data_flat_last_dim *= data.shape.data[i]; - } - } - NVTEShape data_shape{}; - data_shape.data[0] = data_flat_first_dim; - data_shape.data[1] = data_flat_last_dim; - data_shape.ndim = 2; - - // Recreate input tensor with rowwise usage - transformer_engine::TensorWrapper input_cu(scaling_mode); - input_cu.set_rowwise_data(data.data_ptr, input.dtype(), data_shape); - const NVTEBasicTensor scale_inv = - rowwise ? 
input.get_rowwise_scale_inv() : input.get_columnwise_scale_inv(); - input_cu.set_rowwise_scale_inv( - scale_inv.data_ptr, static_cast(scale_inv.dtype), scale_inv.shape); - - // Create output tensor - transformer_engine::TensorWrapper output_cu(NVTE_MXFP8_1D_SCALING); - output_cu.set_rowwise_data(data.data_ptr, input.dtype(), data_shape); - // Output swizzled mxfp8 scaling factor dimensions - const size_t swizzled_scale_inv_first_dim = ceildiv(data_flat_first_dim, 128) * 128; - const size_t swizzled_scale_inv_last_dim = ceildiv(data_flat_last_dim, 128) * 4; - // Allocate memory for swizzled mxfp8 scaling factors - at::Tensor swizzled_scale_inv = - allocateSpace(std::vector{swizzled_scale_inv_first_dim, swizzled_scale_inv_last_dim}, - transformer_engine::DType::kByte, false); - // Set rowwise scaling factors on output - void *const swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0); - NVTEShape swizzled_scale_inv_shape{}; - swizzled_scale_inv_shape.data[0] = swizzled_scale_inv_first_dim; - swizzled_scale_inv_shape.data[1] = swizzled_scale_inv_last_dim; - swizzled_scale_inv_shape.ndim = 2; - output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, - swizzled_scale_inv_shape); - output_cu.set_with_gemm_swizzled_scales(true); - - // Convert scaling factors from FP8 block scaling GEMM_READY format to mxfp8 swizzled format - NVTE_SCOPED_GIL_RELEASE({ - nvte_swizzle_block_scaling_to_mxfp8_scaling_factors(input_cu.data(), output_cu.data(), - at::cuda::getCurrentCUDAStream()); - }); - - // Set the input tensor to be the converted mxfp8 tensor and return the swizzled scaling factor - // for it to be kept alive during the GEMM - input = std::move(output_cu); - return swizzled_scale_inv; -} - -std::optional maybe_swizzle_grouped_tensor_for_gemm( - GroupedTensorWrapper &input) { - if (input.scaling_mode() != NVTE_MXFP8_1D_SCALING) { - return std::nullopt; - } - if (input.get_with_gemm_swizzled_scales()) { - return std::nullopt; - 
} - - const auto row_scales = input.get_rowwise_scale_inv(); - const auto col_scales = input.get_columnwise_scale_inv(); - const bool has_rowwise_scales = !is_empty_grouped_tensor_param(row_scales); - const bool has_columnwise_scales = !is_empty_grouped_tensor_param(col_scales); - if (!has_rowwise_scales && !has_columnwise_scales) { - return std::nullopt; - } - const auto first_dims = input.get_first_dims(); - const auto last_dims = input.get_last_dims(); - if (first_dims.data_ptr != nullptr || last_dims.data_ptr != nullptr) { - NVTE_ERROR( - "Grouped GEMM swizzle requires uniform shapes for now (first_dims/last_dims must be " - "absent)."); - } - - std::optional rowwise_scales_pyt; - std::optional columnwise_scales_pyt; - GroupedTensorWrapper output(input.num_tensors(), input.logical_shape(), input.scaling_mode()); - - const auto rowwise_data = input.get_rowwise_data(); - if (rowwise_data.data_ptr != nullptr) { - output.set_rowwise_data(rowwise_data.data_ptr, static_cast(rowwise_data.dtype), - rowwise_data.shape); - } - const auto columnwise_data = input.get_columnwise_data(); - if (columnwise_data.data_ptr != nullptr) { - output.set_columnwise_data(columnwise_data.data_ptr, static_cast(columnwise_data.dtype), - columnwise_data.shape); - } - const auto tensor_offsets = input.get_tensor_offsets(); - if (tensor_offsets.data_ptr != nullptr) { - output.set_tensor_offsets(tensor_offsets.data_ptr, static_cast(tensor_offsets.dtype), - tensor_offsets.shape); - } - - if (has_rowwise_scales) { - const auto scales_dtype = static_cast(row_scales.dtype); - rowwise_scales_pyt = allocateSpace(row_scales.shape, scales_dtype, false); - void *output_scales_dptr = getDataPtr(*rowwise_scales_pyt); - output.set_rowwise_scale_inv(output_scales_dptr, scales_dtype, row_scales.shape); - } - if (has_columnwise_scales) { - const auto scales_dtype = static_cast(col_scales.dtype); - columnwise_scales_pyt = allocateSpace(col_scales.shape, scales_dtype, false); - void *output_scales_dptr = 
getDataPtr(*columnwise_scales_pyt); - output.set_columnwise_scale_inv(output_scales_dptr, scales_dtype, col_scales.shape); - } - - output.set_with_gemm_swizzled_scales(true); - NVTE_SCOPED_GIL_RELEASE({ - nvte_swizzle_grouped_scaling_factors(input.data(), output.data(), - at::cuda::getCurrentCUDAStream()); - }); - - if (has_rowwise_scales) { - const auto scales_dtype = static_cast(row_scales.dtype); - input.set_rowwise_scale_inv(getDataPtr(*rowwise_scales_pyt), scales_dtype, row_scales.shape); - } - if (has_columnwise_scales) { - const auto scales_dtype = static_cast(col_scales.dtype); - input.set_columnwise_scale_inv(getDataPtr(*columnwise_scales_pyt), scales_dtype, - col_scales.shape); - } - input.set_with_gemm_swizzled_scales(true); - - return SwizzledGroupedScales{std::move(rowwise_scales_pyt), std::move(columnwise_scales_pyt)}; -} - -void inplace_swizzle_scale_for_gemm(py::handle &tensor) { - // Convert Python tensor to C++ tensor - auto tensor_nvte = makeTransformerEngineTensor(tensor, py::none()); - - // Return early if scale swizzling is not required - const auto scaling_mode = tensor_nvte.scaling_mode(); - switch (scaling_mode) { - case NVTE_MXFP8_1D_SCALING: - case NVTE_NVFP4_1D_SCALING: - // Tensor format requires scale swizzling - break; - case NVTE_INVALID_SCALING: - NVTE_ERROR("Invalid scaling mode for swizzling scaling factors."); - default: - // Tensor format does not require scale swizzling for GEMM - return; - } - - // Return early if scales are already swizzled - if (tensor_nvte.get_with_gemm_swizzled_scales()) { - return; - } - - // Check what scaling factors the tensor contains - auto is_empty = [](const NVTEBasicTensor &t) -> bool { - return t.shape.ndim == 1 && t.shape.data[0] == 0; - }; - const bool has_rowwise_scales = !is_empty(tensor_nvte.get_rowwise_scale_inv()); - const bool has_columnwise_scales = !is_empty(tensor_nvte.get_columnwise_scale_inv()); - - // Swizzle scaling factors - auto [rowwise_scales, columnwise_scales] = - 
swizzle_scales_for_gemm(tensor_nvte, has_rowwise_scales, has_columnwise_scales); - - // Update Python tensor with swizzled scales - switch (scaling_mode) { - case NVTE_MXFP8_1D_SCALING: - if (has_rowwise_scales) { - tensor.attr("_rowwise_scale_inv") = rowwise_scales; - } - if (has_columnwise_scales) { - tensor.attr("_columnwise_scale_inv") = columnwise_scales; - } - tensor.attr("_with_gemm_swizzled_scales") = true; - break; - case NVTE_NVFP4_1D_SCALING: - if (has_rowwise_scales) { - tensor.attr("_rowwise_scale_inv") = rowwise_scales; - } - if (has_columnwise_scales) { - tensor.attr("_columnwise_scale_inv") = columnwise_scales; - } - tensor.attr("_with_gemm_swizzled_scales") = true; - break; - default: - NVTE_ERROR("Invalid scaling mode for swizzling scaling factors."); - } -} - -} // namespace pytorch -} // namespace transformer_engine diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cpp b/transformer_engine/pytorch/csrc/extensions/transpose.cpp index aaa27a104a..4b7c88d294 100644 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cpp +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cpp @@ -4,184 +4,119 @@ * See LICENSE for license information. 
************************************************************************/ -#include #include #include -#include -#include +#include "../stable_common.h" -#include "../extensions.h" -#include "pybind.h" +namespace transformer_engine::pytorch::stable { -namespace transformer_engine { -namespace pytorch { +using Tensor = torch::stable::Tensor; -at::Tensor fp8_transpose(at::Tensor input, DType otype, std::optional output) { - init_extension(); +Tensor fp8_transpose(Tensor input, int64_t otype, std::optional output) { + auto shape = getStableTensorShape(input); + auto te_otype = static_cast(otype); - // Tensor dimensions - const auto shape = getTensorShape(input); std::vector transpose_shape_int64; - if (shape.size() > 0) { - transpose_shape_int64.push_back(shape.back()); - for (size_t i = 0; i < shape.size() - 1; ++i) { - transpose_shape_int64.push_back(shape[i]); + if (!shape.empty()) { + transpose_shape_int64.push_back(static_cast(shape.back())); + for (size_t i = 0; i + 1 < shape.size(); ++i) { + transpose_shape_int64.push_back(static_cast(shape[i])); } } - const size_t M = shape.size() > 0 ? product(shape) / shape.back() : 1; - const size_t N = shape.size() > 0 ? shape.back() : 1; + const size_t M = shape.empty() ? 1 : (shape.size() > 1 ? 1 : shape[0]); + size_t total = 1; + for (auto s : shape) total *= s; + const size_t N = shape.empty() ? 1 : shape.back(); + const size_t M_actual = shape.empty() ? 
1 : total / N; - // Output tensor - at::Tensor out; + Tensor out; if (output.has_value()) { - out = *output; + out = output.value(); } else { - const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); - out = at::empty(transpose_shape_int64, opts); + out = allocateStableTensor(transpose_shape_int64, ScalarType::Byte, input.get_device_index()); } - // Return immediately if tensor is empty - if (M == 0 || N == 0) { - return out; - } + if (M_actual == 0 || N == 0) return out; - // Compute transpose - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), std::vector{M, N}, otype); - auto output_cu = makeTransformerEngineTensor(out.data_ptr(), std::vector{N, M}, otype); - nvte_transpose(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); + auto input_cu = + makeTransformerEngineTensor(input.data_ptr(), std::vector{M_actual, N}, te_otype); + auto output_cu = + makeTransformerEngineTensor(out.data_ptr(), std::vector{N, M_actual}, te_otype); + nvte_transpose(input_cu.data(), output_cu.data(), + getCurrentCUDAStreamRaw(input.get_device_index())); return out; } -at::Tensor nvfp4_data_transpose(at::Tensor input, std::optional output) { - init_extension(); - - // Input is packed FP4: logical [M, K] stored as [M, K/2] bytes - // Output is packed FP4: logical [K, M] stored as [K, M/2] bytes - const auto shape = getTensorShape(input); - NVTE_CHECK(shape.size() == 2, "NVFP4 transpose expects 2D input (packed storage)."); +Tensor nvfp4_data_transpose(Tensor input, std::optional output) { + auto shape = getStableTensorShape(input); + NVTE_CHECK(shape.size() == 2, "NVFP4 transpose expects 2D input."); const size_t M = shape[0]; const size_t K_packed = shape[1]; - const size_t K = K_packed * 2; // logical K + const size_t K = K_packed * 2; const size_t M_packed = M / 2; + NVTE_CHECK(M % 2 == 0, "NVFP4 transpose requires M to be even."); - NVTE_CHECK(M % 2 == 0, "NVFP4 transpose requires M (", M, ") to be even."); - - // Output shape: 
[K, M/2] - std::vector output_shape = {static_cast(K), static_cast(M_packed)}; - - // Output tensor - at::Tensor out; + Tensor out; if (output.has_value()) { - out = *output; - NVTE_CHECK( - static_cast(out.size(0)) == K && static_cast(out.size(1)) == M_packed, - "Output shape mismatch for NVFP4 transpose."); + out = output.value(); } else { - const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); - out = at::empty(output_shape, opts); + out = allocateStableTensor({static_cast(K), static_cast(M_packed)}, + ScalarType::Byte, input.get_device_index()); } - // Return immediately if tensor is empty - if (M == 0 || K == 0) { - return out; - } + if (M == 0 || K == 0) return out; - // Call the NVFP4 transpose kernel auto input_cu = makeTransformerEngineTensor(input.data_ptr(), std::vector{M, K_packed}, DType::kByte); auto output_cu = makeTransformerEngineTensor(out.data_ptr(), std::vector{K, M_packed}, DType::kByte); - nvte_nvfp4_data_transpose(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); + nvte_nvfp4_data_transpose(input_cu.data(), output_cu.data(), + getCurrentCUDAStreamRaw(input.get_device_index())); return out; } -void nvfp4_2d_scale_transpose(at::Tensor input, at::Tensor output, int64_t M_tiles, - int64_t K_tiles) { - init_extension(); - - // Input: rowwise_scale_inv [M_padded, K_tiles], uint8 (E4M3 stored as bytes) - // Output: columnwise_scale_inv [K_padded, M_tiles], uint8 (E4M3 stored as bytes) - const auto in_shape = getTensorShape(input); - const auto out_shape = getTensorShape(output); - NVTE_CHECK(in_shape.size() == 2, "NVFP4 scale transpose expects 2D input."); - NVTE_CHECK(out_shape.size() == 2, "NVFP4 scale transpose expects 2D output."); - NVTE_CHECK(input.scalar_type() == at::kByte, "NVFP4 scale transpose input must be uint8 (E4M3)."); - NVTE_CHECK(output.scalar_type() == at::kByte, - "NVFP4 scale transpose output must be uint8 (E4M3)."); - - auto input_cu = makeTransformerEngineTensor( - 
input.data_ptr(), std::vector{in_shape[0], in_shape[1]}, DType::kByte); - auto output_cu = makeTransformerEngineTensor( - output.data_ptr(), std::vector{out_shape[0], out_shape[1]}, DType::kByte); +void nvfp4_2d_scale_transpose(Tensor input, Tensor output, int64_t M_tiles, int64_t K_tiles) { + auto in_shape = getStableTensorShape(input); + auto out_shape = getStableTensorShape(output); + + auto input_cu = makeTransformerEngineTensor(input.data_ptr(), in_shape, DType::kByte); + auto output_cu = makeTransformerEngineTensor(output.data_ptr(), out_shape, DType::kByte); nvte_nvfp4_scale_transpose(input_cu.data(), output_cu.data(), static_cast(M_tiles), - static_cast(K_tiles), at::cuda::getCurrentCUDAStream()); + static_cast(K_tiles), + getCurrentCUDAStreamRaw(input.get_device_index())); } -void nvfp4_expand_scale_to_fp8(at::Tensor input, at::Tensor output, int64_t tile_rows, - int64_t tile_cols, int64_t rows_padded, int64_t block_len) { - init_extension(); - - // Input: per_block_decode_scale [tile_rows, tile_cols], float32 - // Output: target_scale [rows_padded, tile_cols], uint8 (E4M3) - const auto in_shape = getTensorShape(input); - const auto out_shape = getTensorShape(output); - NVTE_CHECK(in_shape.size() == 2, "NVFP4 expand scale expects 2D input."); - NVTE_CHECK(out_shape.size() == 2, "NVFP4 expand scale expects 2D output."); - NVTE_CHECK(input.scalar_type() == at::kFloat, "NVFP4 expand scale input must be float32."); - NVTE_CHECK(output.scalar_type() == at::kByte, "NVFP4 expand scale output must be uint8 (E4M3)."); +void nvfp4_expand_scale_to_fp8(Tensor input, Tensor output, int64_t tile_rows, int64_t tile_cols, + int64_t rows_padded, int64_t block_len) { + auto in_shape = getStableTensorShape(input); + auto out_shape = getStableTensorShape(output); - auto input_cu = makeTransformerEngineTensor( - input.data_ptr(), std::vector{in_shape[0], in_shape[1]}, DType::kFloat32); - auto output_cu = makeTransformerEngineTensor( - output.data_ptr(), 
std::vector{out_shape[0], out_shape[1]}, DType::kByte); + auto input_cu = makeTransformerEngineTensor(input.data_ptr(), in_shape, DType::kFloat32); + auto output_cu = makeTransformerEngineTensor(output.data_ptr(), out_shape, DType::kByte); nvte_nvfp4_expand_scale_to_fp8(input_cu.data(), output_cu.data(), static_cast(tile_rows), static_cast(tile_cols), static_cast(rows_padded), - static_cast(block_len), at::cuda::getCurrentCUDAStream()); + static_cast(block_len), + getCurrentCUDAStreamRaw(input.get_device_index())); } -void nvfp4_compute_per_block_scale(at::Tensor block_amax, at::Tensor scale, - at::Tensor global_amax) { - init_extension(); - - // block_amax and scale: [tile_rows, tile_cols], float32 - // global_amax: single element tensor, float32 (avoids D2H transfer) - NVTE_CHECK(block_amax.scalar_type() == at::kFloat, "Block amax must be float32."); - NVTE_CHECK(scale.scalar_type() == at::kFloat, "Scale must be float32."); - NVTE_CHECK(global_amax.scalar_type() == at::kFloat, "Global amax must be float32."); - NVTE_CHECK(global_amax.numel() == 1, "Global amax must be a single element tensor."); - +void nvfp4_compute_per_block_scale(Tensor block_amax, Tensor scale, Tensor global_amax) { auto block_amax_cu = makeTransformerEngineTensor(block_amax); auto scale_cu = makeTransformerEngineTensor(scale); auto global_amax_cu = makeTransformerEngineTensor(global_amax); nvte_nvfp4_compute_per_block_scale(block_amax_cu.data(), scale_cu.data(), global_amax_cu.data(), - at::cuda::getCurrentCUDAStream()); + getCurrentCUDAStreamRaw(block_amax.get_device_index())); } -void nvfp4_fused_scale(at::Tensor block_amax, at::Tensor global_amax, at::Tensor per_block_scale, - at::Tensor target_scale, at::Tensor target_amax, int64_t tile_rows, +void nvfp4_fused_scale(Tensor block_amax, Tensor global_amax, Tensor per_block_scale, + Tensor target_scale, Tensor target_amax, int64_t tile_rows, int64_t tile_cols, int64_t rows_padded, int64_t block_len) { - init_extension(); - - // block_amax: 
[tile_rows, tile_cols], float32 - // global_amax: [1], float32 - // per_block_scale: [tile_rows, tile_cols], float32 (for partial_cast) - // target_scale: [rows_padded, tile_cols], uint8 (E4M3) - // target_amax: [1], float32 - NVTE_CHECK(block_amax.scalar_type() == at::kFloat, "Block amax must be float32."); - NVTE_CHECK(global_amax.scalar_type() == at::kFloat, "Global amax must be float32."); - NVTE_CHECK(per_block_scale.scalar_type() == at::kFloat, "Per-block scale must be float32."); - NVTE_CHECK(target_scale.scalar_type() == at::kByte, "Target scale must be uint8 (E4M3)."); - NVTE_CHECK(target_amax.scalar_type() == at::kFloat, "Target amax must be float32."); - NVTE_CHECK(global_amax.numel() == 1, "Global amax must be a single element tensor."); - NVTE_CHECK(target_amax.numel() == 1, "Target amax must be a single element tensor."); - auto block_amax_cu = makeTransformerEngineTensor(block_amax); auto global_amax_cu = makeTransformerEngineTensor(global_amax); auto per_block_scale_cu = makeTransformerEngineTensor(per_block_scale); @@ -192,164 +127,63 @@ void nvfp4_fused_scale(at::Tensor block_amax, at::Tensor global_amax, at::Tensor target_scale_cu.data(), target_amax_cu.data(), static_cast(tile_rows), static_cast(tile_cols), static_cast(rows_padded), static_cast(block_len), - at::cuda::getCurrentCUDAStream()); + getCurrentCUDAStreamRaw(block_amax.get_device_index())); } -void nvfp4_multi_tensor_fused_scale( - std::vector block_amax_list, std::vector global_amax_list, - std::vector per_block_scale_list, std::vector target_scale_list, - std::vector target_amax_list, std::vector tile_rows_list, - std::vector tile_cols_list, std::vector rows_padded_list, int64_t block_len) { - init_extension(); - - const size_t num_tensors = block_amax_list.size(); - NVTE_CHECK(global_amax_list.size() == num_tensors, "global_amax_list size mismatch"); - NVTE_CHECK(per_block_scale_list.size() == num_tensors, "per_block_scale_list size mismatch"); - NVTE_CHECK(target_scale_list.size() 
== num_tensors, "target_scale_list size mismatch"); - NVTE_CHECK(target_amax_list.size() == num_tensors, "target_amax_list size mismatch"); - NVTE_CHECK(tile_rows_list.size() == num_tensors, "tile_rows_list size mismatch"); - NVTE_CHECK(tile_cols_list.size() == num_tensors, "tile_cols_list size mismatch"); - NVTE_CHECK(rows_padded_list.size() == num_tensors, "rows_padded_list size mismatch"); - - if (num_tensors == 0) { - return; - } - - auto stream = at::cuda::getCurrentCUDAStream(); - - for (size_t i = 0; i < num_tensors; ++i) { - const auto& block_amax = block_amax_list[i]; - const auto& global_amax = global_amax_list[i]; - auto& per_block_scale = per_block_scale_list[i]; - auto& target_scale = target_scale_list[i]; - auto& target_amax = target_amax_list[i]; - const size_t tile_rows = static_cast(tile_rows_list[i]); - const size_t tile_cols = static_cast(tile_cols_list[i]); - const size_t rows_padded = static_cast(rows_padded_list[i]); - - NVTE_CHECK(block_amax.scalar_type() == at::kFloat, "Block amax must be float32."); - NVTE_CHECK(global_amax.scalar_type() == at::kFloat, "Global amax must be float32."); - NVTE_CHECK(per_block_scale.scalar_type() == at::kFloat, "Per-block scale must be float32."); - NVTE_CHECK(target_scale.scalar_type() == at::kByte, "Target scale must be uint8 (E4M3)."); - NVTE_CHECK(target_amax.scalar_type() == at::kFloat, "Target amax must be float32."); - NVTE_CHECK(global_amax.numel() == 1, "Global amax must be a single element tensor."); - NVTE_CHECK(target_amax.numel() == 1, "Target amax must be a single element tensor."); - - auto block_amax_cu = makeTransformerEngineTensor(block_amax); - auto global_amax_cu = makeTransformerEngineTensor(global_amax); - auto per_block_scale_cu = makeTransformerEngineTensor(per_block_scale); - auto target_scale_cu = makeTransformerEngineTensor(target_scale); - auto target_amax_cu = makeTransformerEngineTensor(target_amax); - - nvte_nvfp4_fused_scale(block_amax_cu.data(), global_amax_cu.data(), 
per_block_scale_cu.data(), - target_scale_cu.data(), target_amax_cu.data(), tile_rows, tile_cols, - rows_padded, static_cast(block_len), stream); - } -} - -void nvfp4_compute_global_scale(at::Tensor global_amax, at::Tensor global_scale) { - init_extension(); - - // global_amax and global_scale: [num_params], float32 - NVTE_CHECK(global_amax.scalar_type() == at::kFloat, "Global amax must be float32."); - NVTE_CHECK(global_scale.scalar_type() == at::kFloat, "Global scale must be float32."); - +void nvfp4_compute_global_scale(Tensor global_amax, Tensor global_scale) { auto global_amax_cu = makeTransformerEngineTensor(global_amax); auto global_scale_cu = makeTransformerEngineTensor(global_scale); nvte_nvfp4_compute_global_scale(global_amax_cu.data(), global_scale_cu.data(), - at::cuda::getCurrentCUDAStream()); + getCurrentCUDAStreamRaw(global_amax.get_device_index())); } -at::Tensor swap_first_dims(at::Tensor tensor, std::optional out) { - init_extension(); - - // Make sure input is contiguous - const auto& input = tensor.contiguous(); - - // Allocate output tensor if needed - if (!out) { - auto in_shape = getTensorShape(input); - NVTE_CHECK(in_shape.size() >= 2, "Invalid input tensor dimensions (shape=", in_shape, ")"); - std::vector out_shape_int64(in_shape.begin(), in_shape.end()); - out_shape_int64[0] = static_cast(in_shape[1]); - out_shape_int64[1] = static_cast(in_shape[0]); - auto opts = at::TensorOptions().dtype(input.dtype()).device(input.device()); - out = at::empty(out_shape_int64, opts); +Tensor swap_first_dims(Tensor tensor, std::optional out) { + auto input = torch::stable::contiguous(tensor); + auto shape = getStableTensorShape(input); + NVTE_CHECK(shape.size() >= 2, "Invalid input tensor dimensions."); + + if (!out.has_value()) { + std::vector out_shape(shape.begin(), shape.end()); + out_shape[0] = static_cast(shape[1]); + out_shape[1] = static_cast(shape[0]); + out = allocateStableTensor(out_shape, input.scalar_type(), input.get_device_index()); } - // 
Launch kernel - const TensorWrapper te_input = makeTransformerEngineTensor(input); - TensorWrapper te_output = makeTransformerEngineTensor(*out); - nvte_swap_first_dims(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream()); + auto te_input = makeTransformerEngineTensor(input); + auto te_output = makeTransformerEngineTensor(out.value()); + nvte_swap_first_dims(te_input.data(), te_output.data(), + getCurrentCUDAStreamRaw(input.get_device_index())); - return std::move(*out); + return out.value(); } -void nvfp4_2d_multi_tensor_transpose(std::vector rowwise_data_list, - std::vector columnwise_data_list, - std::vector rowwise_scale_inv_list, - std::vector columnwise_scale_inv_list, - std::vector M_list, std::vector K_list) { - init_extension(); - - const size_t num_tensors = rowwise_data_list.size(); - NVTE_CHECK(columnwise_data_list.size() == num_tensors, "Tensor list size mismatch"); - NVTE_CHECK(rowwise_scale_inv_list.size() == num_tensors, "Tensor list size mismatch"); - NVTE_CHECK(columnwise_scale_inv_list.size() == num_tensors, "Tensor list size mismatch"); - NVTE_CHECK(M_list.size() == num_tensors, "M_list size mismatch"); - NVTE_CHECK(K_list.size() == num_tensors, "K_list size mismatch"); - - if (num_tensors == 0) { - return; - } - - auto stream = at::cuda::getCurrentCUDAStream(); - - // Process each tensor - the main benefit is reduced Python overhead - // by doing the iteration in C++ rather than Python - constexpr size_t TILE_SIZE = 16; - - for (size_t i = 0; i < num_tensors; ++i) { - const auto& rowwise_data = rowwise_data_list[i]; - auto& columnwise_data = columnwise_data_list[i]; - const auto& rowwise_scale_inv = rowwise_scale_inv_list[i]; - auto& columnwise_scale_inv = columnwise_scale_inv_list[i]; - const int64_t M = M_list[i]; - const int64_t K = K_list[i]; - - // Transpose data: [M, K/2] -> [K, M/2] - const auto data_shape = getTensorShape(rowwise_data); - NVTE_CHECK(data_shape.size() == 2, "NVFP4 data must be 2D."); - const size_t 
M_packed = static_cast(M) / 2; - const size_t K_packed = data_shape[1]; - - auto input_cu = makeTransformerEngineTensor( - rowwise_data.data_ptr(), std::vector{static_cast(M), K_packed}, - DType::kByte); - auto output_cu = makeTransformerEngineTensor( - columnwise_data.data_ptr(), std::vector{static_cast(K), M_packed}, - DType::kByte); - nvte_nvfp4_data_transpose(input_cu.data(), output_cu.data(), stream); - - // Transpose scales - const size_t M_tiles = (static_cast(M) + TILE_SIZE - 1) / TILE_SIZE; - const size_t K_tiles = (static_cast(K) + TILE_SIZE - 1) / TILE_SIZE; - - const auto scale_in_shape = getTensorShape(rowwise_scale_inv); - const auto scale_out_shape = getTensorShape(columnwise_scale_inv); - - auto scale_input_cu = makeTransformerEngineTensor( - rowwise_scale_inv.data_ptr(), std::vector{scale_in_shape[0], scale_in_shape[1]}, - DType::kByte); - auto scale_output_cu = makeTransformerEngineTensor( - columnwise_scale_inv.data_ptr(), - std::vector{scale_out_shape[0], scale_out_shape[1]}, DType::kByte); - - nvte_nvfp4_scale_transpose(scale_input_cu.data(), scale_output_cu.data(), M_tiles, K_tiles, - stream); - } +} // namespace transformer_engine::pytorch::stable + +STABLE_TORCH_LIBRARY_FRAGMENT(transformer_engine_stable, m) { + m.def("fp8_transpose(Tensor input, int otype, Tensor? output) -> Tensor"); + m.def("nvfp4_data_transpose(Tensor input, Tensor? 
output) -> Tensor"); + m.def("nvfp4_2d_scale_transpose(Tensor input, Tensor output, int M_tiles, int K_tiles) -> ()"); + m.def( + "nvfp4_expand_scale_to_fp8(Tensor input, Tensor output, int tile_rows, int tile_cols, int " + "rows_padded, int block_len) -> ()"); + m.def("nvfp4_compute_per_block_scale(Tensor block_amax, Tensor scale, Tensor global_amax) -> ()"); + m.def( + "nvfp4_fused_scale(Tensor block_amax, Tensor global_amax, Tensor per_block_scale, Tensor " + "target_scale, Tensor target_amax, int tile_rows, int tile_cols, int rows_padded, int " + "block_len) -> ()"); + m.def("nvfp4_compute_global_scale(Tensor global_amax, Tensor global_scale) -> ()"); + m.def("swap_first_dims(Tensor tensor, Tensor? out) -> Tensor"); } -} // namespace pytorch -} // namespace transformer_engine +STABLE_TORCH_LIBRARY_IMPL(transformer_engine_stable, CUDA, m) { + using namespace transformer_engine::pytorch::stable; + m.impl("fp8_transpose", TORCH_BOX(fp8_transpose)); + m.impl("nvfp4_data_transpose", TORCH_BOX(nvfp4_data_transpose)); + m.impl("nvfp4_2d_scale_transpose", TORCH_BOX(nvfp4_2d_scale_transpose)); + m.impl("nvfp4_expand_scale_to_fp8", TORCH_BOX(nvfp4_expand_scale_to_fp8)); + m.impl("nvfp4_compute_per_block_scale", TORCH_BOX(nvfp4_compute_per_block_scale)); + m.impl("nvfp4_fused_scale", TORCH_BOX(nvfp4_fused_scale)); + m.impl("nvfp4_compute_global_scale", TORCH_BOX(nvfp4_compute_global_scale)); + m.impl("swap_first_dims", TORCH_BOX(swap_first_dims)); +} diff --git a/transformer_engine/pytorch/csrc/pybind.h b/transformer_engine/pytorch/csrc/pybind.h deleted file mode 100644 index 9e640537f9..0000000000 --- a/transformer_engine/pytorch/csrc/pybind.h +++ /dev/null @@ -1,121 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. 
- ************************************************************************/ - -#define PYBIND11_DETAILED_ERROR_MESSAGES // TODO remove - -#ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_PYBIND_H_ -#define TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_PYBIND_H_ - -#include -#include -#include -#include -#include - -#include "common.h" -#include "transformer_engine/transformer_engine.h" - -namespace transformer_engine::pytorch { - -#define NVTE_SCOPED_GIL_RELEASE(code_block) \ - do { \ - if (PyGILState_Check()) { \ - pybind11::gil_scoped_release _gil_release; \ - code_block \ - } else { \ - code_block \ - } \ - } while (false); - -extern PyTypeObject *Float8TensorPythonClass; -extern PyTypeObject *Float8TensorStoragePythonClass; -extern PyTypeObject *Float8QuantizerClass; -extern PyTypeObject *Float8CurrentScalingQuantizerClass; -extern PyTypeObject *MXFP8TensorPythonClass; -extern PyTypeObject *MXFP8TensorStoragePythonClass; -extern PyTypeObject *MXFP8QuantizerClass; -extern PyTypeObject *Float8BlockwiseQTensorPythonClass; -extern PyTypeObject *Float8BlockwiseQTensorStoragePythonClass; -extern PyTypeObject *Float8BlockwiseQuantizerClass; -extern PyTypeObject *NVFP4TensorPythonClass; -extern PyTypeObject *NVFP4TensorStoragePythonClass; -extern PyTypeObject *NVFP4QuantizerClass; -extern PyTypeObject *GroupedTensorPythonClass; -extern PyTypeObject *GroupedTensorStoragePythonClass; - -void init_extension(); - -namespace detail { - -inline bool IsFloat8Quantizers(PyObject *obj) { return Py_TYPE(obj) == Float8QuantizerClass; } - -inline bool IsFloat8CurrentScalingQuantizers(PyObject *obj) { - return Py_TYPE(obj) == Float8CurrentScalingQuantizerClass; -} - -inline bool IsFloat8Tensor(PyObject *obj) { - return Py_TYPE(obj) == Float8TensorPythonClass || Py_TYPE(obj) == Float8TensorStoragePythonClass; -} - -inline bool IsMXFP8Quantizers(PyObject *obj) { return Py_TYPE(obj) == MXFP8QuantizerClass; } - -inline bool IsMXFP8Tensor(PyObject *obj) { - return Py_TYPE(obj) == 
MXFP8TensorPythonClass || Py_TYPE(obj) == MXFP8TensorStoragePythonClass; -} - -inline bool IsFloat8BlockwiseQuantizers(PyObject *obj) { - return Py_TYPE(obj) == Float8BlockwiseQuantizerClass; -} - -inline bool IsNVFP4Quantizers(PyObject *obj) { return Py_TYPE(obj) == NVFP4QuantizerClass; } - -inline bool IsFloat8BlockwiseQTensor(PyObject *obj) { - return Py_TYPE(obj) == Float8BlockwiseQTensorPythonClass || - Py_TYPE(obj) == Float8BlockwiseQTensorStoragePythonClass; -} - -inline bool IsNVFP4Tensor(PyObject *obj) { - return Py_TYPE(obj) == NVFP4TensorPythonClass || Py_TYPE(obj) == NVFP4TensorStoragePythonClass; -} - -TensorWrapper NVTETensorFromFloat8Tensor(py::handle tensor, Quantizer *quantizer); - -template -std::unique_ptr CreateQuantizer(const py::handle quantizer) { - return std::make_unique(quantizer); -} - -TensorWrapper NVTETensorFromMXFP8Tensor(py::handle tensor, Quantizer *quantization_params); - -std::unique_ptr CreateMXFP8Params(const py::handle params); - -TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, - Quantizer *quantization_params); - -TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer); - -GroupedTensorWrapper GroupedTensorFromPyTorchGroupedTensor(py::handle tensor); - -inline bool IsFloatingPointType(at::ScalarType type) { - return type == at::kFloat || type == at::kHalf || type == at::kBFloat16; -} - -constexpr std::array custom_types_converters = { - std::make_tuple(IsFloat8Tensor, IsFloat8Quantizers, NVTETensorFromFloat8Tensor, - CreateQuantizer), - std::make_tuple(IsFloat8Tensor, IsFloat8CurrentScalingQuantizers, NVTETensorFromFloat8Tensor, - CreateQuantizer), - std::make_tuple(IsMXFP8Tensor, IsMXFP8Quantizers, NVTETensorFromMXFP8Tensor, - CreateQuantizer), - std::make_tuple(IsFloat8BlockwiseQTensor, IsFloat8BlockwiseQuantizers, - NVTETensorFromFloat8BlockwiseQTensor, CreateQuantizer), - std::make_tuple(IsNVFP4Tensor, IsNVFP4Quantizers, NVTETensorFromNVFP4Tensor, - CreateQuantizer)}; -} // 
namespace detail - -} // namespace transformer_engine::pytorch - -#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_PYBIND_H_ diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp deleted file mode 100644 index b59f3fa3c5..0000000000 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ /dev/null @@ -1,2465 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#include - -#include "common.h" -#include "common/util/system.h" -#include "pybind.h" -#include "torch/torch.h" - -namespace transformer_engine::pytorch { - -namespace { - -/*! @brief Transposed tensor shape - * - * The tensor is interpreted as a 2D matrix by flattening all but the - * last dimension, and then transposed. - */ -template -std::vector make_transpose_shape(const std::vector& shape) { - std::vector ret; - if (shape.size() > 0) { - ret.push_back(shape.back()); - for (size_t i = 0; i < shape.size() - 1; ++i) { - ret.push_back(shape[i]); - } - } - return ret; -} - -/*! @brief Calculate stride from shape for contiguous tensors */ -template -std::vector stride_from_shape(const std::vector& shape) { - std::vector stride; - if (shape.empty()) { - return stride; - } - std::vector rstride; - rstride.reserve(shape.size()); - rstride.push_back(static_cast(1)); - for (size_t i = shape.size(); i > 1; --i) { - rstride.push_back(rstride.back() * shape[i - 1]); - } - stride.assign(rstride.rbegin(), rstride.rend()); - return stride; -} - -/*! 
@brief Convert shape for FP4 data by dividing the last dimension by 2 */ -template -std::vector convert_shape_for_fp4(const std::vector& shape) { - std::vector ret; - for (size_t i = 0; i < shape.size() - 1; ++i) { - ret.push_back(shape[i]); - } - ret.push_back(shape.back() / 2); - return ret; -} - -std::optional build_grouped_tensor_offsets(const size_t num_tensors, - const std::optional& first_dims, - const size_t logical_last_dim) { - if (!first_dims.has_value()) { - return std::nullopt; - } - - const auto& first_dims_tensor = first_dims.value(); - NVTE_CHECK(first_dims_tensor.is_cuda(), "first_dims must be on CUDA."); - NVTE_CHECK(first_dims_tensor.scalar_type() == at::kLong, "first_dims must have dtype int64."); - NVTE_CHECK(static_cast(first_dims_tensor.numel()) == num_tensors, - "first_dims must have length ", num_tensors, "."); - - const int64_t logical_last_dim_i64 = static_cast(logical_last_dim); - const auto first_dims_contiguous = first_dims_tensor.contiguous(); - auto tensor_offsets = - at::empty({static_cast(num_tensors) + 1}, first_dims_contiguous.options()); - NVTE_SCOPED_GIL_RELEASE({ - nvte_splits_to_offsets(static_cast(first_dims_contiguous.data_ptr()), - static_cast(tensor_offsets.data_ptr()), num_tensors, - logical_last_dim_i64, at::cuda::getCurrentCUDAStream()); - }); - return tensor_offsets; -} - -at::TensorOptions grouped_tensor_data_options(const DType dtype) { - return at::TensorOptions().dtype(GetATenDType(dtype)).device(torch::kCUDA); -} - -py::object maybe_tensor_to_py(const std::optional& tensor) { - return tensor ? py::cast(*tensor) : py::none(); -} - -py::handle grouped_tensor_python_class(const bool internal) { - PyTypeObject* cls = internal ? 
GroupedTensorStoragePythonClass : GroupedTensorPythonClass; - return py::handle(reinterpret_cast(cls)); -} - -} // namespace - -constexpr size_t NVFP4_BLOCK_SIZE = 16; -constexpr size_t MXFP8_BLOCK_SIZE = 32; - -Quantizer::Quantizer(const py::handle& quantizer) { - if (quantizer.is_none()) { - this->rowwise_usage = true; - this->columnwise_usage = true; - this->internal = false; - this->optimize_for_gemm = false; - } else { - this->rowwise_usage = quantizer.attr("rowwise_usage").cast(); - this->columnwise_usage = quantizer.attr("columnwise_usage").cast(); - this->internal = quantizer.attr("internal").cast(); - this->optimize_for_gemm = quantizer.attr("optimize_for_gemm").cast(); - this->quantizer = quantizer; - } -} - -Float8Quantizer::Float8Quantizer(const py::handle& quantizer) : Quantizer(quantizer) { - const at::Tensor& scale = quantizer.attr("scale").cast(); - const at::Tensor& amax = quantizer.attr("amax").cast(); - const DType type = quantizer.attr("dtype").cast(); - - this->amax = amax; - this->scale = scale; - this->dtype = type; -} - -std::pair NoneQuantizer::create_tensor(const std::vector& shape, - DType dtype) const { - const std::vector shape_int64(shape.begin(), shape.end()); - const auto opts = at::TensorOptions().dtype(GetATenDType(dtype)).device(torch::kCUDA); - return create_tensor(shape, dtype, at::empty(shape_int64, opts)); -} - -std::pair NoneQuantizer::create_tensor(const std::vector& shape, - DType dtype, - at::Tensor data) const { - TensorWrapper out_cpp; - out_cpp.set_rowwise_data(data.data_ptr(), dtype, shape); - set_quantization_params(&out_cpp); - return {std::move(out_cpp), py::cast(data)}; -} - -std::pair NoneQuantizer::create_grouped_tensor( - const size_t num_tensors, const std::vector& logical_shape, const DType dtype, - py::object quantizer, const std::optional& first_dims, - const size_t logical_first_dim, const size_t logical_last_dim) const { - using namespace pybind11::literals; - - const auto tensor_offsets = - 
build_grouped_tensor_offsets(num_tensors, first_dims, logical_last_dim); - const int64_t total_elements = - static_cast(logical_first_dim) * static_cast(logical_last_dim); - - std::optional rowwise_data; - std::optional columnwise_data; - const bool with_rowwise_data = rowwise_usage; - const bool with_columnwise_data = columnwise_usage; - if (with_rowwise_data) { - rowwise_data = at::empty({total_elements}, grouped_tensor_data_options(dtype)); - } - if (with_columnwise_data) { - columnwise_data = at::empty({total_elements}, grouped_tensor_data_options(dtype)); - } - - GroupedTensorWrapper out_cpp(num_tensors, logical_shape, this->get_scaling_mode()); - if (with_rowwise_data) { - out_cpp.set_rowwise_data(rowwise_data->data_ptr(), dtype, getTensorShape(*rowwise_data)); - } - if (with_columnwise_data) { - out_cpp.set_columnwise_data(columnwise_data->data_ptr(), dtype, - getTensorShape(*columnwise_data)); - } - if (first_dims.has_value()) { - out_cpp.set_first_dims(first_dims->data_ptr(), DType::kInt64, getTensorShape(*first_dims)); - } - if (tensor_offsets.has_value()) { - out_cpp.set_tensor_offsets(tensor_offsets->data_ptr(), DType::kInt64, - getTensorShape(*tensor_offsets)); - } - - py::handle GroupedTensorClass = grouped_tensor_python_class(this->internal); - py::dict kwargs; - py::tuple args(0); - const std::vector grouped_shape = {static_cast(logical_first_dim), - static_cast(logical_last_dim)}; - const std::vector grouped_stride = stride_from_shape(grouped_shape); - kwargs["shape"] = py::cast(grouped_shape); - kwargs["stride"] = py::cast(grouped_stride); - kwargs["dtype"] = py::cast(GetATenDType(dtype)); - kwargs["num_tensors"] = py::cast(num_tensors); - kwargs["quantizer"] = quantizer; - kwargs["data"] = maybe_tensor_to_py(rowwise_data); - kwargs["columnwise_data"] = maybe_tensor_to_py(columnwise_data); - kwargs["scale_inv"] = py::none(); - kwargs["columnwise_scale_inv"] = py::none(); - kwargs["amax"] = py::none(); - kwargs["columnwise_amax"] = py::none(); - 
kwargs["scale"] = py::none(); - kwargs["first_dims"] = first_dims.has_value() ? py::cast(*first_dims) : py::none(); - kwargs["last_dims"] = py::none(); - kwargs["tensor_offsets"] = tensor_offsets.has_value() ? py::cast(*tensor_offsets) : py::none(); - kwargs["with_gemm_swizzled_scales"] = py::cast(false); - PyObject* result = PyObject_Call(GroupedTensorClass.ptr(), args.ptr(), kwargs.ptr()); - if (result == nullptr) { - PyErr_Print(); - } - NVTE_CHECK(result != nullptr, "Failed to create GroupedTensor instance"); - py::object out_py = py::reinterpret_steal(result); - - return {std::move(out_cpp), std::move(out_py)}; -} - -std::pair NoneQuantizer::convert_and_update_tensor( - py::object tensor) const { - auto tensor_pyt = tensor.cast(); - TensorWrapper out_cpp; - out_cpp.set_rowwise_data(tensor_pyt.data_ptr(), - GetTransformerEngineDType(tensor_pyt.scalar_type()), - getTensorShape(tensor_pyt)); - set_quantization_params(&out_cpp); - return {std::move(out_cpp), std::move(tensor)}; -} - -void NoneQuantizer::quantize(const TensorWrapper& input, TensorWrapper& out, - const std::optional& noop_flag) { - NVTE_ERROR("NoneQuantizer does not support quantization"); -} - -void Float8Quantizer::set_quantization_params(TensorWrapper* tensor) const { - tensor->set_scale(scale.data_ptr(), GetTransformerEngineDType(scale.scalar_type()), - getTensorShape(scale)); - at::TensorOptions opts = opts.dtype(torch::kFloat32).device(torch::kCUDA); - tensor->set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), - getTensorShape(amax)); -} - -std::pair Float8Quantizer::create_tensor( - const std::vector& shape, DType dtype) const { - const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); - at::Tensor scale_inv = at::empty(std::vector{1}, opts); - return create_tensor(shape, dtype, std::nullopt, std::nullopt, std::move(scale_inv)); -} - -std::pair Float8Quantizer::create_tensor( - const std::vector& shape, DType dtype, std::optional data, - 
std::optional transpose, std::optional scale_inv) const { - using namespace pybind11::literals; - int is_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); - // Initialize data tensor - const bool with_data = rowwise_usage || is_non_tn_fp8_gemm_supported; - if (with_data && !data) { - const std::vector shape_int64(shape.begin(), shape.end()); - const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); - data = at::empty(shape_int64, opts); - } else if (!with_data && data) { - data.reset(); - } - py::object data_py = with_data ? py::cast(*data) : py::none(); - - // Initialize transpose tensor - const bool with_transpose = columnwise_usage && !is_non_tn_fp8_gemm_supported; - if (with_transpose && !transpose) { - const auto transpose_shape = make_transpose_shape(shape); - const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); - transpose = at::empty(transpose_shape, opts); - } else if (!with_transpose && transpose) { - transpose.reset(); - } - py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none(); - // Initialize scale-inverse tensor - if (!scale_inv) { - scale_inv = at::reciprocal(scale); - } - py::object scale_inv_py = py::cast(*scale_inv); - at::Device device = - with_data ? data->device() - : (with_transpose ? 
transpose->device() - : at::Device(torch::kCUDA, c10::cuda::current_device())); - // Construct Python FP8 tensor - py::object out_py; - if (internal) { - // Use direct C API call bypassing pybind11 overhead - py::dict kwargs; - py::tuple args(0); - kwargs["data"] = data_py; - kwargs["fp8_scale_inv"] = scale_inv_py; - kwargs["fp8_dtype"] = py::cast(this->dtype); - kwargs["data_transpose"] = transpose_py; - kwargs["quantizer"] = this->quantizer; - kwargs["fake_dtype"] = GetATenDType(dtype); - - PyObject* result = PyObject_Call(reinterpret_cast(Float8TensorStoragePythonClass), - args.ptr(), kwargs.ptr()); - if (result == nullptr) { - PyErr_Print(); - } - NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance"); - out_py = py::reinterpret_steal(result); - } else { - const std::vector shape_int64(shape.begin(), shape.end()); - const auto stride_int64 = stride_from_shape(shape_int64); - - // Use direct C API call bypassing pybind11 overhead - py::dict kwargs; - py::tuple args(0); - kwargs["shape"] = py::cast(shape_int64); - kwargs["stride"] = py::cast(stride_int64); - kwargs["dtype"] = py::cast(GetATenDType(dtype)); - kwargs["data"] = data_py; - kwargs["fp8_scale_inv"] = scale_inv_py; - kwargs["fp8_dtype"] = py::cast(this->dtype); - kwargs["data_transpose"] = transpose_py; - kwargs["quantizer"] = this->quantizer; - kwargs["device"] = py::cast(device); - PyObject* result = PyObject_Call(reinterpret_cast(Float8TensorPythonClass), - args.ptr(), kwargs.ptr()); - if (result == nullptr) { - PyErr_Print(); - } - - NVTE_CHECK(result != nullptr, "Failed to create Float8Tensor instance"); - out_py = py::reinterpret_steal(result); - } - - // Construct C++ FP8 tensor - TensorWrapper out_cpp(this->get_scaling_mode()); - if (with_data) { - out_cpp.set_rowwise_data(data->data_ptr(), this->dtype, shape); - out_cpp.set_rowwise_scale_inv(scale_inv->data_ptr(), DType::kFloat32, std::vector{1}); - } - if (with_transpose) { - const auto transpose_shape = 
make_transpose_shape(shape); - out_cpp.set_columnwise_data(transpose->data_ptr(), this->dtype, transpose_shape); - out_cpp.set_columnwise_scale_inv(scale_inv->data_ptr(), DType::kFloat32, - std::vector{1}); - } - this->set_quantization_params(&out_cpp); - - return {std::move(out_cpp), std::move(out_py)}; -} - -std::pair Float8Quantizer::create_grouped_tensor( - const size_t num_tensors, const std::vector& logical_shape, const DType dtype, - py::object quantizer, const std::optional& first_dims, - const size_t logical_first_dim, const size_t logical_last_dim) const { - using namespace pybind11::literals; - - const auto tensor_offsets = - build_grouped_tensor_offsets(num_tensors, first_dims, logical_last_dim); - const int64_t total_elements = - static_cast(logical_first_dim) * static_cast(logical_last_dim); - - const auto uint8_opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); - const auto float_opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); - - std::optional rowwise_data; - std::optional columnwise_data; - std::optional rowwise_scale_inv; - std::optional columnwise_scale_inv; - at::Tensor amax = at::empty({static_cast(num_tensors)}, float_opts); - - if (rowwise_usage) { - rowwise_data = at::empty({total_elements}, uint8_opts); - rowwise_scale_inv = at::empty({static_cast(num_tensors)}, float_opts); - } - if (columnwise_usage) { - columnwise_data = at::empty({total_elements}, uint8_opts); - columnwise_scale_inv = at::empty({static_cast(num_tensors)}, float_opts); - } - - GroupedTensorWrapper out_cpp(num_tensors, logical_shape, this->get_scaling_mode()); - if (rowwise_usage) { - out_cpp.set_rowwise_data(rowwise_data->data_ptr(), this->dtype, getTensorShape(*rowwise_data)); - out_cpp.set_rowwise_scale_inv(rowwise_scale_inv->data_ptr(), DType::kFloat32, - getTensorShape(*rowwise_scale_inv)); - } - if (columnwise_usage) { - out_cpp.set_columnwise_data(columnwise_data->data_ptr(), this->dtype, - 
getTensorShape(*columnwise_data)); - out_cpp.set_columnwise_scale_inv(columnwise_scale_inv->data_ptr(), DType::kFloat32, - getTensorShape(*columnwise_scale_inv)); - } - out_cpp.set_amax(amax.data_ptr(), DType::kFloat32, getTensorShape(amax)); - if (first_dims.has_value()) { - out_cpp.set_first_dims(first_dims->data_ptr(), DType::kInt64, getTensorShape(*first_dims)); - } - if (tensor_offsets.has_value()) { - out_cpp.set_tensor_offsets(tensor_offsets->data_ptr(), DType::kInt64, - getTensorShape(*tensor_offsets)); - } - - py::handle GroupedTensorClass = grouped_tensor_python_class(this->internal); - py::dict kwargs; - py::tuple args(0); - const std::vector grouped_shape = {static_cast(logical_first_dim), - static_cast(logical_last_dim)}; - const std::vector grouped_stride = stride_from_shape(grouped_shape); - kwargs["shape"] = py::cast(grouped_shape); - kwargs["stride"] = py::cast(grouped_stride); - kwargs["dtype"] = py::cast(GetATenDType(dtype)); - kwargs["num_tensors"] = py::cast(num_tensors); - kwargs["quantizer"] = quantizer; - kwargs["data"] = maybe_tensor_to_py(rowwise_data); - kwargs["columnwise_data"] = maybe_tensor_to_py(columnwise_data); - kwargs["scale_inv"] = maybe_tensor_to_py(rowwise_scale_inv); - kwargs["columnwise_scale_inv"] = maybe_tensor_to_py(columnwise_scale_inv); - kwargs["amax"] = amax; - kwargs["columnwise_amax"] = py::none(); - kwargs["scale"] = py::none(); - kwargs["first_dims"] = first_dims.has_value() ? py::cast(*first_dims) : py::none(); - kwargs["last_dims"] = py::none(); - kwargs["tensor_offsets"] = tensor_offsets.has_value() ? 
py::cast(*tensor_offsets) : py::none(); - kwargs["with_gemm_swizzled_scales"] = py::cast(false); - PyObject* result = PyObject_Call(GroupedTensorClass.ptr(), args.ptr(), kwargs.ptr()); - if (result == nullptr) { - PyErr_Print(); - } - NVTE_CHECK(result != nullptr, "Failed to create GroupedTensor instance"); - py::object out_py = py::reinterpret_steal(result); - - return {std::move(out_cpp), std::move(out_py)}; -} - -std::pair Float8Quantizer::convert_and_update_tensor( - py::object tensor) const { - NVTE_CHECK(detail::IsFloat8Tensor(tensor.ptr()), "Float8Quantizer must output to Float8Tensor."); - int is_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); - // Expected buffers - const bool need_data = rowwise_usage || is_non_tn_fp8_gemm_supported; - const bool need_transpose = columnwise_usage && !is_non_tn_fp8_gemm_supported; - NVTE_CHECK(need_data || need_transpose, "Invalid usages for Float8Quantizer."); - - // Extract buffers from Python tensor - auto data_py = tensor.attr("_data"); - auto transpose_py = tensor.attr("_transpose"); - const bool has_data = !data_py.is_none(); - const bool has_transpose = !transpose_py.is_none(); - NVTE_CHECK(has_data || has_transpose, "Float8Tensor has no data."); - std::optional data_tensor, transpose_tensor; - if (has_data) { - data_tensor = data_py.cast(); - } - if (has_transpose) { - transpose_tensor = transpose_py.cast(); - } - at::Tensor scale_inv_tensor = tensor.attr("_scale_inv").cast(); - - // Tensor dimensions - std::vector shape; - if (has_transpose) { - const auto transpose_shape = getTensorShape(*transpose_tensor); - if (transpose_shape.size() > 0) { - for (size_t i = 1; i < transpose_shape.size(); ++i) { - shape.push_back(transpose_shape[i]); - } - shape.push_back(transpose_shape.front()); - } - if (has_data) { - auto expected_shape = getTensorShape(*data_tensor); - NVTE_CHECK(shape == expected_shape, "FP8 data (shape=", expected_shape, - ") and transpose (shape=", transpose_shape, ") do not match"); - 
} - } else { // Already checked has_data == true - shape = getTensorShape(*data_tensor); - } - - // Coerce data tensor - if (has_data && !need_data) { - data_tensor.reset(); - data_py = py::none(); - tensor.attr("_data") = data_py; - } else if (!has_data && need_data) { - const std::vector shape_int64(shape.begin(), shape.end()); - const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); - data_tensor = at::empty(shape_int64, opts); - data_py = py::cast(data_tensor); - tensor.attr("_data") = data_py; - } - - // Coerce transpose tensor - if (has_transpose && !need_transpose) { - transpose_tensor.reset(); - transpose_py = py::none(); - tensor.attr("_transpose") = transpose_py; - } else if (!has_transpose && need_transpose) { - const auto transpose_shape = make_transpose_shape(shape); - const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); - transpose_tensor = at::empty(transpose_shape, opts); - transpose_py = py::cast(transpose_tensor); - tensor.attr("_transpose") = transpose_py; - } - tensor.attr("_transpose_invalid") = !need_transpose; - - // Coerce other attrs - tensor.attr("_fp8_dtype") = dtype; - - // Construct C++ FP8 tensor - TensorWrapper out_cpp; - if (data_tensor) { - out_cpp.set_rowwise_data(data_tensor->data_ptr(), this->dtype, shape); - out_cpp.set_rowwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32, - std::vector{1}); - } - if (transpose_tensor) { - const auto transpose_shape = make_transpose_shape(shape); - out_cpp.set_columnwise_data(transpose_tensor->data_ptr(), this->dtype, transpose_shape); - out_cpp.set_columnwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32, - std::vector{1}); - } - this->set_quantization_params(&out_cpp); - - return {std::move(out_cpp), std::move(tensor)}; -} - -void Float8Quantizer::quantize(const TensorWrapper& input, TensorWrapper& out, - const std::optional& noop_flag) { - if (input.numel() == 0) { - return; - } - QuantizationConfigWrapper quant_config; 
- if (noop_flag) { - quant_config.set_noop_tensor(noop_flag->data()); - } - NVTE_SCOPED_GIL_RELEASE({ - nvte_quantize_v2(input.data(), out.data(), quant_config, at::cuda::getCurrentCUDAStream()); - }); -} - -Float8CurrentScalingQuantizer::Float8CurrentScalingQuantizer(const py::handle& quantizer) - : Quantizer(quantizer) { - const at::Tensor& scale = quantizer.attr("scale").cast(); - const at::Tensor& amax = quantizer.attr("amax").cast(); - const DType type = quantizer.attr("dtype").cast(); - this->amax = amax; - this->scale = scale; - this->dtype = type; - - // Get amax reduction group if needed - const bool with_amax_reduction = quantizer.attr("with_amax_reduction").cast(); - c10::intrusive_ptr amax_reduction_group; - if (with_amax_reduction) { - auto group = quantizer.attr("_canonicalized_amax_reduction_group")(); - NVTE_CHECK(!group.is_none(), - "Float8CurrentScalingQuantizer could not canonicalize amax reduction group"); - amax_reduction_group = group.cast>(); - } - this->with_amax_reduction = with_amax_reduction; - this->amax_reduction_group = amax_reduction_group; - - // fp8 current scaling specific quantization params - this->force_pow_2_scales = quantizer.attr("force_pow_2_scales").cast(); - this->amax_epsilon = quantizer.attr("amax_epsilon").cast(); -} - -void Float8CurrentScalingQuantizer::set_quantization_params(TensorWrapper* tensor) const { - // transfer amax and scale pointer from quantizer to output tensor (only as gpu buffer, no meaningful data in them) - tensor->set_scale(scale.data_ptr(), GetTransformerEngineDType(scale.scalar_type()), - getTensorShape(scale)); - at::TensorOptions opts = opts.dtype(torch::kFloat32).device(torch::kCUDA); - tensor->set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), - getTensorShape(amax)); -} - -std::pair Float8CurrentScalingQuantizer::create_tensor( - const std::vector& shape, DType dtype) const { - using namespace pybind11::literals; - - // Initialize data tensor - at::Tensor data_tensor; - 
int is_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); - const bool with_data = rowwise_usage || is_non_tn_fp8_gemm_supported; - if (with_data) { - const std::vector shape_int64(shape.begin(), shape.end()); - const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); - data_tensor = at::empty(shape_int64, opts); - } - - // Initialize transpose tensor - at::Tensor transpose_tensor; - const bool with_transpose = columnwise_usage && !is_non_tn_fp8_gemm_supported; - if (with_transpose) { - const auto transpose_shape = make_transpose_shape(shape); - const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); - transpose_tensor = at::empty(transpose_shape, opts); - } - // Initialize scale-inverse tensor - at::Tensor scale_inv_tensor; - { - const std::vector scale_inv_shape = {1}; - const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); - scale_inv_tensor = at::empty(scale_inv_shape, opts); - } - at::Device device = - with_data ? data_tensor.device() - : (with_transpose ? transpose_tensor.device() - : at::Device(torch::kCUDA, c10::cuda::current_device())); - // Construct Python FP8 tensor - py::object out_py; - py::object scale_inv_py = py::cast(scale_inv_tensor); - py::object data_py = with_data ? py::cast(data_tensor) : py::none(); - py::object transpose_py = with_transpose ? 
py::cast(transpose_tensor) : py::none(); - if (internal) { - // Use direct C API call bypassing pybind11 overhead - py::dict kwargs; - kwargs["data"] = data_py; - kwargs["fp8_scale_inv"] = scale_inv_py; - kwargs["fp8_dtype"] = py::cast(this->dtype); - kwargs["data_transpose"] = transpose_py; - kwargs["quantizer"] = this->quantizer; - kwargs["fake_dtype"] = GetATenDType(dtype); - - py::tuple args(0); - PyObject* result = PyObject_Call(reinterpret_cast(Float8TensorStoragePythonClass), - args.ptr(), kwargs.ptr()); - if (result == nullptr) { - PyErr_Print(); - } - NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance"); - out_py = py::reinterpret_steal(result); - } else { - const std::vector shape_int64(shape.begin(), shape.end()); - const auto stride_int64 = stride_from_shape(shape_int64); - // Use direct C API call bypassing pybind11 overhead - py::dict kwargs; - kwargs["shape"] = py::cast(shape_int64); - kwargs["stride"] = py::cast(stride_int64); - kwargs["dtype"] = py::cast(GetATenDType(dtype)); - kwargs["data"] = data_py; - kwargs["fp8_scale_inv"] = scale_inv_py; - kwargs["fp8_dtype"] = py::cast(this->dtype); - kwargs["data_transpose"] = transpose_py; - kwargs["quantizer"] = this->quantizer; - kwargs["device"] = py::cast(device); - py::tuple args(0); - PyObject* result = PyObject_Call(reinterpret_cast(Float8TensorPythonClass), - args.ptr(), kwargs.ptr()); - if (result == nullptr) { - PyErr_Print(); - } - - NVTE_CHECK(result != nullptr, "Failed to create Float8Tensor instance"); - out_py = py::reinterpret_steal(result); - } - - // Construct C++ FP8 tensor - TensorWrapper out_cpp(this->get_scaling_mode()); - if (with_data) { - out_cpp.set_rowwise_data(data_tensor.data_ptr(), this->dtype, shape); - out_cpp.set_rowwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32, - std::vector{1}); - } - if (with_transpose) { - const auto transpose_shape = make_transpose_shape(shape); - out_cpp.set_columnwise_data(transpose_tensor.data_ptr(), 
this->dtype, transpose_shape); - out_cpp.set_columnwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32, - std::vector{1}); - } - this->set_quantization_params(&out_cpp); - - return {std::move(out_cpp), std::move(out_py)}; -} - -std::pair Float8CurrentScalingQuantizer::create_grouped_tensor( - const size_t num_tensors, const std::vector& logical_shape, const DType dtype, - py::object quantizer, const std::optional& first_dims, - const size_t logical_first_dim, const size_t logical_last_dim) const { - using namespace pybind11::literals; - - const auto tensor_offsets = - build_grouped_tensor_offsets(num_tensors, first_dims, logical_last_dim); - const int64_t total_elements = - static_cast(logical_first_dim) * static_cast(logical_last_dim); - - const auto uint8_opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); - const auto float_opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); - - std::optional rowwise_data; - std::optional columnwise_data; - std::optional rowwise_scale_inv; - std::optional columnwise_scale_inv; - at::Tensor scale = at::empty({static_cast(num_tensors)}, float_opts); - at::Tensor amax = at::empty({static_cast(num_tensors)}, float_opts); - - if (rowwise_usage) { - rowwise_data = at::empty({total_elements}, uint8_opts); - rowwise_scale_inv = at::empty({static_cast(num_tensors)}, float_opts); - } - if (columnwise_usage) { - columnwise_data = at::empty({total_elements}, uint8_opts); - columnwise_scale_inv = at::empty({static_cast(num_tensors)}, float_opts); - } - - GroupedTensorWrapper out_cpp(num_tensors, logical_shape, this->get_scaling_mode()); - if (rowwise_usage) { - out_cpp.set_rowwise_data(rowwise_data->data_ptr(), this->dtype, getTensorShape(*rowwise_data)); - out_cpp.set_rowwise_scale_inv(rowwise_scale_inv->data_ptr(), DType::kFloat32, - getTensorShape(*rowwise_scale_inv)); - } - if (columnwise_usage) { - out_cpp.set_columnwise_data(columnwise_data->data_ptr(), this->dtype, - 
getTensorShape(*columnwise_data)); - out_cpp.set_columnwise_scale_inv(columnwise_scale_inv->data_ptr(), DType::kFloat32, - getTensorShape(*columnwise_scale_inv)); - } - out_cpp.set_scale(scale.data_ptr(), DType::kFloat32, getTensorShape(scale)); - out_cpp.set_amax(amax.data_ptr(), DType::kFloat32, getTensorShape(amax)); - if (first_dims.has_value()) { - out_cpp.set_first_dims(first_dims->data_ptr(), DType::kInt64, getTensorShape(*first_dims)); - } - if (tensor_offsets.has_value()) { - out_cpp.set_tensor_offsets(tensor_offsets->data_ptr(), DType::kInt64, - getTensorShape(*tensor_offsets)); - } - - py::handle GroupedTensorClass = grouped_tensor_python_class(this->internal); - py::dict kwargs; - py::tuple args(0); - const std::vector grouped_shape = {static_cast(logical_first_dim), - static_cast(logical_last_dim)}; - const std::vector grouped_stride = stride_from_shape(grouped_shape); - kwargs["shape"] = py::cast(grouped_shape); - kwargs["stride"] = py::cast(grouped_stride); - kwargs["dtype"] = py::cast(GetATenDType(dtype)); - kwargs["num_tensors"] = py::cast(num_tensors); - kwargs["quantizer"] = quantizer; - kwargs["data"] = maybe_tensor_to_py(rowwise_data); - kwargs["columnwise_data"] = maybe_tensor_to_py(columnwise_data); - kwargs["scale_inv"] = maybe_tensor_to_py(rowwise_scale_inv); - kwargs["columnwise_scale_inv"] = maybe_tensor_to_py(columnwise_scale_inv); - kwargs["amax"] = amax; - kwargs["columnwise_amax"] = py::none(); - kwargs["scale"] = scale; - kwargs["first_dims"] = first_dims.has_value() ? py::cast(*first_dims) : py::none(); - kwargs["last_dims"] = py::none(); - kwargs["tensor_offsets"] = tensor_offsets.has_value() ? 
py::cast(*tensor_offsets) : py::none(); - kwargs["with_gemm_swizzled_scales"] = py::cast(false); - PyObject* result = PyObject_Call(GroupedTensorClass.ptr(), args.ptr(), kwargs.ptr()); - if (result == nullptr) { - PyErr_Print(); - } - NVTE_CHECK(result != nullptr, "Failed to create GroupedTensor instance"); - py::object out_py = py::reinterpret_steal(result); - - return {std::move(out_cpp), std::move(out_py)}; -} - -std::pair -Float8CurrentScalingQuantizer::create_unquantized_tensor_with_amax(const std::vector& shape, - DType dtype, - std::optional data) { - amax.zero_(); - auto out = data.has_value() ? NoneQuantizer(py::none()).create_tensor(shape, dtype, data.value()) - : NoneQuantizer(py::none()).create_tensor(shape, dtype); - TensorWrapper out_cpp = std::move(out.first); - py::object out_py = std::move(out.second); - out_cpp.set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), - getTensorShape(amax)); - return {std::move(out_cpp), std::move(out_py)}; -} - -std::pair Float8CurrentScalingQuantizer::convert_and_update_tensor( - py::object tensor) const { - NVTE_CHECK(detail::IsFloat8Tensor(tensor.ptr()), - "Float8CurrentScalingQuantizer must output to Float8Tensor."); - int is_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); - // Expected buffers - const bool need_data = rowwise_usage || is_non_tn_fp8_gemm_supported; - const bool need_transpose = columnwise_usage && !is_non_tn_fp8_gemm_supported; - NVTE_CHECK(need_data || need_transpose, "Invalid quantizer usages."); - - // Extract buffers from Python tensor - auto data_py = tensor.attr("_data"); - auto transpose_py = tensor.attr("_transpose"); - const bool has_data = !data_py.is_none(); - const bool has_transpose = !transpose_py.is_none(); - NVTE_CHECK(has_data || has_transpose, "Tensor has no data."); - std::optional data_tensor, transpose_tensor; - if (has_data) { - data_tensor = data_py.cast(); - } - if (has_transpose) { - transpose_tensor = transpose_py.cast(); - } - 
at::Tensor scale_inv_tensor = tensor.attr("_scale_inv").cast(); - - // Tensor dimensions - std::vector shape; - if (has_transpose) { - const auto transpose_shape = getTensorShape(*transpose_tensor); - if (transpose_shape.size() > 0) { - for (size_t i = 1; i < transpose_shape.size(); ++i) { - shape.push_back(transpose_shape[i]); - } - shape.push_back(transpose_shape.front()); - } - if (has_data) { - auto expected_shape = getTensorShape(*data_tensor); - NVTE_CHECK(shape == expected_shape, "FP8 data (shape=", expected_shape, - ") and transpose (shape=", transpose_shape, ") do not match"); - } - } else { // Already checked has_data == true - shape = getTensorShape(*data_tensor); - } - - // Coerce data tensor in Python tensor - if (has_data && !need_data) { - data_tensor.reset(); - data_py = py::none(); - tensor.attr("_data") = data_py; - } else if (!has_data && need_data) { - const std::vector shape_int64(shape.begin(), shape.end()); - const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); - data_tensor = at::empty(shape_int64, opts); - data_py = py::cast(data_tensor); - tensor.attr("_data") = data_py; - } - - // Coerce transpose tensor - if (has_transpose && !need_transpose) { - transpose_tensor.reset(); - transpose_py = py::none(); - tensor.attr("_transpose") = transpose_py; - } else if (!has_transpose && need_transpose) { - const auto transpose_shape = make_transpose_shape(shape); - const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); - transpose_tensor = at::empty(transpose_shape, opts); - transpose_py = py::cast(transpose_tensor); - tensor.attr("_transpose") = transpose_py; - } - tensor.attr("_transpose_invalid") = !need_transpose; - - // Coerce other attrs - tensor.attr("_fp8_dtype") = dtype; - - // Construct C++ FP8 tensor - TensorWrapper out_cpp; - if (data_tensor) { - out_cpp.set_rowwise_data(data_tensor->data_ptr(), this->dtype, shape); - out_cpp.set_rowwise_scale_inv(scale_inv_tensor.data_ptr(), 
DType::kFloat32, - std::vector{1}); - } - if (transpose_tensor) { - const auto transpose_shape = make_transpose_shape(shape); - out_cpp.set_columnwise_data(transpose_tensor->data_ptr(), this->dtype, transpose_shape); - out_cpp.set_columnwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32, - std::vector{1}); - } - this->set_quantization_params(&out_cpp); - - return {std::move(out_cpp), std::move(tensor)}; -} - -void Float8CurrentScalingQuantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& out, - const std::optional& noop_flag, - bool compute_amax) { - auto stream = at::cuda::getCurrentCUDAStream(); - - // Nothing to be done if input is empty - if (input.numel() == 0) { - return; - } - - // Quantization configs - QuantizationConfigWrapper quant_config; - if (noop_flag) { - quant_config.set_noop_tensor(noop_flag->data()); - } - quant_config.set_force_pow_2_scales(force_pow_2_scales); - quant_config.set_amax_epsilon(amax_epsilon); - - // Compute amax - if (compute_amax) { - NVTE_SCOPED_GIL_RELEASE( - { nvte_compute_amax_with_config(input.data(), out.data(), quant_config, stream); }); - } - - // Perform amax reduction if needed - if (with_amax_reduction) { - // allreduce amax tensor - c10d::AllreduceOptions opts; - opts.reduceOp = c10d::ReduceOp::MAX; - std::vector tensors = {amax}; - NVTE_SCOPED_GIL_RELEASE({ amax_reduction_group->allreduce(tensors, opts)->wait(); }); - } - - // Compute scaling factor - NVTE_SCOPED_GIL_RELEASE({ nvte_compute_scale_from_amax(out.data(), quant_config, stream); }); - - // Cast to FP8 - out.set_amax(nullptr, DType::kFloat32, out.defaultShape); // Avoid atomic amax updates - NVTE_SCOPED_GIL_RELEASE({ nvte_quantize_v2(input.data(), out.data(), quant_config, stream); }); -} - -void Float8CurrentScalingQuantizer::quantize(const TensorWrapper& input, TensorWrapper& out, - const std::optional& noop_flag) { - this->quantize_impl(input, out, noop_flag, true); -} - -void Float8CurrentScalingQuantizer::quantize_with_amax( - 
TensorWrapper& input, TensorWrapper& out, const std::optional& noop_flag) { - NVTE_CHECK(input.get_amax().data_ptr == amax.data_ptr(), - "Input does not use the appropriate amax tensor"); - input.set_amax(nullptr, DType::kFloat32, input.defaultShape); - this->quantize_impl(input, out, noop_flag, false); -} - -Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quantizer(quantizer) { - this->dtype = quantizer.attr("dtype").cast(); - this->block_scaling_dim = quantizer.attr("block_scaling_dim").cast(); - this->force_pow_2_scales = quantizer.attr("force_pow_2_scales").cast(); - this->amax_epsilon = quantizer.attr("amax_epsilon").cast(); - NVTE_CHECK(this->block_scaling_dim == 1 || this->block_scaling_dim == 2, - "Unsupported block scaling dim."); -} - -void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const {} - -std::pair Float8BlockQuantizer::create_tensor( - const std::vector& shape, DType dtype) const { - using namespace pybind11::literals; - std::vector torch_shape; - for (auto s : shape) { - torch_shape.emplace_back(static_cast(s)); - } - - TensorWrapper tensor(this->get_scaling_mode()); - at::TensorOptions opts; - at::TensorOptions scale_opts; - at::Tensor data_rowwise, data_colwise, scale_inv_rowwise, scale_inv_colwise; - opts = opts.dtype(torch::kUInt8).device(torch::kCUDA); - scale_opts = scale_opts.dtype(torch::kFloat32).device(torch::kCUDA); - - if (rowwise_usage) { - data_rowwise = at::empty(torch_shape, opts); - auto scale_shape = get_scale_shape(shape, false); - size_t sinv0 = scale_shape[0]; - size_t sinv1 = scale_shape[1]; - scale_inv_rowwise = - at::empty({static_cast(sinv0), static_cast(sinv1)}, scale_opts); - tensor.set_rowwise_data(data_rowwise.data_ptr(), this->dtype, shape); - tensor.set_rowwise_scale_inv(scale_inv_rowwise.data_ptr(), DType::kFloat32, - std::vector{sinv0, sinv1}); - } - - if (columnwise_usage) { - std::vector torch_columnwise_shape; - std::vector columnwise_shape; - 
NVTE_CHECK(torch_shape.size() == shape.size(), "Shape expected to match torch shape. Shape ", - columnwise_shape, " torch shape: ", torch_columnwise_shape); - if (torch_shape.size() > 0) { - torch_columnwise_shape.reserve(torch_shape.size()); - columnwise_shape.reserve(shape.size()); - torch_columnwise_shape.push_back(torch_shape[torch_shape.size() - 1]); - columnwise_shape.push_back(shape[shape.size() - 1]); - for (size_t i = 0; i < torch_shape.size() - 1; ++i) { - torch_columnwise_shape.push_back(torch_shape[i]); - columnwise_shape.push_back(shape[i]); - } - } - auto scale_shape = get_scale_shape(shape, true); - size_t sinv0 = scale_shape[0]; - size_t sinv1 = scale_shape[1]; - data_colwise = at::empty(torch_columnwise_shape, opts); - scale_inv_colwise = - at::empty({static_cast(sinv0), static_cast(sinv1)}, scale_opts); - - tensor.set_columnwise_data(data_colwise.data_ptr(), this->dtype, columnwise_shape); - tensor.set_columnwise_scale_inv(scale_inv_colwise.data_ptr(), DType::kFloat32, - std::vector{sinv0, sinv1}); - } - this->set_quantization_params(&tensor); - - py::object ret; - if (internal) { - // Use direct C API call bypassing pybind11 overhead - py::dict kwargs; - kwargs["rowwise_data"] = py::cast(data_rowwise); - kwargs["columnwise_data"] = py::cast(data_colwise); - kwargs["rowwise_scale_inv"] = py::cast(scale_inv_rowwise); - kwargs["columnwise_scale_inv"] = py::cast(scale_inv_colwise); - kwargs["fp8_dtype"] = py::cast(this->dtype); - kwargs["quantizer"] = this->quantizer; - kwargs["is_2D_scaled"] = py::cast(block_scaling_dim == 2); - kwargs["fake_dtype"] = GetATenDType(dtype); - - py::tuple args(0); - PyObject* result = - PyObject_Call(reinterpret_cast(Float8BlockwiseQTensorStoragePythonClass), - args.ptr(), kwargs.ptr()); - if (result == nullptr) { - PyErr_Print(); - } - - NVTE_CHECK(result != nullptr, "Failed to create Float8BlockwiseQTensorStorage instance"); - ret = py::reinterpret_steal(result); - } else { - // Use direct C API call bypassing 
pybind11 overhead - py::dict kwargs; - const auto stride_int64 = stride_from_shape(torch_shape); - kwargs["shape"] = py::cast(torch_shape); - kwargs["stride"] = py::cast(stride_int64); - kwargs["dtype"] = py::cast(GetATenDType(dtype)); - kwargs["rowwise_data"] = py::cast(data_rowwise); - kwargs["columnwise_data"] = py::cast(data_colwise); - kwargs["rowwise_scale_inv"] = py::cast(scale_inv_rowwise); - kwargs["columnwise_scale_inv"] = py::cast(scale_inv_colwise); - kwargs["fp8_dtype"] = py::cast(this->dtype); - kwargs["quantizer"] = this->quantizer; - kwargs["is_2D_scaled"] = py::cast(block_scaling_dim == 2); - - py::tuple args(0); - PyObject* result = PyObject_Call(reinterpret_cast(Float8BlockwiseQTensorPythonClass), - args.ptr(), kwargs.ptr()); - if (result == nullptr) { - PyErr_Print(); - } - NVTE_CHECK(result != nullptr, "Failed to create Float8BlockwiseQTensor instance"); - ret = py::reinterpret_steal(result); - } - - return {std::move(tensor), std::move(ret)}; -} - -std::pair Float8BlockQuantizer::create_grouped_tensor( - const size_t num_tensors, const std::vector& logical_shape, const DType dtype, - py::object quantizer, const std::optional& first_dims, - const size_t logical_first_dim, const size_t logical_last_dim) const { - using namespace pybind11::literals; - - const auto tensor_offsets = - build_grouped_tensor_offsets(num_tensors, first_dims, logical_last_dim); - const int64_t total_elements = - static_cast(logical_first_dim) * static_cast(logical_last_dim); - - const auto uint8_opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); - const auto float_opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); - - std::optional rowwise_data; - std::optional columnwise_data; - std::optional rowwise_scale_inv; - std::optional columnwise_scale_inv; - const std::vector logical_shape_vec = {logical_first_dim, logical_last_dim}; - - if (rowwise_usage) { - rowwise_data = at::empty({total_elements}, uint8_opts); - const auto 
scale_shape = get_scale_shape(logical_shape_vec, false); - const int64_t total_scale_elements = static_cast(product(scale_shape)); - rowwise_scale_inv = at::empty({total_scale_elements}, float_opts); - } - - if (columnwise_usage) { - columnwise_data = at::empty({total_elements}, uint8_opts); - const auto scale_shape = get_scale_shape(logical_shape_vec, true); - const int64_t total_scale_elements = static_cast(product(scale_shape)); - columnwise_scale_inv = at::empty({total_scale_elements}, float_opts); - } - - GroupedTensorWrapper out_cpp(num_tensors, logical_shape, this->get_scaling_mode()); - if (rowwise_usage) { - out_cpp.set_rowwise_data(rowwise_data->data_ptr(), this->dtype, getTensorShape(*rowwise_data)); - out_cpp.set_rowwise_scale_inv(rowwise_scale_inv->data_ptr(), DType::kFloat32, - getTensorShape(*rowwise_scale_inv)); - } - if (columnwise_usage) { - out_cpp.set_columnwise_data(columnwise_data->data_ptr(), this->dtype, - getTensorShape(*columnwise_data)); - out_cpp.set_columnwise_scale_inv(columnwise_scale_inv->data_ptr(), DType::kFloat32, - getTensorShape(*columnwise_scale_inv)); - } - if (first_dims.has_value()) { - out_cpp.set_first_dims(first_dims->data_ptr(), DType::kInt64, getTensorShape(*first_dims)); - } - if (tensor_offsets.has_value()) { - out_cpp.set_tensor_offsets(tensor_offsets->data_ptr(), DType::kInt64, - getTensorShape(*tensor_offsets)); - } - - py::handle GroupedTensorClass = grouped_tensor_python_class(this->internal); - py::dict kwargs; - py::tuple args(0); - const std::vector grouped_shape = {static_cast(logical_first_dim), - static_cast(logical_last_dim)}; - const std::vector grouped_stride = stride_from_shape(grouped_shape); - kwargs["shape"] = py::cast(grouped_shape); - kwargs["stride"] = py::cast(grouped_stride); - kwargs["dtype"] = py::cast(GetATenDType(dtype)); - kwargs["num_tensors"] = py::cast(num_tensors); - kwargs["quantizer"] = quantizer; - kwargs["data"] = maybe_tensor_to_py(rowwise_data); - kwargs["columnwise_data"] = 
maybe_tensor_to_py(columnwise_data); - kwargs["scale_inv"] = maybe_tensor_to_py(rowwise_scale_inv); - kwargs["columnwise_scale_inv"] = maybe_tensor_to_py(columnwise_scale_inv); - kwargs["amax"] = py::none(); - kwargs["columnwise_amax"] = py::none(); - kwargs["scale"] = py::none(); - kwargs["first_dims"] = first_dims.has_value() ? py::cast(*first_dims) : py::none(); - kwargs["last_dims"] = py::none(); - kwargs["tensor_offsets"] = tensor_offsets.has_value() ? py::cast(*tensor_offsets) : py::none(); - kwargs["with_gemm_swizzled_scales"] = py::cast(false); - PyObject* result = PyObject_Call(GroupedTensorClass.ptr(), args.ptr(), kwargs.ptr()); - if (result == nullptr) { - PyErr_Print(); - } - NVTE_CHECK(result != nullptr, "Failed to create GroupedTensor instance"); - py::object out_py = py::reinterpret_steal(result); - - return {std::move(out_cpp), std::move(out_py)}; -} - -std::pair Float8BlockQuantizer::convert_and_update_tensor( - py::object tensor) const { - const DType dtype = tensor.attr("_fp8_dtype").cast(); - bool is_2D_scaled = tensor.attr("_is_2D_scaled").cast(); - const bool with_gemm_swizzled_scales = true; - - // Extract buffers from Python tensor - auto get_tensor = [&tensor](const char* name) -> std::optional { - auto attr_py = tensor.attr(name); - if (attr_py.is_none()) { - return std::nullopt; - } - return attr_py.cast(); - }; - auto rowwise_data = get_tensor("_rowwise_data"); - auto rowwise_scale_inv = get_tensor("_rowwise_scale_inv"); - auto columnwise_data = get_tensor("_columnwise_data"); - auto columnwise_scale_inv = get_tensor("_columnwise_scale_inv"); - NVTE_CHECK(rowwise_data || columnwise_data, "FP8BlockwiseTensor has no data."); - - // Tensor options and dimensions - at::TensorOptions opts; - at::TensorOptions scale_opts; - opts = opts.dtype(torch::kUInt8).device(torch::kCUDA); - scale_opts = scale_opts.dtype(torch::kFloat32).device(torch::kCUDA); - - auto get_columnwise_shape = [&columnwise_data]() -> std::vector { - if (!columnwise_data) { - 
return std::vector(); - } - std::vector shape = getTensorShape(*columnwise_data); - std::vector shape_transposed(shape.size()); - for (size_t i = 0; i + 1 < shape.size(); ++i) { - shape_transposed[i] = shape[i + 1]; - } - if (shape.size() > 0) { - shape_transposed[shape.size() - 1] = shape[0]; - } - return shape_transposed; - }; - std::vector shape; - if (rowwise_data) { - shape = getTensorShape(*rowwise_data); - if (columnwise_data) { - auto expected_shape = get_columnwise_shape(); - NVTE_CHECK(shape == expected_shape, "BlockwiseFP8 row-wise data (shape=", shape, - ") and column-wise data (shape=", expected_shape, ") do not match"); - } - } else { - shape = get_columnwise_shape(); - } - std::vector torch_shape; - for (auto s : shape) { - torch_shape.emplace_back(static_cast(s)); - } - - // Coerce row-wise data - if (rowwise_usage) { - if (!rowwise_data) { - rowwise_data = at::empty(torch_shape, opts); - tensor.attr("_rowwise_data") = *rowwise_data; - } - if (!rowwise_scale_inv) { - auto scale_shape = get_scale_shape(shape, false); - size_t sinv0 = scale_shape[0]; - size_t sinv1 = scale_shape[1]; - rowwise_scale_inv = - at::empty({static_cast(sinv0), static_cast(sinv1)}, scale_opts); - tensor.attr("_rowwise_scale_inv") = *rowwise_scale_inv; - } - } else { // rowwise_usage == false - if (rowwise_data) { - rowwise_data.reset(); - tensor.attr("_rowwise_data") = py::none(); - } - if (rowwise_scale_inv) { - rowwise_scale_inv.reset(); - tensor.attr("_rowwise_scale_inv") = py::none(); - } - } - - // Coerce column-wise data - if (columnwise_usage) { - std::vector columnwise_shape; - std::vector torch_columnwise_shape; - if (torch_shape.size() > 0) { - torch_columnwise_shape.reserve(torch_shape.size()); - columnwise_shape.reserve(shape.size()); - torch_columnwise_shape.push_back(torch_shape[torch_shape.size() - 1]); - columnwise_shape.push_back(shape[shape.size() - 1]); - for (size_t i = 0; i < torch_shape.size() - 1; ++i) { - 
torch_columnwise_shape.push_back(torch_shape[i]); - columnwise_shape.push_back(shape[i]); - } - } - if (!columnwise_data) { - columnwise_data = at::empty(torch_columnwise_shape, opts); - tensor.attr("_columnwise_data") = *columnwise_data; - } - if (!columnwise_scale_inv) { - auto scale_shape = get_scale_shape(shape, true); - size_t sinv0 = scale_shape[0]; - size_t sinv1 = scale_shape[1]; - columnwise_scale_inv = - at::empty({static_cast(sinv0), static_cast(sinv1)}, scale_opts); - tensor.attr("_columnwise_scale_inv") = *columnwise_scale_inv; - } - } else { // columnwise_usage == false - if (columnwise_data) { - columnwise_data.reset(); - tensor.attr("_columnwise_data") = py::none(); - } - if (columnwise_scale_inv) { - columnwise_scale_inv.reset(); - tensor.attr("_columnwise_scale_inv") = py::none(); - } - } - - auto ret = TensorWrapper(is_2D_scaled ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D); - - if (rowwise_usage) { - const at::Tensor& data_rowwise = tensor.attr("_rowwise_data").cast(); - const at::Tensor& scale_inv_rowwise = tensor.attr("_rowwise_scale_inv").cast(); - void* scale_inv_rowwise_dptr = scale_inv_rowwise.data_ptr(); - const auto& rowwise_shape = getTensorShape(data_rowwise); - ret.set_rowwise_data(data_rowwise.data_ptr(), dtype, rowwise_shape); - const auto scale_inv_rowwise_shape = getTensorShape(scale_inv_rowwise); - ret.set_rowwise_scale_inv(scale_inv_rowwise_dptr, DType::kFloat32, scale_inv_rowwise_shape); - } - if (columnwise_usage) { - const at::Tensor& data_colwise = tensor.attr("_columnwise_data").cast(); - const at::Tensor& scale_inv_colwise = tensor.attr("_columnwise_scale_inv").cast(); - void* scale_inv_colwise_dptr = scale_inv_colwise.data_ptr(); - const auto& shape = getTensorShape(data_colwise); - ret.set_columnwise_data(data_colwise.data_ptr(), dtype, shape); - const auto scale_inv_colwise_shape = getTensorShape(scale_inv_colwise); - ret.set_columnwise_scale_inv(scale_inv_colwise_dptr, DType::kFloat32, scale_inv_colwise_shape); - } 
- ret.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); - set_quantization_params(&ret); - return {std::move(ret), std::move(tensor)}; -} - -void Float8BlockQuantizer::quantize(const TensorWrapper& input, TensorWrapper& out, - const std::optional& noop_flag) { - if (input.numel() == 0) { - return; - } - QuantizationConfigWrapper quant_config; - if (noop_flag) { - quant_config.set_noop_tensor(noop_flag->data()); - } - quant_config.set_force_pow_2_scales(force_pow_2_scales); - quant_config.set_amax_epsilon(amax_epsilon); - NVTE_SCOPED_GIL_RELEASE({ - nvte_quantize_v2(input.data(), out.data(), quant_config, at::cuda::getCurrentCUDAStream()); - }); -} - -std::vector Float8BlockQuantizer::get_scale_shape(const std::vector& shape, - bool columnwise) const { - size_t numel = 1; - for (auto s : shape) { - numel *= s; - } - - size_t k_dim = shape.size() == 0 ? 1u : shape.back(); - size_t m_dim = numel / k_dim; - constexpr size_t kBlockLen = 128; - - std::vector scale_shape; - - bool rowwise_usage = !columnwise; - - if (rowwise_usage) { - // rowwise scaling factor shape - size_t sinv0 = 0; - size_t sinv1 = 0; - if (block_scaling_dim == 2) { - sinv0 = ceildiv(m_dim, kBlockLen); - sinv1 = roundup(ceildiv(k_dim, kBlockLen), 4); - } else if (block_scaling_dim == 1) { - // default rowwise scaling factor shape already transpose the scaling factor so it's GEMM_READY - sinv0 = ceildiv(k_dim, kBlockLen); - sinv1 = roundup(m_dim, 4); - } else { - NVTE_ERROR( - "Unsupported block_scaling_dim in create_tensor rowwise." - "Expected 1 or 2. 
Got ", - block_scaling_dim); - } - scale_shape = {sinv0, sinv1}; - } else { - // columnwise scaling factor shape - size_t sinv0 = 0; - size_t sinv1 = 0; - if (block_scaling_dim == 2) { - sinv0 = ceildiv(k_dim, kBlockLen); - sinv1 = roundup(ceildiv(m_dim, kBlockLen), 4); - } else if (block_scaling_dim == 1) { - sinv0 = ceildiv(m_dim, kBlockLen); - sinv1 = roundup(k_dim, 4); - } else { - NVTE_ERROR( - "Unsupported block_scaling_dim in create_tensor columnwise." - "Expected 1 or 2. Got ", - block_scaling_dim); - } - scale_shape = {sinv0, sinv1}; - } - return scale_shape; -} - -MXFP8Quantizer::MXFP8Quantizer(const py::handle& quantizer) : Quantizer(quantizer) { - this->dtype = quantizer.attr("dtype").cast(); -} - -void MXFP8Quantizer::set_quantization_params(TensorWrapper* tensor) const {} - -std::pair MXFP8Quantizer::create_tensor(const std::vector& shape, - DType dtype) const { - using namespace pybind11::literals; - - // Scaling factor format - const bool with_gemm_swizzled_scales = this->optimize_for_gemm; - - // Tensor dimensions - const std::vector shape_int64(shape.begin(), shape.end()); - size_t flat_first_dim = 1; - if (shape.size() > 0) { - for (size_t i = 0; i < shape.size() - 1; ++i) { - flat_first_dim *= shape[i]; - } - } - const size_t flat_last_dim = shape.size() > 0 ? 
shape.back() : 1; - NVTE_CHECK(flat_first_dim % MXFP8_BLOCK_SIZE == 0 && flat_last_dim % MXFP8_BLOCK_SIZE == 0, - "MXFP8 requires tensor dims that are divisible by ", MXFP8_BLOCK_SIZE, - " (got shape=", shape, ")"); - const auto rowwise_scale_inv_shape = get_scale_shape(shape, false); - const auto columnwise_scale_inv_shape = get_scale_shape(shape, true); - - // Allocate tensors - at::Tensor rowwise_data_tensor, rowwise_scale_inv_tensor; - at::Tensor columnwise_data_tensor, columnwise_scale_inv_tensor; - const auto uint8_tensor_opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); - if (rowwise_usage) { - const std::vector scale_inv_shape_int64(rowwise_scale_inv_shape.begin(), - rowwise_scale_inv_shape.end()); - rowwise_data_tensor = at::empty(shape_int64, uint8_tensor_opts); - rowwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, uint8_tensor_opts); - } - if (columnwise_usage) { - const std::vector scale_inv_shape_int64(columnwise_scale_inv_shape.begin(), - columnwise_scale_inv_shape.end()); - columnwise_data_tensor = at::empty(shape_int64, uint8_tensor_opts); - columnwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, uint8_tensor_opts); - } - - // Convert tensors to Python - auto py_cast = [](at::Tensor& tensor, bool need_cast) -> py::object { - return need_cast ? 
py::cast(tensor) : py::none(); - }; - auto rowwise_data_py = py_cast(rowwise_data_tensor, rowwise_usage); - auto rowwise_scale_inv_py = py_cast(rowwise_scale_inv_tensor, rowwise_usage); - auto columnwise_data_py = py_cast(columnwise_data_tensor, columnwise_usage); - auto columnwise_scale_inv_py = py_cast(columnwise_scale_inv_tensor, columnwise_usage); - - // Construct Python MXFP8 tensor - py::object out_py; - if (internal) { - // Use direct C API call bypassing pybind11 overhead - py::dict kwargs; - py::tuple args(0); - kwargs["rowwise_data"] = rowwise_data_py; - kwargs["columnwise_data"] = columnwise_data_py; - kwargs["rowwise_scale_inv"] = rowwise_scale_inv_py; - kwargs["columnwise_scale_inv"] = columnwise_scale_inv_py; - kwargs["fp8_dtype"] = py::cast(this->dtype); - kwargs["quantizer"] = this->quantizer; - kwargs["with_gemm_swizzled_scales"] = py::cast(with_gemm_swizzled_scales); - kwargs["fake_dtype"] = GetATenDType(dtype); - - PyObject* result = PyObject_Call(reinterpret_cast(MXFP8TensorStoragePythonClass), - args.ptr(), kwargs.ptr()); - if (result == nullptr) { - PyErr_Print(); - } - - NVTE_CHECK(result != nullptr, "Failed to create MXFP8TensorStorage instance"); - out_py = py::reinterpret_steal(result); - } else { - // Use direct C API call bypassing pybind11 overhead - py::dict kwargs; - const auto stride_int64 = stride_from_shape(shape_int64); - kwargs["shape"] = py::cast(shape_int64); - kwargs["stride"] = py::cast(stride_int64); - kwargs["dtype"] = py::cast(GetATenDType(dtype)); - kwargs["rowwise_data"] = rowwise_data_py; - kwargs["columnwise_data"] = columnwise_data_py; - kwargs["rowwise_scale_inv"] = rowwise_scale_inv_py; - kwargs["columnwise_scale_inv"] = columnwise_scale_inv_py; - kwargs["fp8_dtype"] = py::cast(this->dtype); - kwargs["quantizer"] = this->quantizer; - kwargs["with_gemm_swizzled_scales"] = py::cast(with_gemm_swizzled_scales); - - py::tuple args(0); - PyObject* result = PyObject_Call(reinterpret_cast(MXFP8TensorPythonClass), - 
args.ptr(), kwargs.ptr()); - if (result == nullptr) { - PyErr_Print(); - } - - NVTE_CHECK(result != nullptr, "Failed to create MXFP8Tensor instance"); - out_py = py::reinterpret_steal(result); - } - - // Construct C++ MXFP8 tensor - TensorWrapper out_cpp(NVTE_MXFP8_1D_SCALING); - if (rowwise_usage) { - out_cpp.set_rowwise_data(rowwise_data_tensor.data_ptr(), this->dtype, shape); - out_cpp.set_rowwise_scale_inv(rowwise_scale_inv_tensor.data_ptr(), DType::kFloat8E8M0, - rowwise_scale_inv_shape); - } - if (columnwise_usage) { - out_cpp.set_columnwise_data(columnwise_data_tensor.data_ptr(), this->dtype, shape); - out_cpp.set_columnwise_scale_inv(columnwise_scale_inv_tensor.data_ptr(), DType::kFloat8E8M0, - columnwise_scale_inv_shape); - } - out_cpp.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); - this->set_quantization_params(&out_cpp); - - return {std::move(out_cpp), std::move(out_py)}; -} - -std::pair MXFP8Quantizer::create_grouped_tensor( - const size_t num_tensors, const std::vector& logical_shape, const DType dtype, - py::object quantizer, const std::optional& first_dims, - const size_t logical_first_dim, const size_t logical_last_dim) const { - using namespace pybind11::literals; - - const auto tensor_offsets = - build_grouped_tensor_offsets(num_tensors, first_dims, logical_last_dim); - const int64_t total_elements = - static_cast(logical_first_dim) * static_cast(logical_last_dim); - - const auto uint8_opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); - - std::optional rowwise_data; - std::optional columnwise_data; - std::optional rowwise_scale_inv; - std::optional columnwise_scale_inv; - const std::vector logical_shape_vec = {logical_first_dim, logical_last_dim}; - - if (rowwise_usage) { - rowwise_data = at::empty({total_elements}, uint8_opts); - const auto scale_shape = get_scale_shape(logical_shape_vec, false); - const int64_t total_scale_elements = static_cast(product(scale_shape)); - rowwise_scale_inv = 
at::empty({total_scale_elements}, uint8_opts); - } - - if (columnwise_usage) { - columnwise_data = at::empty({total_elements}, uint8_opts); - const auto scale_shape = get_scale_shape(logical_shape_vec, true); - const int64_t total_scale_elements = static_cast(product(scale_shape)); - columnwise_scale_inv = at::empty({total_scale_elements}, uint8_opts); - } - - GroupedTensorWrapper out_cpp(num_tensors, logical_shape, this->get_scaling_mode()); - if (rowwise_usage) { - out_cpp.set_rowwise_data(rowwise_data->data_ptr(), this->dtype, getTensorShape(*rowwise_data)); - out_cpp.set_rowwise_scale_inv(rowwise_scale_inv->data_ptr(), DType::kFloat8E8M0, - getTensorShape(*rowwise_scale_inv)); - } - if (columnwise_usage) { - out_cpp.set_columnwise_data(columnwise_data->data_ptr(), this->dtype, - getTensorShape(*columnwise_data)); - out_cpp.set_columnwise_scale_inv(columnwise_scale_inv->data_ptr(), DType::kFloat8E8M0, - getTensorShape(*columnwise_scale_inv)); - } - if (first_dims.has_value()) { - out_cpp.set_first_dims(first_dims->data_ptr(), DType::kInt64, getTensorShape(*first_dims)); - } - if (tensor_offsets.has_value()) { - out_cpp.set_tensor_offsets(tensor_offsets->data_ptr(), DType::kInt64, - getTensorShape(*tensor_offsets)); - } - - out_cpp.set_with_gemm_swizzled_scales(this->optimize_for_gemm); - - py::handle GroupedTensorClass = grouped_tensor_python_class(this->internal); - py::dict kwargs; - py::tuple args(0); - const std::vector grouped_shape = {static_cast(logical_first_dim), - static_cast(logical_last_dim)}; - const std::vector grouped_stride = stride_from_shape(grouped_shape); - kwargs["shape"] = py::cast(grouped_shape); - kwargs["stride"] = py::cast(grouped_stride); - kwargs["dtype"] = py::cast(GetATenDType(dtype)); - kwargs["num_tensors"] = py::cast(num_tensors); - kwargs["quantizer"] = quantizer; - kwargs["data"] = maybe_tensor_to_py(rowwise_data); - kwargs["columnwise_data"] = maybe_tensor_to_py(columnwise_data); - kwargs["scale_inv"] = 
maybe_tensor_to_py(rowwise_scale_inv); - kwargs["columnwise_scale_inv"] = maybe_tensor_to_py(columnwise_scale_inv); - kwargs["amax"] = py::none(); - kwargs["columnwise_amax"] = py::none(); - kwargs["scale"] = py::none(); - kwargs["first_dims"] = first_dims.has_value() ? py::cast(*first_dims) : py::none(); - kwargs["last_dims"] = py::none(); - kwargs["tensor_offsets"] = tensor_offsets.has_value() ? py::cast(*tensor_offsets) : py::none(); - kwargs["with_gemm_swizzled_scales"] = this->optimize_for_gemm; - PyObject* result = PyObject_Call(GroupedTensorClass.ptr(), args.ptr(), kwargs.ptr()); - if (result == nullptr) { - PyErr_Print(); - } - NVTE_CHECK(result != nullptr, "Failed to create GroupedTensor instance"); - py::object out_py = py::reinterpret_steal(result); - - return {std::move(out_cpp), std::move(out_py)}; -} - -std::pair MXFP8Quantizer::convert_and_update_tensor( - py::object tensor) const { - NVTE_CHECK(detail::IsMXFP8Tensor(tensor.ptr()), "MXFP8Quantizer must output to MXFP8Tensor."); - - // Scaling factor format - const bool with_gemm_swizzled_scales = this->optimize_for_gemm; - - // Extract buffers from Python tensor - auto get_tensor = [&tensor](const char* name) -> std::optional { - auto attr_py = tensor.attr(name); - if (attr_py.is_none()) { - return std::nullopt; - } - return attr_py.cast(); - }; - auto rowwise_data = get_tensor("_rowwise_data"); - auto rowwise_scale_inv = get_tensor("_rowwise_scale_inv"); - auto columnwise_data = get_tensor("_columnwise_data"); - auto columnwise_scale_inv = get_tensor("_columnwise_scale_inv"); - NVTE_CHECK(rowwise_data || columnwise_data, "MXFP8Tensor has no data."); - - // Tensor dimensions - std::vector shape; - if (columnwise_data) { - shape = getTensorShape(*columnwise_data); - if (rowwise_data) { - auto expected_shape = getTensorShape(*rowwise_data); - NVTE_CHECK(shape == expected_shape, "MXFP8 row-wise data (shape=", expected_shape, - ") and column-wise data (shape=", shape, ") do not match"); - } - } else { // 
Already checked columnwise_data_tensor == true - shape = getTensorShape(*rowwise_data); - } - - // Coerce row-wise data - if (rowwise_usage) { - if (!rowwise_data) { - const std::vector shape_int64(shape.begin(), shape.end()); - const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); - rowwise_data = at::empty(shape_int64, opts); - tensor.attr("_rowwise_data") = *rowwise_data; - } - if (!rowwise_scale_inv) { - const auto scale_inv_shape = get_scale_shape(shape, false); - const std::vector scale_inv_shape_int64(scale_inv_shape.begin(), - scale_inv_shape.end()); - const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); - rowwise_scale_inv = at::empty(scale_inv_shape_int64, opts); - tensor.attr("_rowwise_scale_inv") = *rowwise_scale_inv; - } - } else { // rowwise_usage == false - if (rowwise_data) { - rowwise_data.reset(); - tensor.attr("_rowwise_data") = py::none(); - } - if (rowwise_scale_inv) { - rowwise_scale_inv.reset(); - tensor.attr("_rowwise_scale_inv") = py::none(); - } - } - - // Coerce column-wise data - if (columnwise_usage) { - if (!columnwise_data) { - const std::vector shape_int64(shape.begin(), shape.end()); - const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); - columnwise_data = at::empty(shape_int64, opts); - tensor.attr("_columnwise_data") = *columnwise_data; - } - if (!columnwise_scale_inv) { - const auto scale_inv_shape = get_scale_shape(shape, true); - const std::vector scale_inv_shape_int64(scale_inv_shape.begin(), - scale_inv_shape.end()); - const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); - columnwise_scale_inv = at::empty(scale_inv_shape_int64, opts); - tensor.attr("_columnwise_scale_inv") = *columnwise_scale_inv; - } - } else { // columnwise_usage == false - if (columnwise_data) { - columnwise_data.reset(); - tensor.attr("_columnwise_data") = py::none(); - } - if (columnwise_scale_inv) { - columnwise_scale_inv.reset(); - 
tensor.attr("_columnwise_scale_inv") = py::none(); - } - } - - // Coerce other attrs - tensor.attr("_fp8_dtype") = dtype; - tensor.attr("_with_gemm_swizzled_scales") = with_gemm_swizzled_scales; - - // Construct C++ MXFP8 tensor - TensorWrapper out_cpp(NVTE_MXFP8_1D_SCALING); - if (rowwise_usage) { - out_cpp.set_rowwise_data(rowwise_data->data_ptr(), dtype, shape); - out_cpp.set_rowwise_scale_inv(rowwise_scale_inv->data_ptr(), DType::kFloat8E8M0, - getTensorShape(*rowwise_scale_inv)); - } - if (columnwise_usage) { - out_cpp.set_columnwise_data(columnwise_data->data_ptr(), dtype, shape); - out_cpp.set_columnwise_scale_inv(columnwise_scale_inv->data_ptr(), DType::kFloat8E8M0, - getTensorShape(*columnwise_scale_inv)); - } - out_cpp.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); - this->set_quantization_params(&out_cpp); - - return {std::move(out_cpp), std::move(tensor)}; -} - -void MXFP8Quantizer::quantize(const TensorWrapper& input, TensorWrapper& out, - const std::optional& noop_flag) { - if (input.numel() == 0) { - return; - } - QuantizationConfigWrapper quant_config; - if (noop_flag) { - quant_config.set_noop_tensor(noop_flag->data()); - } - NVTE_SCOPED_GIL_RELEASE({ - nvte_quantize_v2(input.data(), out.data(), quant_config, at::cuda::getCurrentCUDAStream()); - }); -} - -std::vector MXFP8Quantizer::get_scale_shape(const std::vector& shape, - bool columnwise) const { - size_t numel = 1; - for (auto s : shape) { - numel *= s; - } - - auto last_dim = shape.back(); - - NVTE_CHECK(last_dim % MXFP8_BLOCK_SIZE == 0 && (numel / last_dim) % MXFP8_BLOCK_SIZE == 0, - "MXFP8 requires tensor dims that are divisible by ", MXFP8_BLOCK_SIZE, - " (got shape=", shape, ")"); - - std::vector scale_shape; - - bool rowwise_usage = !columnwise; - - if (rowwise_usage) { - // rowwise scaling factor shape - size_t sinv0 = roundup(numel / last_dim, 128); - size_t sinv1 = roundup(last_dim / MXFP8_BLOCK_SIZE, 4); - scale_shape = {sinv0, sinv1}; - } else { - // columnwise scaling 
factor shape - size_t sinv0 = roundup(numel / (last_dim * MXFP8_BLOCK_SIZE), 4); - size_t sinv1 = roundup(last_dim, 128); - scale_shape = {sinv0, sinv1}; - } - return scale_shape; -} - -NVFP4Quantizer::NVFP4Quantizer(const py::handle& quantizer) : Quantizer(quantizer) { - this->dtype = quantizer.attr("dtype").cast(); - this->with_rht = quantizer.attr("with_rht").cast(); - this->with_post_rht_amax = quantizer.attr("with_post_rht_amax").cast(); - this->with_2d_quantization = quantizer.attr("with_2d_quantization").cast(); - this->stochastic_rounding = quantizer.attr("stochastic_rounding").cast(); - - // Get amax reduction group if needed for NVFP4 AG - const bool with_amax_reduction = quantizer.attr("with_amax_reduction").cast(); - c10::intrusive_ptr amax_reduction_group; - if (with_amax_reduction) { - auto group = quantizer.attr("_canonicalized_amax_reduction_group")(); - NVTE_CHECK(!group.is_none(), "NVFP4Quantizer could not canonicalize amax reduction group"); - amax_reduction_group = group.cast>(); - } - this->with_amax_reduction = with_amax_reduction; - this->amax_reduction_group = amax_reduction_group; - - this->rht_matrix_random_sign_mask_t = quantizer.attr("rht_matrix_random_sign_mask_t").cast(); - this->rht_matrix = quantizer.attr("rht_matrix").cast(); -} - -void NVFP4Quantizer::set_quantization_params(TensorWrapper* tensor) const { - // set dtype for rowwise and columnwise data in tensor wrapper - auto rowwise_data = tensor->get_rowwise_data(); - rowwise_data.dtype = static_cast(this->dtype); - - auto columnwise_data = tensor->get_columnwise_data(); - columnwise_data.dtype = static_cast(this->dtype); - - tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast(rowwise_data.dtype), - rowwise_data.shape); - tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast(columnwise_data.dtype), - columnwise_data.shape); -} - -std::pair NVFP4Quantizer::create_tensor(const std::vector& shape, - DType dtype) const { - using namespace pybind11::literals; - - 
// Scaling factor format - const bool with_gemm_swizzled_scales = false; /// TODO (tmoon) self->optimize_for_gemm - - // Tensor dimensions - const std::vector shape_int64(shape.begin(), shape.end()); - size_t flat_first_dim = 1; - if (shape.size() > 0) { - for (size_t i = 0; i < shape.size() - 1; ++i) { - flat_first_dim *= shape[i]; - } - } - const size_t flat_last_dim = shape.size() > 0 ? shape.back() : 1; - NVTE_CHECK(flat_first_dim % NVFP4_BLOCK_SIZE == 0, "First dim for NVFP4 must be divisible by ", - NVFP4_BLOCK_SIZE, " (got shape=", shape, ")"); - NVTE_CHECK(flat_last_dim % NVFP4_BLOCK_SIZE == 0, - "NVFP4 requires tensor dims that are divisible by ", NVFP4_BLOCK_SIZE, - " (got shape=", shape, ")"); - const auto rowwise_scale_inv_shape = get_scale_shape(shape, false); - const auto columnwise_scale_inv_shape = get_scale_shape(shape, true); - - // Allocate tensors - at::Tensor rowwise_data_tensor, rowwise_scale_inv_tensor, amax_rowwise; - at::Tensor columnwise_data_tensor, columnwise_scale_inv_tensor, amax_columnwise; - const auto bit8_tensor_opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); - const auto bit32_tensor_opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); - if (rowwise_usage) { - const std::vector scale_inv_shape_int64(rowwise_scale_inv_shape.begin(), - rowwise_scale_inv_shape.end()); - rowwise_data_tensor = at::empty(convert_shape_for_fp4(shape_int64), bit8_tensor_opts); - rowwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, bit8_tensor_opts); - // hadamard amax kernel will zero out pointer with ZeroAmaxKernel - // nvte_compute_amax_with_config will zero out the pointer if needed - amax_rowwise = at::empty({1}, bit32_tensor_opts); - } - if (columnwise_usage) { - const std::vector scale_inv_shape_int64(columnwise_scale_inv_shape.begin(), - columnwise_scale_inv_shape.end()); - // enforce 2D shape to avoid [S, B, H] shape and B and be 1 - // and the transposed shape is [H, S, B], so divide last dim by 2 
gives zero - std::vector shape_int64_2d = {static_cast(flat_first_dim), - static_cast(flat_last_dim)}; - const auto transpose_shape_int64 = make_transpose_shape(shape_int64_2d); - columnwise_data_tensor = - at::empty(convert_shape_for_fp4(transpose_shape_int64), bit8_tensor_opts); - columnwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, bit8_tensor_opts); - // hadamard amax kernel will zero out pointer with ZeroAmaxKernel - // nvte_compute_amax_with_config will zero out the pointer if needed - amax_columnwise = at::empty({1}, bit32_tensor_opts); - } - - // Convert tensors to Python - auto py_cast = [](at::Tensor& tensor, bool need_cast) -> py::object { - return need_cast ? py::cast(tensor) : py::none(); - }; - auto rowwise_data_py = py_cast(rowwise_data_tensor, rowwise_usage); - auto rowwise_scale_inv_py = py_cast(rowwise_scale_inv_tensor, rowwise_usage); - auto columnwise_data_py = py_cast(columnwise_data_tensor, columnwise_usage); - auto columnwise_scale_inv_py = py_cast(columnwise_scale_inv_tensor, columnwise_usage); - auto amax_rowwise_py = py_cast(amax_rowwise, rowwise_usage); - auto amax_columnwise_py = py_cast(amax_columnwise, columnwise_usage); - - // Construct Python NVFP4 tensor - py::object out_py; - if (internal) { - // Use direct C API call bypassing pybind11 overhead - py::dict kwargs; - kwargs["rowwise_data"] = rowwise_data_py; - kwargs["columnwise_data"] = columnwise_data_py; - kwargs["rowwise_scale_inv"] = rowwise_scale_inv_py; - kwargs["columnwise_scale_inv"] = columnwise_scale_inv_py; - kwargs["amax_rowwise"] = amax_rowwise_py; - kwargs["amax_columnwise"] = amax_columnwise_py; - kwargs["fp4_dtype"] = py::cast(this->dtype); - kwargs["quantizer"] = this->quantizer; - kwargs["with_gemm_swizzled_scales"] = py::cast(with_gemm_swizzled_scales); - kwargs["fake_dtype"] = GetATenDType(dtype); - - py::tuple args(0); - - PyObject* result = PyObject_Call(reinterpret_cast(NVFP4TensorStoragePythonClass), - args.ptr(), kwargs.ptr()); - if (result == 
nullptr) { - PyErr_Print(); - } - - NVTE_CHECK(result != nullptr, "Failed to create NVFP4TensorStorage instance"); - out_py = py::reinterpret_steal(result); - } else { - // Use direct C API call bypassing pybind11 overhead - py::dict kwargs; - const auto stride_int64 = stride_from_shape(shape_int64); - kwargs["shape"] = py::cast(shape_int64); - kwargs["stride"] = py::cast(stride_int64); - kwargs["dtype"] = py::cast(GetATenDType(dtype)); - kwargs["rowwise_data"] = rowwise_data_py; - kwargs["columnwise_data"] = columnwise_data_py; - kwargs["rowwise_scale_inv"] = rowwise_scale_inv_py; - kwargs["columnwise_scale_inv"] = columnwise_scale_inv_py; - kwargs["amax_rowwise"] = amax_rowwise_py; - kwargs["amax_columnwise"] = amax_columnwise_py; - kwargs["fp4_dtype"] = py::cast(this->dtype); - kwargs["quantizer"] = this->quantizer; - kwargs["with_gemm_swizzled_scales"] = py::cast(with_gemm_swizzled_scales); - py::tuple args(0); - PyObject* result = PyObject_Call(reinterpret_cast(NVFP4TensorPythonClass), - args.ptr(), kwargs.ptr()); - if (result == nullptr) { - PyErr_Print(); - } - - NVTE_CHECK(result != nullptr, "Failed to create NVFP4Tensor instance"); - out_py = py::reinterpret_steal(result); - } - - // Construct C++ tensor - TensorWrapper out_cpp(NVTE_NVFP4_1D_SCALING); - if (rowwise_usage) { - out_cpp.set_rowwise_data(rowwise_data_tensor.data_ptr(), DType::kFloat4E2M1, shape); - out_cpp.set_rowwise_scale_inv(rowwise_scale_inv_tensor.data_ptr(), DType::kFloat8E4M3, - rowwise_scale_inv_shape); - out_cpp.set_amax(amax_rowwise.data_ptr(), DType::kFloat32, std::vector{1}); - } - if (columnwise_usage) { - // enforce 2D shape to avoid [S, B, H] shape and B and be 1 - // and the transposed shape is [H, S, B], so divide last dim by 2 gives zero - std::vector shape_2d = {flat_first_dim, flat_last_dim}; - auto col_data_shape_fp4 = make_transpose_shape(shape_2d); - out_cpp.set_columnwise_data(columnwise_data_tensor.data_ptr(), DType::kFloat4E2M1, - col_data_shape_fp4); - 
out_cpp.set_columnwise_scale_inv(columnwise_scale_inv_tensor.data_ptr(), DType::kFloat8E4M3, - columnwise_scale_inv_shape); - out_cpp.set_columnwise_amax(amax_columnwise.data_ptr(), DType::kFloat32, - std::vector{1}); - } - out_cpp.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); - this->set_quantization_params(&out_cpp); - - return {std::move(out_cpp), std::move(out_py)}; -} - -std::pair NVFP4Quantizer::create_grouped_tensor( - const size_t num_tensors, const std::vector& logical_shape, const DType dtype, - py::object quantizer, const std::optional& first_dims, - const size_t logical_first_dim, const size_t logical_last_dim) const { - using namespace pybind11::literals; - - const auto tensor_offsets = - build_grouped_tensor_offsets(num_tensors, first_dims, logical_last_dim); - const int64_t total_elements = - static_cast(logical_first_dim) * static_cast(logical_last_dim); - NVTE_CHECK(total_elements % 2 == 0, "NVFP4 data size must be divisible by 2."); - - const auto uint8_opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); - const auto float_opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); - - std::optional rowwise_data; - std::optional columnwise_data; - std::optional rowwise_scale_inv; - std::optional columnwise_scale_inv; - std::optional rowwise_amax; - std::optional columnwise_amax; - const std::vector logical_shape_vec = {logical_first_dim, logical_last_dim}; - - const int64_t total_data_elements = total_elements / 2; - - if (rowwise_usage) { - rowwise_data = at::empty({total_data_elements}, uint8_opts); - const auto scale_shape = get_scale_shape(logical_shape_vec, false); - const int64_t total_scale_elements = static_cast(product(scale_shape)); - rowwise_scale_inv = at::empty({total_scale_elements}, uint8_opts); - rowwise_amax = at::empty({static_cast(num_tensors)}, float_opts); - } - - if (columnwise_usage) { - columnwise_data = at::empty({total_data_elements}, uint8_opts); - const auto scale_shape = 
get_scale_shape(logical_shape_vec, true); - const int64_t total_scale_elements = static_cast(product(scale_shape)); - columnwise_scale_inv = at::empty({total_scale_elements}, uint8_opts); - columnwise_amax = at::empty({static_cast(num_tensors)}, float_opts); - } - - GroupedTensorWrapper out_cpp(num_tensors, logical_shape, this->get_scaling_mode()); - if (rowwise_usage) { - out_cpp.set_rowwise_data(rowwise_data->data_ptr(), this->dtype, getTensorShape(*rowwise_data)); - out_cpp.set_rowwise_scale_inv(rowwise_scale_inv->data_ptr(), DType::kFloat8E4M3, - getTensorShape(*rowwise_scale_inv)); - out_cpp.set_amax(rowwise_amax->data_ptr(), DType::kFloat32, getTensorShape(*rowwise_amax)); - } - if (columnwise_usage) { - out_cpp.set_columnwise_data(columnwise_data->data_ptr(), this->dtype, - getTensorShape(*columnwise_data)); - out_cpp.set_columnwise_scale_inv(columnwise_scale_inv->data_ptr(), DType::kFloat8E4M3, - getTensorShape(*columnwise_scale_inv)); - out_cpp.set_columnwise_amax(columnwise_amax->data_ptr(), DType::kFloat32, - getTensorShape(*columnwise_amax)); - } - if (first_dims.has_value()) { - out_cpp.set_first_dims(first_dims->data_ptr(), DType::kInt64, getTensorShape(*first_dims)); - } - if (tensor_offsets.has_value()) { - out_cpp.set_tensor_offsets(tensor_offsets->data_ptr(), DType::kInt64, - getTensorShape(*tensor_offsets)); - } - - out_cpp.set_with_gemm_swizzled_scales(this->optimize_for_gemm); - - py::handle GroupedTensorClass = grouped_tensor_python_class(this->internal); - py::dict kwargs; - py::tuple args(0); - const std::vector grouped_shape = {static_cast(logical_first_dim), - static_cast(logical_last_dim)}; - const std::vector grouped_stride = stride_from_shape(grouped_shape); - kwargs["shape"] = py::cast(grouped_shape); - kwargs["stride"] = py::cast(grouped_stride); - kwargs["dtype"] = py::cast(GetATenDType(dtype)); - kwargs["num_tensors"] = py::cast(num_tensors); - kwargs["quantizer"] = quantizer; - kwargs["data"] = maybe_tensor_to_py(rowwise_data); - 
kwargs["columnwise_data"] = maybe_tensor_to_py(columnwise_data); - kwargs["scale_inv"] = maybe_tensor_to_py(rowwise_scale_inv); - kwargs["columnwise_scale_inv"] = maybe_tensor_to_py(columnwise_scale_inv); - kwargs["amax"] = maybe_tensor_to_py(rowwise_amax); - kwargs["columnwise_amax"] = maybe_tensor_to_py(columnwise_amax); - kwargs["scale"] = py::none(); - kwargs["first_dims"] = first_dims.has_value() ? py::cast(*first_dims) : py::none(); - kwargs["last_dims"] = py::none(); - kwargs["tensor_offsets"] = tensor_offsets.has_value() ? py::cast(*tensor_offsets) : py::none(); - kwargs["with_gemm_swizzled_scales"] = this->optimize_for_gemm; - PyObject* result = PyObject_Call(GroupedTensorClass.ptr(), args.ptr(), kwargs.ptr()); - if (result == nullptr) { - PyErr_Print(); - } - NVTE_CHECK(result != nullptr, "Failed to create GroupedTensor instance"); - py::object out_py = py::reinterpret_steal(result); - - return {std::move(out_cpp), std::move(out_py)}; -} - -std::pair NVFP4Quantizer::create_unquantized_tensor_with_amax( - TensorWrapper& quantized_tensor, DType dtype) { - // Construct tensor - auto shape = convertShape(quantized_tensor.shape()); - auto [out_cpp, out_py] = NoneQuantizer(py::none()).create_tensor(shape, dtype); - - // Register amax pointer from quantized tensor - void* amax_ptr = quantized_tensor.amax(); - if (amax_ptr == nullptr) { - amax_ptr = quantized_tensor.get_columnwise_amax().data_ptr; - } - NVTE_CHECK(amax_ptr != nullptr, "Could not extract amax pointer from NVFP4 tensor."); - out_cpp.set_amax(amax_ptr, DType::kFloat32, std::vector{1}); - - // Zero out amax - NVTE_CHECK_CUDA(cudaMemsetAsync(amax_ptr, 0, sizeof(float), at::cuda::getCurrentCUDAStream())); - - return {std::move(out_cpp), std::move(out_py)}; -} - -std::pair NVFP4Quantizer::convert_and_update_tensor( - py::object tensor) const { - NVTE_CHECK(detail::IsNVFP4Tensor(tensor.ptr()), "NVFP4Quantizer must output to IsNVFP4Tensor."); - - // Scaling factor format - const bool 
with_gemm_swizzled_scales = false; // TODO (tmoon) Enable with optimize_for_gemm - - // Extract buffers from Python tensor - auto get_tensor = [&tensor](const char* name) -> std::optional { - auto attr_py = tensor.attr(name); - if (attr_py.is_none()) { - return std::nullopt; - } - return attr_py.cast(); - }; - auto rowwise_data = get_tensor("_rowwise_data"); - auto rowwise_scale_inv = get_tensor("_rowwise_scale_inv"); - auto columnwise_data = get_tensor("_columnwise_data"); - auto columnwise_scale_inv = get_tensor("_columnwise_scale_inv"); - auto amax_rowwise = get_tensor("_amax_rowwise"); - auto amax_columnwise = get_tensor("_amax_columnwise"); - NVTE_CHECK(rowwise_data || columnwise_data, "NVFP4Tensor has no data."); - - // Tensor dimensions, shape means original shape - std::vector shape; - if (columnwise_data) { - shape = convert_shape_back_from_fp4(getTensorShape(*columnwise_data), true); - if (rowwise_data) { - auto expected_shape = convert_shape_back_from_fp4(getTensorShape(*rowwise_data), false); - NVTE_CHECK(shape == expected_shape, "NVFP4 row-wise data (shape=", expected_shape, - ") and column-wise data (shape=", shape, ") do not match"); - } - } else { // Already checked columnwise_data_tensor == true - shape = convert_shape_back_from_fp4(getTensorShape(*rowwise_data), false); - } - - size_t flat_first_dim = 1; - if (shape.size() > 0) { - for (size_t i = 0; i < shape.size() - 1; ++i) { - flat_first_dim *= shape[i]; - } - } - const size_t flat_last_dim = shape.size() > 0 ? 
shape.back() : 1; - - // Coerce row-wise data - if (rowwise_usage) { - if (!rowwise_data) { - const std::vector shape_int64(shape.begin(), shape.end()); - const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); - rowwise_data = at::empty(convert_shape_for_fp4(shape_int64), opts); - tensor.attr("_rowwise_data") = *rowwise_data; - } - if (!rowwise_scale_inv) { - const auto scale_inv_shape = get_scale_shape(shape, false); - const std::vector scale_inv_shape_int64(scale_inv_shape.begin(), - scale_inv_shape.end()); - const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); - rowwise_scale_inv = at::empty(scale_inv_shape_int64, opts); - tensor.attr("_rowwise_scale_inv") = *rowwise_scale_inv; - } - if (!amax_rowwise) { - const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); - // hadamard amax kernel will zero out pointer with ZeroAmaxKernel - // nvte_compute_amax_with_config will zero out the pointer if needed - amax_rowwise = at::empty({1}, opts); - tensor.attr("_amax_rowwise") = *amax_rowwise; - } - } else { // rowwise_usage == false - if (rowwise_data) { - rowwise_data.reset(); - tensor.attr("_rowwise_data") = py::none(); - } - if (rowwise_scale_inv) { - rowwise_scale_inv.reset(); - tensor.attr("_rowwise_scale_inv") = py::none(); - } - if (amax_rowwise) { - amax_rowwise.reset(); - tensor.attr("_amax_rowwise") = py::none(); - } - } - - // Coerce column-wise data - if (columnwise_usage) { - if (!columnwise_data) { - // enforce 2D shape to avoid [S, B, H] shape and B and be 1 - // and the transposed shape is [H, S, B], so divide last dim by 2 gives zero - std::vector shape_int64_2d = {static_cast(flat_first_dim), - static_cast(flat_last_dim)}; - const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); - const auto transpose_shape_int64 = make_transpose_shape(shape_int64_2d); - columnwise_data = at::empty(convert_shape_for_fp4(transpose_shape_int64), opts); - 
tensor.attr("_columnwise_data") = *columnwise_data; - } - if (!columnwise_scale_inv) { - const auto scale_inv_shape = get_scale_shape(shape, true); - const std::vector scale_inv_shape_int64(scale_inv_shape.begin(), - scale_inv_shape.end()); - const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); - columnwise_scale_inv = at::empty(scale_inv_shape_int64, opts); - tensor.attr("_columnwise_scale_inv") = *columnwise_scale_inv; - } - if (!amax_columnwise) { - const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); - // hadamard amax kernel will zero out pointer with ZeroAmaxKernel - // nvte_compute_amax_with_config will zero out the pointer if needed - amax_columnwise = at::empty({1}, opts); - tensor.attr("_amax_columnwise") = *amax_columnwise; - } - } else { // columnwise_usage == false - if (columnwise_data) { - columnwise_data.reset(); - tensor.attr("_columnwise_data") = py::none(); - } - if (columnwise_scale_inv) { - columnwise_scale_inv.reset(); - tensor.attr("_columnwise_scale_inv") = py::none(); - } - if (amax_columnwise) { - amax_columnwise.reset(); - tensor.attr("_amax_columnwise") = py::none(); - } - } - - // Construct C++ tensor - TensorWrapper out_cpp(NVTE_NVFP4_1D_SCALING); - if (rowwise_usage) { - out_cpp.set_rowwise_data(rowwise_data->data_ptr(), DType::kFloat4E2M1, shape); - out_cpp.set_rowwise_scale_inv(rowwise_scale_inv->data_ptr(), DType::kFloat8E4M3, - getTensorShape(*rowwise_scale_inv)); - out_cpp.set_amax(amax_rowwise->data_ptr(), DType::kFloat32, std::vector{1}); - } - if (columnwise_usage) { - // enforce 2D shape to avoid [S, B, H] shape and B and be 1 - // and the transposed shape is [H, S, B], so divide last dim by 2 gives zero - std::vector shape_2d = {flat_first_dim, flat_last_dim}; - auto col_data_shape_fp4 = make_transpose_shape(shape_2d); - out_cpp.set_columnwise_data(columnwise_data->data_ptr(), DType::kFloat4E2M1, - col_data_shape_fp4); - 
out_cpp.set_columnwise_scale_inv(columnwise_scale_inv->data_ptr(), DType::kFloat8E4M3, - getTensorShape(*columnwise_scale_inv)); - out_cpp.set_columnwise_amax(amax_columnwise->data_ptr(), DType::kFloat32, - std::vector{1}); - } - out_cpp.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); - this->set_quantization_params(&out_cpp); - - return {std::move(out_cpp), std::move(tensor)}; -} - -void NVFP4Quantizer::quantize_with_rht_unfused_helper( - const TensorWrapper& input, TensorWrapper& out, TensorWrapper& rht_output_t_cpp, - QuantizationConfigWrapper& quant_config, QuantizationConfigWrapper& quant_config_columnwise, - cudaStream_t stream) { - // only triggered for irregular shapes where RHT cast fusion kernel is not eligible - if (rowwise_usage) { - // For rowwise usage, we need to quantize the input directly, but we need to avoid quantizing columnwise - TensorWrapper out_identity(out.scaling_mode()); - auto out_identity_data = out.get_rowwise_data(); - auto out_identity_scale_inv = out.get_rowwise_scale_inv(); - auto out_identity_amax = out.get_amax(); - out_identity.set_rowwise_data(out_identity_data.data_ptr, - static_cast(out_identity_data.dtype), - out_identity_data.shape); - out_identity.set_rowwise_scale_inv(out_identity_scale_inv.data_ptr, - static_cast(out_identity_scale_inv.dtype), - out_identity_scale_inv.shape); - out_identity.set_amax(out_identity_amax.data_ptr, static_cast(out_identity_amax.dtype), - out_identity_amax.shape); - - NVTE_SCOPED_GIL_RELEASE( - { nvte_quantize_v2(input.data(), out_identity.data(), quant_config, stream); }); - } - - if (columnwise_usage) { - // Get the output columnwise data, scale_inv, and amax - auto out_columnwise_data = out.get_columnwise_data(); - auto out_columnwise_scale_inv = out.get_columnwise_scale_inv(); - // NOTE: should already be populated. - auto out_columnwise_amax = out.get_columnwise_amax(); - - // Create a wrapper for the columnwise output, as the rowwise output. 
- // The reason is due to the input `rht_output_t` is already in the transposed layout. - // Thus, we only need a rowwise quantization to generate the columnwise output. - TensorWrapper out_transpose(out.scaling_mode()); - // Note: since we are faking columnwise tensor into rowwise, the flat first dim check will fail - // need to convert the shape to 2D here - auto colwise_data_shape = out_columnwise_data.shape; - std::vector colwise_data_shape_2d; - // shape could be [512, 32, 64], that's actually 512, 32, 128 because 2 FP4 take 1 byte - // the 2D shape should be [512, 32*128], but columnwise data shape expect last dim to be halved again - // so the multiple 2 get cancelled out - colwise_data_shape_2d.push_back(colwise_data_shape.data[0]); - size_t last_dim = 1; - for (size_t i = 1; i < colwise_data_shape.ndim; ++i) { - last_dim *= colwise_data_shape.data[i]; - } - colwise_data_shape_2d.push_back(last_dim); - - out_transpose.set_rowwise_data(out_columnwise_data.data_ptr, - static_cast(out_columnwise_data.dtype), - colwise_data_shape_2d); - out_transpose.set_rowwise_scale_inv(out_columnwise_scale_inv.data_ptr, - static_cast(out_columnwise_scale_inv.dtype), - out_columnwise_scale_inv.shape); - out_transpose.set_amax(out_columnwise_amax.data_ptr, - static_cast(out_columnwise_amax.dtype), - out_columnwise_amax.shape); - - // Invoking fallback RHT kernel unfused. - - NVTE_SCOPED_GIL_RELEASE({ - // Perform the RHT(input.t), and write to rht_output_cpp.columnwise. - nvte_hadamard_transform(input.data(), rht_output_t_cpp.data(), 0, - this->rht_matrix_random_sign_mask_t, stream); - }); - - // Quantize kernel will treat everything as rowwise input/output, which is - // intended. 
- NVTE_SCOPED_GIL_RELEASE({ - nvte_quantize_v2(rht_output_t_cpp.data(), out_transpose.data(), quant_config_columnwise, - stream); - }); - } -} - -void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& out, - const std::optional& noop_flag, - bool compute_amax) { - // Nothing to be done if input is empty - if (input.numel() == 0) { - return; - } - - auto stream = at::cuda::getCurrentCUDAStream(); - - QuantizationConfigWrapper quant_config; - QuantizationConfigWrapper quant_config_columnwise; - if (noop_flag) { - quant_config.set_noop_tensor(noop_flag->data()); - quant_config_columnwise.set_noop_tensor(noop_flag->data()); - } - quant_config.set_nvfp4_2d_quantization(this->with_2d_quantization); - quant_config.set_stochastic_rounding(this->stochastic_rounding); - - // We only need RHT for columnwise usage. - // flat first dim and last dim for multi dimensional input - size_t rows = 1; - for (size_t i = 0; i < input.ndim() - 1; ++i) { - rows *= input.size(i); - } - size_t cols = input.size(input.ndim() - 1); - - // Restriction for the RHT cast fusion kernel because we are using MMA hardware for computing RHT - bool eligible_for_rht_cast_fusion = - input.dtype() == DType::kBFloat16 && rows % 64 == 0 && cols % 128 == 0; - - // Stochastic rounding - // When both rowwise and columnwise quantization are used with RHT, - // we need separate RNG states for each to ensure they use different random numbers. - TensorWrapper te_rng_state; - TensorWrapper te_rng_state_columnwise; - - // Only need a separate rng state when: - // 1. Stochastic rounding is enabled - // 2. RHT is enabled - // 3. Columnwise usage is enabled - // 4. 
Rowwise and columnwise quantization are not fused, - // because within a single kernel we can generate two different random numbers for rowwise and columnwise - const bool need_separate_columnwise_rng = this->stochastic_rounding && this->with_rht && - this->columnwise_usage && - (!eligible_for_rht_cast_fusion); - - if (this->stochastic_rounding) { - const size_t rng_elts_per_thread = 1024; // Wild guess, probably can be tightened - auto gen = at::get_generator_or_default( - std::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); - auto opts = at::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA); - - // Generate RNG state for rowwise quantization - at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread); - auto rng_state = torch::empty({2}, opts); - philox_unpack(philox_args, static_cast(rng_state.data_ptr())); - te_rng_state = makeTransformerEngineTensor(rng_state); - quant_config.set_rng_state(te_rng_state.data()); - - // Generate separate RNG state for columnwise quantization - if (need_separate_columnwise_rng) { - at::PhiloxCudaState philox_args_columnwise = init_philox_state(gen, rng_elts_per_thread); - auto rng_state_columnwise = torch::empty({2}, opts); - philox_unpack(philox_args_columnwise, static_cast(rng_state_columnwise.data_ptr())); - te_rng_state_columnwise = makeTransformerEngineTensor(rng_state_columnwise); - quant_config_columnwise.set_stochastic_rounding(true); - quant_config_columnwise.set_rng_state(te_rng_state_columnwise.data()); - quant_config_columnwise.set_nvfp4_2d_quantization(this->with_2d_quantization); - } - } - - // Compute amax. - if (this->with_rht) { - if (input.dtype() != DType::kBFloat16) { - NVTE_ERROR("RHT is only supported for bfloat16 input, got dtype enum value ", - static_cast(input.dtype())); - } - if (this->with_post_rht_amax) { - // We need: - // 1. Rowwise amax = amax for input - // 2. 
Columnwise amax = amax for RHT(input.t) - NVTE_SCOPED_GIL_RELEASE({ - nvte_hadamard_transform_amax(input.data(), out.data(), 0, - this->rht_matrix_random_sign_mask_t, stream); - }); - } else { - // raise error since it's not supported yet - NVTE_ERROR( - "Pre-RHT amax is not supported yet. " - "Use with_post_rht_amax=true instead."); - } - } else { // Without RHT - if (compute_amax) { - // Amax pointers - auto rowwise_amax_ptr = out.get_amax().data_ptr; - auto columnwise_amax_ptr = out.get_columnwise_amax().data_ptr; - void* amax_ptr = rowwise_amax_ptr != nullptr ? rowwise_amax_ptr : columnwise_amax_ptr; - NVTE_CHECK(amax_ptr != nullptr, "Could not find amax pointer"); - - // Compute amax of input tensor - out.set_amax(amax_ptr, DType::kFloat32, std::vector{1}); - NVTE_SCOPED_GIL_RELEASE( - { nvte_compute_amax_with_config(input.data(), out.data(), quant_config, stream); }); - out.set_amax(rowwise_amax_ptr, DType::kFloat32, std::vector{1}); - - // Make sure row-wise and column-wise amaxes match - if (rowwise_amax_ptr != amax_ptr && rowwise_amax_ptr != nullptr) { - NVTE_CHECK_CUDA(cudaMemcpyAsync(rowwise_amax_ptr, amax_ptr, sizeof(float), - cudaMemcpyDeviceToDevice, stream)); - } - if (columnwise_amax_ptr != amax_ptr && columnwise_amax_ptr != nullptr) { - NVTE_CHECK_CUDA(cudaMemcpyAsync(columnwise_amax_ptr, amax_ptr, sizeof(float), - cudaMemcpyDeviceToDevice, stream)); - } - } - } - - // amax reduction - if (this->with_amax_reduction) { - std::vector amax_tensors; - // push amax tensors inside if they need to be reduced - auto make_amax_tensor = [](void* data_ptr) { - return at::from_blob( - data_ptr, std::vector{1}, - [](void*) {}, // deleter doing nothing since it doesn't own the data - at::device(at::kCUDA).dtype(torch::kFloat32)); - }; - if (rowwise_usage) { - amax_tensors.push_back(make_amax_tensor(out.get_amax().data_ptr)); - } - if (columnwise_usage) { - amax_tensors.push_back(make_amax_tensor(out.get_columnwise_amax().data_ptr)); - } - 
c10d::AllreduceCoalescedOptions opts; - opts.reduceOp = c10d::ReduceOp::MAX; - NVTE_SCOPED_GIL_RELEASE( - { this->amax_reduction_group->allreduce_coalesced(amax_tensors, opts)->wait(); }); - } - - // Fast math toggle: RHT transform can be accelerated - // What math is accelerated? Only the high precision math, so numerical impact is minimal - // 1. replace 1 / x by reciprocal_approximate_ftz(x) - // 2. when RHT cast fusion is available, fusion allows cast to be performed on FP32 data, - // this will essentially remove a round trip between FP32 to BF16 then FP32 - const auto use_fast_math = transformer_engine::getenv("NVTE_USE_FAST_MATH"); - if (use_fast_math) { - quant_config.set_use_fast_math(true); - quant_config_columnwise.set_use_fast_math(true); - } - - if (this->with_rht) { - if (eligible_for_rht_cast_fusion) { - // fusion kernel requires passing in RHT matrix directly for maximum performance - NVTE_CHECK(this->rht_matrix.defined() && this->rht_matrix.numel() > 0, - "RHT matrix is not available."); - auto rht_matrix_nvte = makeTransformerEngineTensor(this->rht_matrix); - // Fusion kernel that does the following: - // 1. Rowwise quantization - // 2. RHT followed by columnwise quantization & transpose - NVTE_SCOPED_GIL_RELEASE({ - nvte_quantize_with_hadamard_transform(input.data(), out.data(), rht_matrix_nvte.data(), - quant_config, stream); - }); - } else { - // Use separate RNG state for columnwise to ensure different random numbers than rowwise - // This is only necessary because it's the unfused path where rowwise and columnwise - // are separate kernel launches - auto& columnwise_quant_config_to_use = - need_separate_columnwise_rng ? quant_config_columnwise : quant_config; - // unfused path also needs memory allocation for intermediate buffer for RHT output - at::Tensor rht_output_t; // The RHT(x_t) output, in columnwise layout - // This wrapper is going to be passed as input to the quantization kernel. 
- TensorWrapper rht_output_t_cpp; // Wrapper to contain the RHT(x) and RHT(x_t) outputs - rht_output_t = - allocateTorchTensor(static_cast(cols), static_cast(rows), input.dtype()); - // NOTE (frsun): This is non-intuitive, we are writing the - // result of transposed RHT to the output of rowwise. - rht_output_t_cpp.set_rowwise_data(rht_output_t.data_ptr(), input.dtype(), - std::vector{cols, rows}); - this->quantize_with_rht_unfused_helper(input, out, rht_output_t_cpp, quant_config, - columnwise_quant_config_to_use, stream); - } - } else { - NVTE_SCOPED_GIL_RELEASE({ nvte_quantize_v2(input.data(), out.data(), quant_config, stream); }); - } -} - -void NVFP4Quantizer::quantize(const TensorWrapper& input, TensorWrapper& out, - const std::optional& noop_flag) { - this->quantize_impl(input, out, noop_flag, true); -} - -void NVFP4Quantizer::quantize_with_amax(TensorWrapper& input, TensorWrapper& out) { - // Update output tensor amaxes with input tensor amax - auto input_amax_ptr = input.amax(); - auto output_rowwise_amax_ptr = out.get_amax().data_ptr; - auto output_columnwise_amax_ptr = out.get_columnwise_amax().data_ptr; - NVTE_CHECK(input_amax_ptr != nullptr || - (output_rowwise_amax_ptr == nullptr && output_columnwise_amax_ptr == nullptr), - "Input tensor does not have pre-computed amax"); - if (input_amax_ptr != output_rowwise_amax_ptr && input_amax_ptr != nullptr && - output_rowwise_amax_ptr != nullptr) { - NVTE_CHECK_CUDA(cudaMemcpyAsync(output_rowwise_amax_ptr, input_amax_ptr, sizeof(float), - cudaMemcpyDeviceToDevice, at::cuda::getCurrentCUDAStream())); - } - if (input_amax_ptr != output_columnwise_amax_ptr && input_amax_ptr != nullptr && - output_columnwise_amax_ptr != nullptr) { - NVTE_CHECK_CUDA(cudaMemcpyAsync(output_columnwise_amax_ptr, input_amax_ptr, sizeof(float), - cudaMemcpyDeviceToDevice, at::cuda::getCurrentCUDAStream())); - } - input.set_amax(nullptr, DType::kFloat32, input.defaultShape); - - // Perform quantization - this->quantize_impl(input, out, 
std::nullopt, false); -} - -std::vector NVFP4Quantizer::get_scale_shape(const std::vector& shape, - bool columnwise) const { - size_t numel = 1; - for (auto s : shape) { - numel *= s; - } - - auto last_dim = shape.back(); - auto flat_first_dim = numel / last_dim; - - NVTE_CHECK(last_dim % NVFP4_BLOCK_SIZE == 0, "Last dim for NVFP4 must be divisible by ", - NVFP4_BLOCK_SIZE, " (got dim=", last_dim, ")"); - NVTE_CHECK(flat_first_dim % NVFP4_BLOCK_SIZE == 0, - "NVFP4 requires tensor dims that are divisible by ", NVFP4_BLOCK_SIZE, - " (got shape=", shape, ")"); - - std::vector scale_shape; - - bool rowwise_usage = !columnwise; - - if (rowwise_usage) { - // rowwise scaling factor shape - size_t sinv0 = roundup(flat_first_dim, 128); - size_t sinv1 = roundup(last_dim / NVFP4_BLOCK_SIZE, 4); - scale_shape = {sinv0, sinv1}; - } else { - // columnwise scaling factor shape - size_t sinv0 = roundup(last_dim, 128); - size_t sinv1 = roundup(flat_first_dim / NVFP4_BLOCK_SIZE, 4); - scale_shape = {sinv0, sinv1}; - } - return scale_shape; -} - -} // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/stable_common.h b/transformer_engine/pytorch/csrc/stable_common.h new file mode 100644 index 0000000000..968a241dcd --- /dev/null +++ b/transformer_engine/pytorch/csrc/stable_common.h @@ -0,0 +1,265 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_STABLE_COMMON_H_ +#define TRANSFORMER_ENGINE_PYTORCH_CSRC_STABLE_COMMON_H_ + +// PyTorch Stable ABI headers +#include +#include +#include +#include +#include + +// CUDA headers +#include + +// Transformer Engine C API headers +#include + +#include +#include +#include + +#include "common/util/logging.h" + +namespace transformer_engine::pytorch::stable { + +using torch::headeronly::ScalarType; + +// ============================================================================ +// DType <-> ScalarType converters +// ============================================================================ + +inline ScalarType GetStableScalarType(transformer_engine::DType t) { + switch (t) { + case transformer_engine::DType::kInt16: + return ScalarType::Short; + case transformer_engine::DType::kInt32: + return ScalarType::Int; + case transformer_engine::DType::kInt64: + return ScalarType::Long; + case transformer_engine::DType::kFloat32: + return ScalarType::Float; + case transformer_engine::DType::kFloat16: + return ScalarType::Half; + case transformer_engine::DType::kBFloat16: + return ScalarType::BFloat16; + case transformer_engine::DType::kByte: + return ScalarType::Byte; + case transformer_engine::DType::kFloat8E4M3: + return ScalarType::Float8_e4m3fn; + case transformer_engine::DType::kFloat8E5M2: + return ScalarType::Float8_e5m2; + case transformer_engine::DType::kFloat8E8M0: + return ScalarType::Byte; // e8m0 not natively supported + default: + NVTE_ERROR("Invalid DType (", static_cast(t), ")."); + } +} + +inline transformer_engine::DType GetTransformerEngineDType(ScalarType t) { + switch (t) { + case ScalarType::Float8_e4m3fn: + return transformer_engine::DType::kFloat8E4M3; + case ScalarType::Float8_e5m2: + return transformer_engine::DType::kFloat8E5M2; + case ScalarType::Half: + return transformer_engine::DType::kFloat16; + case ScalarType::Float: + return 
transformer_engine::DType::kFloat32; + case ScalarType::BFloat16: + return transformer_engine::DType::kBFloat16; + case ScalarType::Bool: + return transformer_engine::DType::kByte; + case ScalarType::Byte: + return transformer_engine::DType::kByte; + case ScalarType::Short: + return transformer_engine::DType::kInt16; + case ScalarType::Int: + return transformer_engine::DType::kInt32; + case ScalarType::Long: + return transformer_engine::DType::kInt64; + default: + NVTE_ERROR("Invalid ScalarType (", static_cast(t), ")."); + } +} + +// ============================================================================ +// CUDA stream utilities +// ============================================================================ + +/// Get the current CUDA stream as a raw cudaStream_t. +/// Uses the stable ABI's aoti_torch_get_current_cuda_stream. +inline cudaStream_t getCurrentCUDAStreamRaw(int32_t device_index = -1) { + if (device_index < 0) { + device_index = torch::stable::accelerator::getCurrentDeviceIndex(); + } + void* stream_ptr = nullptr; + TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream(device_index, &stream_ptr)); + return reinterpret_cast(stream_ptr); +} + +// ============================================================================ +// Device properties +// ============================================================================ + +/// Get SM count for the given CUDA device (or current device if -1). +/// Replaces at::cuda::getCurrentDeviceProperties()->multiProcessorCount. 
+inline int getSMCount(int device_index = -1) { + if (device_index < 0) { + device_index = static_cast(torch::stable::accelerator::getCurrentDeviceIndex()); + } + cudaDeviceProp prop; + cudaError_t err = cudaGetDeviceProperties(&prop, device_index); + NVTE_CHECK(err == cudaSuccess, "cudaGetDeviceProperties failed: ", cudaGetErrorString(err)); + return prop.multiProcessorCount; +} + +// ============================================================================ +// Shape utilities +// ============================================================================ + +/// Convert stable tensor sizes (int64_t array) to vector. +inline std::vector getStableTensorShape(const torch::stable::Tensor& t) { + auto sizes = t.sizes(); + std::vector shape; + shape.reserve(sizes.size()); + for (size_t i = 0; i < sizes.size(); ++i) { + shape.push_back(static_cast(sizes[i])); + } + return shape; +} + +// ============================================================================ +// TensorWrapper construction from stable::Tensor +// ============================================================================ + +/// Create a TensorWrapper from a torch::stable::Tensor. +/// Extracts data_ptr, shape, and dtype. +inline transformer_engine::TensorWrapper makeTransformerEngineTensor( + const torch::stable::Tensor& tensor) { + transformer_engine::DType dtype = GetTransformerEngineDType(tensor.scalar_type()); + std::vector shape = getStableTensorShape(tensor); + return transformer_engine::TensorWrapper(tensor.data_ptr(), shape, dtype); +} + +/// Create a TensorWrapper from raw components (same as unstable version). 
+inline transformer_engine::TensorWrapper makeTransformerEngineTensor( + void* data_ptr, const std::vector& shape, const transformer_engine::DType type) { + return transformer_engine::TensorWrapper(data_ptr, shape, type); +} + +// ============================================================================ +// Tensor allocation via stable ABI +// ============================================================================ + +/// Allocate an empty tensor on CUDA via the stable ABI. +inline torch::stable::Tensor allocateStableTensor(const std::vector& shape, + ScalarType dtype, int32_t device_index = -1) { + if (device_index < 0) { + device_index = torch::stable::accelerator::getCurrentDeviceIndex(); + } + torch::headeronly::IntHeaderOnlyArrayRef size_ref(shape.data(), shape.size()); + torch::stable::Device device(torch::headeronly::DeviceType::CUDA, device_index); + return torch::stable::empty(size_ref, dtype, + std::nullopt, // layout + device, + std::nullopt, // pin_memory + std::nullopt // memory_format + ); +} + +/// Allocate an empty tensor on CUDA, using TE DType. +inline torch::stable::Tensor allocateStableTensor(const std::vector& shape, + transformer_engine::DType te_dtype, + int32_t device_index = -1) { + return allocateStableTensor(shape, GetStableScalarType(te_dtype), device_index); +} + +/// Allocate a zero-filled tensor on CUDA via the stable ABI. +inline torch::stable::Tensor allocateStableTensorZeros(const std::vector& shape, + ScalarType dtype, + int32_t device_index = -1) { + auto t = allocateStableTensor(shape, dtype, device_index); + torch::stable::zero_(t); + return t; +} + +/// Allocate a zero-filled tensor on CUDA, using TE DType. 
+inline torch::stable::Tensor allocateStableTensorZeros(const std::vector& shape, + transformer_engine::DType te_dtype, + int32_t device_index = -1) { + return allocateStableTensorZeros(shape, GetStableScalarType(te_dtype), device_index); +} + +// ============================================================================ +// TensorWrapper construction with quantization metadata +// ============================================================================ + +/// Build a TensorWrapper with rowwise quantization metadata. +/// The output_data tensor holds the quantized data. +/// amax, scale, scale_inv are optional quantization parameters. +/// If scale_inv_dtype is -1, defaults to kFloat32 (use kFloat8E8M0=10 for +/// MXFP8, kFloat8E4M3=8 for NVFP4). +inline transformer_engine::TensorWrapper makeQuantizedTensorWrapper( + const torch::stable::Tensor& output_data, transformer_engine::DType te_dtype, + const std::vector& shape, const std::optional& amax, + const std::optional& scale, + const std::optional& scale_inv, NVTEScalingMode scaling_mode) { + TensorWrapper out(scaling_mode); + out.set_rowwise_data(output_data.data_ptr(), te_dtype, shape); + + const std::vector scalar_shape{1}; + if (amax.has_value() && amax->numel() > 0) { + out.set_amax(amax->data_ptr(), DType::kFloat32, scalar_shape); + } + if (scale.has_value() && scale->numel() > 0) { + out.set_scale(scale->data_ptr(), DType::kFloat32, scalar_shape); + } + if (scale_inv.has_value() && scale_inv->numel() > 0) { + // Determine scale_inv dtype from scaling mode + DType si_dtype = DType::kFloat32; + if (scaling_mode == NVTE_MXFP8_1D_SCALING) { + si_dtype = DType::kFloat8E8M0; + } else if (scaling_mode == NVTE_NVFP4_1D_SCALING) { + si_dtype = DType::kFloat8E4M3; + } + auto si_shape = getStableTensorShape(scale_inv.value()); + out.set_rowwise_scale_inv(scale_inv->data_ptr(), si_dtype, si_shape); + } + return out; +} + +/// Helper to run the two-phase workspace pattern for any NVTE function. 
+/// The callable should have signature: void(NVTETensor workspace) +/// First call queries workspace size, second call runs the kernel. +template +inline void runWithWorkspace(Fn&& fn, int32_t device_idx) { + TensorWrapper workspace; + fn(workspace.data()); + + // workspace_data must outlive the second fn() call because the kernel + // launched by fn() reads from the workspace via a raw data_ptr(). + // Declaring it here (outside the if block) keeps the tensor alive until + // after the kernel is submitted. + torch::stable::Tensor workspace_data; + auto ws_shape = workspace.shape(); + auto ws_dtype = workspace.dtype(); + if (ws_shape.ndim > 0 && workspace.numel() > 0) { + workspace_data = allocateStableTensor( + std::vector(ws_shape.data, ws_shape.data + ws_shape.ndim), ws_dtype, device_idx); + workspace = makeTransformerEngineTensor( + workspace_data.data_ptr(), + std::vector(ws_shape.data, ws_shape.data + ws_shape.ndim), ws_dtype); + } + + fn(workspace.data()); +} + +} // namespace transformer_engine::pytorch::stable + +#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_STABLE_COMMON_H_ diff --git a/transformer_engine/pytorch/csrc/type_converters.cpp b/transformer_engine/pytorch/csrc/type_converters.cpp deleted file mode 100644 index e9c6ca882e..0000000000 --- a/transformer_engine/pytorch/csrc/type_converters.cpp +++ /dev/null @@ -1,296 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. 
- ************************************************************************/ - -#include -#include -#include - -#include "common.h" -#include "pybind.h" - -namespace transformer_engine::pytorch { -namespace detail { - -TensorWrapper NVTETensorFromFloat8Tensor(py::handle tensor, Quantizer *quantizer) { - auto ret = TensorWrapper(quantizer->get_scaling_mode()); - - bool data_exists = !tensor.attr("_data").is_none(); - bool transpose_exists = - !tensor.attr("_transpose_invalid").cast() && !tensor.attr("_transpose").is_none(); - - NVTE_CHECK(data_exists || transpose_exists, "No data found for FP8 Tensor."); - - // FP8 data - const DType fp8_dtype = tensor.attr("_fp8_dtype").cast(); - if (data_exists) { - const auto &data = tensor.attr("_data").cast(); - ret.set_rowwise_data(data.data_ptr(), fp8_dtype, getTensorShape(data)); - } - - // FP8 data transpose - if (transpose_exists) { - const auto &data_transpose = tensor.attr("_transpose").cast(); - ret.set_columnwise_data(data_transpose.data_ptr(), fp8_dtype, getTensorShape(data_transpose)); - } - - // Scale-inverse - { - const auto &scale_inv = tensor.attr("_scale_inv").cast(); - float *dptr = reinterpret_cast(scale_inv.data_ptr()); - const auto &dtype = GetTransformerEngineDType(scale_inv.scalar_type()); - const auto &shape = getTensorShape(scale_inv); - ret.set_rowwise_scale_inv(dptr, dtype, shape); - ret.set_columnwise_scale_inv(dptr, dtype, shape); - } - - // Quantizer state - quantizer->set_quantization_params(&ret); - - return ret; -} - -TensorWrapper NVTETensorFromMXFP8Tensor(py::handle tensor, Quantizer *quantizer) { - auto ret = TensorWrapper(NVTE_MXFP8_1D_SCALING); - - const bool rowwise_usage = !(tensor.attr("_rowwise_data").is_none()); - const bool columnwise_usage = !(tensor.attr("_columnwise_data").is_none()); - const bool with_gemm_swizzled_scales = tensor.attr("_with_gemm_swizzled_scales").cast(); - - NVTE_CHECK(rowwise_usage || columnwise_usage, "No data found for MXFP8 Tensor."); - - // Row-scaled data - 
const DType fp8_dtype = tensor.attr("_fp8_dtype").cast(); - if (rowwise_usage) { - const auto &data = tensor.attr("_rowwise_data").cast(); - const auto &scale_inv = tensor.attr("_rowwise_scale_inv").cast(); - ret.set_rowwise_data(data.data_ptr(), fp8_dtype, getTensorShape(data)); - ret.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E8M0, getTensorShape(scale_inv)); - } - - // Column-scaled data - if (columnwise_usage) { - const auto &data = tensor.attr("_columnwise_data").cast(); - const auto &scale_inv = tensor.attr("_columnwise_scale_inv").cast(); - ret.set_columnwise_data(data.data_ptr(), fp8_dtype, getTensorShape(data)); - ret.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E8M0, - getTensorShape(scale_inv)); - } - - // Scale layout - ret.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); - - // Quantizer state - quantizer->set_quantization_params(&ret); - - return ret; -} - -TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, Quantizer *quantizer) { - const DType dtype = tensor.attr("_fp8_dtype").cast(); - bool is_2D_scaled = tensor.attr("_is_2D_scaled").cast(); - - bool rowwise_usage = !(tensor.attr("_rowwise_data").is_none()); - bool columnwise_usage = !(tensor.attr("_columnwise_data").is_none()); - - auto ret = TensorWrapper(is_2D_scaled ? 
NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D); - - // Row-wise data - if (rowwise_usage) { - const at::Tensor &data_rowwise = tensor.attr("_rowwise_data").cast(); - const at::Tensor &scale_inv_rowwise = tensor.attr("_rowwise_scale_inv").cast(); - void *scale_inv_rowwise_dptr = scale_inv_rowwise.data_ptr(); - const auto &rowwise_shape = getTensorShape(data_rowwise); - ret.set_rowwise_data(data_rowwise.data_ptr(), dtype, rowwise_shape); - const auto scale_inv_rowwise_shape = getTensorShape(scale_inv_rowwise); - ret.set_rowwise_scale_inv(scale_inv_rowwise_dptr, DType::kFloat32, scale_inv_rowwise_shape); - } - - // Column-wise data - if (columnwise_usage) { - const at::Tensor &data_colwise = tensor.attr("_columnwise_data").cast(); - const at::Tensor &scale_inv_colwise = tensor.attr("_columnwise_scale_inv").cast(); - void *scale_inv_colwise_dptr = scale_inv_colwise.data_ptr(); - const auto &shape = getTensorShape(data_colwise); - ret.set_columnwise_data(data_colwise.data_ptr(), dtype, shape); - - const auto scale_inv_colwise_shape = getTensorShape(scale_inv_colwise); - ret.set_columnwise_scale_inv(scale_inv_colwise_dptr, DType::kFloat32, scale_inv_colwise_shape); - } - - // Quantizer state - quantizer->set_quantization_params(&ret); - - return ret; -} - -TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer) { - const DType dtype = tensor.attr("_fp4_dtype").cast(); - - auto ret = TensorWrapper(NVTE_NVFP4_1D_SCALING); - - const bool rowwise_usage = !(tensor.attr("_rowwise_data").is_none()); - const bool columnwise_usage = !(tensor.attr("_columnwise_data").is_none()); - const bool with_gemm_swizzled_scales = tensor.attr("_with_gemm_swizzled_scales").cast(); - - NVTE_CHECK(rowwise_usage || columnwise_usage, "No data found for NVFP4 Tensor."); - - // Row-scaled data - if (rowwise_usage) { - const auto &data = tensor.attr("_rowwise_data").cast(); - const auto &scale_inv = tensor.attr("_rowwise_scale_inv").cast(); - const auto &amax_rowwise = 
tensor.attr("_amax_rowwise").cast(); - ret.set_rowwise_data(data.data_ptr(), dtype, - convert_shape_back_from_fp4(getTensorShape(data), false)); - ret.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E4M3, getTensorShape(scale_inv)); - ret.set_amax(amax_rowwise.data_ptr(), DType::kFloat32, getTensorShape(amax_rowwise)); - } - - // Column-scaled data - if (columnwise_usage) { - const auto &data = tensor.attr("_columnwise_data").cast(); - const auto &scale_inv = tensor.attr("_columnwise_scale_inv").cast(); - const auto &amax_columnwise = tensor.attr("_amax_columnwise").cast(); - ret.set_columnwise_data(data.data_ptr(), DType::kFloat4E2M1, - convert_shape_back_from_fp4(getTensorShape(data), false)); - ret.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E4M3, - getTensorShape(scale_inv)); - ret.set_columnwise_amax(amax_columnwise.data_ptr(), DType::kFloat32, - getTensorShape(amax_columnwise)); - } - - // Scale layout - ret.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); - - // Quantizer state - quantizer->set_quantization_params(&ret); - - return ret; -} - -NVTEScalingMode ScalingModeFromQuantizer(py::handle quantizer) { - auto *quantizer_ptr = quantizer.ptr(); - if (IsMXFP8Quantizers(quantizer_ptr)) { - return NVTE_MXFP8_1D_SCALING; - } - if (IsNVFP4Quantizers(quantizer_ptr)) { - return NVTE_NVFP4_1D_SCALING; - } - if (IsFloat8BlockwiseQuantizers(quantizer_ptr)) { - const int block_scaling_dim = quantizer.attr("block_scaling_dim").cast(); - return (block_scaling_dim == 2) ? 
NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D; - } - return NVTE_DELAYED_TENSOR_SCALING; -} - -DType GetTransformerEngineDTypeForScaleInv(py::handle quantizer, at::Tensor scale_inv) { - auto *quantizer_ptr = quantizer.ptr(); - if (IsMXFP8Quantizers(quantizer_ptr)) { - return DType::kFloat8E8M0; - } - if (IsFloat8BlockwiseQuantizers(quantizer_ptr)) { - return DType::kFloat32; - } - if (IsNVFP4Quantizers(quantizer_ptr)) { - return DType::kFloat8E4M3; - } - return GetTransformerEngineDType(scale_inv.scalar_type()); -} - -GroupedTensorWrapper GroupedTensorFromPyTorchGroupedTensor(py::handle tensor) { - // Returns a GroupedTensorWrapper from a PyTorch GroupedTensor. - const auto num_tensors = tensor.attr("num_tensors").cast(); - const auto logical_shape = tensor.attr("logical_shape").cast>(); - py::handle quantizer = py::none(); - DType quantizer_dtype = DType::kNumTypes; - NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING; - if (!tensor.attr("quantizer").is_none()) { - quantizer = tensor.attr("quantizer"); - if (!quantizer.is_none()) { - scaling_mode = ScalingModeFromQuantizer(quantizer); - quantizer_dtype = quantizer.attr("dtype").cast(); - } - } - auto ret = GroupedTensorWrapper(num_tensors, logical_shape, scaling_mode); - - // Rowwise data - if (!tensor.attr("rowwise_data").is_none()) { - const auto &data = tensor.attr("rowwise_data").cast(); - DType data_dtype = - quantizer.is_none() ? GetTransformerEngineDType(data.scalar_type()) : quantizer_dtype; - ret.set_rowwise_data(data.data_ptr(), data_dtype, getTensorShape(data)); - } - - // Columnwise data - if (!tensor.attr("columnwise_data").is_none()) { - const auto &data = tensor.attr("columnwise_data").cast(); - DType data_dtype = - quantizer.is_none() ? 
GetTransformerEngineDType(data.scalar_type()) : quantizer_dtype; - ret.set_columnwise_data(data.data_ptr(), data_dtype, getTensorShape(data)); - } - - // Scale - if (!tensor.attr("scale").is_none()) { - const auto &scale = tensor.attr("scale").cast(); - ret.set_scale(scale.data_ptr(), GetTransformerEngineDType(scale.scalar_type()), - getTensorShape(scale)); - } - - // Amax - if (!tensor.attr("amax").is_none()) { - const auto &amax = tensor.attr("amax").cast(); - ret.set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), - getTensorShape(amax)); - } - if (!tensor.attr("columnwise_amax").is_none()) { - const auto &amax = tensor.attr("columnwise_amax").cast(); - ret.set_columnwise_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), - getTensorShape(amax)); - } - - // Scale inverse - if (!tensor.attr("scale_inv").is_none()) { - const auto &scale_inv = tensor.attr("scale_inv").cast(); - ret.set_rowwise_scale_inv(scale_inv.data_ptr(), - GetTransformerEngineDTypeForScaleInv(quantizer, scale_inv), - getTensorShape(scale_inv)); - } - if (!tensor.attr("columnwise_scale_inv").is_none()) { - const auto &scale_inv = tensor.attr("columnwise_scale_inv").cast(); - ret.set_columnwise_scale_inv(scale_inv.data_ptr(), - GetTransformerEngineDTypeForScaleInv(quantizer, scale_inv), - getTensorShape(scale_inv)); - } - - // Shape metadata - if (!tensor.attr("first_dims").is_none()) { - const auto &first_dims = tensor.attr("first_dims").cast(); - ret.set_first_dims(first_dims.data_ptr(), GetTransformerEngineDType(first_dims.scalar_type()), - getTensorShape(first_dims)); - } - if (!tensor.attr("last_dims").is_none()) { - const auto &last_dims = tensor.attr("last_dims").cast(); - ret.set_last_dims(last_dims.data_ptr(), GetTransformerEngineDType(last_dims.scalar_type()), - getTensorShape(last_dims)); - } - if (!tensor.attr("tensor_offsets").is_none()) { - const auto &tensor_offsets = tensor.attr("tensor_offsets").cast(); - 
ret.set_tensor_offsets(tensor_offsets.data_ptr(), - GetTransformerEngineDType(tensor_offsets.scalar_type()), - getTensorShape(tensor_offsets)); - } - - bool with_gemm_swizzled = false; - if (py::hasattr(tensor, "_with_gemm_swizzled_scales")) { - with_gemm_swizzled = tensor.attr("_with_gemm_swizzled_scales").cast(); - } - ret.set_with_gemm_swizzled_scales(with_gemm_swizzled); - - return ret; -} - -} // namespace detail - -} // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/util.h b/transformer_engine/pytorch/csrc/util.h deleted file mode 100644 index 587ec289a4..0000000000 --- a/transformer_engine/pytorch/csrc/util.h +++ /dev/null @@ -1,63 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_ -#define TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_ - -#include - -#include -#include -#include - -#include "transformer_engine/transformer_engine.h" - -namespace transformer_engine { -namespace pytorch { - -/*! \brief Convert tensor block scales into GEMM swizzled format. - * - * The returned swizzled scales should be kept alive during the GEMM. - */ -std::tuple, std::optional> swizzle_scales_for_gemm( - TensorWrapper& tensor, bool rowwise_usage, bool columnwise_usage); - -/*! \brief Convert multiple tensor block scales into GEMM swizzled format. - * - * The returned swizzled scales should be kept alive during the GEMMs. - */ -std::optional multi_tensor_swizzle_scales_for_gemm(std::vector& tensors, - bool rowwise_usage, - bool columnwise_usage); - -using SwizzledGroupedScales = std::pair, std::optional>; - -/*! \brief Swizzle grouped tensor scales for GEMM if needed. - * Currently only works for MXFP8 1D scaling with uniform shapes. 
- * - * The returned swizzled scales should be kept alive during the GEMM. - */ -std::optional maybe_swizzle_grouped_tensor_for_gemm( - GroupedTensorWrapper& input); - -/*! \brief Convert a block scaling tensor to an mxfp8 tensor in-place. - * - * If rowwise==false, the columnwise data will be reinterpreted as - * rowwise data to avoid transposing it in memory. Due to differences - * in how block scaling and mxfp8 store data, this requires the - * calling code to treat the output tensor as having been transposed - * in this case. - * - * Returns the swizzled scaling factor of the converted mxfp8 tensor. - * The returned swizzled scaling factor tensor should be kept alive - * during the GEMM. - */ -at::Tensor convert_block_scaling_to_mxfp8_tensor(TensorWrapper& input, bool rowwise); - -} // namespace pytorch -} // namespace transformer_engine - -#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_ diff --git a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py index dd01ae05d3..2422f57ac3 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py @@ -231,8 +231,15 @@ class NVFP4TensorRef(QuantizedTensorStorage): @property def custom(self) -> bool: - """Flag to indicate this quantized tensor is custom.""" - return True + """Flag to indicate this quantized tensor is custom. + + Returns False so that GEMM operations route through cuBLAS + (the same kernel used by the production recipe) rather than + the Python qgemm reference implementation. This keeps the + reference quantization path independent while still exercising + the production GEMM. 
+ """ + return False def prepare_for_saving( self, @@ -255,23 +262,60 @@ def restore_from_saved( self.scale_t = tensors[3] return tensors[4:] - # Compatibility + # NVFP4TensorStorage-compatible properties so that extract_tensor_data() + # and _extract_gemm_operand() can handle this type on the cuBLAS path. @property - def _data(self): + def _rowwise_data(self): return self.data - @_data.setter - def _data(self, value): + @_rowwise_data.setter + def _rowwise_data(self, value): self.data = value @property - def _scale_inv(self): + def _rowwise_scale_inv(self): return self.scale - @_scale_inv.setter - def _scale_inv(self, value): + @_rowwise_scale_inv.setter + def _rowwise_scale_inv(self, value): self.scale = value + @property + def _columnwise_data(self): + return self.data_t + + @_columnwise_data.setter + def _columnwise_data(self, value): + self.data_t = value + + @property + def _columnwise_scale_inv(self): + return self.scale_t + + @_columnwise_scale_inv.setter + def _columnwise_scale_inv(self, value): + self.scale_t = value + + @property + def _amax_rowwise(self): + return self.global_amax_row + + @_amax_rowwise.setter + def _amax_rowwise(self, value): + self.global_amax_row = value + + @property + def _amax_columnwise(self): + return self.global_amax_col + + @_amax_columnwise.setter + def _amax_columnwise(self, value): + self.global_amax_col = value + + @property + def _with_gemm_swizzled_scales(self): + return False + def __repr__(self): return ( f"{self.__class__.__name__}(" @@ -577,6 +621,23 @@ def _rm_pad_tensor(tensor: torch.Tensor, original_size: tuple[int, ...]) -> torc out = tensor[:M, :N].contiguous() return out + @staticmethod + def _pad_scale_for_gemm(scale: torch.Tensor, M: int, K: int) -> torch.Tensor: + """Pad scale tensor to match cuBLAS alignment requirements. + + cuBLAS expects NVFP4 scale tensors with M padded to a multiple + of 128 and K//16 padded to a multiple of 4. 
+ """ + BLOCK = 16 + target_m = ((M + 127) // 128) * 128 + target_k = (((K // BLOCK) + 3) // 4) * 4 + cur_m, cur_k = scale.shape + if cur_m >= target_m and cur_k >= target_k: + return scale + padded = torch.zeros(target_m, target_k, dtype=scale.dtype, device=scale.device) + padded[:cur_m, :cur_k] = scale + return padded + def _quantize(self, tensor: torch.Tensor) -> Tuple[ Optional[torch.Tensor], Optional[torch.Tensor], @@ -619,12 +680,13 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ ) # Prepare inputs once so we can reuse for both amax and quantization # Row-input will always be the original input. + # Note: RHT is NOT applied here because the stable ABI quantize + # kernel does not fuse RHT into standalone quantization; RHT is + # only applied by fused LN/activation+quantize kernels. Matching + # the production behaviour ensures the reference's columnwise data + # is identical to production, allowing exact wgrad comparison. row_input = tensor - col_input = ( - self._apply_rht(tensor.t().contiguous()) - if self.with_rht - else tensor.t().contiguous() - ) + col_input = tensor.t().contiguous() # Compute amax for rowwise and columnwise paths separately global_amax_row = torch.max(torch.abs(row_input)).to(torch.float32).view(1) global_amax_col = ( @@ -654,6 +716,7 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ sx = sx.T qx = self._rm_pad_tensor(qx, (M, N // 2)) + sx = self._pad_scale_for_gemm(sx, M, N) else: qx = None @@ -675,6 +738,7 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ ) qx_t = self._rm_pad_tensor(qx_t, (N, M // 2)) + sx_t = self._pad_scale_for_gemm(sx_t, N, M) if transpose_scales: sx_t = sx_t.T diff --git a/transformer_engine/pytorch/ops/basic/dropout.py b/transformer_engine/pytorch/ops/basic/dropout.py index 8850604aad..b20764c88f 100644 --- a/transformer_engine/pytorch/ops/basic/dropout.py +++ b/transformer_engine/pytorch/ops/basic/dropout.py @@ -11,7 +11,6 @@ import transformer_engine_torch as tex from ...cpu_offload import 
is_cpu_offload_enabled, mark_activation_offload from ...tensor import Quantizer -from ...tensor.storage.float8_tensor_storage import Float8TensorStorage from .._common import maybe_autocast_dtype, maybe_dequantize from ..op import BasicOperation, OperationContext @@ -55,9 +54,7 @@ def op_forward( if impl == "evaluation": out = input_ elif impl == "fused": - x = input_ - if not isinstance(x, Float8TensorStorage): - x = maybe_dequantize(x, dtype=dtype) + x = maybe_dequantize(input_, dtype=dtype) out, mask = tex.dropout_fwd(x, self.dropout_probability) elif impl == "unfused": x = maybe_dequantize(input_, dtype=dtype) diff --git a/transformer_engine/pytorch/pyproject.toml b/transformer_engine/pytorch/pyproject.toml index 0b42b0a8da..0435c0f668 100755 --- a/transformer_engine/pytorch/pyproject.toml +++ b/transformer_engine/pytorch/pyproject.toml @@ -3,7 +3,7 @@ # See LICENSE for license information. [build-system] -requires = ["setuptools>=61.0", "pip", "torch>=2.1"] +requires = ["setuptools>=61.0", "pip", "torch>=2.6"] # Use legacy backend to import local packages in setup.py build-backend = "setuptools.build_meta:__legacy__" diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index a7722f777e..951923d84a 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -741,3 +741,21 @@ def to_dtype(self, dtype: torch.dtype) -> QuantizedTensor: """ return self.__class__.make_like(self, dtype=dtype) + + +# Register stable GEMM op as a passthrough so FP8 tensors are not dequantized +# before entering the stable ABI GEMM kernel. We do this here rather than in +# _stable_torch_module.py to avoid a circular import (quantized_tensor is loaded +# before the stable module has had a chance to run its module-level registration). 
+def _register_stable_abi_passthrough_ops(): + import sys + + if "transformer_engine.pytorch._stable_torch_module" not in sys.modules: + return + try: + _quantized_tensor_passthrough_ops.add(torch.ops.transformer_engine_stable.gemm.default) + except AttributeError: + pass + + +_register_stable_abi_passthrough_ops() diff --git a/transformer_engine/pytorch/setup.py b/transformer_engine/pytorch/setup.py index 99f6a99efa..0fb6c925f2 100644 --- a/transformer_engine/pytorch/setup.py +++ b/transformer_engine/pytorch/setup.py @@ -48,7 +48,7 @@ from build_tools.utils import copy_common_headers, min_python_version_str from build_tools.te_version import te_version from build_tools.pytorch import ( - setup_pytorch_extension, + setup_pytorch_stable_extension, install_requirements, test_requirements, ) @@ -151,11 +151,12 @@ def run(self): # Extensions common_headers_dir = "common_headers" copy_common_headers(current_file_path.parent, str(current_file_path / common_headers_dir)) - ext_modules = [ - setup_pytorch_extension( - "csrc", current_file_path / "csrc", current_file_path / common_headers_dir - ) - ] + ext_modules = [] + stable_ext = setup_pytorch_stable_extension( + "csrc", current_file_path / "csrc", current_file_path / common_headers_dir + ) + if stable_ext is not None: + ext_modules.append(stable_ext) # Setup version and requirements. # Having the framework extension depend on the core lib allows diff --git a/transformer_engine/pytorch/tensor/_extract.py b/transformer_engine/pytorch/tensor/_extract.py new file mode 100644 index 0000000000..f1f3a26e29 --- /dev/null +++ b/transformer_engine/pytorch/tensor/_extract.py @@ -0,0 +1,226 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Extract raw tensor data + quantization metadata from quantized tensor types. + +Used by the stable ABI Python shim to convert quantized tensors into +raw buffers that can be passed to stable ABI ops. 
+""" + +import torch + +# TE DType values (must match transformer_engine/transformer_engine.h) +_DTYPE_MAP = { + torch.float32: 4, # kFloat32 + torch.float16: 5, # kFloat16 + torch.bfloat16: 6, # kBFloat16 + torch.uint8: 0, # kByte (used for FP8 storage) + torch.int32: 2, # kInt32 + torch.int64: 3, # kInt64 + torch.bool: 0, # kByte +} + +# FP8 dtype enum values — maps various string representations to TE dtype ints +_FP8_DTYPE_TO_TE = { + "fp8e4m3": 7, # kFloat8E4M3 + "fp8e5m2": 8, # kFloat8E5M2 + "torch.float8_e4m3fn": 7, + "torch.float8_e5m2": 8, + "float8_e4m3fn": 7, + "float8_e5m2": 8, + # Integer DType enum values (set by _stable_torch_module.py) + "7": 7, + "8": 8, +} + +# Scaling mode values (must match transformer_engine.h NVTEScalingMode enum) +NVTE_DELAYED_TENSOR_SCALING = 0 +NVTE_MXFP8_1D_SCALING = 1 +NVTE_BLOCK_SCALING_1D = 2 +NVTE_BLOCK_SCALING_2D = 3 +NVTE_NVFP4_1D_SCALING = 4 + + +def _detect_scaling_mode(tensor): + """Detect the NVTEScalingMode for a quantized tensor. + + Checks tensor attributes (_is_2D_scaled, _block_scaling_dim) first, + then falls back to type-name detection for MXFP8/NVFP4 tensors which + lack those attributes. 
+ """ + if hasattr(tensor, "_is_2D_scaled"): + return NVTE_BLOCK_SCALING_2D if tensor._is_2D_scaled else NVTE_BLOCK_SCALING_1D + if hasattr(tensor, "_block_scaling_dim"): + return NVTE_BLOCK_SCALING_2D if tensor._block_scaling_dim == 2 else NVTE_BLOCK_SCALING_1D + cls_name = type(tensor).__name__ + if "MXFP8" in cls_name: + return NVTE_MXFP8_1D_SCALING + if "NVFP4" in cls_name: + return NVTE_NVFP4_1D_SCALING + quantizer = getattr(tensor, "_quantizer", None) + if quantizer is not None: + q_name = type(quantizer).__name__ + if "MXFP8" in q_name: + return NVTE_MXFP8_1D_SCALING + if "NVFP4" in q_name: + return NVTE_NVFP4_1D_SCALING + if "Block" in q_name: + block_dim = getattr(quantizer, "block_scaling_dim", None) + return NVTE_BLOCK_SCALING_2D if block_dim == 2 else NVTE_BLOCK_SCALING_1D + return NVTE_DELAYED_TENSOR_SCALING + + +def extract_tensor_data(tensor): + """Extract raw data, dtype, scale_inv, and scaling_mode from a tensor. + + For regular PyTorch tensors, returns the tensor as-is with default metadata. + For quantized TE tensor types, extracts the underlying raw buffers. 
+ + Returns: + (data, te_dtype, scale_inv, scaling_mode) + """ + # Check for quantized TE tensor types FIRST (they subclass torch.Tensor) + # TE quantized tensors have _rowwise_data or _data attributes + if hasattr(tensor, "_rowwise_data") and tensor._rowwise_data is not None: + data = tensor._rowwise_data + scale_inv = getattr(tensor, "_rowwise_scale_inv", None) + fp8_dtype = getattr(tensor, "_fp8_dtype", None) + cls_name = type(tensor).__name__ + # Detect NVFP4 tensors first (they don't have _fp8_dtype or _is_2D_scaled) + if "NVFP4" in cls_name: + te_dtype = 10 # kFloat4E2M1 + sm = NVTE_NVFP4_1D_SCALING + else: + te_dtype = 0 # kByte + if fp8_dtype is not None: + te_dtype = _FP8_DTYPE_TO_TE.get(str(fp8_dtype), 7) + if hasattr(tensor, "_is_2D_scaled"): + sm = NVTE_BLOCK_SCALING_2D if tensor._is_2D_scaled else NVTE_BLOCK_SCALING_1D + elif hasattr(tensor, "_block_scaling_dim"): + sm = ( + NVTE_BLOCK_SCALING_2D + if tensor._block_scaling_dim == 2 + else NVTE_BLOCK_SCALING_1D + ) + elif "MXFP8" in cls_name: + sm = NVTE_MXFP8_1D_SCALING + else: + sm = NVTE_DELAYED_TENSOR_SCALING + return data, te_dtype, scale_inv, sm + + # Columnwise-only block-scaling tensor (after update_usage(rowwise_usage=False)). + # _rowwise_data is None but _columnwise_data exists — return columnwise data with + # correct scaling_mode so callers don't fall through to the generic tensor path (sm=0). 
+ if ( + hasattr(tensor, "_rowwise_data") + and tensor._rowwise_data is None + and hasattr(tensor, "_columnwise_data") + and tensor._columnwise_data is not None + and not hasattr(tensor, "_data") + ): + col_data = tensor._columnwise_data + col_si = getattr(tensor, "_columnwise_scale_inv", None) + fp8_dtype = getattr(tensor, "_fp8_dtype", None) + cls_name = type(tensor).__name__ + if "NVFP4" in cls_name: + te_dtype = 10 # kFloat4E2M1 + sm = NVTE_NVFP4_1D_SCALING + else: + te_dtype = 0 # kByte + if fp8_dtype is not None: + te_dtype = _FP8_DTYPE_TO_TE.get(str(fp8_dtype), 7) + if hasattr(tensor, "_is_2D_scaled"): + sm = NVTE_BLOCK_SCALING_2D if tensor._is_2D_scaled else NVTE_BLOCK_SCALING_1D + elif hasattr(tensor, "_block_scaling_dim"): + sm = ( + NVTE_BLOCK_SCALING_2D + if tensor._block_scaling_dim == 2 + else NVTE_BLOCK_SCALING_1D + ) + elif "MXFP8" in cls_name: + sm = NVTE_MXFP8_1D_SCALING + else: + sm = NVTE_DELAYED_TENSOR_SCALING + return col_data, te_dtype, col_si, sm + + if hasattr(tensor, "_data") and tensor._data is not None: + data = tensor._data + scale_inv = getattr(tensor, "_scale_inv", None) + fp8_dtype = getattr(tensor, "_fp8_dtype", None) + te_dtype = 0 + if fp8_dtype is not None: + te_dtype = _FP8_DTYPE_TO_TE.get(str(fp8_dtype), 7) + return data, te_dtype, scale_inv, NVTE_DELAYED_TENSOR_SCALING + + if isinstance(tensor, torch.Tensor): + # Regular PyTorch tensor + te_dtype = _DTYPE_MAP.get(tensor.dtype, 4) # default kFloat32 + return tensor, te_dtype, None, NVTE_DELAYED_TENSOR_SCALING + + # Try Float8TensorStorage + try: + from transformer_engine.pytorch.tensor.storage.float8_tensor_storage import ( + Float8TensorStorage, + ) + + if isinstance(tensor, Float8TensorStorage): + data = tensor._data # uint8 tensor + scale_inv = tensor._scale_inv # float32 tensor + fp8_dtype = str(tensor._fp8_dtype) + te_dtype = _FP8_DTYPE_TO_TE.get(fp8_dtype, 7) # default e4m3 + return data, te_dtype, scale_inv, NVTE_DELAYED_TENSOR_SCALING + except ImportError: + pass + + # 
Try MXFP8TensorStorage + try: + from transformer_engine.pytorch.tensor.storage.mxfp8_tensor_storage import ( + MXFP8TensorStorage, + ) + + if isinstance(tensor, MXFP8TensorStorage): + data = tensor._rowwise_data + scale_inv = tensor._rowwise_scale_inv + fp8_dtype = str(tensor._fp8_dtype) + te_dtype = _FP8_DTYPE_TO_TE.get(fp8_dtype, 7) + return data, te_dtype, scale_inv, NVTE_MXFP8_1D_SCALING + except ImportError: + pass + + # Try Float8BlockwiseQTensorStorage + try: + from transformer_engine.pytorch.tensor.storage.float8_blockwise_tensor_storage import ( + Float8BlockwiseQTensorStorage, + ) + + if isinstance(tensor, Float8BlockwiseQTensorStorage): + data = tensor._rowwise_data + scale_inv = tensor._rowwise_scale_inv + fp8_dtype = str(tensor._fp8_dtype) + te_dtype = _FP8_DTYPE_TO_TE.get(fp8_dtype, 7) + # Check 1D vs 2D block scaling + sm = NVTE_BLOCK_SCALING_2D if tensor._is_2D_scaled else NVTE_BLOCK_SCALING_1D + return data, te_dtype, scale_inv, sm + except ImportError: + pass + + # Try NVFP4TensorStorage + try: + from transformer_engine.pytorch.tensor.storage.nvfp4_tensor_storage import ( + NVFP4TensorStorage, + ) + + if isinstance(tensor, NVFP4TensorStorage): + data = tensor._rowwise_data + scale_inv = tensor._rowwise_scale_inv + te_dtype = 10 # kFloat4E2M1 + return data, te_dtype, scale_inv, NVTE_NVFP4_1D_SCALING + except ImportError: + pass + + # Fallback: treat as regular tensor + if hasattr(tensor, "data") and isinstance(tensor.data, torch.Tensor): + return tensor.data, _DTYPE_MAP.get(tensor.data.dtype, 4), None, NVTE_DELAYED_TENSOR_SCALING + + raise TypeError(f"Cannot extract tensor data from type {type(tensor)}") diff --git a/transformer_engine/pytorch/tensor/_quantize_stable.py b/transformer_engine/pytorch/tensor/_quantize_stable.py new file mode 100644 index 0000000000..12eb37bd51 --- /dev/null +++ b/transformer_engine/pytorch/tensor/_quantize_stable.py @@ -0,0 +1,585 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# +# See LICENSE for license information. + +"""Stable ABI quantize implementation for TE quantizer classes. + +Replaces tex.quantize(tensor, quantizer, output, noop) calls with +direct calls to stable ABI ops, eliminating the pybind11 dependency. +""" + +import torch + +from transformer_engine.pytorch.tensor._extract import extract_tensor_data + +# Load stable ops +_ops = None + + +def _get_ops(): + global _ops + if _ops is not None: + return _ops + import glob + import importlib.util + from pathlib import Path + + te_spec = importlib.util.find_spec("transformer_engine") + if te_spec is not None and te_spec.origin is not None: + te_dir = Path(te_spec.origin).parent + candidates = glob.glob(str(te_dir / "te_stable_abi*")) + if candidates: + torch.ops.load_library(candidates[0]) + _ops = torch.ops.transformer_engine_stable + return _ops + + +def _maybe_allreduce_amax(quantizer, amax_tensors): + """All-reduce amax tensors across ranks if the quantizer requires it. + + The pybind path did this in C++ via process_group->allreduce(). We replicate + it here in Python using torch.distributed. + """ + if not getattr(quantizer, "with_amax_reduction", False): + return + group = getattr(quantizer, "amax_reduction_group", None) + if group is None: + return + # Canonicalize the process group if needed + if hasattr(quantizer, "_canonicalized_amax_reduction_group"): + group = quantizer._canonicalized_amax_reduction_group() + import torch.distributed as dist + + for amax in amax_tensors: + if amax is not None and isinstance(amax, torch.Tensor) and amax.numel() > 0: + dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group) + + +def quantize_into(src, quantizer, dst, noop_flag=None): + """Quantize src into pre-allocated dst using stable ABI ops. 
+ + Replaces: tex.quantize(src, quantizer, dst, noop_flag) + """ + ops = _get_ops() + + # Early return for empty tensors + if src.numel() == 0: + return + + # Ensure contiguous input + if not src.is_contiguous(): + src = src.contiguous() + + # Helper: transpose src to match columnwise layout [K, *M_dims]. + # get_columnwise_shape puts the last dim first, rest in original order. + def _transpose_for_colwise(t): + if t.ndim == 2: + return t.T.contiguous() + # For ndim >= 3: put last dim first, keep remaining dims in order + perm = [t.ndim - 1] + list(range(t.ndim - 1)) + return t.permute(*perm).contiguous() + + # Handle columnwise-only Float8BlockwiseQTensor destination. + # When _rowwise_data=None but _columnwise_data exists, we quantize the + # transposed input (to match the [K,M] columnwise layout) into _columnwise_data. + _col_only = ( + hasattr(dst, "_rowwise_data") + and getattr(dst, "_rowwise_data", None) is None + and hasattr(dst, "_columnwise_data") + and getattr(dst, "_columnwise_data", None) is not None + and not hasattr(dst, "_data") # exclude Float8Tensor (which uses _data/_transpose) + ) + # For MXFP8/NVFP4 columnwise-only: allocate temporary rowwise buffers so the + # bidirectional kernel can fill both. The GEMM only uses columnwise. 
+ _col_only_bidir = _col_only and ( + "MXFP8" in type(quantizer).__name__ or "NVFP4" in type(quantizer).__name__ + ) + if _col_only_bidir: + from transformer_engine.pytorch.utils import round_up_to_nearest_multiple + import math + + q_name = type(quantizer).__name__ + col_data = dst._columnwise_data + col_si = dst._columnwise_scale_inv + + shape = list(src.shape) + M = math.prod(shape[:-1]) + K = shape[-1] + + if "MXFP8" in q_name: + from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8_BLOCK_SCALING_SIZE + + BLOCK = MXFP8_BLOCK_SCALING_SIZE + tmp_rw_data = torch.empty(M, K, dtype=torch.uint8, device=src.device) + si_shape = ( + round_up_to_nearest_multiple(M, 128), + round_up_to_nearest_multiple(K // BLOCK, 4), + ) + tmp_rw_si = torch.empty(si_shape, dtype=torch.uint8, device=src.device) + sm = 1 # MXFP8_1D_SCALING + te_dtype = 7 # kFloat8E4M3 + else: + # NVFP4 + BLOCK = 16 # NVFP4_BLOCK_SCALING_SIZE + tmp_rw_data = torch.empty(M, K // 2, dtype=torch.uint8, device=src.device) + si_shape = ( + round_up_to_nearest_multiple(M, 128), + round_up_to_nearest_multiple(K // BLOCK, 4), + ) + tmp_rw_si = torch.empty(si_shape, dtype=torch.uint8, device=src.device) + sm = 4 # NVFP4_1D_SCALING + te_dtype = 10 # kFloat4E2M1 + + fp8_dtype_attr = getattr(dst, "_fp8_dtype", None) + if fp8_dtype_attr is not None: + from transformer_engine.pytorch.tensor._extract import _FP8_DTYPE_TO_TE + + te_dtype = _FP8_DTYPE_TO_TE.get(str(fp8_dtype_attr), te_dtype) + + # For NVFP4, compute amax first + amax_tmp = None + if "NVFP4" in q_name: + amax_tmp = getattr(dst, "_amax_rowwise", None) + if amax_tmp is None: + amax_tmp = torch.zeros(1, dtype=torch.float32, device=src.device) + ops.compute_amax(src, amax_tmp) + _maybe_allreduce_amax(quantizer, [amax_tmp]) + + nvfp4_2d = getattr(quantizer, "with_2d_quantization", False) + force_pow_2 = getattr(quantizer, "force_pow_2_scales", False) + amax_eps = getattr(quantizer, "amax_epsilon", 0.0) + + ops.quantize_bidirectional( + src, + 
tmp_rw_data, + te_dtype, + amax_tmp, + None, + tmp_rw_si, + col_data, + col_si, + sm, + force_pow_2, + amax_eps, + noop_flag, + nvfp4_2d, + ) + + # Copy amax for NVFP4 + if "NVFP4" in q_name and amax_tmp is not None: + amax_cw = getattr(dst, "_amax_columnwise", None) + if amax_cw is not None: + amax_cw.copy_(amax_tmp) + # Replace zero amax with default value (GPU-only, no .item() sync). + amax_tmp.masked_fill_(amax_tmp == 0.0, 6.0 * 448.0) + + dst._with_gemm_swizzled_scales = False + return + if _col_only: + col_data = dst._columnwise_data + col_si = getattr(dst, "_columnwise_scale_inv", None) + fp8_dtype_attr = getattr(dst, "_fp8_dtype", None) + from transformer_engine.pytorch.tensor._extract import _FP8_DTYPE_TO_TE + + out_dtype = _FP8_DTYPE_TO_TE.get(str(fp8_dtype_attr), 7) if fp8_dtype_attr else 7 + q_type_col = type(quantizer).__name__ + if "NVFP4" in q_type_col: + out_sm = 4 # NVFP4_1D_SCALING + else: + block_dim = getattr(quantizer, "block_scaling_dim", 2) + out_sm = 3 if block_dim == 2 else 2 # BLOCK_SCALING_2D=3, BLOCK_1D=2 + force_pow_2 = getattr(quantizer, "force_pow_2_scales", False) + amax_eps = getattr(quantizer, "amax_epsilon", 0.0) + if ( + hasattr(quantizer, "block_scaling_dim") + and getattr(quantizer, "block_scaling_dim", 2) == 2 + ): + # 2D block scaling: quantize src (original shape) → tmp rowwise buffer, + # then FP8-transpose into col_data and transpose the scale. + # Do NOT pass src_transposed to ops.quantize: the kernel computes scale + # block count from the input tensor shape, and src_transposed has a + # different shape (e.g. 512×128×1) than what col_si expects (e.g. 1×4). 
+ rowwise_scale_shape = quantizer.get_scale_shape(list(src.shape), columnwise=False) + tmp_si = torch.empty(rowwise_scale_shape, dtype=torch.float32, device=src.device) + tmp_rowwise = col_data.new_empty(list(src.shape)) # uint8, same shape as src + ops.quantize( + src, + tmp_rowwise, + out_dtype, + None, + None, + tmp_si, + out_sm, + force_pow_2, + amax_eps, + noop_flag, + ) + ops.fp8_transpose(tmp_rowwise, out_dtype, col_data) + if col_si is not None: + col_si.zero_() + transposed_si = tmp_si.T.contiguous() + h = min(col_si.shape[0], transposed_si.shape[0]) + w = min(col_si.shape[1], transposed_si.shape[1]) + col_si[0:h, 0:w].copy_(transposed_si[0:h, 0:w]) + else: + # 1D block scaling: the kernel must see src_transposed as (K, M) 2D. + # _transpose_for_colwise gives (K, *M_dims) which the kernel would + # flatten to (K*M_rest, M_last) — wrong shape for col_si. + # Reshape to (K=last dim of src, M=all other dims) so the kernel sees + # (dim0=K, dim1=M) and produces the correct per-row scale shape. 
+ K = src.shape[-1] + M = src.numel() // K + src_transposed_2d = _transpose_for_colwise(src).view(K, M) + ops.quantize( + src_transposed_2d, + col_data, + out_dtype, + None, + None, + col_si, + out_sm, + force_pow_2, + amax_eps, + noop_flag, + ) + dst._fp8_dtype = quantizer.dtype if hasattr(quantizer, "dtype") else dst._fp8_dtype + return + + # Extract raw output buffers from dst + out_data, out_dtype, out_scale_inv, out_sm = extract_tensor_data(dst) + + # Override scaling mode from quantizer if available (more reliable than tensor attrs) + q_type = type(quantizer).__name__ + if "Block" in q_type: + block_dim = getattr(quantizer, "block_scaling_dim", 2) + out_sm = 3 if block_dim == 2 else 2 # BLOCK_SCALING_2D=3 or 1D=2 + elif "MXFP8" in q_type: + out_sm = 1 # MXFP8_1D_SCALING=1 + elif "NVFP4" in q_type: + out_sm = 4 # NVFP4_1D_SCALING=4 + elif "CurrentScaling" in q_type: + out_sm = 0 # DELAYED_TENSOR_SCALING (current scaling uses delayed mode internally) + + # Get scale/amax from quantizer + scale = getattr(quantizer, "scale", None) + amax = getattr(quantizer, "amax", None) + if scale is not None and (not isinstance(scale, torch.Tensor) or scale.numel() == 0): + scale = None + if amax is not None and (not isinstance(amax, torch.Tensor) or amax.numel() == 0): + amax = None + + # Also check output for amax + if amax is None: + for attr in ("_amax", "_amax_rowwise", "amax_rowwise"): + a = getattr(dst, attr, None) + if isinstance(a, torch.Tensor) and a.numel() > 0: + amax = a + break + force_pow_2 = getattr(quantizer, "force_pow_2_scales", False) + amax_eps = getattr(quantizer, "amax_epsilon", 0.0) + use_existing_amax = getattr(quantizer, "use_existing_amax", False) + + # For CurrentScaling, zero amax/scale to avoid stale values from a + # previous quantization affecting the current one. Skip when + # use_existing_amax is set because the caller already computed amax + # (e.g. fused activation+amax path) and zeroing would destroy it. 
+ if "CurrentScaling" in q_type and not use_existing_amax: + if amax is not None: + amax.zero_() + if scale is not None: + scale.zero_() + nvfp4_2d = getattr(quantizer, "with_2d_quantization", False) + q_type = type(quantizer).__name__ + + # Only pass scale_inv for FP8/FP4 output dtypes. The C++ CheckOutputTensor + # asserts that scale_inv must NOT be set for non-FP8 outputs. + is_fp8_or_fp4 = out_dtype in (7, 8, 9, 10) # kFloat8E4M3, kFloat8E5M2, kFloat8E8M0, kFloat4E2M1 + effective_scale_inv = out_scale_inv if is_fp8_or_fp4 else None + + # For MXFP8/NVFP4 with both rowwise and columnwise pre-allocated, use the fused + # bidirectional kernel that fills both buffers in one nvte_quantize_v2 call. + # This is essential because GEMM with NT layout reads columnwise data. + # For NVFP4, the columnwise data must be independently quantized (not just + # transposed from rowwise) because the per-block scales differ. + _bidir = ( + ("MXFP8" in q_type or "NVFP4" in q_type) + and hasattr(dst, "_rowwise_data") + and getattr(dst, "_rowwise_data", None) is not None + and hasattr(dst, "_columnwise_data") + and getattr(dst, "_columnwise_data", None) is not None + and hasattr(dst, "_columnwise_scale_inv") + and getattr(dst, "_columnwise_scale_inv", None) is not None + and out_scale_inv is not None + ) + if _bidir: + # For NVFP4, compute amax before quantization (the kernel doesn't do it) + if "NVFP4" in q_type and amax is not None: + ops.compute_amax(src, amax) + _maybe_allreduce_amax(quantizer, [amax]) + col_data = dst._columnwise_data + col_si = dst._columnwise_scale_inv + ops.quantize_bidirectional( + src, + out_data, + out_dtype, + amax, + scale, + out_scale_inv, + col_data, + col_si, + out_sm, + force_pow_2, + amax_eps, + noop_flag, + nvfp4_2d, + ) + # Set with_2d_quantization config for NVFP4 if needed + if "NVFP4" in q_type: + # Copy rowwise amax to columnwise amax + amax_rw = getattr(dst, "_amax_rowwise", None) + amax_cw = getattr(dst, "_amax_columnwise", None) + if amax_rw 
is not None and amax_cw is not None: + amax_cw.copy_(amax_rw) + elif amax_rw is not None and amax_cw is None: + dst._amax_columnwise = amax_rw.clone() + # Safety fallback for zero amax (GPU-only, no .item() sync). + if amax is not None: + amax.masked_fill_(amax == 0.0, 6.0 * 448.0) + # Ensure swizzle flag is False (stable path doesn't swizzle during quantize) + if hasattr(dst, "_with_gemm_swizzled_scales"): + dst._with_gemm_swizzled_scales = False + return + + if use_existing_amax and amax is not None: + ops.quantize_from_amax( + src, + out_data, + out_dtype, + amax, + scale if scale is not None else torch.ones(1, dtype=torch.float32, device=src.device), + effective_scale_inv, + out_sm, + force_pow_2, + amax_eps, + noop_flag, + ) + elif "CurrentScaling" in q_type: + if amax is None: + amax = torch.zeros(1, dtype=torch.float32, device=src.device) + if scale is None: + scale = torch.zeros(1, dtype=torch.float32, device=src.device) + if getattr(quantizer, "with_amax_reduction", False) and src.is_floating_point(): + # Distributed path: compute local amax, all-reduce to get global amax, + # then quantize with the global amax. This matches the pybind path which + # called nvte_compute_amax → allreduce(MAX) → nvte_compute_scale → nvte_quantize. + # Skip for non-float inputs (e.g., FP8 uint8 data) which can't compute amax. + ops.compute_amax(src, amax) + _maybe_allreduce_amax(quantizer, [amax]) + ops.quantize_from_amax( + src, + out_data, + out_dtype, + amax, + scale, + effective_scale_inv, + out_sm, + force_pow_2, + amax_eps, + noop_flag, + ) + else: + ops.quantize_with_amax( + src, + out_data, + out_dtype, + amax, + scale, + effective_scale_inv, + out_sm, + force_pow_2, + amax_eps, + noop_flag, + ) + else: + # For NVFP4, nvte_quantize_v2 does NOT compute amax internally. + # The pybind path calls nvte_compute_amax_with_config first, then quantizes. + # Replicate that by computing amax before quantization. 
+ if "NVFP4" in q_type and amax is not None: + ops.compute_amax(src, amax) + _maybe_allreduce_amax(quantizer, [amax]) + ops.quantize( + src, + out_data, + out_dtype, + amax, + scale, + effective_scale_inv, + out_sm, + force_pow_2, + amax_eps, + noop_flag, + nvfp4_2d, + ) + + # For NVFP4 dequantize: output = fp4_value * scale_e4m3 * amax / (6 * 448). + # With correct amax from compute_amax above, this formula is already correct. + # But if amax is still 0 (e.g. empty tensor), set it to 2688 as a safety fallback. + # Safety fallback for zero amax in NVFP4 (GPU-only, no .item() sync). + if "NVFP4" in q_type and amax is not None: + amax.masked_fill_(amax == 0.0, 6.0 * 448.0) + + # The stable ABI quantize path does not swizzle MXFP8/NVFP4 scales during + # quantization. Ensure the flag is False so the GEMM C++ code will swizzle + # on-the-fly. This overrides any True value that may have been set during + # tensor construction via optimize_for_gemm=True on the quantizer. + if hasattr(dst, "_with_gemm_swizzled_scales"): + dst._with_gemm_swizzled_scales = False + + # For Float8Tensor (delayed scaling), mark transpose as invalid after filling + # rowwise data. Callers like Float8Quantizer.update_quantized / quantize_impl + # call _create_transpose() immediately after to fill it, matching pybind11 + # tex.quantize() behavior. Other callers (e.g. quantize_new) rely on + # update_usage(columnwise_usage=True) to lazily create the transpose. + if hasattr(dst, "_transpose_invalid"): + dst._transpose_invalid = True + + # For block-scaling tensors with both rowwise AND columnwise pre-allocated, + # also fill the columnwise buffer. The pybind path filled both in one fused + # nvte_quantize_v2 kernel. The stable path fills rowwise above, then derives + # columnwise by FP8-transposing the quantized bytes and transposing the scales. + # This matches _create_columnwise() in float8_blockwise_tensor_storage.py. 
+ _has_colwise = ( + hasattr(dst, "_rowwise_data") + and getattr(dst, "_rowwise_data", None) is not None + and hasattr(dst, "_columnwise_data") + and getattr(dst, "_columnwise_data", None) is not None + and not hasattr(dst, "_data") # exclude Float8Tensor (uses _transpose/_create_transpose) + ) + if _has_colwise and ("Block" in q_type or "NVFP4" in q_type): + col_data = dst._columnwise_data + col_si = getattr(dst, "_columnwise_scale_inv", None) + fp8_dtype_attr = getattr(dst, "_fp8_dtype", None) + from transformer_engine.pytorch.tensor._extract import _FP8_DTYPE_TO_TE + + col_dtype = ( + _FP8_DTYPE_TO_TE.get(str(fp8_dtype_attr), out_dtype) if fp8_dtype_attr else out_dtype + ) + if "NVFP4" in q_type: + # NVFP4 columnwise: derive from rowwise data by transposing the + # already-quantized FP4 bytes and scales. This matches the pybind + # path (_create_columnwise in nvfp4_tensor_storage.py) which uses + # nvfp4_data_transpose + nvfp4_2d_scale_transpose. + # nvfp4_data_transpose expects 2D [M, K_bytes]; flatten leading dims + rd = out_data.reshape(-1, out_data.shape[-1]) if out_data.ndim > 2 else out_data + ops.nvfp4_data_transpose(rd, col_data) + if col_si is not None and out_scale_inv is not None: + logical_shape = list(src.shape) + M_val = 1 + for d in logical_shape[:-1]: + M_val *= d + K_val = logical_shape[-1] + TILE_SIZE = 16 + M_tiles = (M_val + TILE_SIZE - 1) // TILE_SIZE + K_tiles = (K_val + TILE_SIZE - 1) // TILE_SIZE + ops.nvfp4_2d_scale_transpose(out_scale_inv, col_si, M_tiles, K_tiles) + # Copy rowwise amax to columnwise amax (matches _create_columnwise + # in nvfp4_tensor_storage.py:445-447). cuBLAS NVFP4 GEMM uses amax + # in the formula: out = fp4 * scale * amax / (6*448). 
+ amax_rw = getattr(dst, "_amax_rowwise", None) + amax_cw = getattr(dst, "_amax_columnwise", None) + if amax_rw is not None and amax_cw is not None: + amax_cw.copy_(amax_rw) + elif amax_rw is not None and amax_cw is None: + dst._amax_columnwise = amax_rw.clone() + else: + block_dim = getattr(quantizer, "block_scaling_dim", 2) + if block_dim == 2: + # 2D block scaling: columnwise scale = transposed rowwise scale. + # FP8-transpose the quantized bytes (identical to _create_columnwise) + ops.fp8_transpose(out_data, col_dtype, col_data) + # Transpose the rowwise scale_inv into the columnwise scale_inv buffer + if col_si is not None and out_scale_inv is not None: + col_si.zero_() + transposed_si = out_scale_inv.T.contiguous() + h = min(col_si.shape[0], transposed_si.shape[0]) + w = min(col_si.shape[1], transposed_si.shape[1]) + col_si[0:h, 0:w].copy_(transposed_si[0:h, 0:w]) + else: + # 1D block scaling: columnwise scale ≠ transposed rowwise scale (they cover + # different block directions). Quantize src in "columnwise mode" by reshaping + # the transposed src to (K, M) and calling ops.quantize in ROWWISE mode. + # This gives per-K-block scales == the per-M-block columnwise scales we need. + if col_si is not None and src.ndim >= 1: + K = src.shape[-1] + M = src.numel() // K + src_transposed_2d = _transpose_for_colwise(src).view(K, M) + ops.quantize( + src_transposed_2d, + col_data, + col_dtype, + None, + None, + col_si, + out_sm, + force_pow_2, + amax_eps, + noop_flag, + ) + + +def quantize_new(tensor, quantizer): + """Allocate output and quantize tensor using stable ABI ops. + + Replaces: return tex.quantize(tensor, quantizer) + """ + # Ensure contiguous + if not tensor.is_contiguous(): + tensor = tensor.contiguous() + + # MXFP8 requires dimensions divisible by block size (32). The pybind fused + # C++ kernel handles non-aligned sizes internally by padding. In the stable + # path we pad the input, quantize, then slice back to the original shape. 
+ _MXFP8_BLOCK = 32 + padded = False + orig_shape = list(tensor.shape) + q_type = type(quantizer).__name__ + if "MXFP8" in q_type and len(orig_shape) >= 2: + last_dim = orig_shape[-1] + first_dims_prod = 1 + for d in orig_shape[:-1]: + first_dims_prod *= d + need_pad_last = last_dim % _MXFP8_BLOCK != 0 + need_pad_first = first_dims_prod % _MXFP8_BLOCK != 0 + if need_pad_last or need_pad_first: + pad_last = (_MXFP8_BLOCK - last_dim % _MXFP8_BLOCK) % _MXFP8_BLOCK + # Flatten to 2D for padding, then reshape back + flat = tensor.reshape(first_dims_prod, last_dim) + pad_first = (_MXFP8_BLOCK - first_dims_prod % _MXFP8_BLOCK) % _MXFP8_BLOCK + if pad_last > 0 or pad_first > 0: + flat = torch.nn.functional.pad(flat, (0, pad_last, 0, pad_first)) + tensor = flat # keep as 2D for quantize + padded = True + + # Allocate output via quantizer's make_empty (pure Python) + dst = quantizer.make_empty(list(tensor.shape), dtype=tensor.dtype, device=tensor.device) + + # Quantize into the new output + quantize_into(tensor, quantizer, dst) + + # If we padded, slice the quantized output back to the original shape + if padded: + first_dims_prod = 1 + for d in orig_shape[:-1]: + first_dims_prod *= d + # Slice back to original 2D shape, then restore original dims + if hasattr(dst, "_rowwise_data") and dst._rowwise_data is not None: + dst._rowwise_data = dst._rowwise_data[:first_dims_prod, : orig_shape[-1]] + if hasattr(dst, "_rowwise_scale_inv") and dst._rowwise_scale_inv is not None: + si = dst._rowwise_scale_inv + # Scale has ceil(M/32) rows and ceil(K/32) cols + orig_si_rows = (first_dims_prod + _MXFP8_BLOCK - 1) // _MXFP8_BLOCK + orig_si_cols = (orig_shape[-1] + _MXFP8_BLOCK - 1) // _MXFP8_BLOCK + if si.ndim == 2: + dst._rowwise_scale_inv = si[:orig_si_rows, :orig_si_cols] + + return dst diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index bbfc43e9bb..b59055dc64 100644 --- 
a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -10,7 +10,7 @@ from typing import Any, Optional, Tuple, Union import torch -import transformer_engine_torch as tex + from transformer_engine_torch import DType as TE_DType from transformer_engine.common.recipe import Float8BlockScaling, Recipe from .storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage @@ -99,6 +99,9 @@ def update_quantized( AssertionError If the destination tensor is not a Float8BlockwiseQTensor """ + # Handle non-torch.Tensor inputs (e.g. DebugQuantizedTensor from debug mode backward pass) + if not isinstance(src, torch.Tensor) and hasattr(src, "dequantize"): + src = src.dequantize() assert isinstance( dst, Float8BlockwiseQTensor ), f"Cannot store quantized blockwise tensor in {type(dst)} type." @@ -108,15 +111,26 @@ def update_quantized( if not src.is_contiguous(): src = src.contiguous() - # Launch cast kernel - tex.quantize(src, self, dst, noop_flag) + # Launch cast kernel via stable ABI + from transformer_engine.pytorch.tensor._quantize_stable import quantize_into + + quantize_into(src, self, dst, noop_flag) dst._fp8_dtype = self.dtype return dst def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: - """Quantize tensor implementation""" - return tex.quantize(tensor, self) + """Quantize tensor implementation via stable ABI""" + from transformer_engine.pytorch.tensor._quantize_stable import quantize_into + + # Handle non-torch.Tensor inputs (e.g. 
DebugQuantizedTensor from debug mode backward pass) + if not isinstance(tensor, torch.Tensor) and hasattr(tensor, "dequantize"): + tensor = tensor.dequantize() + dst = self.make_empty(list(tensor.shape), dtype=tensor.dtype, device=tensor.device) + if tensor.numel() > 0: + t = tensor.contiguous() if not tensor.is_contiguous() else tensor + quantize_into(t, self, dst) + return dst def get_scale_shape(self, shape: Iterable[int], columnwise: bool) -> Tuple[int, int]: """Scaling tensor shape. diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index e8284eaa53..fbf74604e4 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -8,7 +8,6 @@ import warnings import torch from torch.distributed.fsdp._fully_shard._fsdp_common import TrainingState -import transformer_engine_torch as tex from transformer_engine_torch import DType as TE_DType from transformer_engine.common.recipe import ( @@ -94,23 +93,62 @@ def update_quantized( if not isinstance(dst, Float8Tensor): raise ValueError("Float8Quantizer can only update Float8Tensor") + # Handle non-torch.Tensor inputs (e.g. DebugQuantizedTensor from debug mode backward pass) + if not isinstance(src, torch.Tensor) and hasattr(src, "dequantize"): + src = src.dequantize() # Make sure input is in expected format if not devices_match(src.device, dst.device): src = src.to(device=dst.device) if not src.is_contiguous(): src = src.contiguous() - # Launch cast kernel - tex.quantize(src, self, dst, noop_flag) + # Launch cast kernel via stable ABI + from transformer_engine.pytorch.tensor._quantize_stable import quantize_into + + quantize_into(src, self, dst, noop_flag) # Update FP8 dtype dst._fp8_dtype = self.dtype + # quantize_into only fills rowwise data (_data). Recompute the + # transpose from _data so that _transpose is valid. 
+ if dst._transpose is not None: + if dst._data.ndim <= 1: + dst._transpose.copy_(dst._data) + dst._transpose_invalid = False + else: + dst._create_transpose() + else: + dst._transpose_invalid = True + return dst def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: - """Quantize tensor implementation""" - return tex.quantize(tensor, self) + """Quantize tensor implementation via stable ABI""" + from transformer_engine.pytorch.tensor._quantize_stable import quantize_into + + # Handle non-torch.Tensor inputs (e.g. DebugQuantizedTensor from debug mode backward pass) + if not isinstance(tensor, torch.Tensor) and hasattr(tensor, "dequantize"): + tensor = tensor.dequantize() + dst = self.make_empty(list(tensor.shape), dtype=tensor.dtype, device=tensor.device) + # Initialize scale_inv from quantizer scale (C++ create_tensor does reciprocal(scale)) + if hasattr(self, "scale") and self.scale is not None and self.scale.numel() > 0: + dst._scale_inv.copy_(1.0 / self.scale) + if tensor.numel() > 0: + t = tensor.contiguous() if not tensor.is_contiguous() else tensor + quantize_into(t, self, dst) + # quantize_into only fills rowwise data (_data). Recompute the + # transpose from _data so that _transpose is valid. 
+ if dst._transpose is not None: + if dst._data.ndim <= 1: + # For 0-dim/1-dim tensors, transpose is identity + dst._transpose.copy_(dst._data) + dst._transpose_invalid = False + else: + dst._create_transpose() + else: + dst._transpose_invalid = True + return dst def make_empty( self, @@ -134,7 +172,7 @@ def make_empty( # Allocate FP8 data transpose if needed data_transpose = None if self.columnwise_usage: - transpose_shape = [shape[-1]] + list(shape[:-1]) + transpose_shape = [shape[-1]] + list(shape[:-1]) if len(shape) > 0 else [] data_transpose = torch.empty( transpose_shape, dtype=torch.uint8, @@ -323,23 +361,59 @@ def update_quantized( if not isinstance(dst, Float8Tensor): raise ValueError("Float8CurrentScalingQuantizer can only update Float8Tensor") + # Handle non-torch.Tensor inputs (e.g. DebugQuantizedTensor from debug mode backward pass) + if not isinstance(src, torch.Tensor) and hasattr(src, "dequantize"): + src = src.dequantize() # Make sure input is in expected format if not devices_match(src.device, dst.device): src = src.to(device=dst.device) if not src.is_contiguous(): src = src.contiguous() - # Launch cast kernel - tex.quantize(src, self, dst, noop_flag) + # Launch cast kernel via stable ABI + from transformer_engine.pytorch.tensor._quantize_stable import quantize_into + + quantize_into(src, self, dst, noop_flag) # Update FP8 dtype dst._fp8_dtype = self.dtype + # quantize_into only fills rowwise data (_data). Recompute the + # transpose from _data so that _transpose is valid. 
+ if dst._transpose is not None: + if dst._data.ndim <= 1: + dst._transpose.copy_(dst._data) + dst._transpose_invalid = False + else: + dst._create_transpose() + else: + dst._transpose_invalid = True + return dst def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: - """Quantize tensor implementation""" - return tex.quantize(tensor, self) + """Quantize tensor implementation via stable ABI""" + from transformer_engine.pytorch.tensor._quantize_stable import quantize_into + + # Handle non-torch.Tensor inputs (e.g. DebugQuantizedTensor from debug mode backward pass) + if not isinstance(tensor, torch.Tensor) and hasattr(tensor, "dequantize"): + tensor = tensor.dequantize() + dst = self.make_empty(list(tensor.shape), dtype=tensor.dtype, device=tensor.device) + if tensor.numel() > 0: + t = tensor.contiguous() if not tensor.is_contiguous() else tensor + quantize_into(t, self, dst) + # quantize_into only fills rowwise data (_data). Recompute the + # transpose from _data so that _transpose is valid. 
+ if dst._transpose is not None: + if dst._data.ndim <= 1: + # For 0-dim/1-dim tensors, transpose is identity + dst._transpose.copy_(dst._data) + dst._transpose_invalid = False + else: + dst._create_transpose() + else: + dst._transpose_invalid = True + return dst def make_empty( self, @@ -363,7 +437,7 @@ def make_empty( # Allocate FP8 data transpose if needed data_transpose = None if self.columnwise_usage: - transpose_shape = [shape[-1]] + list(shape[:-1]) + transpose_shape = [shape[-1]] + list(shape[:-1]) if len(shape) > 0 else [] data_transpose = torch.empty( transpose_shape, dtype=torch.uint8, diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 965f59b320..d4c2a21efe 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -68,23 +68,39 @@ def update_quantized( assert isinstance(dst, MXFP8Tensor), f"Cannot store quantized MXFP8 in {type(dst)} type." + # Handle non-torch.Tensor inputs (e.g. 
DebugQuantizedTensor from debug mode backward pass) + if not isinstance(src, torch.Tensor) and hasattr(src, "dequantize"): + src = src.dequantize() # Make sure input is in expected format if not devices_match(src.device, dst.device): src = src.to(device=dst.device) if not src.is_contiguous(): src = src.contiguous() - # Launch cast kernel - tex.quantize(src, self, dst, noop_flag) + # Launch cast kernel via stable ABI + from transformer_engine.pytorch.tensor._quantize_stable import quantize_into + + quantize_into(src, self, dst, noop_flag) # Update FP8 dtype dst._fp8_dtype = self.dtype + # The stable ABI quantize path does not swizzle scales, so reset the flag + dst._with_gemm_swizzled_scales = False return dst def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: - """Quantize tensor implementation""" - return tex.quantize(tensor, self) + """Quantize tensor implementation via stable ABI""" + from transformer_engine.pytorch.tensor._quantize_stable import quantize_into + + # Handle non-torch.Tensor inputs (e.g. DebugQuantizedTensor from debug mode backward pass) + if not isinstance(tensor, torch.Tensor) and hasattr(tensor, "dequantize"): + tensor = tensor.dequantize() + dst = self.make_empty(list(tensor.shape), dtype=tensor.dtype, device=tensor.device) + if tensor.numel() > 0: + t = tensor.contiguous() if not tensor.is_contiguous() else tensor + quantize_into(t, self, dst) + return dst def is_quantizable(self, inp: torch.Tensor) -> bool: """Returns whether or not given inp can be quantized""" @@ -157,7 +173,10 @@ def make_empty( columnwise_scale_inv=columnwise_scale_inv, quantizer=self, requires_grad=requires_grad, - with_gemm_swizzled_scales=self.optimize_for_gemm, + # The stable ABI quantize path does not swizzle scales during + # quantization, so always report unswizzled. The GEMM C++ code + # will swizzle on-the-fly when it sees this flag is False. 
+ with_gemm_swizzled_scales=False, ) def calibrate(self, tensor: torch.Tensor) -> None: diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 8ed1b4682c..89f5b24508 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -174,14 +174,19 @@ def update_quantized( assert isinstance(dst, NVFP4Tensor), f"Cannot store quantized NVFP4 in {type(dst)} type." + # Handle non-torch.Tensor inputs (e.g. DebugQuantizedTensor from debug mode backward pass) + if not isinstance(src, torch.Tensor) and hasattr(src, "dequantize"): + src = src.dequantize() # Make sure input is in expected format if not devices_match(src.device, dst.device): src = src.to(device=dst.device) if not src.is_contiguous(): src = src.contiguous() - # Launch cast kernel - tex.quantize(src, self, dst, noop_flag) + # Launch cast kernel via stable ABI + from transformer_engine.pytorch.tensor._quantize_stable import quantize_into + + quantize_into(src, self, dst, noop_flag) return dst @@ -207,8 +212,17 @@ def copy(self) -> NVFP4Quantizer: return quantizer def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: - """Quantize tensor implementation""" - return tex.quantize(tensor, self) + """Quantize tensor implementation via stable ABI""" + from transformer_engine.pytorch.tensor._quantize_stable import quantize_into + + # Handle non-torch.Tensor inputs (e.g. 
DebugQuantizedTensor from debug mode backward pass) + if not isinstance(tensor, torch.Tensor) and hasattr(tensor, "dequantize"): + tensor = tensor.dequantize() + dst = self.make_empty(list(tensor.shape), dtype=tensor.dtype, device=tensor.device) + if tensor.numel() > 0: + t = tensor.contiguous() if not tensor.is_contiguous() else tensor + quantize_into(t, self, dst) + return dst def is_quantizable(self, inp: torch.Tensor) -> bool: """Returns whether or not given inp can be quantized""" diff --git a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py index 7bbe809c9d..3cd7a81506 100644 --- a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py @@ -265,6 +265,10 @@ def update_usage( """ For MXFP8, columnwise scaled output is only produced by x2 scaling kernels, so this function only disables usages. + + Note: rowwise data is preserved even when rowwise_usage=False + because the stable ABI GEMM wrapper may need it to create + columnwise data on-the-fly via _ensure_mxfp8_columnwise. """ # Default usage is based on available data @@ -283,21 +287,17 @@ def update_usage( raise RuntimeError( "Requested row-wise usage, but MXFP8Tensor is missing row-scaled scale-inverses" ) - else: - self._rowwise_data = None - self._rowwise_scale_inv = None + # Note: do NOT clear rowwise data when rowwise_usage=False. + # The GEMM wrapper needs rowwise data to create columnwise on-the-fly. # Update column-scaled data if columnwise_usage: if self._columnwise_data is None: - raise RuntimeError( - "Requested column-wise usage, but MXFP8Tensor is missing column-scaled FP8 data" - ) + # Columnwise data not available — the GEMM wrapper will create + # it on-the-fly from rowwise data via _ensure_mxfp8_columnwise. 
+ pass if self._columnwise_scale_inv is None: - raise RuntimeError( - "Requested column-wise usage, " - "but MXFP8Tensor is missing column-scaled scale-inverses" - ) + pass # Will be created on-the-fly else: self._columnwise_data = None self._columnwise_scale_inv = None