diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index ddb74fd636..8a64c07763 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -43,6 +43,7 @@ noop_quantizer_set, QuantizeMetaSet, QuantizeMeta, + get_device_compute_capability, ) from transformer_engine.jax.quantize import helper from transformer_engine.jax.activation import activation @@ -77,6 +78,9 @@ supported_recipes = helper.get_supported_quantization_recipes() supported_recipes = [pytest.param(r, id=r.__class__.__name__) for r in supported_recipes] +is_v2_grouped_gemm_supported = get_device_compute_capability(0) >= 100 +v2_grouped_gemm_unsupported_reason = "V2 grouped GEMM requires SM100+ (Blackwell or newer)" + def is_shape_supported_by_mxfp8(input_shape): try: @@ -1068,7 +1072,13 @@ def test_rht_gemm(self, in_dtype, q_dtype, scaling_mode, m, n, k, data_layout, w @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE) -@pytest_parametrize_wrapper("input_shape", [(8, 16, 32)]) +@pytest_parametrize_wrapper( + "input_shape", + [ + (8, 16, 32), # V1 MXFP8: K=32 not 128-aligned + (4, 8, 128), # V2 MXFP8 eligible: K=128, M*32=256 both 128-aligned + ], +) @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn]) @pytest_parametrize_wrapper("scaling_mode", non_fp4_supported_scaling_modes) @pytest_parametrize_wrapper("flatten_axis", [-1]) @@ -1084,8 +1094,17 @@ def test_grouped_qdq( key = jax.random.PRNGKey(0) subkeys = jax.random.split(key, 2) - # *32 so that the input shapes works for MXFP8 - input_shape = (m * 32, n) + # Use 128 multiplier for V2-eligible MXFP8 shapes (both M and K 128-aligned) + # so that per-group row counts are also 128-aligned as required by the V2 kernel. + # Use 32 for other shapes (V1 handles arbitrary group sizes). 
+ v2_eligible = ( + scaling_mode == ScalingMode.MXFP8_1D_SCALING + and is_v2_grouped_gemm_supported + and (m * 32) % 128 == 0 + and n % 128 == 0 + ) + group_size_multiplier = 128 if v2_eligible else 32 + input_shape = (m * group_size_multiplier, n) if with_group_sizes: group_sizes = jnp.sort(jax.random.randint(subkeys[0], (n_groups - 1,), 0, m)) @@ -1093,7 +1112,7 @@ def test_grouped_qdq( group_sizes = jnp.diff(group_sizes) assert group_sizes.sum() == m assert jnp.any(group_sizes == 0) # make sure that at least one group has 0 row - group_sizes = group_sizes * 32 + group_sizes = group_sizes * group_size_multiplier else: group_sizes = None input_shape = (n_groups, input_shape[0] // n_groups, input_shape[1]) @@ -1115,6 +1134,28 @@ def test_grouped_qdq( assert_dequantized_grouped_scaled_tensor(scaled_tensor, x) + # Verify MXFP8 pre_swizzled flag for ROWWISE with explicit group_sizes. + # pre_swizzled=True indicates the V2 kernel was used (SM100+, 128-aligned dims). + if ( + scaling_mode == ScalingMode.MXFP8_1D_SCALING + and q_layout == QuantizeLayout.ROWWISE + and with_group_sizes + and isinstance(scaled_tensor, GroupedScaledTensor1x) + ): + total_m = m * group_size_multiplier + k_dim = n + if is_v2_grouped_gemm_supported and total_m % 128 == 0 and k_dim % 128 == 0: + # V2 path on SM100+: scales are pre-swizzled for GEMM + assert scaled_tensor.pre_swizzled, ( + "V2 grouped quantize (SM100+, 128-aligned M and K) must produce" + " pre_swizzled=True" + ) + elif k_dim % 128 != 0: + # V1 path: non-128-aligned K forces V1 quantize + assert ( + not scaled_tensor.pre_swizzled + ), "V1 grouped quantize (non-128-aligned K) must produce pre_swizzled=False" + @pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE) class TestFusedQuantize: @@ -1713,10 +1754,11 @@ def ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2): ] GROUPED_DENSE_INPUT_SHAPES = [ - # (n_groups, m, n, k), the actual m will be multiplied by 32 - (5, 32, 128, 64), # Test the case where n_groups is 
not a multiple of 4 - (8, 64, 32, 128), - (8, 64, 128, 256), + # (n_groups, m, n, k), the actual m will be multiplied by group_size_multiplier + (5, 32, 128, 64), # V1 MXFP8: K=64 not 128-aligned; also tests n_groups not a multiple of 4 + (8, 64, 32, 128), # V1 MXFP8 GEMM: N=32 not 128-aligned + (8, 64, 128, 256), # V2 MXFP8 eligible: K=256, N=128 both 128-aligned + (4, 4, 128, 128), # V2 MXFP8 eligible: K=128, N=128 both 128-aligned (smaller shape) ] @@ -1742,7 +1784,9 @@ def _ref_grouped_dense(self, lhs, rhs, bias, group_sizes, contracting_dims): ref_out.append(jnp.squeeze(out_i)) return ref_out - def _generate_grouped_dense_input(self, dtype, input_shape, data_layout="NN", with_bias=False): + def _generate_grouped_dense_input( + self, dtype, input_shape, data_layout="NN", with_bias=False, group_size_multiplier=32 + ): key = jax.random.PRNGKey(0) subkeys = jax.random.split(key, 4) n_groups, m, n, k = input_shape @@ -1755,9 +1799,12 @@ def _generate_grouped_dense_input(self, dtype, input_shape, data_layout="NN", wi group_sizes = group_sizes.at[1].set(0) assert group_sizes.sum() == m - # *32 to make sure that input shape works for MXFP8 - group_sizes = group_sizes * 32 - m = m * 32 + # Scale group sizes by the multiplier. + # Use group_size_multiplier=128 for MXFP8 V2 tests so that each group's row count + # is divisible by 128, satisfying the V2 kernel's per-group alignment requirement. + # Use group_size_multiplier=32 for V1 tests or non-MXFP8 tests. + group_sizes = group_sizes * group_size_multiplier + m = m * group_size_multiplier lhs_shape = (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m) rhs_shape = (n_groups, k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k) @@ -1831,8 +1878,10 @@ def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout quantizer.q_dtype = bwd_dtype out_dtype = jnp.bfloat16 + # MXFP8 V2 kernel requires each group's row count to be divisible by 128. 
+ is_mxfp8 = scaling_mode == ScalingMode.MXFP8_1D_SCALING lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input( - out_dtype, input_shape, layout + out_dtype, input_shape, layout, group_size_multiplier=128 if is_mxfp8 else 32 ) ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims) @@ -1906,10 +1955,13 @@ def test_grouped_dense_grad_fp16(self, dtype, input_shape): def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape): fwd_dtype, bwd_dtype = fwd_bwd_dtype dtype = jnp.bfloat16 + # MXFP8 V2 kernel requires each group's row count to be divisible by 128. + is_mxfp8 = scaling_mode == ScalingMode.MXFP8_1D_SCALING x, kernel, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input( dtype, input_shape, with_bias=True, + group_size_multiplier=128 if is_mxfp8 else 32, ) quantizer_set = QuantizerFactory.create_set( diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu index a8e0b6df83..985c53f760 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -1414,6 +1414,24 @@ __global__ void convert_int32_to_int64_kernel(const int32_t *src, int64_t *dst, if (idx < n) dst[idx] = static_cast(src[idx]); } +// Like convert_int32_to_int64_kernel but scales each element by multiplier. +// Used to convert per-expert slice counts to per-expert row counts for multi-dim tensors. +__global__ void convert_int32_to_int64_with_multiplier_kernel(const int32_t *src, int64_t *dst, + size_t n, int64_t multiplier) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) dst[idx] = static_cast(src[idx]) * multiplier; +} + +// Computes exclusive prefix sums: offsets[0]=0, offsets[i]=sum(first_dims[0..i-1]*last_dim). +// Produces n_groups+1 values. Single-threaded sequential scan; n_groups is typically small. 
+__global__ void compute_grouped_tensor_offsets_kernel(const int64_t *first_dims, int64_t *offsets, + size_t n_groups, int64_t last_dim) { + offsets[0] = 0; + for (size_t i = 0; i < n_groups; i++) { + offsets[i + 1] = offsets[i] + first_dims[i] * last_dim; + } +} + } // namespace void nvte_convert_int32_to_int64(const int32_t *src, int64_t *dst, size_t n, cudaStream_t stream) { @@ -1424,3 +1442,23 @@ void nvte_convert_int32_to_int64(const int32_t *src, int64_t *dst, size_t n, cud convert_int32_to_int64_kernel<<>>(src, dst, n); NVTE_CHECK_CUDA(cudaGetLastError()); } + +void nvte_convert_int32_to_int64_with_multiplier(const int32_t *src, int64_t *dst, size_t n, + int64_t multiplier, cudaStream_t stream) { + NVTE_API_CALL(nvte_convert_int32_to_int64_with_multiplier); + if (n == 0) return; + const int threads = 256; + const int blocks = static_cast((n + threads - 1) / threads); + convert_int32_to_int64_with_multiplier_kernel<<>>(src, dst, n, + multiplier); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +void nvte_compute_grouped_tensor_offsets(const int64_t *first_dims, int64_t *offsets, + size_t n_groups, int64_t last_dim, cudaStream_t stream) { + NVTE_API_CALL(nvte_compute_grouped_tensor_offsets); + // Always write at least offsets[0]=0 (needed even for n_groups==0). + compute_grouped_tensor_offsets_kernel<<<1, 1, 0, stream>>>(first_dims, offsets, n_groups, + last_dim); + NVTE_CHECK_CUDA(cudaGetLastError()); +} diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index 6999dd857f..fcd08a40a9 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -356,6 +356,35 @@ size_t nvte_get_grouped_gemm_setup_workspace_size(size_t num_tensors); */ void nvte_convert_int32_to_int64(const int32_t *src, int64_t *dst, size_t n, cudaStream_t stream); +/*! 
\brief Convert int32 array to int64 while scaling each element by a multiplier. + * + * Computes dst[i] = (int64_t)src[i] * multiplier for each i in [0, n). + * CUDA-graph safe (no host-device synchronization). + * + * \param[in] src Device pointer to source int32 array. + * \param[out] dst Device pointer to destination int64 array. + * \param[in] n Number of elements. + * \param[in] multiplier Scale factor applied to each element. + * \param[in] stream CUDA stream. + */ +void nvte_convert_int32_to_int64_with_multiplier(const int32_t *src, int64_t *dst, size_t n, + int64_t multiplier, cudaStream_t stream); + +/*! \brief Compute exclusive prefix-sum offsets from per-group first-dimension sizes. + * + * Writes n_groups+1 values to offsets: offsets[0]=0, + * offsets[i] = sum(first_dims[0..i-1] * last_dim) for i in [1, n_groups]. + * This is CUDA-graph safe (no host-device synchronization). + * + * \param[in] first_dims Device pointer to int64 array of length n_groups. + * \param[out] offsets Device pointer to int64 array of length n_groups+1. + * \param[in] n_groups Number of groups. + * \param[in] last_dim Common last dimension (number of columns). + * \param[in] stream CUDA stream. 
+ */ +void nvte_compute_grouped_tensor_offsets(const int64_t *first_dims, int64_t *offsets, + size_t n_groups, int64_t last_dim, cudaStream_t stream); + void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedTensor B, int transb, const NVTEGroupedTensor C, NVTEGroupedTensor D, const NVTETensor alpha, const NVTETensor beta, NVTETensor workspace_setup, diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index c081e451a7..f819401d52 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -1610,9 +1610,9 @@ def _compute_cublas_workspace_size( workspace_size += lhs_scale_inv_aval.size * tensor_scaling_sinv_aligment workspace_size += rhs_scale_inv_aval.size * tensor_scaling_sinv_aligment elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: - # We also pad scale_inv swizzle buffers size for 256 bytes alignment. - workspace_size += lhs_scale_inv_aval.size + mxfp8_scaling_sinv_alignment_padding - workspace_size += rhs_scale_inv_aval.size + mxfp8_scaling_sinv_alignment_padding + # Both V1 and V2 quantize now produce pre-swizzled scales, so the GEMM + # does not need extra workspace for nvte_swizzle_scaling_factors. + pass return workspace_size @staticmethod @@ -2040,11 +2040,14 @@ def _can_use_v2_grouped_gemm( scaling_mode: ScalingMode, dtype: jnp.dtype, has_bias: bool, + lhs_shape=None, + rhs_shape=None, + lhs_axis_boundary=None, + rhs_axis_boundary=None, ) -> bool: """Determine whether the cuda-graphable grouped GEMM implementation can be used based on the input parameters.""" - # Use the cuda-graphable path for plain BF16 non-quantized inputs; fall back to the legacy - # nvte_multi_tensor_gemm path for all other cases (FP8, MXFP8, etc.) to stay - # feature-compatible with the main branch. 
+ # Use the cuda-graphable path for plain BF16 non-quantized inputs and MXFP8; fall back to + # the legacy nvte_multi_tensor_gemm path for all other cases (tensor-scaled FP8, etc.). # Bias can be supported in a kernel or in pure-JAX in the future. enforce_v2_gmm = _should_enforce_v2_grouped_gemm() @@ -2069,13 +2072,86 @@ def _can_use_v2_grouped_gemm( ) return False - if scaling_mode == ScalingMode.NO_SCALING and dtype == jnp.bfloat16 and not has_bias: + if has_bias: + if enforce_v2_gmm: + raise RuntimeError( + "Grouped GEMM with bias is not supported in the TE V2 grouped GEMM kernel, but" + " NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is enabled and has_bias is True." + ) + return False + + if scaling_mode == ScalingMode.NO_SCALING and dtype == jnp.bfloat16: + return True + + if scaling_mode == ScalingMode.MXFP8_1D_SCALING: + # V2 MXFP8 requires that the total first dimension of both operands (up to + # axis_boundary) is divisible by 128, matching the quantize V2 kernel requirement. + # Individual group sizes must also be 128-aligned (dynamic constraint). + if lhs_shape is not None and lhs_axis_boundary is not None: + lhs_first_dim = math.prod(lhs_shape[:lhs_axis_boundary]) + if lhs_first_dim % 128 != 0: + if enforce_v2_gmm: + raise RuntimeError( + "The TE V2 grouped GEMM for MXFP8 requires the product of the first" + " dimensions (up to axis_boundary) of LHS to be divisible by 128, but got" + f" {lhs_first_dim} with lhs_shape={lhs_shape} and" + f" lhs_axis_boundary={lhs_axis_boundary}, and" + " NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is enabled." 
+ ) + return False + if rhs_shape is not None and rhs_axis_boundary is not None: + rhs_first_dim = math.prod(rhs_shape[:rhs_axis_boundary]) + if rhs_first_dim % 128 != 0: + if enforce_v2_gmm: + raise RuntimeError( + "The TE V2 grouped GEMM for MXFP8 requires the product of the first" + " dimensions (up to axis_boundary) of RHS to be divisible by 128, but got" + f" {rhs_first_dim} with rhs_shape={rhs_shape} and" + f" rhs_axis_boundary={rhs_axis_boundary}, and" + " NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is enabled." + ) + return False + # V2 MXFP8 also requires that the "last" dimension (after axis_boundary) of both + # operands is a multiple of 128. The V2 GEMM setup kernel computes per-group + # scale pointers as ``data_offset / 32``, which equals ``K_blocks * last_dim``. + # The quantize kernel, however, pads the colwise scale stride to + # ``ceil(last_dim / 128) * 128``, making per-group padded scale larger than + # ``K_blocks * last_dim`` when ``last_dim`` is not 128-aligned. This causes + # adjacent groups' scales to overlap in the flat buffer. Fall back to V1 (which + # swizzles per-group scales individually) when the condition is not met. + if lhs_shape is not None and lhs_axis_boundary is not None: + lhs_last_dim = math.prod(lhs_shape[lhs_axis_boundary:]) + if lhs_last_dim % 128 != 0: + if enforce_v2_gmm: + raise RuntimeError( + "The TE V2 grouped GEMM for MXFP8 requires the product of the last" + " dimensions (after axis_boundary) of LHS to be divisible by 128, but got" + f" {lhs_last_dim} with lhs_shape={lhs_shape} and" + f" lhs_axis_boundary={lhs_axis_boundary}, and" + " NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is enabled." 
+ ) + return False + if rhs_shape is not None and rhs_axis_boundary is not None: + rhs_last_dim = math.prod(rhs_shape[rhs_axis_boundary:]) + if rhs_last_dim % 128 != 0: + if enforce_v2_gmm: + raise RuntimeError( + "The TE V2 grouped GEMM for MXFP8 requires the product of the last" + " dimensions (after axis_boundary) of RHS to be divisible by 128, but got" + f" {rhs_last_dim} with rhs_shape={rhs_shape} and" + f" rhs_axis_boundary={rhs_axis_boundary}, and" + " NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is enabled." + ) + return False return True if enforce_v2_gmm: raise RuntimeError( - "The TE V2 grouped GEMM currently only supports BF16 with no quantization recipe and" - f" without bias, but received {scaling_mode=}, {dtype=}, {has_bias=}" + "The TE V2 grouped GEMM currently only supports non-quantized BF16 and MXFP8 with 1D" + " block scaling, but NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is enabled and the input" + f" parameters do not meet these requirements (scaling_mode= {scaling_mode}," + f" dtype={dtype}, has_bias={has_bias}, lhs_shape={lhs_shape}, rhs_shape={rhs_shape}," + f" lhs_axis_boundary={lhs_axis_boundary}, rhs_axis_boundary={rhs_axis_boundary})." ) return False @@ -2334,7 +2410,34 @@ def grouped_gemm( " and padded with zeros to not affect the result of the MoE block." ) - use_v2_ffi = _can_use_v2_grouped_gemm(scaling_mode, lhs_data.dtype, has_bias) + use_v2_ffi = _can_use_v2_grouped_gemm( + scaling_mode, + lhs_data.dtype, + has_bias, + lhs_shape=lhs_shape, + rhs_shape=rhs_shape, + lhs_axis_boundary=lhs_axis_boundary, + rhs_axis_boundary=rhs_axis_boundary, + ) + + # V2 grouped GEMM requires MXFP8 inputs to be pre-swizzled by V2 grouped quantize + # (nvte_group_quantize fuses the swizzle). The C++ V2 GEMM FFI does not re-swizzle. + if use_v2_ffi and scaling_mode == ScalingMode.MXFP8_1D_SCALING: + if isinstance(lhs, GroupedScaledTensor1x) and not lhs.pre_swizzled: + raise ValueError( + "V2 grouped GEMM requires MXFP8 lhs scale_inv to be pre-swizzled. 
" + "GroupedScaledTensor1x.pre_swizzled is False. " + "Use V2 grouped quantize (nvte_group_quantize, requires SM100+ and " + "128-aligned shapes) to produce pre-swizzled tensors." + ) + if isinstance(rhs, GroupedScaledTensor1x) and not rhs.pre_swizzled: + raise ValueError( + "V2 grouped GEMM requires MXFP8 rhs scale_inv to be pre-swizzled. " + "GroupedScaledTensor1x.pre_swizzled is False. " + "Use V2 grouped quantize (nvte_group_quantize, requires SM100+ and " + "128-aligned shapes) to produce pre-swizzled tensors." + ) + if use_v2_ffi: additional_arg_0 = jnp.ones((num_gemms,), jnp.float32) # alpha additional_arg_1 = jnp.zeros((num_gemms,), jnp.float32) # beta diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index a3d363e42a..3ef1444178 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -994,7 +994,8 @@ class GroupedQuantizePrimitive(BasePrimitive): Cast Primitive wrapping nvte_quantize and nvte_quantize_dbias """ - name = "te_grouped_quantize_ffi" + name = "te_grouped_quantize_ffi" # V1: fallback path (supports all shapes, not CUDA-graph safe) + name_v2 = "te_grouped_quantize_v2_ffi" # V2: MXFP8, CUDA-graph safe multiple_results = True impl_static_args = ( 3, @@ -1006,6 +1007,52 @@ class GroupedQuantizePrimitive(BasePrimitive): inner_primitive = None outer_primitive = None + @staticmethod + def _use_v2_kernel(scaling_mode, x_shape, flatten_axis): + """Return True when the V2 (CUDA-graph-safe) MXFP8 kernel can be used. + + V2 requires: + 1. SM100+ (Blackwell) — V2 grouped quantize fuses the scale_inv swizzle via + nvte_group_quantize. The swizzled scale_inv must then be consumed by the + V2 grouped GEMM, which also requires SM100+. Keeping both decisions tied + to SM100+ prevents a mismatch where V2-quantized (pre-swizzled) tensors + are passed to the V1 grouped GEMM (which would re-swizzle and corrupt). + 2. 
The total first logical dimension (product of x_shape up to flatten_axis) + is divisible by 128. + 3. For multi-dim group tensors (eff > 1, e.g., kernel shape G×K×N), the + per-group row count non_group_m = prod(x_shape[1:eff]) must also be + divisible by 128. + 4. For lhs-style tensors (eff == 1, shape M×K), individual group sizes must + be 128-aligned — this is a dynamic constraint assumed by the caller. + 5. The last logical dimension (contracting dim K or output dim N) must be + divisible by 128, matching the V2 grouped GEMM constraint so that the + two always agree on V1 vs V2. + + Falls back to V1 when constraints are not met. V1 supports arbitrary shapes + but performs a D2H copy of group_sizes (not CUDA-graph safe). + """ + if ScalingMode(scaling_mode) != ScalingMode.MXFP8_1D_SCALING: + return False + # Require SM100+ so V2 quantize (fused swizzle) is only used alongside V2 GEMM. + if get_min_device_compute_capability() < 100: + return False + ndim = len(x_shape) + eff = flatten_axis if flatten_axis >= 0 else flatten_axis + ndim + total_first_dim = math.prod(x_shape[:eff]) + if total_first_dim % 128 != 0: + return False + # For multi-dim group tensors (e.g., kernel shape G×K×N with eff=2), + # non_group_m = K must also be 128-aligned. + if eff > 1: + non_group_m = math.prod(x_shape[1:eff]) + if non_group_m % 128 != 0: + return False + # Last dim must be 128-aligned to match the V2 grouped GEMM requirement. 
+ last_dim = math.prod(x_shape[eff:]) + if last_dim % 128 != 0: + return False + return True + @staticmethod def abstract( x_aval, @@ -1048,7 +1095,16 @@ def abstract( rowwise_scale_inv_shape = (1,) rowwise_out_aval = jax.core.ShapedArray(shape=rowwise_out_shape, dtype=out_dtype) - amax_aval = jax.core.ShapedArray(shape=(group_sizes_aval.size,), dtype=jnp.float32) + use_v2 = GroupedQuantizePrimitive._use_v2_kernel(scaling_mode, x_aval.shape, flatten_axis) + if use_v2: + # V2 path: 5th output is int64_workspace laid out as: + # [n_groups int64 group_sizes | n_groups+1 int64 offsets] + # = (2*n_groups + 1) * sizeof(int64_t) bytes stored as uint8. + n_groups = group_sizes_aval.size + fifth_out_aval = jax.core.ShapedArray(shape=((2 * n_groups + 1) * 8,), dtype=jnp.uint8) + else: + # V1 path: 5th output is amax + fifth_out_aval = jax.core.ShapedArray(shape=(group_sizes_aval.size,), dtype=jnp.float32) if q_layout.has_colwise: colwise_out_shape = out_shape @@ -1068,7 +1124,7 @@ def abstract( colwise_out_aval, rowwise_scale_inv_aval, colwise_scale_inv_aval, - amax_aval, + fifth_out_aval, ) @staticmethod @@ -1082,9 +1138,17 @@ def outer_abstract(*args, **kwargs): colwise_out, scale_inv, colwise_scale_inv, - updated_amax, + fifth_out, ) = GroupedQuantizePrimitive.abstract(*args, **kwargs) - return rowwise_out, colwise_out, scale_inv, colwise_scale_inv, updated_amax + # When V2 is used, the inner abstract returns int64_workspace as the 5th output. + # The outer interface always presents amax (float32, n_groups) for a consistent API. 
+ scaling_mode = kwargs.get("scaling_mode") + x_aval = args[0] + group_sizes_aval = args[2] + flatten_axis = kwargs.get("flatten_axis") + if GroupedQuantizePrimitive._use_v2_kernel(scaling_mode, x_aval.shape, flatten_axis): + fifth_out = jax.core.ShapedArray(shape=(group_sizes_aval.size,), dtype=jnp.float32) + return rowwise_out, colwise_out, scale_inv, colwise_scale_inv, fifth_out @staticmethod def lowering( @@ -1107,6 +1171,21 @@ def lowering( assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert scale_aval.dtype == jnp.float32 assert group_sizes_aval.dtype == jnp.int32 + use_v2 = GroupedQuantizePrimitive._use_v2_kernel(scaling_mode, x_aval.shape, flatten_axis) + if use_v2: + # V2: CUDA-graph safe; scale is passed but ignored by the C++ handler. + # Requires total_first_dim % 128 == 0 (checked above) and all individual + # group sizes % 128 == 0 (dynamic constraint, enforced by the kernel). + return ffi.ffi_lowering(GroupedQuantizePrimitive.name_v2)( + ctx, + x, + scale, + group_sizes, + q_layout=q_layout.value.value, + flatten_axis=flatten_axis, + ) + # V1: supports arbitrary shapes but not CUDA-graph safe (performs D2H copy of group_sizes). + # Used for non-MXFP8 scaling modes and for MXFP8 when total_first_dim % 128 != 0. 
return ffi.ffi_lowering(GroupedQuantizePrimitive.name)( ctx, x, @@ -1137,7 +1216,7 @@ def impl( colwise_out, rowwise_scale_inv, colwise_scale_inv, - updated_amax, + fifth, ) = GroupedQuantizePrimitive.inner_primitive.bind( x, scale, @@ -1148,6 +1227,12 @@ def impl( flatten_axis=flatten_axis, scale_dtype=scale_dtype, ) + use_v2 = GroupedQuantizePrimitive._use_v2_kernel(scaling_mode, x.shape, flatten_axis) + if use_v2: + # fifth is int64_workspace; return a dummy zero amax for interface compatibility + updated_amax = jnp.zeros((group_sizes.size,), jnp.float32) + else: + updated_amax = fifth return (rowwise_out, colwise_out, rowwise_scale_inv, colwise_scale_inv, updated_amax) @@ -1259,6 +1344,14 @@ def grouped_quantize( for i, quantizer_i in enumerate(quantizer.quantizers): quantizer_i.update(updated_amax[i].reshape((1,))) + # V2 grouped quantize (nvte_group_quantize) fuses the scale_inv swizzle into + # the kernel, so the resulting tensors are already swizzled for GEMM. + # Note: V1 also produces swizzled scales (via set_with_gemm_swizzled_scales), + # but pre_swizzled is only set for V2 to maintain pytree compatibility. + # The dequantizer detects MXFP8 swizzling via the scaling_mode instead. 
+ use_v2 = GroupedQuantizePrimitive._use_v2_kernel( + quantizer.scaling_mode.value, x.shape, flatten_axis + ) out = ScaledTensorFactory.create( data=rowwise_casted_output, scale_inv=rowwise_scale_inv, @@ -1271,6 +1364,7 @@ def grouped_quantize( flatten_axis=flatten_axis, first_dims=ragged_first_dims, original_shape=original_shape, + pre_swizzled=use_v2, ) return out diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index a74b209e4f..3ba0e7e9b2 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -119,6 +119,8 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(DBiasQuantizeHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedQuantizeHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedQuantizeV2Handler); + XLA_FFI_DECLARE_HANDLER_SYMBOL(DequantizeHandler); pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index a7f16bb31f..6ca907032c 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -481,6 +481,8 @@ class JAXX_GroupedTensorWrapper { m_grouped_tensor(other.m_grouped_tensor), m_data_tensor(other.m_data_tensor), m_scale_inv_tensor(other.m_scale_inv_tensor), + m_colwise_data_tensor(other.m_colwise_data_tensor), + m_colwise_scale_inv_tensor(other.m_colwise_scale_inv_tensor), m_sizes_tensor(other.m_sizes_tensor), m_offsets_tensor(other.m_offsets_tensor) { other.m_grouped_tensor = nullptr; @@ -489,6 +491,10 @@ class JAXX_GroupedTensorWrapper { ~JAXX_GroupedTensorWrapper(); void set_rowwise(Buffer_Type const &data, std::optional const &scale_inv); + void set_columnwise(Buffer_Type const &data, std::optional const &scale_inv); + void set_with_gemm_swizzled_scales(bool val); + void replace_scale_inv(bool use_colwise, uint8_t *sinv_ptr, NVTEDType sinv_dtype, + NVTEShape sinv_shape); void 
set_group_info(Buffer_Type const &group_sizes, Buffer_Type const &group_offsets, NVTEGroupedTensorParam group_sizes_param_name); // Set only group sizes (no offsets); the setup kernel will compute offsets from sizes. @@ -505,6 +511,8 @@ class JAXX_GroupedTensorWrapper { // Internal tensors. These need to be kept alive as long as the grouped tensor is alive. NVTEBasicTensor m_data_tensor{}; NVTEBasicTensor m_scale_inv_tensor{}; + NVTEBasicTensor m_colwise_data_tensor{}; + NVTEBasicTensor m_colwise_scale_inv_tensor{}; NVTEBasicTensor m_sizes_tensor{}; NVTEBasicTensor m_offsets_tensor{}; @@ -556,6 +564,58 @@ void JAXX_GroupedTensorWrapper::set_rowwise(Buffer_Type const &data, } } +void JAXX_GroupedTensorWrapper::set_columnwise(Buffer_Type const &data, + std::optional const &scale_inv) { + NVTEDType data_dtype = + static_cast(convert_ffi_datatype_to_te_dtype(data.element_type())); + m_colwise_data_tensor = + NVTEBasicTensor{reinterpret_cast(data.untyped_data()), data_dtype, m_data_shape}; + + nvte_set_grouped_tensor_param(m_grouped_tensor, kNVTEGroupedColumnwiseData, + &m_colwise_data_tensor, sizeof(m_colwise_data_tensor)); + + if (scale_inv.has_value()) { + NVTEDType scale_inv_dtype = + static_cast(convert_ffi_datatype_to_te_dtype(scale_inv->element_type())); + NVTEShape logical_scale_shape{}; + if (scale_inv->dimensions().size() == 1) { + logical_scale_shape.ndim = 1; + logical_scale_shape.data[0] = scale_inv->dimensions()[0]; + } else if (scale_inv->dimensions().size() == 2) { + logical_scale_shape.ndim = 2; + logical_scale_shape.data[0] = scale_inv->dimensions()[0]; + logical_scale_shape.data[1] = scale_inv->dimensions()[1]; + } else { + NVTE_CHECK(false, "Expected 1D or 2D tensor for GEMM columnwise scale_inv but received ndim=", + scale_inv->dimensions().size()); + } + m_colwise_scale_inv_tensor = + NVTEBasicTensor{reinterpret_cast(scale_inv->untyped_data()), scale_inv_dtype, + logical_scale_shape}; + nvte_set_grouped_tensor_param(m_grouped_tensor, 
kNVTEGroupedColumnwiseScaleInv, + &m_colwise_scale_inv_tensor, sizeof(m_colwise_scale_inv_tensor)); + } +} + +void JAXX_GroupedTensorWrapper::set_with_gemm_swizzled_scales(bool val) { + auto v = static_cast(val); + nvte_set_grouped_tensor_param(m_grouped_tensor, kNVTEGroupedWithGEMMSwizzledScales, &v, + sizeof(v)); +} + +void JAXX_GroupedTensorWrapper::replace_scale_inv(bool use_colwise, uint8_t *sinv_ptr, + NVTEDType sinv_dtype, NVTEShape sinv_shape) { + if (use_colwise) { + m_colwise_scale_inv_tensor = NVTEBasicTensor{sinv_ptr, sinv_dtype, sinv_shape}; + nvte_set_grouped_tensor_param(m_grouped_tensor, kNVTEGroupedColumnwiseScaleInv, + &m_colwise_scale_inv_tensor, sizeof(m_colwise_scale_inv_tensor)); + } else { + m_scale_inv_tensor = NVTEBasicTensor{sinv_ptr, sinv_dtype, sinv_shape}; + nvte_set_grouped_tensor_param(m_grouped_tensor, kNVTEGroupedRowwiseScaleInv, + &m_scale_inv_tensor, sizeof(m_scale_inv_tensor)); + } +} + void JAXX_GroupedTensorWrapper::set_group_info(Buffer_Type const &group_sizes, Buffer_Type const &group_offsets, NVTEGroupedTensorParam group_sizes_param_name) { @@ -619,22 +679,19 @@ JAXX_GroupedTensorWrapper make_grouped_tensor(Buffer_Type const &data, return std::move(grouped_tensor_wrapper); } -// V2 variant: derives data shape from the XLA buffer directly, converts group_sizes +// V2 variant (NO_SCALING): derives data shape from the XLA buffer directly, converts group_sizes // int32→int64 per-tensor into a dedicated slot of int64_workspace, and wires first_dims/last_dims. // int64_offset (in int64 elements) is updated on return to the next available slot so callers can // thread it through successive make_grouped_tensor calls without aliasing. Bounds are checked -// before each slot is used. Only NO_SCALING is supported. +// before each slot is used. Only NO_SCALING is supported by this overload. 
JAXX_GroupedTensorWrapper make_grouped_tensor( Buffer_Type const &data, Buffer_Type const &first_dims, Buffer_Type const &last_dims, int64_t *int64_workspace_base, size_t int64_workspace_capacity, size_t &int64_offset, - size_t num_gemms, cudaStream_t stream, int64_t axis_boundary = -1) { + size_t num_gemms, cudaStream_t stream, size_t left_size, size_t right_size) { auto dims = data.dimensions(); - NVTE_CHECK(dims.size() >= 2, "grouped GEMM data buffer must be at least 2D."); - // Flatten dims at axis_boundary to produce a 2D NVTE shape. - // axis_boundary=-1 (default) collapses dims[0..N-2] → rows and keeps dims[N-1] → cols, - // preserving the prior behaviour for output buffers (e.g. [G, K, N] for wgrad). - size_t ab = (axis_boundary < 0) ? dims.size() - 1 : static_cast(axis_boundary); - NVTEShape dataShape{.data = {product(dims, 0, ab), product(dims, ab, dims.size())}, .ndim = 2}; + NVTE_CHECK(product(dims) == left_size * right_size, + "grouped GEMM data buffer element count does not match the provided 2D shape."); + NVTEShape dataShape{.data = {left_size, right_size}, .ndim = 2}; JAXX_GroupedTensorWrapper wrapper(JAXX_Scaling_Mode::NO_SCALING, num_gemms, dataShape); wrapper.set_rowwise(data, std::nullopt); if (first_dims.element_count() > 0) { @@ -660,6 +717,56 @@ JAXX_GroupedTensorWrapper make_grouped_tensor( return wrapper; } +// V2 variant with scaling support (MXFP8 or NO_SCALING). Accepts scale_inv buffer and +// use_colwise flag to wire rowwise or columnwise data+scales for the grouped tensor. +// Pre-swizzled scales are indicated via set_with_gemm_swizzled_scales(true). 
+JAXX_GroupedTensorWrapper make_grouped_tensor( + Buffer_Type const &data, Buffer_Type const &scale_inv, JAXX_Scaling_Mode scaling_mode, + bool use_colwise, Buffer_Type const &first_dims, Buffer_Type const &last_dims, + int64_t *int64_workspace_base, size_t int64_workspace_capacity, size_t &int64_offset, + size_t num_gemms, cudaStream_t stream, size_t left_size, size_t right_size) { + auto dims = data.dimensions(); + NVTE_CHECK(product(dims) == left_size * right_size, + "grouped GEMM data buffer element count does not match the provided 2D shape."); + NVTEShape dataShape{.data = {left_size, right_size}, .ndim = 2}; + JAXX_GroupedTensorWrapper wrapper(scaling_mode, num_gemms, dataShape); + + const bool is_mxfp8 = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING; + if (is_mxfp8 && use_colwise) { + wrapper.set_columnwise(data, scale_inv); + } else if (is_mxfp8) { + wrapper.set_rowwise(data, scale_inv); + } else { + // NO_SCALING: no scale_inv needed + wrapper.set_rowwise(data, std::nullopt); + } + if (is_mxfp8) { + wrapper.set_with_gemm_swizzled_scales(true); + } + + if (first_dims.element_count() > 0) { + NVTE_CHECK(first_dims.element_type() == xla::ffi::DataType::S32, "group_sizes must be int32."); + NVTE_CHECK(int64_offset + num_gemms <= int64_workspace_capacity, + "int64_workspace overflow: not enough space for first_dims conversion."); + auto *slot = int64_workspace_base + int64_offset; + nvte_convert_int32_to_int64(reinterpret_cast(first_dims.untyped_data()), slot, + num_gemms, stream); + wrapper.set_group_sizes_only(slot, num_gemms, kNVTEGroupedFirstDims); + int64_offset += num_gemms; + } + if (last_dims.element_count() > 0) { + NVTE_CHECK(last_dims.element_type() == xla::ffi::DataType::S32, "group_sizes must be int32."); + NVTE_CHECK(int64_offset + num_gemms <= int64_workspace_capacity, + "int64_workspace overflow: not enough space for last_dims conversion."); + auto *slot = int64_workspace_base + int64_offset; + 
nvte_convert_int32_to_int64(reinterpret_cast(last_dims.untyped_data()), slot, + num_gemms, stream); + wrapper.set_group_sizes_only(slot, num_gemms, kNVTEGroupedLastDims); + int64_offset += num_gemms; + } + return wrapper; +} + // Returns num_gemms from the first non-empty per-tensor group_sizes buffer, // falling back to the element count of alpha for the uniform-batch case. size_t grouped_gemm_num_gemms(Buffer_Type const &lhs_first_dims, Buffer_Type const &lhs_last_dims, @@ -752,13 +859,19 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty auto [lhs_is_trans, rhs_is_trans, scaling_mode, lhs_axis_boundary, rhs_axis_boundary, lhs_left_size, lhs_right_size, rhs_left_size, rhs_right_size] = config; - NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING, - "Only non-quantized grouped GEMM is supported in current implementation."); + NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING || + scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING, + "Only NO_SCALING and MXFP8_1D_SCALING are supported in the V2 grouped GEMM."); + + const bool is_mxfp8 = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING; size_t num_gemms = grouped_gemm_num_gemms(lhs_first_dims, lhs_last_dims, rhs_first_dims, rhs_last_dims, out_first_dims, out_last_dims, alpha); // Workspaces. + // V2 GEMM receives scale_inv already swizzled by nvte_group_quantize (V2 grouped quantize + // fuses the swizzle). No extra sinv reservation is needed; the full cublas_workspace is + // available for cuBLAS. 
auto setup_workspace_ptr = reinterpret_cast(setup_workspace->untyped_data()); auto cublas_workspace_ptr = reinterpret_cast(cublas_workspace->untyped_data()); cublas_workspace_ptr = move_ptr_to_next_256B_aligned(cublas_workspace_ptr); @@ -783,14 +896,39 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty auto *int64_base = reinterpret_cast(int64_workspace->untyped_data()); size_t int64_capacity = int64_workspace->element_count() / sizeof(int64_t); size_t int64_offset = 0; + + // For MXFP8: in JAX, rhs=cuBLAS_A, lhs=cuBLAS_B (swapped). + // Colwise is needed when the operand's contracting dim is NOT the last dim in its layout. + const bool rhs_use_colwise = is_mxfp8 && !rhs_is_trans; + const bool lhs_use_colwise = is_mxfp8 && lhs_is_trans; + + // For MXFP8: scale_inv is already swizzled (pre-swizzled by V2 grouped quantize via + // nvte_group_quantize). Pass the buffers directly to make_grouped_tensor which sets + // with_gemm_swizzled_scales(true) for MXFP8 automatically. No re-swizzling needed. auto rhs_tensor = - make_grouped_tensor(rhs_data, rhs_first_dims, rhs_last_dims, int64_base, int64_capacity, - int64_offset, num_gemms, stream, rhs_axis_boundary); + is_mxfp8 + ? make_grouped_tensor(rhs_data, rhs_sinv, scaling_mode, rhs_use_colwise, rhs_first_dims, + rhs_last_dims, int64_base, int64_capacity, int64_offset, num_gemms, + stream, rhs_left_size, rhs_right_size) + : make_grouped_tensor(rhs_data, rhs_first_dims, rhs_last_dims, int64_base, int64_capacity, + int64_offset, num_gemms, stream, rhs_left_size, rhs_right_size); auto lhs_tensor = - make_grouped_tensor(lhs_data, lhs_first_dims, lhs_last_dims, int64_base, int64_capacity, - int64_offset, num_gemms, stream, lhs_axis_boundary); - auto out_tensor = make_grouped_tensor(*output, out_first_dims, out_last_dims, int64_base, - int64_capacity, int64_offset, num_gemms, stream); + is_mxfp8 + ? 
make_grouped_tensor(lhs_data, lhs_sinv, scaling_mode, lhs_use_colwise, lhs_first_dims, + lhs_last_dims, int64_base, int64_capacity, int64_offset, num_gemms, + stream, lhs_left_size, lhs_right_size) + : make_grouped_tensor(lhs_data, lhs_first_dims, lhs_last_dims, int64_base, int64_capacity, + int64_offset, num_gemms, stream, lhs_left_size, lhs_right_size); + + // Output stays NO_SCALING. Derive 2D shape from the output buffer's own dims using + // last-dim-as-columns convention (equivalent to axis_boundary=-1 in the old API). + auto out_dims = output->dimensions(); + NVTE_CHECK(out_dims.size() > 0, "output buffer must have at least 1 dimension"); + size_t out_left_size = product(out_dims, 0, out_dims.size() - 1); + size_t out_right_size = static_cast(out_dims[out_dims.size() - 1]); + auto out_tensor = + make_grouped_tensor(*output, out_first_dims, out_last_dims, int64_base, int64_capacity, + int64_offset, num_gemms, stream, out_left_size, out_right_size); auto [avg_m, avg_k_lhs] = grouped_gemm_avg_dims( lhs_first_dims, lhs_last_dims, {lhs_left_size, lhs_right_size}, num_gemms, lhs_is_trans); @@ -943,20 +1081,14 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type const size_t tensor_scaling_sinv_aligment = 16; const size_t mxfp8_scaling_sinv_alignment_padding = 256; auto workspace_size = workspace_total_size - workspace_alignment_padding; - if (is_mxfp8_scaling) { - // For MXFP8 swizzled scale_inv buffers, only the first pointer needs to be with 256B alignment padding. Later pointers are guaranteed to be 256-aligned as the scale_inv shapes are padded by 128x4. - workspace_size -= (lhs_sinv_size + rhs_sinv_size + 2 * mxfp8_scaling_sinv_alignment_padding); - } else if (is_tensor_scaling) { + if (is_tensor_scaling) { // For tensor scaling, each matrix has a single scale value, and all scales need to be aligned // by 16 bytes to meet the requirement of CUDA 12.9.1 and later. 
workspace_size -= tensor_scaling_sinv_aligment * (lhs_sinv_size + rhs_sinv_size); } workspace_size = workspace_size / num_streams; - auto swizzled_lhs_sinv_ptr = workspace_ptr + workspace_size * num_streams; - swizzled_lhs_sinv_ptr = move_ptr_to_next_256B_aligned(swizzled_lhs_sinv_ptr); - auto swizzled_rhs_sinv_ptr = swizzled_lhs_sinv_ptr + lhs_sinv_size; - swizzled_rhs_sinv_ptr = move_ptr_to_next_256B_aligned(swizzled_rhs_sinv_ptr); - auto lhs_scatter_aligned_ptr = swizzled_lhs_sinv_ptr; // Already 256B aligned + auto lhs_scatter_aligned_ptr = workspace_ptr + workspace_size * num_streams; + lhs_scatter_aligned_ptr = move_ptr_to_next_256B_aligned(lhs_scatter_aligned_ptr); auto rhs_scatter_aligned_ptr = lhs_scatter_aligned_ptr + num_gemms * tensor_scaling_sinv_aligment; size_t lhs_dtype_bytes = te_dtype_bytes(lhs_dtype); @@ -1050,8 +1182,6 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type // These lists are to keep the TensorWrapper objects alive std::vector<TensorWrapper> lhs_wrapper_list; std::vector<TensorWrapper> rhs_wrapper_list; - std::vector<TensorWrapper> lhs_swizzle_wrapper_list; // For MXFP8 scale_inv swizzling - std::vector<TensorWrapper> rhs_swizzle_wrapper_list; std::vector<TensorWrapper> bias_wrapper_list; std::vector<TensorWrapper> pre_gelu_wrapper_list; std::vector<TensorWrapper> out_wrapper_list; @@ -1060,8 +1190,6 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type // These lists are the actual NVTETensor (void *) lists for multi-stream GEMM std::vector<NVTETensor> lhs_list; std::vector<NVTETensor> rhs_list; - std::vector<NVTETensor> lhs_swizzle_list; - std::vector<NVTETensor> rhs_swizzle_list; std::vector<NVTETensor> bias_list; std::vector<NVTETensor> pre_gelu_list; std::vector<NVTETensor> out_list; @@ -1134,13 +1262,8 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type else lhs_i.set_rowwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, tensor_scaling_sinv_shape); } else if (is_mxfp8_scaling) { - auto lhs_swizzle_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); - auto rhs_swizzle_i = 
TensorWrapper(get_nvte_scaling_mode(scaling_mode)); - void *swizzled_lhs_sinv_vptr = static_cast(swizzled_lhs_sinv_ptr); - void *swizzled_rhs_sinv_vptr = static_cast(swizzled_rhs_sinv_ptr); - - // {lhs, rhs}_swizzle_i point to unswizzled scale_inv data as input, while {lhs, rhs}_i - // point to swizzled scale_inv data (store on workspace, only used for GEMM). + // MXFP8 scales are pre-swizzled by the quantize kernel (both V1 and V2), + // so we pass them directly to the GEMM without a separate swizzle pass. // Note: even if is_empty_gemm is true, sinv are still non-empty, need to move the pointers auto lhs_sinv_shape_i = get_block_scale_shape(scaling_mode, lhs_shape_i[0], lhs_shape_i[1], lhs_use_colwise); @@ -1149,32 +1272,17 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv_size_i = lhs_sinv_shape_i[0] * lhs_sinv_shape_i[1]; rhs_sinv_size_i = rhs_sinv_shape_i[0] * rhs_sinv_shape_i[1]; if (lhs_use_colwise) { - lhs_swizzle_i.set_columnwise_data(lhs_vptr, lhs_dtype, lhs_shape_i); - lhs_swizzle_i.set_columnwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i); - lhs_i.set_columnwise_scale_inv(swizzled_lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i); + lhs_i.set_columnwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i); } else { - lhs_swizzle_i.set_rowwise_data(lhs_vptr, lhs_dtype, lhs_shape_i); - lhs_swizzle_i.set_rowwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i); - lhs_i.set_rowwise_scale_inv(swizzled_lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i); + lhs_i.set_rowwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i); } lhs_i.set_with_gemm_swizzled_scales(true); if (rhs_use_colwise) { - rhs_swizzle_i.set_columnwise_data(rhs_vptr, rhs_dtype, rhs_shape_i); - rhs_swizzle_i.set_columnwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i); - rhs_i.set_columnwise_scale_inv(swizzled_rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i); + 
rhs_i.set_columnwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i); } else { - rhs_swizzle_i.set_rowwise_data(rhs_vptr, rhs_dtype, rhs_shape_i); - rhs_swizzle_i.set_rowwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i); - rhs_i.set_rowwise_scale_inv(swizzled_rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i); + rhs_i.set_rowwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i); } rhs_i.set_with_gemm_swizzled_scales(true); - - if (!is_empty_gemm) { - lhs_swizzle_wrapper_list.push_back(std::move(lhs_swizzle_i)); - rhs_swizzle_wrapper_list.push_back(std::move(rhs_swizzle_i)); - lhs_swizzle_list.push_back(lhs_swizzle_wrapper_list.back().data()); - rhs_swizzle_list.push_back(rhs_swizzle_wrapper_list.back().data()); - } } else { NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING, "Unsupported scaling mode: ", static_cast<int>(scaling_mode)); @@ -1192,10 +1300,6 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type rhs_sinv_ptr += rhs_sinv_size_i * rhs_sinv_dtype_bytes; lhs_sinv_total_size += lhs_sinv_size_i; rhs_sinv_total_size += rhs_sinv_size_i; - if (is_mxfp8_scaling) { - swizzled_lhs_sinv_ptr += lhs_sinv_size_i * lhs_sinv_dtype_bytes; - swizzled_rhs_sinv_ptr += rhs_sinv_size_i * rhs_sinv_dtype_bytes; - } } if (has_bias) bias_ptr += n * bias_dtype_bytes; @@ -1236,18 +1340,6 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type size_t num_non_empty_gemms = lhs_list.size(); - if (is_mxfp8_scaling) { - for (int i = 0; i < num_non_empty_gemms; i++) { - // The i-th GEMM will use the (i % num_streams)-th stream to compute, - // use the same stream to swizzle the scaling factors to make sure that - // the swizzling is done before the GEMM computation starts. 
- int stream_id = i % num_streams; - cudaStream_t stream_i = nvte_get_compute_stream(stream_id); - nvte_swizzle_scaling_factors(lhs_swizzle_list[i], lhs_list[i], stream_i); - nvte_swizzle_scaling_factors(rhs_swizzle_list[i], rhs_list[i], stream_i); - } - } - // Launch zero-out kernels before the GEMM calls to use the sync in the multi-stream GEMM size_t num_zero_outs = zero_out_dptr_list.size(); for (int i = 0; i < num_zero_outs; i++) { diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 28cb39b5d1..e3bc122403 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -33,6 +33,7 @@ pybind11::dict Registrations() { // Quantization dict["te_dbias_quantize_ffi"] = EncapsulateFFI(DBiasQuantizeHandler); dict["te_grouped_quantize_ffi"] = EncapsulateFFI(GroupedQuantizeHandler); + dict["te_grouped_quantize_v2_ffi"] = EncapsulateFFI(GroupedQuantizeV2Handler); dict["te_dequantize_ffi"] = EncapsulateFFI(DequantizeHandler); // Softmax diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index c5a766f7f2..db9cf94db5 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -9,6 +9,7 @@ #include "../extensions.h" #include "transformer_engine/cast.h" +#include "transformer_engine/gemm.h" #include "transformer_engine/hadamard_transform.h" #include "transformer_engine/recipe.h" #include "transformer_engine/transformer_engine.h" @@ -451,6 +452,12 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty } } + // For MXFP8, produce pre-swizzled scales so the GEMM can consume them directly + // without a separate swizzle pass. 
+ if (is_mxfp8_scaling) { + out_i.set_with_gemm_swizzled_scales(true); + } + input_holders.push_back(std::move(inp_i)); output_holders.push_back(std::move(out_i)); @@ -494,5 +501,135 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedQuantizeHandler, GroupedQuantizeFFI, .Attr("q_layout") .Attr("flatten_axis")); +Error_Type GroupedQuantizeV2FFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Type scale_unused, + Buffer_Type group_sizes, Result_Type rowwise_out, + Result_Type colwise_out, Result_Type rowwise_sinv, + Result_Type colwise_sinv, Result_Type int64_workspace, + JAXX_Quantize_Layout quantize_layout, int64_t flatten_axis) { + (void)scale_unused; // scale is unused for MXFP8; accepted to match V1 input arity + auto in_dtype = convert_ffi_datatype_to_te_dtype(inputs.element_type()); + auto out_dtype = convert_ffi_datatype_to_te_dtype(rowwise_out->element_type()); + auto sinv_dtype = convert_ffi_datatype_to_te_dtype(rowwise_sinv->element_type()); + + NVTE_CHECK(is_fp8_dtype(out_dtype), "Output datatype must be FP8 for GroupedQuantizeV2."); + NVTE_CHECK(sinv_dtype == DType::kFloat8E8M0, + "scale_inv must be E8M0 for MXFP8 grouped quantize."); + + auto input_dims = inputs.dimensions(); + int64_t input_ndim = input_dims.size(); + if (flatten_axis < 0) flatten_axis += input_ndim; + NVTE_CHECK(flatten_axis < input_ndim && flatten_axis > 0, "flatten_axis is out of bounds!"); + + auto m = product(input_dims, 0, flatten_axis); + auto n = product(input_dims, flatten_axis, input_ndim); + size_t n_groups = group_sizes.dimensions()[0]; + + // Workspace layout (CUDA-graph safe, all device-side): + // int64_ptr[0 .. n_groups-1] : per-group ROW counts (int64) + // int64_ptr[n_groups .. 
2*n_groups] : exclusive prefix-sum offsets (n_groups+1 values) + auto *int64_ptr = reinterpret_cast(int64_workspace->untyped_data()); + auto *offsets_ptr_out = int64_ptr + n_groups; // n_groups+1 values follow group_sizes + + // non_group_m handles multi-dim tensors (e.g., kernel shape G×K×N with flatten_axis=2): + // group_sizes[i] counts "slices" along the outermost group axis (e.g., 1 per expert), + // while the kernel expects actual ROW counts (e.g., K rows per expert). + // non_group_m = product(input_dims[1..flatten_axis)) converts slice→row count. + // For the lhs case (shape M×K, flatten_axis=1), non_group_m=1 (no-op). + int64_t non_group_m = + (flatten_axis > 1) ? product(input_dims, 1, static_cast(flatten_axis)) : 1; + + // Convert int32 group_sizes to int64 row counts on device (CUDA-graph safe, no D2H). + nvte_convert_int32_to_int64_with_multiplier( + reinterpret_cast(group_sizes.untyped_data()), int64_ptr, n_groups, + non_group_m, stream); + + // Compute exclusive prefix-sum offsets on device (CUDA-graph safe, no D2H). + nvte_compute_grouped_tensor_offsets(int64_ptr, offsets_ptr_out, n_groups, static_cast(n), + stream); + + NVTEShape data_shape{}; + data_shape.data[0] = m; + data_shape.data[1] = n; + data_shape.ndim = 2; + + NVTEShape sz_shape{}; + sz_shape.ndim = 1; + sz_shape.data[0] = n_groups; + + // Offsets tensor has n_groups+1 elements (exclusive prefix sums with sentinel). + NVTEShape offsets_shape{}; + offsets_shape.ndim = 1; + offsets_shape.data[0] = n_groups + 1; + + // Build input grouped tensor (plain float data, no quantization on the input side). 
+ GroupedTensorWrapper in_grouped(n_groups, data_shape, + get_nvte_scaling_mode(JAXX_Scaling_Mode::NO_SCALING)); + in_grouped + .set_rowwise_data(reinterpret_cast(inputs.untyped_data()), in_dtype, data_shape) + .set_first_dims(reinterpret_cast(int64_ptr), DType::kInt64, sz_shape) + .set_tensor_offsets(reinterpret_cast(offsets_ptr_out), DType::kInt64, offsets_shape); + + // Build output grouped tensor. + GroupedTensorWrapper out_grouped(n_groups, data_shape, + get_nvte_scaling_mode(JAXX_Scaling_Mode::MXFP8_1D_SCALING)); + out_grouped.set_first_dims(reinterpret_cast(int64_ptr), DType::kInt64, sz_shape) + .set_tensor_offsets(reinterpret_cast(offsets_ptr_out), DType::kInt64, offsets_shape); + + // Rowwise output data + scale_inv. + if (is_quantize_rowwise(quantize_layout)) { + NVTEShape rw_sinv_shape{}; + rw_sinv_shape.ndim = 2; + rw_sinv_shape.data[0] = m; + rw_sinv_shape.data[1] = n / 32; // MXFP8 block size = 32 + out_grouped.set_rowwise_data(rowwise_out->untyped_data(), out_dtype, data_shape) + .set_rowwise_scale_inv(rowwise_sinv->untyped_data(), sinv_dtype, rw_sinv_shape); + } + + // Colwise output data + scale_inv. + if (is_quantize_colwise(quantize_layout)) { + NVTEShape cw_sinv_shape{}; + cw_sinv_shape.ndim = 2; + cw_sinv_shape.data[0] = m / 32; // MXFP8 block size = 32 + cw_sinv_shape.data[1] = n; + out_grouped.set_columnwise_data(colwise_out->untyped_data(), out_dtype, data_shape) + .set_columnwise_scale_inv(colwise_sinv->untyped_data(), sinv_dtype, cw_sinv_shape); + } + + // Zero-initialize scale_inv buffers (mirrors V1 behaviour for MXFP8). + size_t total_rowwise_sinv_size = + is_quantize_rowwise(quantize_layout) ? product(rowwise_sinv->dimensions()) : 0; + size_t total_colwise_sinv_size = + is_quantize_colwise(quantize_layout) ? 
product(colwise_sinv->dimensions()) : 0; + if (total_rowwise_sinv_size > 0) + nvte_memset(rowwise_sinv->untyped_data(), 0, total_rowwise_sinv_size, stream); + if (total_colwise_sinv_size > 0) + nvte_memset(colwise_sinv->untyped_data(), 0, total_colwise_sinv_size, stream); + + // V2 grouped quantize is always paired with V2 grouped GEMM, which expects + // scale_inv in GEMM-swizzled layout. Enable the fused swizzle so the kernel + // writes scales in the layout the GEMM will consume directly. + out_grouped.set_with_gemm_swizzled_scales(true); + + QuantizationConfigWrapper quant_config{}; + nvte_group_quantize(in_grouped.data(), out_grouped.data(), quant_config, stream); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedQuantizeV2Handler, GroupedQuantizeV2FFI, + FFI::Bind() + .Ctx() // stream + .Arg() // inputs + .Arg() // scale (unused, for input arity match) + .Arg() // group_sizes (int32) + .Ret() // rowwise_out + .Ret() // colwise_out + .Ret() // rowwise_sinv + .Ret() // colwise_sinv + .Ret() // int64_workspace + .Attr("q_layout") + .Attr("flatten_axis"), + FFI_CudaGraph_Traits); + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 31ce6e72e9..17c9a242f0 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -16,6 +16,9 @@ from jax import random as jax_random from jax.ad_checkpoint import checkpoint_name +from transformer_engine.common.recipe import ( + MXFP8BlockScaling, +) from ..dense import dense, grouped_dense @@ -1358,7 +1361,12 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): return out, ln_output # Output, layer_norm_output -def wrap_function_in_te_state_module(f, quantization_recipe, name: Optional[str] = None): +def wrap_function_in_te_state_module( + f, + quantization_recipe, + name: Optional[str] = None, + quantization_checkpoint_name: Optional[str] = None, +): 
"""Wraps the given function `f` to support TransformerEngine quantization. This method does a couple things: @@ -1386,6 +1394,7 @@ def generate_quantizer_set(self, postfix: str = "", n_groups: int = None): return super().generate_quantizer_set( postfix=postfix, variable_collection=OVERWRITE_WITH_GRADIENT, + quantization_checkpoint_name=quantization_checkpoint_name, fp8_recipe=quantization_recipe, n_groups=n_groups, ) @@ -1443,10 +1452,15 @@ def te_dot_general(generate_quantizer_set, x, kernel, dims, **kwargs): return wrap_function_in_te_state_module(te_dot_general, quantization_recipe, "dot_general") -def make_grouped_dense_cls(quantization_recipe): +def make_grouped_dense_cls(quantization_recipe, quantization_checkpoint_name: Optional[str] = None): """Creates a grouped dense (grouped GEMM) instance for use with TE state module.""" if quantization_recipe is not None: - raise ValueError("Ragged dot grouped GEMM does not support quantization yet") + allowed_grouped_gemm_recipes = [MXFP8BlockScaling] + assert any(isinstance(quantization_recipe, r) for r in allowed_grouped_gemm_recipes), ( + "Only the following quantization recipes are supported for grouped GEMM or `None` for" + f" BF16 without quantization: {allowed_grouped_gemm_recipes}. Got" + f" {type(quantization_recipe)}." 
+ ) def te_grouped_dot_general(generate_quantizer_set, x, kernel, group_sizes, **kwargs): del kwargs # Unused @@ -1463,5 +1477,8 @@ def te_grouped_dot_general(generate_quantizer_set, x, kernel, group_sizes, **kwa return out return wrap_function_in_te_state_module( - te_grouped_dot_general, quantization_recipe, "ragged_dot" + te_grouped_dot_general, + quantization_recipe, + "ragged_dot", + quantization_checkpoint_name=quantization_checkpoint_name, )() diff --git a/transformer_engine/jax/quantize/dequantizer.py b/transformer_engine/jax/quantize/dequantizer.py index 5abb2e74df..b46e4ff9d5 100644 --- a/transformer_engine/jax/quantize/dequantizer.py +++ b/transformer_engine/jax/quantize/dequantizer.py @@ -263,7 +263,37 @@ def dequantize(scaled_tensor): } -@staticmethod +def _unswizzle_mxfp8_grouped_scale(scale_inv_flat, padded_scale_2d, is_colwise): + """Un-swizzle MXFP8 GEMM-swizzled scale_inv back to plain layout. + + Both V1 and V2 MXFP8 grouped quantize produce scale_inv in a GEMM-swizzled + layout. This is the inverse of ``swizzled_scale`` in ``gemm.py``. + + The swizzle pattern (for rowwise) is: + reshape(R//128, 4, 32, C//4, 4) → transpose(0,3,2,1,4) → reshape(R, C) + The inverse is: + reshape(R//128, C//4, 32, 4, 4) → transpose(0,3,2,1,4) → reshape(R, C) + + For colwise the swizzle is applied to the transposed scale, so the inverse + must un-transpose as well. 
+ """ + if is_colwise: + # Colwise forward: reshape_2d → transpose → swizzle_5d → reshape_original + # Inverse: reshape_to_5d → inverse_swizzle → reshape_to_transposed_2d → transpose + cols, rows = padded_scale_2d + scale_2d = scale_inv_flat.reshape(cols, rows) + # The swizzled data lives in the transposed (rows, cols) domain + reshaped = scale_2d.reshape(rows // 128, cols // 4, 32, 4, 4) + unswizzled = jnp.transpose(reshaped, (0, 3, 2, 1, 4)) + # Back to transposed 2D, then un-transpose + return jnp.transpose(unswizzled.reshape(rows, cols)) + else: + rows, cols = padded_scale_2d + reshaped = scale_inv_flat.reshape(rows // 128, cols // 4, 32, 4, 4) + unswizzled = jnp.transpose(reshaped, (0, 3, 2, 1, 4)) + return unswizzled.reshape(rows, cols) + + def _grouped_dequantize(grouped_scaled_tensor): """Dequantize a grouped tensor. @@ -290,12 +320,13 @@ def _grouped_dequantize(grouped_scaled_tensor): flatten_axis = len(original_shape) + flatten_axis if flatten_axis < 0 else flatten_axis output = [] - # For transposed (colwise) tensors with ragged groups, the group dimension is the last - # axis of original_shape (e.g. original_shape = (N, M) with groups along M), while the - # non-group dimensions are all axes before it. For the uniform-groups case the group - # dimension stays at axis 0, so the existing axis-0 logic applies. + # When data_layout=="T" (colwise, transposed) and first_dims is set (ragged groups), the + # original_shape is stored transposed: the group (variable-size) axis is the LAST dimension + # rather than the first. Non-group dims are original_shape[:-1], not original_shape[1:]. 
is_transposed_ragged = ( - grouped_scaled_tensor.data_layout == "T" and group_sizes.size != original_shape[0] + grouped_scaled_tensor.data_layout == "T" + and grouped_scaled_tensor.first_dims is not None + and grouped_scaled_tensor.first_dims.size > 0 ) if is_transposed_ragged: non_group_shape = original_shape[:-1] @@ -308,7 +339,7 @@ def _grouped_dequantize(grouped_scaled_tensor): scale_inv_ptr = 0 for i, data_i in enumerate(data): if is_transposed_ragged: - data_shape_i = (*non_group_shape, group_sizes[i]) + data_shape_i = (*non_group_shape, int(group_sizes[i])) else: data_shape_i = ( group_sizes[i], @@ -330,24 +361,49 @@ def _grouped_dequantize(grouped_scaled_tensor): is_padded=False, flatten_axis=flatten_axis, ) - scale_inv_i = scale_inv[ - scale_inv_ptr : scale_inv_ptr + math.prod(padded_scale_shape_i) - ].reshape(padded_scale_shape_i) - scale_inv_i = jax.lax.slice( - scale_inv_i, [0] * len(unpadded_scale_shape_i), unpadded_scale_shape_i - ) + scale_inv_i = scale_inv[scale_inv_ptr : scale_inv_ptr + math.prod(padded_scale_shape_i)] + # MXFP8 grouped quantize (both V1 and V2) always produces GEMM-swizzled + # scales. Detect by scaling_mode (not pre_swizzled, which is only set for V2 + # to maintain pytree compatibility with the GEMM path). 
+ is_colwise = grouped_scaled_tensor.is_colwise + needs_unswizzle = scaling_mode == ScalingMode.MXFP8_1D_SCALING + if needs_unswizzle: + flat_data_2d = ( + math.prod(data_shape_i[:flatten_axis]), + math.prod(data_shape_i[flatten_axis:]), + ) + padded_2d = scaling_mode.get_scale_shape( + flat_data_2d, is_colwise=is_colwise, is_padded=True, flatten_axis=1 + ) + unpadded_2d = scaling_mode.get_scale_shape( + flat_data_2d, is_colwise=is_colwise, is_padded=False, flatten_axis=1 + ) + scale_inv_i = _unswizzle_mxfp8_grouped_scale(scale_inv_i, padded_2d, is_colwise) + scale_inv_i = jax.lax.slice(scale_inv_i, [0, 0], list(unpadded_2d)) + else: + scale_inv_i = scale_inv_i.reshape(padded_scale_shape_i) + scale_inv_i = jax.lax.slice( + scale_inv_i, [0] * len(unpadded_scale_shape_i), unpadded_scale_shape_i + ) dequantizer_type = ScalingModeToDequantizerMap.get(grouped_scaled_tensor.scaling_mode) if len(data_i) == 0: out_i = [] else: + # _dequantize_func is designed for 2D-flattened data. Flatten the + # per-group shape to 2D, dequantize, then reshape back. 
+ flat_shape_i = ( + math.prod(data_shape_i[:flatten_axis]), + math.prod(data_shape_i[flatten_axis:]), + ) out_i = dequantizer_type._dequantize_func( - data_i.reshape(data_shape_i), + data_i.reshape(flat_shape_i), scale_inv_i, grouped_scaled_tensor.dq_dtype, scaling_mode=grouped_scaled_tensor.scaling_mode, is_colwise=grouped_scaled_tensor.is_colwise, - flatten_axis=grouped_scaled_tensor.flatten_axis, + flatten_axis=1, ) + out_i = out_i.reshape(data_shape_i) output.append(out_i) scale_inv_ptr += math.prod(padded_scale_shape_i) diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index b1f49dacdc..c5ad0451fd 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -369,11 +369,15 @@ class GroupedScaledTensor1x(ScaledTensor1x): first_dims: Per-group sizes of the first (row) 2D dim, or None if not ragged last_dims: Per-group sizes of the last (col) 2D dim, or None if not ragged original_shape: The original shape of the tensor before grouping + pre_swizzled: Whether the scale_inv is already swizzled for GEMM. True when produced + by V2 grouped quantize (nvte_group_quantize fuses the swizzle). The V2 grouped + GEMM FFI requires pre_swizzled=True for MXFP8 inputs and will not re-swizzle. """ first_dims: Optional[jnp.ndarray] last_dims: Optional[jnp.ndarray] original_shape: Tuple + pre_swizzled: bool = False def __init__( self, @@ -389,11 +393,13 @@ def __init__( data_layout, flatten_axis, original_shape, + pre_swizzled=False, ): self.flatten_axis = flatten_axis self.first_dims = first_dims self.last_dims = last_dims self.original_shape = original_shape + self.pre_swizzled = pre_swizzled # TODO(Phuong):Handle RHT for grouped quantization once grouped quantization supports NVFP4 super().__init__( data=data, @@ -408,6 +414,18 @@ def __init__( has_rht_applied=False, ) + @property + def group_sizes(self) -> jnp.ndarray: + """Per-group sizes along the group axis. 
+ + When first_dims is set (ragged groups), returns first_dims. + When first_dims is None (equal-sized groups), returns an array of ones with + length equal to the number of groups. + """ + if self.first_dims is not None and self.first_dims.size > 0: + return self.first_dims + return jnp.ones((self.original_shape[0],), dtype=jnp.int32) + def __post_init__(self): assert self.scale_inv.ndim == 1, "Only support flattened scale_inv" assert self.data.ndim == 1, "Only support flattened data" @@ -456,6 +474,7 @@ def tree_flatten(self): self.data_layout, self.flatten_axis, self.original_shape, + self.pre_swizzled, ) return (children, aux_data) @@ -653,6 +672,7 @@ def create_1x( last_dims=None, original_shape=None, has_rht_applied=False, + pre_swizzled=False, ): """Creates a single-scale quantized tensor. @@ -722,6 +742,7 @@ def create_1x( first_dims=first_dims, last_dims=last_dims, original_shape=original_shape, + pre_swizzled=pre_swizzled, ) # Handling attrs of transposed tensors @@ -759,6 +780,7 @@ def create_2x( original_shape=None, rowwise_has_rht_applied=False, colwise_has_rht_applied=False, + pre_swizzled=False, ): """Creates a double-scale quantized tensor. @@ -800,6 +822,7 @@ def create_2x( last_dims=last_dims, original_shape=original_shape, has_rht_applied=rowwise_has_rht_applied, + pre_swizzled=pre_swizzled, ) colwise_tensor = ScaledTensorFactory.create_1x( colwise_data, @@ -814,6 +837,7 @@ def create_2x( last_dims=last_dims, original_shape=original_shape, has_rht_applied=colwise_has_rht_applied, + pre_swizzled=pre_swizzled, ) return ScaledTensor2x(rowwise_tensor, colwise_tensor) @@ -835,6 +859,7 @@ def create( original_shape: Tuple[int] = None, rowwise_has_rht_applied: bool = False, colwise_has_rht_applied: bool = False, + pre_swizzled: bool = False, ): """Creates a scaled tensor based on the quantization axis. 
@@ -853,6 +878,7 @@ def create( original_shape: The original shape of the tensor before grouping (default: None) rowwise_has_rht_applied: Whether the row-wise tensor uses the Randomized Hadamard Transform (RHT) (default: False) colwise_has_rht_applied: Whether the col-wise tensor uses the Randomized Hadamard Transform (RHT) (default: False) + pre_swizzled: Whether scale_inv is already swizzled (produced by V2 grouped quantize). Returns: Either a ScaledTensor1x or ScaledTensor2x instance depending on q_layout @@ -876,6 +902,7 @@ def create( original_shape=original_shape, rowwise_has_rht_applied=rowwise_has_rht_applied, colwise_has_rht_applied=colwise_has_rht_applied, + pre_swizzled=pre_swizzled, ) if q_layout.is_colwise_only: @@ -892,6 +919,7 @@ def create( last_dims=last_dims, original_shape=original_shape, has_rht_applied=colwise_has_rht_applied, + pre_swizzled=pre_swizzled, ) return ScaledTensorFactory.create_1x( @@ -907,6 +935,7 @@ def create( last_dims=last_dims, original_shape=original_shape, has_rht_applied=rowwise_has_rht_applied, + pre_swizzled=pre_swizzled, )