Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -65,19 +65,13 @@ __device__ __forceinline__ size_t get_current_tensor_id(
template <typename IType, int kHadamardDimension, int BUFF_DIM_Y, int BUFF_DIM_X,
bool kReturnPreRhtAmax, bool kReturnIdentityAmax, bool kReturnTransposedAmax>
__device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_frag_t[4],
IType* in_sh_ptr, uint32_t& local_pre_rht_amax_reg,
IType* in_sh_ptr, int swizzle_idx,
uint32_t& local_pre_rht_amax_reg,
uint32_t& local_amax_reg,
uint32_t& local_amax_t_reg) {
uint32_t a_frag[4]; // A matrix fragment
uint32_t c_frag[4]; // Result fragment

int warp_id = threadIdx.x / kThreadsPerWarp;
int local_rank = (threadIdx.x % kThreadsPerWarp);

int ld_row_idx = local_rank % kHadamardDimension;
int ld_col_idx = local_rank / kHadamardDimension + warp_id * 2;
int swizzle_idx = swizzle_128B_atom_32B(ld_row_idx, ld_col_idx);

uint32_t temp_amax_reg;
uint32_t temp_amax_t_reg;

Expand Down Expand Up @@ -322,6 +316,12 @@ __global__ void GraphSafeGroupHadamardAmaxTmaKernel(
uint32_t local_amax_reg = *reinterpret_cast<uint32_t*>(&local_amax);
uint32_t local_amax_t_reg = *reinterpret_cast<uint32_t*>(&local_amax_t);

const int warp_id = threadIdx.x / kThreadsPerWarp;
const int local_rank = threadIdx.x % kThreadsPerWarp;
const int ld_row_idx = local_rank % kHadamardDimension;
const int ld_col_idx = local_rank / kHadamardDimension + warp_id * 2;
const int swizzle_idx = swizzle_128B_atom_32B(ld_row_idx, ld_col_idx);

for (int stage_y = 0; stage_y < STAGES_Y; ++stage_y) {
for (int stage_x = 0; stage_x < STAGES_X; ++stage_x) {
int stage = STAGES_X * stage_y + stage_x;
Expand Down Expand Up @@ -364,7 +364,7 @@ __global__ void GraphSafeGroupHadamardAmaxTmaKernel(
had_frag_i, had_frag_t,
in_sh_ptr + in_row_offset +
(compute_stage_x * kHadamardDimension * (THREADS_PER_CHUNK / kThreadsPerWarp)),
local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg);
swizzle_idx, local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg);
}

// Ensure all threads have finished their computation before new data over-writes the shared
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,13 @@ constexpr int kThreadsPerWarp = 32;
template <typename IType, int kHadamardDimension, int BUFF_DIM_Y, int BUFF_DIM_X,
bool kReturnPreRhtAmax, bool kReturnIdentityAmax, bool kReturnTransposedAmax>
__device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_frag_t[4],
IType* in_sh_ptr, uint32_t& local_pre_rht_amax_reg,
IType* in_sh_ptr, int swizzle_idx,
uint32_t& local_pre_rht_amax_reg,
uint32_t& local_amax_reg,
uint32_t& local_amax_t_reg) {
uint32_t a_frag[4]; // A matrix fragment
uint32_t c_frag[4]; // Result fragment

int warp_id = threadIdx.x / kThreadsPerWarp;
int local_rank = (threadIdx.x % kThreadsPerWarp);

int ld_row_idx = local_rank % kHadamardDimension;
int ld_col_idx = local_rank / kHadamardDimension + warp_id * 2;
int swizzle_idx = swizzle_128B_atom_32B(ld_row_idx, ld_col_idx);

uint32_t temp_amax_reg;
uint32_t temp_amax_t_reg;

Expand Down Expand Up @@ -305,6 +299,12 @@ __global__ void GroupHadamardAmaxTmaKernel(const __grid_constant__ CUtensorMap t
uint32_t local_amax_reg = *reinterpret_cast<uint32_t*>(&local_amax);
uint32_t local_amax_t_reg = *reinterpret_cast<uint32_t*>(&local_amax_t);

const int warp_id = threadIdx.x / kThreadsPerWarp;
const int local_rank = threadIdx.x % kThreadsPerWarp;
const int ld_row_idx = local_rank % kHadamardDimension;
const int ld_col_idx = local_rank / kHadamardDimension + warp_id * 2;
const int swizzle_idx = swizzle_128B_atom_32B(ld_row_idx, ld_col_idx);
Comment on lines +302 to +306
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Same optimization not applied to sibling files

hadamard_transform.cu and graph_safe_group_hadamard_transform.cu contain near-identical ComputeKernel definitions that still recompute warp_id, local_rank, ld_row_idx, ld_col_idx, and swizzle_idx inside the function body on every invocation. If the goal is to eliminate redundant per-iteration work, those two files have the same hot-loop structure and would benefit from the same refactor.

This is not a bug — since ComputeKernel is __forceinline__, the compiler can already hoist these invariants under optimization. But for consistency and to complete the stated intent of the PR, consider applying the same pattern to:

  • hadamard_transform.cu:35-40 / call site at ~line 288
  • graph_safe_group_hadamard_transform.cu:74-79 / call site at ~line 362

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good point


for (int stage_y = 0; stage_y < STAGES_Y; ++stage_y) {
for (int stage_x = 0; stage_x < STAGES_X; ++stage_x) {
int stage = STAGES_X * stage_y + stage_x;
Expand Down Expand Up @@ -347,7 +347,7 @@ __global__ void GroupHadamardAmaxTmaKernel(const __grid_constant__ CUtensorMap t
had_frag_i, had_frag_t,
in_sh_ptr + in_row_offset +
(compute_stage_x * kHadamardDimension * (THREADS_PER_CHUNK / kThreadsPerWarp)),
local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg);
swizzle_idx, local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg);
}

// Ensure all threads have finished their computation before new data over-writes the shared
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,13 @@ constexpr int kThreadsPerWarp = 32;
template <typename IType, int kHadamardDimension, int BUFF_DIM_Y, int BUFF_DIM_X,
bool kReturnPreRhtAmax, bool kReturnIdentityAmax, bool kReturnTransposedAmax>
__device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_frag_t[4],
IType* in_sh_ptr, uint32_t& local_pre_rht_amax_reg,
IType* in_sh_ptr, int swizzle_idx,
uint32_t& local_pre_rht_amax_reg,
uint32_t& local_amax_reg,
uint32_t& local_amax_t_reg) {
uint32_t a_frag[4]; // A matrix fragment
uint32_t c_frag[4]; // Result fragment

int warp_id = threadIdx.x / kThreadsPerWarp;
int local_rank = (threadIdx.x % kThreadsPerWarp);

int ld_row_idx = local_rank % kHadamardDimension;
int ld_col_idx = local_rank / kHadamardDimension + warp_id * 2;
int swizzle_idx = swizzle_128B_atom_32B(ld_row_idx, ld_col_idx);

uint32_t temp_amax_reg;
uint32_t temp_amax_t_reg;

Expand Down Expand Up @@ -248,6 +242,12 @@ __global__ void HadamardAmaxTmaKernel(const __grid_constant__ CUtensorMap tensor
uint32_t local_amax_reg = *reinterpret_cast<uint32_t*>(&local_amax);
uint32_t local_amax_t_reg = *reinterpret_cast<uint32_t*>(&local_amax_t);

const int warp_id = threadIdx.x / kThreadsPerWarp;
const int local_rank = threadIdx.x % kThreadsPerWarp;
const int ld_row_idx = local_rank % kHadamardDimension;
const int ld_col_idx = local_rank / kHadamardDimension + warp_id * 2;
const int swizzle_idx = swizzle_128B_atom_32B(ld_row_idx, ld_col_idx);

for (int stage_y = 0; stage_y < STAGES_Y; ++stage_y) {
for (int stage_x = 0; stage_x < STAGES_X; ++stage_x) {
int stage = STAGES_X * stage_y + stage_x;
Expand Down Expand Up @@ -290,7 +290,7 @@ __global__ void HadamardAmaxTmaKernel(const __grid_constant__ CUtensorMap tensor
had_frag_i, had_frag_t,
in_sh_ptr + in_row_offset +
(compute_stage_x * kHadamardDimension * (THREADS_PER_CHUNK / kThreadsPerWarp)),
local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg);
swizzle_idx, local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg);
}

// Ensure all threads have finished their computation before new data over-writes the shared
Expand Down