Precompute writeback dedup indices in forward to eliminate GPU-CPU sync in backward (#5522)
Open
Zhihan-Lu wants to merge 1 commit into pytorch:main from
Conversation
Contributor
@Zhihan-Lu has exported this pull request. If you are a Meta employee, you can view the originating Diff in D97415226.
Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/2494
The writeback backward hook (used by EXACT_SGD) performs dedup index computation
that triggers GPU-to-CPU synchronization (nonzero, unique, .to()) on every backward
pass. This stalls the GPU pipeline and degrades training QPS.
This diff splits the writeback logic into two phases:

1. **compute_writeback_indices** (forward): precomputes which gradient rows to keep, running all sync-causing ops during forward where they can overlap with other work.
2. **writeback_apply_mask** (backward): applies the precomputed mask to zero out duplicate gradient rows. This is fully sync-free.
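The two-phase split can be sketched in plain Python. This is a deliberately simplified model: the real FBGEMM code operates on CUDA tensors, `compute_writeback_keep_mask` is a hypothetical name for illustration, and only `writeback_apply_mask` mirrors a primitive actually named in this PR.

```python
# Simplified, hypothetical model of the two-phase writeback dedup.
# "indices" are the embedding rows touched this step; duplicate rows must
# contribute only their first gradient so EXACT_SGD stays exact.

def compute_writeback_keep_mask(indices):
    """Phase 1 (forward): mark the first occurrence of each row index.

    In the real kernels this is where the sync-causing ops
    (nonzero, unique, .to()) run, overlapped with other forward work.
    """
    seen = set()
    mask = []
    for idx in indices:
        mask.append(idx not in seen)
        seen.add(idx)
    return mask

def writeback_apply_mask(grads, mask):
    """Phase 2 (backward): zero out duplicate gradient rows.

    Pure elementwise work with no data-dependent shapes, so it
    introduces no GPU-CPU synchronization.
    """
    return [g if keep else 0.0 for g, keep in zip(grads, mask)]
```

For example, for row indices `[3, 5, 3, 7]` the forward phase yields the mask `[True, True, False, True]`, and the backward phase zeroes the gradient of the duplicate row 3.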
The optimization is gated by the env var `FBGEMM_PRECOMPUTE_WRITEBACK=1` (default off). When disabled, the legacy code path is used unchanged.
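The gating pattern, caching the env var once at construction rather than re-reading it on every forward call, might look like the following sketch (the class name `WritebackConfig` is illustrative, not from the PR; only the env var name is from the source):

```python
import os

class WritebackConfig:
    """Illustrative holder for the cached writeback gate flag."""

    def __init__(self):
        # Read the environment once at construction; the gate is off
        # unless FBGEMM_PRECOMPUTE_WRITEBACK=1 is set.
        self._precompute_writeback = (
            os.environ.get("FBGEMM_PRECOMPUTE_WRITEBACK", "0") == "1"
        )
```

Caching at construction keeps per-step overhead at zero and makes the flag's value stable for the lifetime of the module, even if the environment changes mid-training.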
Covers all three writeback modes: bag, first-feature-only, and nobag.
Key changes:
- `writeback_util.py`: Refactored into composable primitives (`compute_writeback_indices`, `compute_writeback_indices_first_feature_only`, `compute_writeback_indices_nobag`, `writeback_apply_mask`). Added `compute_writeback_indices_dispatch` for unified dispatch. Update/gradient functions now delegate to these primitives.
- `split_table_batched_embeddings_ops_training.py`: Added `_writeback_precomputed_index` member variable and `_precompute_writeback` flag (cached from env var). Forward pass precomputes indices via `compute_writeback_indices_dispatch` when enabled. `writeback_hook` passes precomputed indices to `writeback_gradient`.
- `writeback_util_test.py`: Added property-based test (`test_precomputed_writeback_all_modes`) verifying dispatch routing and precomputed-vs-without equivalence across all three modes.
Differential Revision: D97415226