Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
28e5f53
Refactor to group_sizes per tensor
jberchtold-nvidia Mar 9, 2026
4a57485
Support first_dims and last_dims instead of a single group_sizes per
jberchtold-nvidia Mar 10, 2026
345d940
Refactor GMM FFIs to store static attrs as structs
jberchtold-nvidia Mar 10, 2026
ed9c8e4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 10, 2026
ed0deaf
Cleanup C++ v2 FFI
jberchtold-nvidia Mar 10, 2026
88bb7da
Fix int64 workspace usage
jberchtold-nvidia Mar 10, 2026
60312c8
Address greptile comments
jberchtold-nvidia Mar 10, 2026
025f598
Refactor wgrad-specific checks to be generic for GMM in gemm.py
jberchtold-nvidia Mar 10, 2026
089e530
Refactor XLA FFI struct setup
jberchtold-nvidia Mar 10, 2026
8ad2294
Fix edge case in TE v1 GMM
jberchtold-nvidia Mar 11, 2026
bac092d
Merge remote-tracking branch 'github-upstream/main' into jberchtold/g…
jberchtold-nvidia Mar 11, 2026
4ff5d1d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 11, 2026
0cb7289
Fix issues on Hopper
jberchtold-nvidia Mar 11, 2026
37d300a
Merge remote-tracking branch 'github-upstream/main…
jberchtold-nvidia Mar 11, 2026
cc236ad
Refactor
jberchtold-nvidia Mar 12, 2026
1d1fec9
MXFP8 grouped quantize V2
jberchtold-nvidia Mar 13, 2026
269a518
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 13, 2026
2b84dfd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 13, 2026
b2b3216
MXFP8 quantization working
jberchtold-nvidia Mar 14, 2026
47218b3
Merge remote-tracking branch 'github-upstream/main' into jberchtold/g…
jberchtold-nvidia Mar 14, 2026
611526f
mxfp8 grouped gemm
jberchtold-nvidia Mar 14, 2026
c97b0b7
te_permutation NaN issue fix
jberchtold-nvidia Mar 14, 2026
0b9a763
Support GroupedDense quantization checkpointing
jberchtold-nvidia Mar 14, 2026
6b64cea
Temporary commit to assert if V1 grouped quantize is used
jberchtold-nvidia Mar 14, 2026
2dd69d4
Fix scale shapes for MXFP8
jberchtold-nvidia Mar 14, 2026
204b326
Fix MXFP8 scale sharding when FSDP+EP on same axis
jberchtold-nvidia Mar 14, 2026
5fb585f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 14, 2026
2902eb2
Merge remote-tracking branch 'github-upstream/main' into jberchtold/g…
jberchtold-nvidia Mar 17, 2026
bee7f3b
Address comments
jberchtold-nvidia Mar 23, 2026
d9b9c44
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 23, 2026
ef0d498
Merge branch 'main' into jberchtold/gmm-refactor
jberchtold-nvidia Mar 23, 2026
9438478
Lint
jberchtold-nvidia Mar 23, 2026
09dfd9c
Fixes for Hopper
jberchtold-nvidia Mar 24, 2026
e25538e
Address review comments
jberchtold-nvidia Mar 24, 2026
78674e9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 24, 2026
d5229e2
Merge branch 'main' into jberchtold/gmm-refactor
jberchtold-nvidia Mar 24, 2026
b78435a
Merge jberchtold/gmm-refactor into jberchtold/gmm-mxfp8
jberchtold-nvidia Mar 24, 2026
06ebb44
Fixes
jberchtold-nvidia Mar 24, 2026
a3f8042
wip
jberchtold-nvidia Mar 30, 2026
7e99314
Fix grouped colwise dequantize for transposed ragged tensors and V1 p…
jberchtold-nvidia Mar 30, 2026
68bcbfc
2D shape fixes for flattened 1D shape from grouped quantization
jberchtold-nvidia Mar 30, 2026
81cb189
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 31, 2026
75995e4
Merge remote-tracking branch 'github-upstream/main' into jberchtold/g…
jberchtold-nvidia Apr 3, 2026
d7b04cc
Fix swizzling
jberchtold-nvidia Apr 4, 2026
064f314
Remove pre-swizzling from non-grouped quantization
jberchtold-nvidia Apr 4, 2026
5edef90
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 4, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
471 changes: 466 additions & 5 deletions tests/jax/test_custom_call_compute.py

Large diffs are not rendered by default.

38 changes: 38 additions & 0 deletions transformer_engine/common/gemm/cublaslt_grouped_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1414,6 +1414,24 @@ __global__ void convert_int32_to_int64_kernel(const int32_t *src, int64_t *dst,
if (idx < n) dst[idx] = static_cast<int64_t>(src[idx]);
}

// Like convert_int32_to_int64_kernel but scales each element by multiplier:
// dst[i] = (int64_t)src[i] * multiplier for i in [0, n).
// Used to convert per-expert slice counts to per-expert row counts for multi-dim tensors.
// Expects a 1D launch; the tail block is bounds-checked against n.
__global__ void convert_int32_to_int64_with_multiplier_kernel(const int32_t *src, int64_t *dst,
                                                              size_t n, int64_t multiplier) {
  // Promote blockIdx.x to size_t before the multiply: blockIdx.x and blockDim.x are
  // 32-bit unsigned, so the product would silently wrap for very large grids.
  size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (idx < n) dst[idx] = static_cast<int64_t>(src[idx]) * multiplier;
}

// Computes exclusive prefix sums: offsets[0]=0, offsets[i]=sum(first_dims[0..i-1]*last_dim).
// Produces n_groups+1 values. Sequential scan performed by a single thread; n_groups is
// typically small, so the O(n_groups) serial work is acceptable.
// Intended to be launched as <<<1, 1>>>; extra threads (if ever launched wider) exit
// immediately so the scan cannot race.
__global__ void compute_grouped_tensor_offsets_kernel(const int64_t *first_dims, int64_t *offsets,
                                                      size_t n_groups, int64_t last_dim) {
  // Guard: only one thread may perform the scan. Without this, a launch with more than
  // one thread would have every thread write the same offsets concurrently.
  if (blockIdx.x != 0 || threadIdx.x != 0) return;
  // Keep the running sum in a register instead of re-reading offsets[i] from global memory.
  int64_t acc = 0;
  offsets[0] = acc;
  for (size_t i = 0; i < n_groups; ++i) {
    acc += first_dims[i] * last_dim;
    offsets[i + 1] = acc;
  }
}

} // namespace

void nvte_convert_int32_to_int64(const int32_t *src, int64_t *dst, size_t n, cudaStream_t stream) {
Expand All @@ -1424,3 +1442,23 @@ void nvte_convert_int32_to_int64(const int32_t *src, int64_t *dst, size_t n, cud
convert_int32_to_int64_kernel<<<blocks, threads, 0, stream>>>(src, dst, n);
NVTE_CHECK_CUDA(cudaGetLastError());
}

// Host-side launcher for the scaled int32 -> int64 conversion:
// dst[i] = (int64_t)src[i] * multiplier for i in [0, n).
// Asynchronous on `stream`; no host-device synchronization, so it is CUDA-graph safe.
void nvte_convert_int32_to_int64_with_multiplier(const int32_t *src, int64_t *dst, size_t n,
                                                 int64_t multiplier, cudaStream_t stream) {
  NVTE_API_CALL(nvte_convert_int32_to_int64_with_multiplier);
  // Nothing to do for an empty array; also avoids a zero-block launch.
  if (n == 0) return;
  constexpr int kThreadsPerBlock = 256;
  const int num_blocks = static_cast<int>((n + kThreadsPerBlock - 1) / kThreadsPerBlock);
  convert_int32_to_int64_with_multiplier_kernel<<<num_blocks, kThreadsPerBlock, 0, stream>>>(
      src, dst, n, multiplier);
  NVTE_CHECK_CUDA(cudaGetLastError());
}

// Host-side launcher computing exclusive prefix-sum byte/row offsets for a grouped tensor:
// writes n_groups+1 values with offsets[0]=0 and offsets[i] = sum(first_dims[0..i-1]) * last_dim.
// Launched with a single thread because the scan kernel is sequential (n_groups is
// typically small). Asynchronous on `stream`; no host-device sync, so CUDA-graph safe.
void nvte_compute_grouped_tensor_offsets(const int64_t *first_dims, int64_t *offsets,
                                         size_t n_groups, int64_t last_dim, cudaStream_t stream) {
  NVTE_API_CALL(nvte_compute_grouped_tensor_offsets);
  // Always write at least offsets[0]=0 (needed even for n_groups==0).
  compute_grouped_tensor_offsets_kernel<<<1, 1, 0, stream>>>(first_dims, offsets, n_groups,
                                                             last_dim);
  NVTE_CHECK_CUDA(cudaGetLastError());
}
29 changes: 29 additions & 0 deletions transformer_engine/common/include/transformer_engine/gemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,35 @@ size_t nvte_get_grouped_gemm_setup_workspace_size(size_t num_tensors);
*/
void nvte_convert_int32_to_int64(const int32_t *src, int64_t *dst, size_t n, cudaStream_t stream);

/*! \brief Convert int32 array to int64 while scaling each element by a multiplier.
*
* Computes dst[i] = (int64_t)src[i] * multiplier for each i in [0, n).
* CUDA-graph safe (no host-device synchronization).
*
* \param[in] src Device pointer to source int32 array.
* \param[out] dst Device pointer to destination int64 array.
* \param[in] n Number of elements.
* \param[in] multiplier Scale factor applied to each element.
* \param[in] stream CUDA stream.
*/
void nvte_convert_int32_to_int64_with_multiplier(const int32_t *src, int64_t *dst, size_t n,
int64_t multiplier, cudaStream_t stream);

/*! \brief Compute exclusive prefix-sum offsets from per-group first-dimension sizes.
*
* Writes n_groups+1 values to offsets: offsets[0]=0,
* offsets[i] = sum(first_dims[0..i-1] * last_dim) for i in [1, n_groups].
* This is CUDA-graph safe (no host-device synchronization).
*
* \param[in] first_dims Device pointer to int64 array of length n_groups.
* \param[out] offsets Device pointer to int64 array of length n_groups+1.
* \param[in] n_groups Number of groups.
* \param[in] last_dim Common last dimension (number of columns).
* \param[in] stream CUDA stream.
*/
void nvte_compute_grouped_tensor_offsets(const int64_t *first_dims, int64_t *offsets,
size_t n_groups, int64_t last_dim, cudaStream_t stream);

void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedTensor B, int transb,
const NVTEGroupedTensor C, NVTEGroupedTensor D, const NVTETensor alpha,
const NVTETensor beta, NVTETensor workspace_setup,
Expand Down
123 changes: 113 additions & 10 deletions transformer_engine/jax/cpp_extensions/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1610,9 +1610,9 @@ def _compute_cublas_workspace_size(
workspace_size += lhs_scale_inv_aval.size * tensor_scaling_sinv_aligment
workspace_size += rhs_scale_inv_aval.size * tensor_scaling_sinv_aligment
elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
# We also pad scale_inv swizzle buffers size for 256 bytes alignment.
workspace_size += lhs_scale_inv_aval.size + mxfp8_scaling_sinv_alignment_padding
workspace_size += rhs_scale_inv_aval.size + mxfp8_scaling_sinv_alignment_padding
# Both V1 and V2 quantize now produce pre-swizzled scales, so the GEMM
# does not need extra workspace for nvte_swizzle_scaling_factors.
pass
return workspace_size

@staticmethod
Expand Down Expand Up @@ -2034,11 +2034,14 @@ def _can_use_v2_grouped_gemm(
scaling_mode: ScalingMode,
dtype: jnp.dtype,
has_bias: bool,
lhs_shape=None,
rhs_shape=None,
lhs_axis_boundary=None,
rhs_axis_boundary=None,
) -> bool:
"""Determine whether the cuda-graphable grouped GEMM implementation can be used based on the input parameters."""
# Use the cuda-graphable path for plain BF16 non-quantized inputs; fall back to the legacy
# nvte_multi_tensor_gemm path for all other cases (FP8, MXFP8, etc.) to stay
# feature-compatible with the main branch.
# Use the cuda-graphable path for plain BF16 non-quantized inputs and MXFP8; fall back to
# the legacy nvte_multi_tensor_gemm path for all other cases (tensor-scaled FP8, etc.).
# Bias can be supported in a kernel or in pure-JAX in the future.

enforce_v2_gmm = _should_enforce_v2_grouped_gemm()
Expand All @@ -2063,13 +2066,86 @@ def _can_use_v2_grouped_gemm(
)
return False

if scaling_mode == ScalingMode.NO_SCALING and dtype == jnp.bfloat16 and not has_bias:
if has_bias:
if enforce_v2_gmm:
raise RuntimeError(
"Grouped GEMM with bias is not supported in the TE V2 grouped GEMM kernel, but"
" NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is enabled and has_bias is True."
)
return False

if scaling_mode == ScalingMode.NO_SCALING and dtype == jnp.bfloat16:
return True

if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
# V2 MXFP8 requires that the total first dimension of both operands (up to
# axis_boundary) is divisible by 128, matching the quantize V2 kernel requirement.
# Individual group sizes must also be 128-aligned (dynamic constraint).
if lhs_shape is not None and lhs_axis_boundary is not None:
lhs_first_dim = math.prod(lhs_shape[:lhs_axis_boundary])
if lhs_first_dim % 128 != 0:
if enforce_v2_gmm:
raise RuntimeError(
"The TE V2 grouped GEMM for MXFP8 requires the product of the first"
" dimensions (up to axis_boundary) of LHS to be divisible by 128, but got"
f" {lhs_first_dim} with lhs_shape={lhs_shape} and"
f" lhs_axis_boundary={lhs_axis_boundary}, and"
" NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is enabled."
)
return False
if rhs_shape is not None and rhs_axis_boundary is not None:
rhs_first_dim = math.prod(rhs_shape[:rhs_axis_boundary])
if rhs_first_dim % 128 != 0:
if enforce_v2_gmm:
raise RuntimeError(
"The TE V2 grouped GEMM for MXFP8 requires the product of the first"
" dimensions (up to axis_boundary) of RHS to be divisible by 128, but got"
f" {rhs_first_dim} with rhs_shape={rhs_shape} and"
f" rhs_axis_boundary={rhs_axis_boundary}, and"
" NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is enabled."
)
return False
# V2 MXFP8 also requires that the "last" dimension (after axis_boundary) of both
# operands is a multiple of 128. The V2 GEMM setup kernel computes per-group
# scale pointers as ``data_offset / 32``, which equals ``K_blocks * last_dim``.
# The quantize kernel, however, pads the colwise scale stride to
# ``ceil(last_dim / 128) * 128``, making per-group padded scale larger than
# ``K_blocks * last_dim`` when ``last_dim`` is not 128-aligned. This causes
# adjacent groups' scales to overlap in the flat buffer. Fall back to V1 (which
# swizzles per-group scales individually) when the condition is not met.
if lhs_shape is not None and lhs_axis_boundary is not None:
lhs_last_dim = math.prod(lhs_shape[lhs_axis_boundary:])
if lhs_last_dim % 128 != 0:
if enforce_v2_gmm:
raise RuntimeError(
"The TE V2 grouped GEMM for MXFP8 requires the product of the last"
" dimensions (after axis_boundary) of LHS to be divisible by 128, but got"
f" {lhs_last_dim} with lhs_shape={lhs_shape} and"
f" lhs_axis_boundary={lhs_axis_boundary}, and"
" NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is enabled."
)
return False
if rhs_shape is not None and rhs_axis_boundary is not None:
rhs_last_dim = math.prod(rhs_shape[rhs_axis_boundary:])
if rhs_last_dim % 128 != 0:
if enforce_v2_gmm:
raise RuntimeError(
"The TE V2 grouped GEMM for MXFP8 requires the product of the last"
" dimensions (after axis_boundary) of RHS to be divisible by 128, but got"
f" {rhs_last_dim} with rhs_shape={rhs_shape} and"
f" rhs_axis_boundary={rhs_axis_boundary}, and"
" NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is enabled."
)
return False
return True

if enforce_v2_gmm:
raise RuntimeError(
"The TE V2 grouped GEMM currently only supports BF16 with no quantization recipe and"
f" without bias, but received {scaling_mode=}, {dtype=}, {has_bias=}"
"The TE V2 grouped GEMM currently only supports non-quantized BF16 and MXFP8 with 1D"
" block scaling, but NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is enabled and the input"
f" parameters do not meet these requirements (scaling_mode= {scaling_mode},"
f" dtype={dtype}, has_bias={has_bias}, lhs_shape={lhs_shape}, rhs_shape={rhs_shape},"
f" lhs_axis_boundary={lhs_axis_boundary}, rhs_axis_boundary={rhs_axis_boundary})."
)
return False

Expand Down Expand Up @@ -2328,7 +2404,34 @@ def grouped_gemm(
" and padded with zeros to not affect the result of the MoE block."
)

use_v2_ffi = _can_use_v2_grouped_gemm(scaling_mode, lhs_data.dtype, has_bias)
use_v2_ffi = _can_use_v2_grouped_gemm(
scaling_mode,
lhs_data.dtype,
has_bias,
lhs_shape=lhs_shape,
rhs_shape=rhs_shape,
lhs_axis_boundary=lhs_axis_boundary,
rhs_axis_boundary=rhs_axis_boundary,
)

# V2 grouped GEMM requires MXFP8 inputs to be pre-swizzled by V2 grouped quantize
# (nvte_group_quantize fuses the swizzle). The C++ V2 GEMM FFI does not re-swizzle.
if use_v2_ffi and scaling_mode == ScalingMode.MXFP8_1D_SCALING:
if isinstance(lhs, GroupedScaledTensor1x) and not lhs.pre_swizzled:
raise ValueError(
"V2 grouped GEMM requires MXFP8 lhs scale_inv to be pre-swizzled. "
"GroupedScaledTensor1x.pre_swizzled is False. "
"Use V2 grouped quantize (nvte_group_quantize, requires SM100+ and "
"128-aligned shapes) to produce pre-swizzled tensors."
)
if isinstance(rhs, GroupedScaledTensor1x) and not rhs.pre_swizzled:
raise ValueError(
"V2 grouped GEMM requires MXFP8 rhs scale_inv to be pre-swizzled. "
"GroupedScaledTensor1x.pre_swizzled is False. "
"Use V2 grouped quantize (nvte_group_quantize, requires SM100+ and "
"128-aligned shapes) to produce pre-swizzled tensors."
)

if use_v2_ffi:
additional_arg_0 = jnp.ones((num_gemms,), jnp.float32) # alpha
additional_arg_1 = jnp.zeros((num_gemms,), jnp.float32) # beta
Expand Down
Loading