Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,6 @@
[submodule "3rdparty/cutlass"]
path = 3rdparty/cutlass
url = https://github.com/NVIDIA/cutlass.git
[submodule "3rdparty/cccl"]
path = 3rdparty/cccl
url = https://github.com/NVIDIA/cccl.git
1 change: 1 addition & 0 deletions 3rdparty/cccl
Submodule cccl added at c262ef
40 changes: 40 additions & 0 deletions tests/jax/test_custom_call_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from transformer_engine.jax.activation import activation
from transformer_engine.jax.dense import dense, grouped_dense
from transformer_engine.jax.layernorm_dense import layernorm_dense
from transformer_engine.jax.cpp_extensions.cub import topk

GEMM_CASES = [
(256, 256, 512),
Expand Down Expand Up @@ -1955,3 +1956,42 @@ def f(x):
actual = load_array_dump("my_tensor_gpu0.bin", shape, dtype)

assert_allclose(actual, expected, dtype=dtype)


@pytest.mark.parametrize("dtype", [jnp.bfloat16, jnp.float16, jnp.float32])
@pytest.mark.parametrize(
    "problem_size", [(10000, 100), (50000, 200), (100000, 500), (1000000, 1000), (5000000, 2000)]
)
class TestTopk:
    def test_topk(self, dtype, problem_size):
        num_items, num_top = problem_size

        # Construct an input where exactly `num_top` elements lie in [1.5, 2.5)
        # and the rest lie in [0, 1), so the expected top-k set is unambiguous.
        root_key = jax.random.PRNGKey(0)
        key_top, key_rest, key_perm = jax.random.split(root_key, 3)
        large_part = jax.random.uniform(
            key_top, shape=(num_top,), dtype=dtype, minval=1.5, maxval=2.5
        )
        small_part = jax.random.uniform(
            key_rest, shape=(num_items - num_top,), dtype=dtype, minval=0.0, maxval=1.0
        )
        x = jax.random.permutation(key_perm, jnp.concatenate([large_part, small_part]))

        reference_fn = jax.jit(jax.lax.top_k, static_argnums=(1,))
        candidate_fn = jax.jit(topk, static_argnums=(1,))

        ref_topk, ref_indices = reference_fn(x, num_top)
        prim_topk, prim_indices = candidate_fn(x, num_top)

        # CUB output does not guarantee the order of the topk values; sort both
        # results (ascending) before comparing them element-wise.
        ref_topk, ref_indices = jax.lax.sort_key_val(ref_topk, ref_indices)
        prim_topk, prim_indices = jax.lax.sort_key_val(prim_topk, prim_indices)

        assert_allclose(ref_topk, prim_topk, dtype=dtype)

        # Sorts are ascending, so the smallest returned top-k value prim_topk[0]
        # must not be smaller than the (k+1)-th largest value of the input.
        sorted_x = jax.lax.sort(x)
        assert prim_topk[0] >= sorted_x[-(num_top + 1)]

        # Top-k values may contain duplicates, which makes index comparison
        # ambiguous; compare the gathered values at the returned indices instead.
        assert_allclose(x[ref_indices], x[prim_indices], dtype=dtype)
16 changes: 15 additions & 1 deletion transformer_engine/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,16 @@ set(CUTLASS_INCLUDE_DIR
set(CUTLASS_TOOLS_INCLUDE_DIR
"${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cutlass/tools/util/include")

# CCCL (CUDA Core Compute Libraries)
set(CCCL_INCLUDE_DIR
"${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cccl")
if(NOT EXISTS "${CCCL_INCLUDE_DIR}")
message(FATAL_ERROR
"Could not find CCCL at ${CCCL_INCLUDE_DIR}. "
"Try running 'git submodule update --init --recursive' "
"within the Transformer Engine source.")
endif()

# Python
find_package(Python COMPONENTS Interpreter Development.Module REQUIRED)

Expand Down Expand Up @@ -151,6 +161,7 @@ list(APPEND transformer_engine_cuda_sources
normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu
permutation/permutation.cu
util/padding.cu
util/cub.cu
swizzle/swizzle.cu
swizzle/swizzle_block_scaling.cu
fused_softmax/scaled_masked_softmax.cu
Expand Down Expand Up @@ -262,8 +273,11 @@ target_link_libraries(transformer_engine PUBLIC

target_include_directories(transformer_engine PRIVATE
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
# Use CCCL from 3rdparty instead of the one from CUDA Toolkit
target_include_directories(transformer_engine SYSTEM PRIVATE
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}/cccl)
${CCCL_INCLUDE_DIR}/thrust
${CCCL_INCLUDE_DIR}/cub
${CCCL_INCLUDE_DIR}/libcudacxx/include)
target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}")
target_include_directories(transformer_engine PRIVATE
${CUTLASS_INCLUDE_DIR}
Expand Down
39 changes: 39 additions & 0 deletions transformer_engine/common/include/transformer_engine/cub.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*************************************************************************
 * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

/*! \file cub.h
 *  \brief C API for CUB-backed device-wide algorithms (currently top-k).
 */

#ifndef TRANSFORMER_ENGINE_CUB_H_
#define TRANSFORMER_ENGINE_CUB_H_

#include "transformer_engine.h"

#ifdef __cplusplus
extern "C" {
#endif

/*! \brief Compute the top-K largest (key, value) pairs.
 *
 * \param[in]     stream          CUDA stream used for the operation.
 * \param[in]     keys_in         Input 1D keys tensor, shape (num_items,)
 * \param[in]     values_in       Input 1D values tensor, shape (num_items,)
 * \param[out]    keys_out        Output 1D keys tensor, shape (k,)
 * \param[out]    values_out      Output 1D values tensor, shape (k,)
 * \param[in,out] workspace       Scratch workspace tensor, shape (workspace_bytes,)
 * \param[in]     num_items       Number of items in the input tensor
 * \param[in]     k               Number of top-K largest values to return
 * \param[in]     workspace_bytes Workspace size in bytes
 *
 * Requirements:
 * - Only supports float32, float16, bfloat16 keys and int32 values.
 *
 * Note: the implementation requests CUB's unsorted/non-deterministic output
 * mode, so the order of the returned top-K pairs is not guaranteed.
 */
void nvte_topk(cudaStream_t stream, const NVTETensor keys_in, const NVTETensor values_in,
               NVTETensor keys_out, NVTETensor values_out, NVTETensor workspace,
               const int num_items, const int k, const size_t workspace_bytes);

#ifdef __cplusplus
}  // extern "C"
#endif

#endif  // TRANSFORMER_ENGINE_CUB_H_
54 changes: 54 additions & 0 deletions transformer_engine/common/util/cub.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/

#include <transformer_engine/cub.h>

#include <cub/device/device_topk.cuh>
#include <cuda/std/execution>

#include "../common.h"

/*! \brief Compute the top-K largest (key, value) pairs via cub::DeviceTopK::MaxPairs.
 *
 * Validates k against num_items and the provided workspace against the size CUB
 * actually requires before launching, so bad inputs fail loudly instead of
 * corrupting GPU memory. Output order is not guaranteed (unsorted mode).
 */
void nvte_topk(cudaStream_t stream, const NVTETensor keys_in, const NVTETensor values_in,
               NVTETensor keys_out, NVTETensor values_out, NVTETensor workspace, int num_items,
               int k, size_t workspace_bytes) {
  NVTE_API_CALL(nvte_topk);
  using namespace transformer_engine;

  const Tensor *keys_in_tensor = convertNVTETensorCheck(keys_in);
  const Tensor *values_in_tensor = convertNVTETensorCheck(values_in);
  Tensor *keys_out_tensor = convertNVTETensor(keys_out);
  Tensor *values_out_tensor = convertNVTETensor(values_out);
  Tensor *workspace_tensor = convertNVTETensor(workspace);
  auto keys_in_dtype = keys_in_tensor->data.dtype;
  auto values_in_dtype = values_in_tensor->data.dtype;

  // cub::DeviceTopK requires 0 <= k <= num_items; anything else is undefined
  // behavior, so reject it up front with a clear error.
  NVTE_CHECK(num_items >= 0, "num_items (", num_items, ") must be non-negative");
  NVTE_CHECK(k >= 0 && k <= num_items, "k (", k, ") must be in [0, num_items] (num_items=",
             num_items, ")");

  // Request the fastest flavor (unsorted output, non-deterministic) and bind
  // the launch to the caller's CUDA stream.
  auto requirements = cuda::execution::require(cuda::execution::determinism::not_guaranteed,
                                               cuda::execution::output_ordering::unsorted);
  cuda::stream_ref stream_ref{stream};
  auto env = cuda::std::execution::env{stream_ref, requirements};

// The first MaxPairs call passes a null workspace pointer, which is CUB's
// size-query mode: it only writes the required temp-storage size. We use that
// to verify the caller-provided buffer is large enough before the real launch.
// Note: no trailing semicolon after `while (0)` so the macro composes like a
// single statement (do-while-0 idiom).
#define DISPATCH_CUB_TOPK(KeyT, ValueT)                                                         \
  do {                                                                                          \
    KeyT *d_keys_in = reinterpret_cast<KeyT *>(keys_in_tensor->data.dptr);                      \
    KeyT *d_keys_out = reinterpret_cast<KeyT *>(keys_out_tensor->data.dptr);                    \
    ValueT *d_values_in = reinterpret_cast<ValueT *>(values_in_tensor->data.dptr);              \
    ValueT *d_values_out = reinterpret_cast<ValueT *>(values_out_tensor->data.dptr);            \
    void *d_workspace = reinterpret_cast<void *>(workspace_tensor->data.dptr);                  \
    size_t required_bytes = 0;                                                                  \
    NVTE_CHECK_CUDA(cub::DeviceTopK::MaxPairs(nullptr, required_bytes, d_keys_in, d_keys_out,   \
                                              d_values_in, d_values_out, num_items, k, env));   \
    NVTE_CHECK(required_bytes <= workspace_bytes, "Insufficient CUB TopK workspace: need ",     \
               required_bytes, " bytes, got ", workspace_bytes);                                \
    NVTE_CHECK_CUDA(cub::DeviceTopK::MaxPairs(d_workspace, workspace_bytes, d_keys_in,          \
                                              d_keys_out, d_values_in, d_values_out, num_items, \
                                              k, env));                                         \
  } while (0)

  if (keys_in_dtype == DType::kFloat32 && values_in_dtype == DType::kInt32) {
    DISPATCH_CUB_TOPK(float, int);
  } else if (keys_in_dtype == DType::kFloat16 && values_in_dtype == DType::kInt32) {
    DISPATCH_CUB_TOPK(__half, int);
  } else if (keys_in_dtype == DType::kBFloat16 && values_in_dtype == DType::kInt32) {
    DISPATCH_CUB_TOPK(__nv_bfloat16, int);
  } else {
    NVTE_ERROR("Unsupported input key and value data types");
  }
#undef DISPATCH_CUB_TOPK
}
111 changes: 111 additions & 0 deletions transformer_engine/jax/cpp_extensions/cub.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""CUB custom ops"""

from typing import Tuple

import jax
import jax.numpy as jnp
from jax import dtypes, ffi

from .base import BasePrimitive, register_primitive

__all__ = ["topk"]


def get_cub_topk_workspace_bytes(num_items=None, k=None) -> int:
    """
    Get the workspace size in bytes for CUB TopK.

    The safe way is calling the CUB kernel to query the workspace size.
    However, JAX JIT compiling needs a fixed tensor size. Using 4MB as
    a WAR since it is large enough for N up to 5,000,000 and K up to 100,000.

    Args:
        num_items: optional number of input items; when provided it is
            validated against the largest N the fixed workspace is known
            to cover, so oversized problems fail loudly instead of risking
            an undersized GPU scratch buffer.
        k: optional number of top elements; validated likewise when provided.

    Returns:
        Workspace size in bytes (fixed 4 MiB).

    Raises:
        ValueError: if ``num_items`` or ``k`` exceeds the documented limits
            for which the fixed 4 MiB workspace is known to be sufficient.
    """
    # Limits for which the fixed 4 MiB buffer is known to be sufficient.
    max_num_items = 5_000_000
    max_k = 100_000
    if num_items is not None and num_items > max_num_items:
        raise ValueError(
            f"CUB TopK workspace heuristic only supports up to {max_num_items} items, "
            f"got num_items={num_items}"
        )
    if k is not None and k > max_k:
        raise ValueError(f"CUB TopK workspace heuristic only supports k up to {max_k}, got k={k}")
    return 4 * 1024 * 1024
Comment on lines +17 to +24
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Hardcoded workspace size may silently corrupt memory for large inputs

get_cub_topk_workspace_bytes() always returns a fixed 4 MiB and the docstring itself acknowledges this only covers "N up to 5,000,000 and K up to 100,000." However, there is no validation in the Python or C++ layer that the user's actual N and K do not exceed these limits.

If a caller passes N > 5_000_000 or K > 100_000, cub::DeviceTopK::MaxPairs will be given an undersized workspace buffer and will write out-of-bounds on the GPU — a silent CUDA memory corruption with no error raised back to the caller.

The correct approach is to call cub::DeviceTopK::MaxPairs(nullptr, temp_storage_bytes, ...) with a null workspace pointer to query the required size at runtime, then allocate that exact amount. The current heuristic should at minimum be accompanied by a runtime guard that raises an error when the inputs exceed the documented limits.



class TopKPrimitive(BasePrimitive):
    """
    JAX primitive wrapping the CUB-backed top-k FFI kernel (te_topk_ffi).

    Takes 1D keys (float32/float16/bfloat16) and 1D int32 values and returns
    the k largest (key, value) pairs. A fixed-size uint8 workspace buffer is
    declared as a third output so XLA allocates the CUB scratch space; it is
    dropped before results reach the caller.
    """

    name = "te_topk_ffi"
    multiple_results = True
    impl_static_args = (2,)  # k_value
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        in_keys_aval,
        in_values_aval,
        *,
        k_value,
    ):
        """Shape/dtype inference for the inner primitive.

        Returns avals for (keys_out, values_out, workspace): keys keep their
        input dtype, values are int32, both of shape (k_value,); the workspace
        is a flat uint8 buffer of the fixed heuristic size.
        """
        keys_dtype = dtypes.canonicalize_dtype(in_keys_aval.dtype)
        values_dtype = dtypes.canonicalize_dtype(in_values_aval.dtype)
        # Mirror the dtype support of the C++ kernel (see nvte_topk).
        assert keys_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert values_dtype == jnp.int32

        # Fixed workspace size: JIT needs a static shape, so the exact CUB
        # requirement cannot be queried here (see get_cub_topk_workspace_bytes).
        workspace_bytes = get_cub_topk_workspace_bytes()
        out_keys_aval = jax.core.ShapedArray(shape=(k_value,), dtype=keys_dtype)
        out_values_aval = jax.core.ShapedArray(shape=(k_value,), dtype=jnp.int32)
        workspace_aval = jax.core.ShapedArray(shape=(workspace_bytes,), dtype=jnp.uint8)
        return (out_keys_aval, out_values_aval, workspace_aval)

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """Same as abstract() but hides the workspace output from callers."""
        out_keys_aval, out_values_aval, _workspace_aval = TopKPrimitive.abstract(*args, **kwargs)
        return (out_keys_aval, out_values_aval)

    @staticmethod
    def lowering(
        ctx,
        in_keys,
        in_values,
        k_value,
    ):
        """Lower to the te_topk_ffi XLA custom call.

        The k_value and workbuf_bytes attributes must line up with the
        Attr<int64_t> declarations in the C++ handler (TopkHandler).
        """
        workspace_bytes = get_cub_topk_workspace_bytes()
        return ffi.ffi_lowering(
            TopKPrimitive.name,
        )(
            ctx,
            in_keys,
            in_values,
            k_value=k_value,
            workbuf_bytes=workspace_bytes,
        )

    @staticmethod
    def impl(
        in_keys,
        in_values,
        k_value,
    ):
        """Eager implementation: bind the inner primitive, drop the workspace."""
        assert TopKPrimitive.inner_primitive is not None
        out_keys, out_values, _workspace = TopKPrimitive.inner_primitive.bind(
            in_keys,
            in_values,
            k_value=k_value,
        )
        return (out_keys, out_values)


register_primitive(TopKPrimitive)


def topk(
    x: jnp.ndarray,
    k_value: int,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Return the ``k_value`` largest elements of 1D array ``x`` and their indices.

    Pairs every element with its position (int32) and dispatches to the
    CUB-backed top-k primitive. The order of the returned pairs is not
    guaranteed.
    """
    positions = jnp.arange(x.shape[0], dtype=jnp.int32)
    selected_keys, selected_positions = TopKPrimitive.outer_primitive.bind(
        x,
        positions,
        k_value=k_value,
    )
    return selected_keys, selected_positions
3 changes: 3 additions & 0 deletions transformer_engine/jax/csrc/extensions.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,9 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedTopkWithScoreFunctionBackwardHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedMoEAuxLossForwardHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedMoEAuxLossBackwardHandler);

// Topk
XLA_FFI_DECLARE_HANDLER_SYMBOL(TopkHandler);

} // namespace jax
} // namespace transformer_engine

Expand Down
73 changes: 73 additions & 0 deletions transformer_engine/jax/csrc/extensions/cub.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/

#include "transformer_engine/cub.h"

#include "../extensions.h"
#include "xla/ffi/api/c_api.h"

namespace transformer_engine {
namespace jax {

// XLA FFI entry point for the CUB top-k kernel. Validates dtypes and shapes of
// the buffers XLA hands us, then forwards to nvte_topk. k_value/workbuf_bytes
// arrive as static attributes set by the Python lowering.
Error_Type TopkFFI(cudaStream_t stream, Buffer_Type keys_in_buf, Buffer_Type values_in_buf,
                   Result_Type keys_out_buf, Result_Type values_out_buf, Result_Type workspace_buf,
                   int64_t k_value, int64_t workbuf_bytes) {
  auto keys_in_dtype = convert_ffi_datatype_to_te_dtype(keys_in_buf.element_type());
  auto values_in_dtype = convert_ffi_datatype_to_te_dtype(values_in_buf.element_type());
  auto keys_out_dtype = convert_ffi_datatype_to_te_dtype(keys_out_buf->element_type());
  auto values_out_dtype = convert_ffi_datatype_to_te_dtype(values_out_buf->element_type());
  NVTE_CHECK(keys_in_dtype == keys_out_dtype, "Input and output keys must have the same datatype");
  NVTE_CHECK(values_in_dtype == values_out_dtype,
             "Input and output values must have the same datatype");
  NVTE_CHECK(values_in_dtype == DType::kInt32, "CubTopkFFI() only supports int32 values for now");

  auto keys_in_shape = keys_in_buf.dimensions();
  auto values_in_shape = values_in_buf.dimensions();
  auto keys_out_shape = keys_out_buf->dimensions();
  auto values_out_shape = values_out_buf->dimensions();
  NVTE_CHECK(keys_in_shape.size() == 1, "Keys input must have 1 dimension");
  NVTE_CHECK(values_in_shape.size() == 1, "Values input must have 1 dimension");
  NVTE_CHECK(keys_out_shape.size() == 1, "Keys output must have 1 dimension");
  NVTE_CHECK(values_out_shape.size() == 1, "Values output must have 1 dimension");
  NVTE_CHECK(keys_in_shape[0] == values_in_shape[0],
             "Keys and values input must have the same number of items");
  NVTE_CHECK(keys_out_shape[0] == values_out_shape[0],
             "Keys and values output must have the same number of items");
  int num_items = static_cast<int>(keys_in_shape[0]);
  int k = static_cast<int>(k_value);
  // CUB's DeviceTopK requires k <= num_items; guard here so a bad caller gets a
  // clear error instead of undefined GPU behavior.
  NVTE_CHECK(k >= 0 && k <= num_items, "k (", k, ") must be in [0, num_items] (num_items=",
             num_items, ")");
  // The output buffers must hold exactly k elements, matching the abstract eval
  // on the Python side.
  NVTE_CHECK(static_cast<int64_t>(keys_out_shape[0]) == k_value,
             "Keys and values output must have k items");

  auto input_shape = std::vector<size_t>{keys_in_shape[0]};
  auto output_shape = std::vector<size_t>{keys_out_shape[0]};
  auto workspace_shape = std::vector<size_t>{static_cast<size_t>(workbuf_bytes)};

  auto keys_in_tensor = TensorWrapper(keys_in_buf.untyped_data(), input_shape, keys_in_dtype);
  auto values_in_tensor = TensorWrapper(values_in_buf.untyped_data(), input_shape, values_in_dtype);
  auto keys_out_tensor = TensorWrapper(keys_out_buf->untyped_data(), output_shape, keys_out_dtype);
  auto values_out_tensor =
      TensorWrapper(values_out_buf->untyped_data(), output_shape, values_out_dtype);
  auto workspace_tensor =
      TensorWrapper(workspace_buf->untyped_data(), workspace_shape, DType::kByte);

  nvte_topk(stream, keys_in_tensor.data(), values_in_tensor.data(), keys_out_tensor.data(),
            values_out_tensor.data(), workspace_tensor.data(), num_items, k, workbuf_bytes);

  return ffi_with_cuda_error_check();
}

// Registers TopkFFI under the TopkHandler symbol. The argument/result/attribute
// order here must match the lowering in transformer_engine/jax/cpp_extensions/cub.py
// (inputs: keys, values; outputs: topk keys, indices, workspace; attrs: k_value,
// workbuf_bytes). NOTE(review): FFI_CudaGraph_Traits presumably enables CUDA-graph
// capture for this call — confirm against its definition in extensions.h.
XLA_FFI_DEFINE_HANDLER_SYMBOL(TopkHandler, TopkFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // keys_buf
                                  .Arg<Buffer_Type>()      // values_buf
                                  .Ret<Buffer_Type>()      // topk_buf
                                  .Ret<Buffer_Type>()      // indices_buf
                                  .Ret<Buffer_Type>()      // workspace_buf
                                  .Attr<int64_t>("k_value")
                                  .Attr<int64_t>("workbuf_bytes"),
                              FFI_CudaGraph_Traits);

} // namespace jax
} // namespace transformer_engine
3 changes: 3 additions & 0 deletions transformer_engine/jax/csrc/extensions/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ pybind11::dict Registrations() {
dict["te_fused_moe_aux_loss_forward_ffi"] = EncapsulateFFI(FusedMoEAuxLossForwardHandler);
dict["te_fused_moe_aux_loss_backward_ffi"] = EncapsulateFFI(FusedMoEAuxLossBackwardHandler);

// Cub Topk
dict["te_topk_ffi"] = EncapsulateFFI(TopkHandler);

return dict;
}

Expand Down