
Avoid CPU offload wait_event for validation #2793

Open

vasunvidia wants to merge 2 commits into NVIDIA:main from vasunvidia:vrengasamy/cpu_offloading_cg

Conversation

@vasunvidia
Collaborator

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps
Contributor

greptile-apps bot commented Mar 23, 2026

Greptile Summary

This PR optimises the CPU-offload code path for the case where no tensors end up in fwd_gpu_tensor_group (e.g. during validation / inference, or when every activation falls below the 256 k-element offload threshold). Previously, start_offload() unconditionally created a torch.cuda.Event and recorded it on the offload stream, and release_activation_forward_gpu_memory() unconditionally called current_stream().wait_event(...) — a GPU synchronisation point — even when there was nothing to wait for. The change guards both operations behind len(self.fwd_gpu_tensor_group.tensor_list) > 0, which is evaluated after tensor_group_process_before_offload() has (potentially) deduplicated the list.
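To make the guarded control flow concrete, the sketch below illustrates the pattern described above. It is not the TransformerEngine implementation: the OffloadStateSketch class, its attribute names, and the trivial deduplication step are invented for illustration; only the emptiness guard around event creation and wait_event mirrors the PR.

```python
import torch


class OffloadStateSketch:
    """Illustrative stand-in for the guarded offload path (not TE code)."""

    def __init__(self):
        self.offload_stream = torch.cuda.Stream()
        self.tensor_list = []  # stands in for fwd_gpu_tensor_group.tensor_list
        self.cpu_copies = []

    def push_tensor(self, t):
        self.tensor_list.append(t)

    def start_offload(self):
        # Stand-in for tensor_group_process_before_offload(): deduplicate first,
        # so the guard below sees the processed list.
        self.tensor_list = list({id(t): t for t in self.tensor_list}.values())

        if len(self.tensor_list) > 0:
            with torch.cuda.stream(self.offload_stream):
                for t in self.tensor_list:
                    cpu = torch.empty(t.shape, dtype=t.dtype, pin_memory=True)
                    cpu.copy_(t, non_blocking=True)  # GPU -> pinned CPU copy
                    self.cpu_copies.append(cpu)
            # The event is created and recorded only when copies were issued.
            self.finish_offload_event = torch.cuda.Event()
            self.finish_offload_event.record(self.offload_stream)

    def release_activation_forward_gpu_memory(self):
        if len(self.tensor_list) > 0:
            # Only wait when something was actually offloaded; on the empty
            # (validation) path no event exists and no wait is issued.
            torch.cuda.current_stream().wait_event(self.finish_offload_event)
            self.tensor_list = []
            del self.finish_offload_event
```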

Key points:

  • The guard condition is checked on the processed (deduplicated) tensor list, which is consistent between the two call sites.
  • The state machine constraints (_validate_state) mean the two methods are always called in order, so the emptiness check in release_activation_forward_gpu_memory() will always agree with the one in start_offload() (see the sketch after this list).
  • finish_offload_event is never created (and therefore never accessed) on the empty-list path, so there is no AttributeError risk.
  • fwd_gpu_tensor_group is not explicitly reset to TensorGroup() on the empty-list path inside release_activation_forward_gpu_memory(), but this is benign: there are no GPU tensor references to release, and release_all_memory() unconditionally reinitialises it before the next forward pass.
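The ordering point above can be made concrete with a deliberately simplified, hypothetical sketch: a strict state sequence means both emptiness checks are evaluated against the same, unmodified tensor list, so they always agree. The state names and method bodies below are invented; only the ordering guarantee reflects what the summary attributes to _validate_state.

```python
# Hypothetical ordering sketch (not TransformerEngine code).
class OrderedOffloadSketch:
    def __init__(self, tensor_list):
        self.state = "initialized"
        self.tensor_list = tensor_list  # stand-in for fwd_gpu_tensor_group.tensor_list

    def _validate_state(self, expected):
        if self.state != expected:
            raise RuntimeError(f"expected state {expected!r}, got {self.state!r}")

    def start_offload(self):
        self._validate_state("initialized")
        will_offload = len(self.tensor_list) > 0   # guard 1
        self.state = "offload_started"
        return will_offload

    def release_activation_forward_gpu_memory(self):
        self._validate_state("offload_started")
        must_wait = len(self.tensor_list) > 0      # guard 2: same list, same answer
        self.state = "offload_finished"
        return must_wait


# Calling the methods out of order raises immediately, which is what rules out
# a mismatch between the two guards.
sketch = OrderedOffloadSketch(tensor_list=[])
assert sketch.start_offload() is False
assert sketch.release_activation_forward_gpu_memory() is False
```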

Confidence Score: 5/5

  • Safe to merge — the change is a targeted, correct optimisation with no functional regressions on the non-empty path.
  • Both modified methods have consistent emptiness guards on the post-deduplication tensor list, the CUDA event lifecycle is correctly maintained (finish_offload_event is only created/deleted when the list is non-empty), the state machine prevents out-of-order calls, and the only minor asymmetry (not resetting fwd_gpu_tensor_group in the empty case) is benign and was already flagged in a prior review thread. No P0/P1 issues found.
  • No files require special attention.

Important Files Changed

  • transformer_engine/pytorch/cpu_offload.py: Guards CUDA event creation and stream wait_event behind a tensor-list emptiness check in both start_offload() and release_activation_forward_gpu_memory(), eliminating unnecessary GPU synchronisation when no tensors were pushed (e.g. during validation or inference with small activations).
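For readers less familiar with CUDA streams, the standalone snippet below (plain PyTorch, not code from cpu_offload.py) shows the cross-stream dependency that torch.cuda.current_stream().wait_event(...) creates, which is exactly what the emptiness guard avoids on the validation path.

```python
import torch

# Standalone illustration (not TransformerEngine code) of the synchronization
# pattern that the PR skips when no tensors were pushed for offload.
if torch.cuda.is_available():
    offload_stream = torch.cuda.Stream()
    finish_offload_event = torch.cuda.Event()

    with torch.cuda.stream(offload_stream):
        # In the real code, GPU->CPU copies would be enqueued here; on the
        # validation path the tensor list is empty and nothing is enqueued.
        finish_offload_event.record()

    # Even with nothing enqueued, this call still makes the current stream wait
    # on the recorded event, i.e. it inserts a cross-stream dependency.
    # Guarding it behind the emptiness check removes that dependency entirely.
    torch.cuda.current_stream().wait_event(finish_offload_event)
```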

Sequence Diagram

```mermaid
sequenceDiagram
    participant CS as Current CUDA Stream
    participant OS as Offload Stream
    participant OLS as OffloadableLayerState

    Note over OLS: start_offload() — tensor_list non-empty
    CS->>OLS: push_tensor(t)
    OLS->>OS: wait_event(t.start_reload_event)
    OS->>OS: copy_(tensor, non_blocking=True)
    OLS->>OS: finish_offload_event.record()

    Note over OLS: release_activation_forward_gpu_memory() — tensor_list non-empty
    OLS->>CS: wait_event(finish_offload_event)
    OLS->>OLS: fwd_gpu_tensor_group = TensorGroup()
    OLS->>OLS: del finish_offload_event

    Note over OLS: start_offload() — tensor_list EMPTY (e.g. validation)
    OLS-->>OLS: (no event created, no copy issued)

    Note over OLS: release_activation_forward_gpu_memory() — tensor_list EMPTY
    OLS-->>OLS: (no wait_event called — PR optimization)
```

Reviews (2): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

Comment on lines +322 to +328

```diff
-torch.cuda.current_stream().wait_event(self.finish_offload_event)  # type: ignore[arg-type]
-
-# GPU memory can be released safely after the offload.
-# Notice that the memory needs to be kept alive when GPU->CPU copy is performed.
-self.fwd_gpu_tensor_group = TensorGroup()
-del self.finish_offload_event
+if len(self.fwd_gpu_tensor_group.tensor_list) > 0:
+    torch.cuda.current_stream().wait_event(self.finish_offload_event)  # type: ignore[arg-type]
+
+    # GPU memory can be released safely after the offload.
+    # Notice that the memory needs to be kept alive when GPU->CPU copy is performed.
+    self.fwd_gpu_tensor_group = TensorGroup()
+    del self.finish_offload_event
```

P2 fwd_gpu_tensor_group not reset on empty-list path

When tensor_list is empty, fwd_gpu_tensor_group is left holding the (empty) processed TensorGroup returned by tensor_group_process_before_offload() rather than being reset to a fresh TensorGroup(). This is benign — there are no GPU tensor references to release — and release_all_memory() will always reinitialise it before the next forward pass. However, the behaviour is now asymmetric with the non-empty path and could confuse future readers.

Consider resetting fwd_gpu_tensor_group unconditionally to keep the post-condition of this method consistent regardless of whether tensors were present:

Suggested change

```diff
 if len(self.fwd_gpu_tensor_group.tensor_list) > 0:
     torch.cuda.current_stream().wait_event(self.finish_offload_event)  # type: ignore[arg-type]

     # GPU memory can be released safely after the offload.
     # Notice that the memory needs to be kept alive when GPU->CPU copy is performed.
-    self.fwd_gpu_tensor_group = TensorGroup()
     del self.finish_offload_event
+self.fwd_gpu_tensor_group = TensorGroup()
 self.state = "offload_finished"
```

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

@pggPL
Collaborator

pggPL commented Mar 30, 2026

/te-ci pytorch

vasunvidia and others added 2 commits March 30, 2026 13:35
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
for more information, see https://pre-commit.ci

Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
@vasunvidia vasunvidia force-pushed the vrengasamy/cpu_offloading_cg branch from 63ed892 to b71c0d5 Compare March 30, 2026 20:35
