2D weights support for permute_1D_data_kernel_vec#5557
Open
kausv wants to merge 1 commit intopytorch:mainfrom
Open
2D weights support for permute_1D_data_kernel_vec#5557kausv wants to merge 1 commit intopytorch:mainfrom
kausv wants to merge 1 commit intopytorch:mainfrom
Conversation
Summary: X-link: meta-pytorch/torchrec#3970 X-link: facebookresearch/FBGEMM#2522 Add 2D weights support to the vectorized permute_1D_sparse_data CUDA kernel. The scalar kernel (permute_1D_data_kernel) already handled 2D weights via a weights_columns loop. The vec kernel had three bugs for this case: 1. Missing weights_columns parameter — kernel had no way to know the stride. 2. Wrong pointer offsets — used input_start/output_start instead of input_start * weights_columns / output_start * weights_columns. 3. Wrong copy count — copied segment_length elements instead of segment_length * weights_columns. Fixes: - Add int32_t weights_columns parameter to permute_1D_data_kernel_vec. - Fix weight pointers to account for 2D stride. - Compute total_weight_elements = segment_length * weights_columns and use it for both the vec4 and scalar fallback weight copy loops. - Split indices and weights into separate vec4 loops (they now have different element counts when weights_columns > 1). - Add weights_columns == 1 || weights_columns % 4 == 0 guard to the vec4 alignment check: treating a 2D weight row as vec4 chunks is only safe when row size is a multiple of 4, otherwise the scalar fallback is used. - Pass weights_columns from the dispatcher to both kernel launches. Differential Revision: D98797897
Contributor
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary:
X-link: meta-pytorch/torchrec#3970
X-link: https://github.com/facebookresearch/FBGEMM/pull/2522
Add 2D weights support to the vectorized permute_1D_sparse_data CUDA kernel.
The scalar kernel (permute_1D_data_kernel) already handled 2D weights via a
weights_columns loop. The vec kernel had three bugs for this case:
input_start * weights_columns / output_start * weights_columns.
segment_length * weights_columns.
Fixes:
it for both the vec4 and scalar fallback weight copy loops.
element counts when weights_columns > 1).
alignment check: treating a 2D weight row as vec4 chunks is only safe when
row size is a multiple of 4, otherwise the scalar fallback is used.
Differential Revision: D98797897