From f101b02d38e28e16652fb28a27f05c56978c0529 Mon Sep 17 00:00:00 2001 From: Cael Ling Date: Fri, 27 Mar 2026 00:42:24 -0700 Subject: [PATCH 1/3] Compute swizzle_idx once per thread and pass into ComputeKernel. Signed-off-by: Cael Ling --- .../group_hadamard_transform.cu | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu b/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu index 07813be059..8b7f079072 100644 --- a/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu @@ -41,19 +41,13 @@ constexpr int kThreadsPerWarp = 32; template __device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_frag_t[4], - IType* in_sh_ptr, uint32_t& local_pre_rht_amax_reg, + IType* in_sh_ptr, int swizzle_idx, + uint32_t& local_pre_rht_amax_reg, uint32_t& local_amax_reg, uint32_t& local_amax_t_reg) { uint32_t a_frag[4]; // A matrix fragment uint32_t c_frag[4]; // Result fragment - int warp_id = threadIdx.x / kThreadsPerWarp; - int local_rank = (threadIdx.x % kThreadsPerWarp); - - int ld_row_idx = local_rank % kHadamardDimension; - int ld_col_idx = local_rank / kHadamardDimension + warp_id * 2; - int swizzle_idx = swizzle_128B_atom_32B(ld_row_idx, ld_col_idx); - uint32_t temp_amax_reg; uint32_t temp_amax_t_reg; @@ -305,6 +299,12 @@ __global__ void GroupHadamardAmaxTmaKernel(const __grid_constant__ CUtensorMap t uint32_t local_amax_reg = *reinterpret_cast(&local_amax); uint32_t local_amax_t_reg = *reinterpret_cast(&local_amax_t); + const int warp_id = threadIdx.x / kThreadsPerWarp; + const int local_rank = threadIdx.x % kThreadsPerWarp; + const int ld_row_idx = local_rank % kHadamardDimension; + const int ld_col_idx = local_rank / kHadamardDimension + warp_id * 2; + const int swizzle_idx = swizzle_128B_atom_32B(ld_row_idx, ld_col_idx); + for (int stage_y = 0; stage_y < STAGES_Y; ++stage_y) { for (int stage_x = 0; stage_x < STAGES_X; ++stage_x) { int stage = STAGES_X * stage_y + stage_x; @@ -347,7 +347,7 @@ __global__ void GroupHadamardAmaxTmaKernel(const __grid_constant__ CUtensorMap t had_frag_i, had_frag_t, in_sh_ptr + in_row_offset + (compute_stage_x * kHadamardDimension * (THREADS_PER_CHUNK / kThreadsPerWarp)), - local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg); + swizzle_idx, local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg); } // Ensure all threads have finished their computation before new data over-writes the shared From e395bdbde6e419a5f412985812483dde2bc6550b Mon Sep 17 00:00:00 2001 From: Cael Ling Date: Mon, 30 Mar 2026 19:10:54 -0700 Subject: [PATCH 2/3] Refactor the change to other variants Signed-off-by: Cael Ling --- .../graph_safe_group_hadamard_transform.cu | 18 +++++++++--------- .../hadamard_transform/hadamard_transform.cu | 18 +++++++++--------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu b/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu index 04e965a9da..8f9a30ac60 100644 --- a/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu @@ -65,19 +65,13 @@ __device__ __forceinline__ size_t get_current_tensor_id( template __device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_frag_t[4], - IType* in_sh_ptr, uint32_t& local_pre_rht_amax_reg, + IType* in_sh_ptr, int swizzle_idx, + uint32_t& local_pre_rht_amax_reg, uint32_t& local_amax_reg, uint32_t& local_amax_t_reg) { uint32_t a_frag[4]; // A matrix fragment uint32_t c_frag[4]; // Result fragment - int warp_id = threadIdx.x / kThreadsPerWarp; - int local_rank = (threadIdx.x % kThreadsPerWarp); - - int ld_row_idx = local_rank % kHadamardDimension; - int ld_col_idx = local_rank / kHadamardDimension + warp_id * 2; - int swizzle_idx = swizzle_128B_atom_32B(ld_row_idx, ld_col_idx); - uint32_t temp_amax_reg; uint32_t temp_amax_t_reg; @@ -322,6 +316,12 @@ __global__ void GraphSafeGroupHadamardAmaxTmaKernel( uint32_t local_amax_reg = *reinterpret_cast(&local_amax); uint32_t local_amax_t_reg = *reinterpret_cast(&local_amax_t); + const int warp_id = threadIdx.x / kThreadsPerWarp; + const int local_rank = threadIdx.x % kThreadsPerWarp; + const int ld_row_idx = local_rank % kHadamardDimension; + const int ld_col_idx = local_rank / kHadamardDimension + warp_id * 2; + const int swizzle_idx = swizzle_128B_atom_32B(ld_row_idx, ld_col_idx); + for (int stage_y = 0; stage_y < STAGES_Y; ++stage_y) { for (int stage_x = 0; stage_x < STAGES_X; ++stage_x) { int stage = STAGES_X * stage_y + stage_x; @@ -364,7 +364,7 @@ __global__ void GraphSafeGroupHadamardAmaxTmaKernel( had_frag_i, had_frag_t, in_sh_ptr + in_row_offset + (compute_stage_x * kHadamardDimension * (THREADS_PER_CHUNK / kThreadsPerWarp)), - local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg); + swizzle_idx, local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg); } // Ensure all threads have finished their computation before new data over-writes the shared diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform.cu b/transformer_engine/common/hadamard_transform/hadamard_transform.cu index 4adc836886..4e3c528fd4 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform.cu @@ -26,19 +26,13 @@ constexpr int kThreadsPerWarp = 32; template __device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_frag_t[4], - IType* in_sh_ptr, uint32_t& local_pre_rht_amax_reg, + IType* in_sh_ptr, int swizzle_idx, + uint32_t& local_pre_rht_amax_reg, uint32_t& local_amax_reg, uint32_t& local_amax_t_reg) { uint32_t a_frag[4]; // A matrix fragment uint32_t c_frag[4]; // Result fragment - int warp_id = threadIdx.x / kThreadsPerWarp; - int local_rank = (threadIdx.x % kThreadsPerWarp); - - int ld_row_idx = local_rank % kHadamardDimension; - int ld_col_idx = local_rank / kHadamardDimension + warp_id * 2; - int swizzle_idx = swizzle_128B_atom_32B(ld_row_idx, ld_col_idx); - uint32_t temp_amax_reg; uint32_t temp_amax_t_reg; @@ -248,6 +242,12 @@ __global__ void HadamardAmaxTmaKernel(const __grid_constant__ CUtensorMap tensor uint32_t local_amax_reg = *reinterpret_cast(&local_amax); uint32_t local_amax_t_reg = *reinterpret_cast(&local_amax_t); + const int warp_id = threadIdx.x / kThreadsPerWarp; + const int local_rank = threadIdx.x % kThreadsPerWarp; + const int ld_row_idx = local_rank % kHadamardDimension; + const int ld_col_idx = local_rank / kHadamardDimension + warp_id * 2; + const int swizzle_idx = swizzle_128B_atom_32B(ld_row_idx, ld_col_idx); + for (int stage_y = 0; stage_y < STAGES_Y; ++stage_y) { for (int stage_x = 0; stage_x < STAGES_X; ++stage_x) { int stage = STAGES_X * stage_y + stage_x; @@ -290,7 +290,7 @@ __global__ void HadamardAmaxTmaKernel(const __grid_constant__ CUtensorMap tensor had_frag_i, had_frag_t, in_sh_ptr + in_row_offset + (compute_stage_x * kHadamardDimension * (THREADS_PER_CHUNK / kThreadsPerWarp)), - local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg); + swizzle_idx, local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg); } // Ensure all threads have finished their computation before new data over-writes the shared From ee46c42bbe7fc3ab138b15ae733aa9851f3d8ddf Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 31 Mar 2026 02:14:22 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../hadamard_transform/graph_safe_group_hadamard_transform.cu | 2 +- .../common/hadamard_transform/hadamard_transform.cu | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu b/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu index 8f9a30ac60..231d522f3a 100644 --- a/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu @@ -364,7 +364,7 @@ __global__ void GraphSafeGroupHadamardAmaxTmaKernel( had_frag_i, had_frag_t, in_sh_ptr + in_row_offset + (compute_stage_x * kHadamardDimension * (THREADS_PER_CHUNK / kThreadsPerWarp)), - swizzle_idx, local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg); + swizzle_idx, local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg); } // Ensure all threads have finished their computation before new data over-writes the shared diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform.cu b/transformer_engine/common/hadamard_transform/hadamard_transform.cu index 4e3c528fd4..216ed1930a 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform.cu @@ -290,7 +290,7 @@ __global__ void HadamardAmaxTmaKernel(const __grid_constant__ CUtensorMap tensor had_frag_i, had_frag_t, in_sh_ptr + in_row_offset + (compute_stage_x * kHadamardDimension * (THREADS_PER_CHUNK / kThreadsPerWarp)), - swizzle_idx, local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg); + swizzle_idx, local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg); } // Ensure all threads have finished their computation before new data over-writes the shared