2D weights support for permute_1D_data_kernel_vec #5557

Open
kausv wants to merge 1 commit into pytorch:main from kausv:export-D98797897
@kausv (Contributor) commented Mar 31, 2026

Summary:
X-link: meta-pytorch/torchrec#3970

X-link: https://github.com/facebookresearch/FBGEMM/pull/2522

Add 2D weights support to the vectorized permute_1D_sparse_data CUDA kernel.

The scalar kernel (permute_1D_data_kernel) already handled 2D weights via a
weights_columns loop. The vec kernel had three bugs for this case:

  1. Missing weights_columns parameter — kernel had no way to know the stride.
  2. Wrong pointer offsets — used input_start/output_start instead of
    input_start * weights_columns / output_start * weights_columns.
  3. Wrong copy count — copied segment_length elements instead of
    segment_length * weights_columns.
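
Bugs 2 and 3 can be illustrated with a minimal host-side sketch. This is a CPU reference written for illustration, not the FBGEMM kernel: the function name `copy_weight_segment`, the `std::vector<float>` storage, and the row-major layout are assumptions. Only `input_start`, `output_start`, `segment_length`, and `weights_columns` come from the description above.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical CPU reference (not the FBGEMM kernel).
// Copies one permuted segment of 2D weights, assumed to be stored row-major
// with one row of `weights_columns` values per index.
// input_start, output_start and segment_length are counted in rows.
void copy_weight_segment(
    const std::vector<float>& weights,
    std::vector<float>& permuted_weights,
    int64_t input_start,
    int64_t output_start,
    int64_t segment_length,
    int32_t weights_columns) {
  assert(weights_columns >= 1);

  // Bug 2: a row offset must be scaled by the row width to get an element
  // offset. Using input_start / output_start unscaled points at the wrong row.
  const float* src = weights.data() + input_start * weights_columns;
  float* dst = permuted_weights.data() + output_start * weights_columns;

  // Bug 3: each row contributes weights_columns elements, so the copy count
  // is segment_length * weights_columns, not segment_length.
  const int64_t total_weight_elements = segment_length * weights_columns;
  for (int64_t i = 0; i < total_weight_elements; ++i) {
    dst[i] = src[i];
  }
}
```

With `weights_columns == 1` both the offsets and the copy count reduce to the 1D case, which is why the original vec kernel was correct for 1D weights only.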

Fixes:

  • Add int32_t weights_columns parameter to permute_1D_data_kernel_vec.
  • Fix weight pointers to account for the 2D stride (offset by
    input_start * weights_columns and output_start * weights_columns).
  • Compute total_weight_elements = segment_length * weights_columns and use
    it for both the vec4 and scalar fallback weight copy loops.
  • Split indices and weights into separate vec4 loops (they now have different
    element counts when weights_columns > 1).
  • Add a weights_columns == 1 || weights_columns % 4 == 0 guard to the vec4
    alignment check: treating a 2D weight row as vec4 chunks is only safe when
    the row size is a multiple of 4; otherwise the scalar fallback is used.
  • Pass weights_columns from the dispatcher to both kernel launches.
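
The split loops and the vec4 guard described in the fixes above might look roughly like the following host-side sketch. It is an assumption-laden illustration rather than the kernel code: `columns_allow_vec4`, `is_aligned_16`, `copy_elements`, and `permute_segment` are hypothetical names, `std::memcpy` of four elements stands in for a CUDA vec4 load/store, and the pointer alignment test is a guess at what the existing alignment check covers.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// The column guard quoted in the PR description.
bool columns_allow_vec4(int32_t weights_columns) {
  return weights_columns == 1 || weights_columns % 4 == 0;
}

// Hypothetical stand-in for the kernel's pointer alignment check.
bool is_aligned_16(const void* p) {
  return reinterpret_cast<std::uintptr_t>(p) % 16 == 0;
}

// Copies `count` elements. When use_vec4 is true, copies 4 elements at a time
// and finishes with a scalar tail; otherwise copies element by element.
// Returns how many elements went through the 4-wide path.
template <typename T>
int64_t copy_elements(const T* src, T* dst, int64_t count, bool use_vec4) {
  int64_t i = 0;
  if (use_vec4) {
    for (; i + 4 <= count; i += 4) {
      std::memcpy(dst + i, src + i, 4 * sizeof(T));  // models a vec4 copy
    }
  }
  const int64_t vec_elements = i;
  for (; i < count; ++i) {
    dst[i] = src[i];  // scalar fallback or tail
  }
  return vec_elements;
}

// Indices and weights use separate loops because their element counts differ
// once weights_columns > 1. Pointers are assumed to be already offset to the
// start of the segment.
void permute_segment(
    const int64_t* indices_src, int64_t* indices_dst,
    const float* weights_src, float* weights_dst,
    int64_t segment_length, int32_t weights_columns) {
  const bool indices_vec4 =
      is_aligned_16(indices_src) && is_aligned_16(indices_dst);
  copy_elements(indices_src, indices_dst, segment_length, indices_vec4);

  const int64_t total_weight_elements = segment_length * weights_columns;
  const bool weights_vec4 = columns_allow_vec4(weights_columns) &&
      is_aligned_16(weights_src) && is_aligned_16(weights_dst);
  copy_elements(weights_src, weights_dst, total_weight_elements, weights_vec4);
}
```

In this sketch a row width such as 3 always takes the scalar path for weights, while the indices loop can still use the 4-wide path independently.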

Differential Revision: D98797897

meta-cla bot added the "cla signed" label Mar 31, 2026
meta-codesync bot (Contributor) commented Mar 31, 2026

@kausv has exported this pull request. If you are a Meta employee, you can view the originating Diff in D98797897.
