-
Notifications
You must be signed in to change notification settings - Fork 683
[Common][JAX] Add CUB TopK MaxPairs interface #2784
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,39 @@ | ||
| /************************************************************************* | ||
| * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| * | ||
| * See LICENSE for license information. | ||
| ************************************************************************/ | ||
|
|
||
| #ifndef TRANSFORMER_ENGINE_CUB_H_ | ||
| #define TRANSFORMER_ENGINE_CUB_H_ | ||
|
|
||
| #include "transformer_engine.h" | ||
|
|
||
| #ifdef __cplusplus | ||
| extern "C" { | ||
| #endif | ||
|
|
||
| /*! \brief Compute the top-K largest (key, value) pairs. | ||
| * | ||
| * \param[in] stream CUDA stream used for the operation. | ||
| * \param[in] keys_in Input 1D keys tensor, shape (num_items,) | ||
| * \param[in] values_in Input 1D values tensor, shape (num_items,) | ||
| * \param[in,out] keys_out Output 1D keys tensor, shape (k,) | ||
| * \param[in,out] values_out Output 1D values tensor, shape (k,) | ||
| * \param[in,out] workspace Workspace tensor, shape (workspace_bytes,) | ||
| * \param[in] num_items Number of items in the input tensor | ||
| * \param[in] k Number of top-K largest values to return | ||
| * \param[in] workspace_bytes Workspace size in bytes | ||
| * | ||
| * Requirements: | ||
| * - Only supports float32, float16, bfloat16 keys and int32 values. | ||
| */ | ||
| void nvte_topk(cudaStream_t stream, const NVTETensor keys_in, const NVTETensor values_in, | ||
| NVTETensor keys_out, NVTETensor values_out, NVTETensor workspace, | ||
| const int num_items, const int k, const size_t workspace_bytes); | ||
|
|
||
| #ifdef __cplusplus | ||
| } // extern "C" | ||
| #endif | ||
|
|
||
| #endif |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,54 @@ | ||
| /************************************************************************* | ||
| * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| * | ||
| * See LICENSE for license information. | ||
| ************************************************************************/ | ||
|
|
||
| #include <transformer_engine/cub.h> | ||
|
|
||
| #include <cub/device/device_topk.cuh> | ||
| #include <cuda/std/execution> | ||
|
|
||
| #include "../common.h" | ||
|
|
||
/* Top-K largest (key, value) pairs via cub::DeviceTopK::MaxPairs.
 *
 * Dispatches on the key dtype (float32 / float16 / bfloat16, int32 values only)
 * and runs on the caller-provided stream. Requests unsorted, nondeterministic
 * output from CUB, so results may come back in any order.
 */
void nvte_topk(cudaStream_t stream, const NVTETensor keys_in, const NVTETensor values_in,
               NVTETensor keys_out, NVTETensor values_out, NVTETensor workspace, int num_items,
               int k, size_t workspace_bytes) {
  NVTE_API_CALL(nvte_topk);
  using namespace transformer_engine;

  const Tensor *keys_in_tensor = convertNVTETensorCheck(keys_in);
  const Tensor *values_in_tensor = convertNVTETensorCheck(values_in);
  Tensor *keys_out_tensor = convertNVTETensor(keys_out);
  Tensor *values_out_tensor = convertNVTETensor(values_out);
  Tensor *workspace_tensor = convertNVTETensor(workspace);
  auto keys_in_dtype = keys_in_tensor->data.dtype;
  auto values_in_dtype = values_in_tensor->data.dtype;

  // Validate the selection size before launching any device work.
  NVTE_CHECK(num_items >= 0, "num_items (", num_items, ") must be non-negative");
  NVTE_CHECK(k >= 0 && k <= num_items, "k (", k, ") must be in [0, num_items (", num_items, ")]");

  // Ask CUB for its fastest configuration: callers do not rely on
  // bitwise-deterministic results or on the output pairs being sorted.
  auto requirements = cuda::execution::require(cuda::execution::determinism::not_guaranteed,
                                               cuda::execution::output_ordering::unsorted);
  cuda::stream_ref stream_ref{stream};
  auto env = cuda::std::execution::env{stream_ref, requirements};

  // Dispatch helper: reinterpret the raw tensor pointers at the concrete
  // key/value types and invoke CUB, surfacing any CUDA error it reports
  // (the original code silently discarded the cudaError_t return).
#define DISPATCH_CUB_TOPK(KeyT, ValueT)                                                        \
  do {                                                                                         \
    KeyT *d_keys_in = reinterpret_cast<KeyT *>(keys_in_tensor->data.dptr);                     \
    KeyT *d_keys_out = reinterpret_cast<KeyT *>(keys_out_tensor->data.dptr);                   \
    ValueT *d_values_in = reinterpret_cast<ValueT *>(values_in_tensor->data.dptr);             \
    ValueT *d_values_out = reinterpret_cast<ValueT *>(values_out_tensor->data.dptr);           \
    void *d_workspace = reinterpret_cast<void *>(workspace_tensor->data.dptr);                 \
    const cudaError_t status =                                                                 \
        cub::DeviceTopK::MaxPairs(d_workspace, workspace_bytes, d_keys_in, d_keys_out,         \
                                  d_values_in, d_values_out, num_items, k, env);               \
    NVTE_CHECK(status == cudaSuccess,                                                          \
               "cub::DeviceTopK::MaxPairs failed: ", cudaGetErrorString(status));              \
  } while (0)

  if (keys_in_dtype == DType::kFloat32 && values_in_dtype == DType::kInt32) {
    DISPATCH_CUB_TOPK(float, int);
  } else if (keys_in_dtype == DType::kFloat16 && values_in_dtype == DType::kInt32) {
    DISPATCH_CUB_TOPK(__half, int);
  } else if (keys_in_dtype == DType::kBFloat16 && values_in_dtype == DType::kInt32) {
    DISPATCH_CUB_TOPK(__nv_bfloat16, int);
  } else {
    NVTE_ERROR("Unsupported input key and value data types");
  }
#undef DISPATCH_CUB_TOPK
}
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,111 @@ | ||
| # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # | ||
| # See LICENSE for license information. | ||
| """CUB custom ops""" | ||
|
|
||
| from typing import Tuple | ||
|
|
||
| import jax | ||
| import jax.numpy as jnp | ||
| from jax import dtypes, ffi | ||
|
|
||
| from .base import BasePrimitive, register_primitive | ||
|
|
||
| __all__ = ["topk"] | ||
|
|
||
|
|
||
def get_cub_topk_workspace_bytes() -> int:
    """Return the scratch-buffer size (in bytes) reserved for CUB top-k.

    JAX JIT compilation needs a statically known tensor size, so instead of
    querying the CUB kernel for the exact workspace requirement we reserve a
    fixed 4 MiB buffer as a workaround. This is large enough for N up to
    5,000,000 and K up to 100,000.

    NOTE(review): inputs beyond those bounds may need a larger workspace than
    this constant provides; the robust approach is to query CUB for the exact
    size — confirm the limits against the kernel side.
    """
    return 4 << 20  # 4 MiB
|
Comment on lines
+17
to
+24
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
If a caller passes a problem size beyond the assumed bounds, the fixed 4 MB workspace may be too small and the kernel could fail or corrupt memory. The correct approach is to call the CUB kernel with a null workspace pointer first to query the exact required size. |
||
|
|
||
|
|
||
class TopKPrimitive(BasePrimitive):
    """JAX primitive wrapping the TE/CUB top-k max-pairs FFI kernel.

    Binds ``te_topk_ffi``, which returns the k largest keys and their paired
    int32 values. A fixed-size byte workspace is threaded through as a third
    output of the inner primitive so the FFI call has scratch space; the outer
    primitive hides it.
    """

    name = "te_topk_ffi"
    multiple_results = True
    impl_static_args = (2,)  # k_value
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        in_keys_aval,
        in_values_aval,
        *,
        k_value,
    ):
        """Shape/dtype inference for the inner primitive (workspace included).

        Validates dtypes and, unlike the FFI kernel itself, rejects invalid
        ``k_value`` at trace time rather than letting the device kernel read
        or write out of bounds.
        """
        keys_dtype = dtypes.canonicalize_dtype(in_keys_aval.dtype)
        values_dtype = dtypes.canonicalize_dtype(in_values_aval.dtype)
        assert keys_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert values_dtype == jnp.int32
        # The kernel requires 1D inputs of equal length and k <= num_items;
        # fail here at trace time instead of corrupting memory at run time.
        assert len(in_keys_aval.shape) == 1, "TopK keys input must be 1D"
        assert in_keys_aval.shape == in_values_aval.shape, (
            "TopK keys and values inputs must have the same shape"
        )
        assert 0 < k_value <= in_keys_aval.shape[0], (
            f"k_value ({k_value}) must be in (0, {in_keys_aval.shape[0]}]"
        )

        workspace_bytes = get_cub_topk_workspace_bytes()
        out_keys_aval = jax.core.ShapedArray(shape=(k_value,), dtype=keys_dtype)
        out_values_aval = jax.core.ShapedArray(shape=(k_value,), dtype=jnp.int32)
        workspace_aval = jax.core.ShapedArray(shape=(workspace_bytes,), dtype=jnp.uint8)
        return (out_keys_aval, out_values_aval, workspace_aval)

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """Same as ``abstract`` but without the workspace output."""
        out_keys_aval, out_values_aval, _workspace_aval = TopKPrimitive.abstract(*args, **kwargs)
        return (out_keys_aval, out_values_aval)

    @staticmethod
    def lowering(
        ctx,
        in_keys,
        in_values,
        k_value,
    ):
        """Lower to the te_topk_ffi XLA custom call."""
        # workbuf_bytes is passed as a static attribute so the FFI side knows
        # the size of the workspace buffer it receives.
        workspace_bytes = get_cub_topk_workspace_bytes()
        return ffi.ffi_lowering(
            TopKPrimitive.name,
        )(
            ctx,
            in_keys,
            in_values,
            k_value=k_value,
            workbuf_bytes=workspace_bytes,
        )

    @staticmethod
    def impl(
        in_keys,
        in_values,
        k_value,
    ):
        """Bind the inner primitive and drop the workspace output."""
        assert TopKPrimitive.inner_primitive is not None
        out_keys, out_values, _workspace = TopKPrimitive.inner_primitive.bind(
            in_keys,
            in_values,
            k_value=k_value,
        )
        return (out_keys, out_values)


register_primitive(TopKPrimitive)
|
|
||
|
|
||
def topk(
    x: jnp.ndarray,
    k_value: int,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Compute the top-k largest entries of a 1D array and their indices.

    Args:
        x: 1D input array (float32, float16, or bfloat16 per the kernel's
           supported key dtypes).
        k_value: number of largest entries to return; must satisfy
           0 < k_value <= x.shape[0].

    Returns:
        Tuple of (top-k values, int32 indices into ``x``).
        NOTE(review): the kernel requests unsorted output ordering from CUB,
        so the pairs are not guaranteed to be sorted — confirm with callers.
    """
    # Validate up front: the device kernel has no such guard, and an invalid
    # k would otherwise surface as out-of-bounds device behavior.
    assert x.ndim == 1, f"topk expects a 1D input, got shape {x.shape}"
    assert 0 < k_value <= x.shape[0], f"k_value ({k_value}) must be in (0, {x.shape[0]}]"
    keys = x
    # Pair each key with its position so the kernel returns indices.
    values = jnp.arange(x.shape[0], dtype=jnp.int32)
    out_keys, out_values = TopKPrimitive.outer_primitive.bind(
        keys,
        values,
        k_value=k_value,
    )
    return out_keys, out_values
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,73 @@ | ||
| /************************************************************************* | ||
| * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| * | ||
| * See LICENSE for license information. | ||
| ************************************************************************/ | ||
|
|
||
| #include "transformer_engine/cub.h" | ||
|
|
||
| #include "../extensions.h" | ||
| #include "xla/ffi/api/c_api.h" | ||
|
|
||
| namespace transformer_engine { | ||
| namespace jax { | ||
|
|
||
| Error_Type TopkFFI(cudaStream_t stream, Buffer_Type keys_in_buf, Buffer_Type values_in_buf, | ||
| Result_Type keys_out_buf, Result_Type values_out_buf, Result_Type workspace_buf, | ||
| int64_t k_value, int64_t workbuf_bytes) { | ||
| auto keys_in_dtype = convert_ffi_datatype_to_te_dtype(keys_in_buf.element_type()); | ||
| auto values_in_dtype = convert_ffi_datatype_to_te_dtype(values_in_buf.element_type()); | ||
| auto keys_out_dtype = convert_ffi_datatype_to_te_dtype(keys_out_buf->element_type()); | ||
| auto values_out_dtype = convert_ffi_datatype_to_te_dtype(values_out_buf->element_type()); | ||
| NVTE_CHECK(keys_in_dtype == keys_out_dtype, "Input and output keys must have the same datatype"); | ||
| NVTE_CHECK(values_in_dtype == values_out_dtype, | ||
| "Input and output values must have the same datatype"); | ||
| NVTE_CHECK(values_in_dtype == DType::kInt32, "CubTopkFFI() only supports int32 values for now"); | ||
|
|
||
| auto keys_in_shape = keys_in_buf.dimensions(); | ||
| auto values_in_shape = values_in_buf.dimensions(); | ||
| auto keys_out_shape = keys_out_buf->dimensions(); | ||
| auto values_out_shape = values_out_buf->dimensions(); | ||
| NVTE_CHECK(keys_in_shape.size() == 1, "Keys input must have 1 dimension"); | ||
| NVTE_CHECK(values_in_shape.size() == 1, "Values input must have 1 dimension"); | ||
| NVTE_CHECK(keys_out_shape.size() == 1, "Keys output must have 1 dimension"); | ||
| NVTE_CHECK(values_out_shape.size() == 1, "Values output must have 1 dimension"); | ||
| NVTE_CHECK(keys_in_shape[0] == values_in_shape[0], | ||
| "Keys and values input must have the same number of items"); | ||
| NVTE_CHECK(keys_out_shape[0] == values_out_shape[0], | ||
| "Keys and values output must have the same number of items"); | ||
| int num_items = static_cast<int>(keys_in_shape[0]); | ||
| int k = static_cast<int>(k_value); | ||
|
Comment on lines
+39
to
+40
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There is no check that A guard should be added here alongside the existing shape checks: NVTE_CHECK(k <= num_items, "k (", k, ") must be <= num_items (", num_items, ")"); |
||
|
|
||
| auto input_shape = std::vector<size_t>{keys_in_shape[0]}; | ||
| auto output_shape = std::vector<size_t>{keys_out_shape[0]}; | ||
| auto workspace_shape = std::vector<size_t>{workbuf_bytes}; | ||
|
|
||
| auto keys_in_tensor = TensorWrapper(keys_in_buf.untyped_data(), input_shape, keys_in_dtype); | ||
| auto values_in_tensor = TensorWrapper(values_in_buf.untyped_data(), input_shape, values_in_dtype); | ||
| auto keys_out_tensor = TensorWrapper(keys_out_buf->untyped_data(), output_shape, keys_out_dtype); | ||
| auto values_out_tensor = | ||
| TensorWrapper(values_out_buf->untyped_data(), output_shape, values_out_dtype); | ||
| auto workspace_tensor = | ||
| TensorWrapper(workspace_buf->untyped_data(), workspace_shape, DType::kByte); | ||
|
|
||
| nvte_topk(stream, keys_in_tensor.data(), values_in_tensor.data(), keys_out_tensor.data(), | ||
| values_out_tensor.data(), workspace_tensor.data(), num_items, k, workbuf_bytes); | ||
|
|
||
| return ffi_with_cuda_error_check(); | ||
| } | ||
|
|
||
| XLA_FFI_DEFINE_HANDLER_SYMBOL(TopkHandler, TopkFFI, | ||
| FFI::Bind() | ||
| .Ctx<FFI_Stream_Type>() // stream | ||
| .Arg<Buffer_Type>() // keys_buf | ||
| .Arg<Buffer_Type>() // values_buf | ||
| .Ret<Buffer_Type>() // topk_buf | ||
| .Ret<Buffer_Type>() // indices_buf | ||
| .Ret<Buffer_Type>() // workspace_buf | ||
| .Attr<int64_t>("k_value") | ||
| .Attr<int64_t>("workbuf_bytes"), | ||
| FFI_CudaGraph_Traits); | ||
|
|
||
| } // namespace jax | ||
| } // namespace transformer_engine | ||
Uh oh!
There was an error while loading. Please reload this page.