
Precompute writeback dedup indices in forward to eliminate GPU-CPU sync in backward (#5522)

Open

Zhihan-Lu wants to merge 1 commit into pytorch:main from Zhihan-Lu:export-D97415226

Conversation


Zhihan-Lu commented Mar 24, 2026

Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/2494

The writeback backward hook (used by `EXACT_SGD`) performs dedup index computation
that triggers GPU-to-CPU synchronization (`nonzero`, `unique`, `.to()`) on every
backward pass. This stalls the GPU pipeline and degrades training QPS.

This diff splits the writeback logic into two phases:

1. **compute_writeback_indices** (forward): precomputes which gradient rows to keep,
   running all sync-causing ops during the forward pass, where they can overlap with
   other work.
2. **writeback_apply_mask** (backward): applies the precomputed mask to zero out
   duplicate gradient rows. This phase is fully sync-free.
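The two phases above can be sketched in plain Python (an illustrative toy, not the FBGEMM implementation; on GPU, phase 1's dedup is where ops like `nonzero`/`unique` force host synchronization, while phase 2 reduces to an elementwise masked write):

```python
def compute_writeback_mask(indices):
    """Forward phase: keep only the first occurrence of each row index.

    This is where the dedup logic (the analogue of nonzero/unique/.to())
    lives, so any host synchronization overlaps with other forward work.
    """
    seen = set()
    mask = []
    for idx in indices:
        mask.append(idx not in seen)
        seen.add(idx)
    return mask


def writeback_apply_mask(grad_rows, mask):
    """Backward phase: zero out duplicate gradient rows.

    Purely elementwise given the precomputed mask, so on GPU this would
    launch kernels without any device-to-host sync.
    """
    return [row if keep else [0.0] * len(row)
            for row, keep in zip(grad_rows, mask)]


mask = compute_writeback_mask([3, 5, 3, 7])  # → [True, True, False, True]
```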

The optimization is gated by the env var `FBGEMM_PRECOMPUTE_WRITEBACK=1` (default off).
When disabled, the legacy code path is used unchanged.
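A hypothetical sketch of that gating (the flag name is from the PR; reading the env var once and caching the result on the module, so the hot path never touches `os.environ`, matches the "`_precompute_writeback` flag (cached from env var)" described in the key changes below):

```python
import os


def _read_precompute_writeback_flag() -> bool:
    # Gate on FBGEMM_PRECOMPUTE_WRITEBACK=1; anything else means "off".
    return os.environ.get("FBGEMM_PRECOMPUTE_WRITEBACK", "0") == "1"


class TBEWithWriteback:
    """Stand-in for the real TBE module (illustrative only)."""

    def __init__(self) -> None:
        # Cached once at construction; forward/backward consult the
        # attribute instead of re-reading the environment every iteration.
        self._precompute_writeback = _read_precompute_writeback_flag()
```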

Covers all three writeback modes: bag, first-feature-only, and nobag.
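A unified dispatcher over the three modes might look like the following. The function names mirror those listed in the PR's key changes, but the signatures are assumptions, and all three primitives are stubbed with the same toy first-occurrence dedup so the routing itself is runnable:

```python
def _first_occurrence_mask(indices):
    # Toy dedup: keep the first occurrence of each row index.
    seen, mask = set(), []
    for idx in indices:
        mask.append(idx not in seen)
        seen.add(idx)
    return mask


# Stubs standing in for the real mode-specific primitives.
def compute_writeback_indices(indices):
    return _first_occurrence_mask(indices)


def compute_writeback_indices_first_feature_only(indices):
    return _first_occurrence_mask(indices)


def compute_writeback_indices_nobag(indices):
    return _first_occurrence_mask(indices)


_DISPATCH = {
    "bag": compute_writeback_indices,
    "first_feature_only": compute_writeback_indices_first_feature_only,
    "nobag": compute_writeback_indices_nobag,
}


def compute_writeback_indices_dispatch(mode, indices):
    """Route to the mode-specific primitive (routing table is assumed)."""
    try:
        return _DISPATCH[mode](indices)
    except KeyError:
        raise ValueError(f"unknown writeback mode: {mode!r}")
```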

Key changes:

- `writeback_util.py`: Refactored into composable primitives (`compute_writeback_indices`,
  `compute_writeback_indices_first_feature_only`, `compute_writeback_indices_nobag`,
  `writeback_apply_mask`). Added `compute_writeback_indices_dispatch` for unified dispatch.
  Update/gradient functions now delegate to these primitives.
- `split_table_batched_embeddings_ops_training.py`: Added a `_writeback_precomputed_index`
  member variable and a `_precompute_writeback` flag (cached from the env var). The forward
  pass precomputes indices via `compute_writeback_indices_dispatch` when enabled;
  `writeback_hook` passes the precomputed indices to `writeback_gradient`.
- `writeback_util_test.py`: Added a property-based test (`test_precomputed_writeback_all_modes`)
  verifying dispatch routing and precomputed-vs-legacy equivalence across all three modes.

Differential Revision: D97415226

meta-codesync bot changed the title from "Precompute writeback dedup indices in forward to eliminate GPU-CPU sync in backward" to "Precompute writeback dedup indices in forward to eliminate GPU-CPU sync in backward (#5522)" on Mar 24, 2026

meta-codesync bot commented Mar 24, 2026

@Zhihan-Lu has exported this pull request. If you are a Meta employee, you can view the originating Diff in D97415226.
