Skip to content

Commit 355d9d8

Browse files
committed
change the way to detect no_grad
1 parent c441044 commit 355d9d8

File tree

2 files changed

+4
-8
lines changed

2 files changed

+4
-8
lines changed

deepspeed/runtime/zero/parameter_offload.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -490,19 +490,19 @@ def _run_after_backward_function(sub_module):
490490
# post backward hook
491491
self.backward_hooks.append(module.register_forward_pre_hook(_post_backward_module_hook))
492492

493-
@torch.no_grad()
494493
def pre_sub_module_forward_function(self, sub_module):
495494
see_memory_usage(f"Before sub module function {sub_module.__class__.__name__}", force=False)
496-
495+
prev_grad_state = torch.is_grad_enabled() # we don't want grad enabled while fetching sub-module params, yet the fetch needs to know whether grad was enabled
496+
torch.set_grad_enabled(False)
497497
global FWD_MODULE_STACK
498498
FWD_MODULE_STACK.append(sub_module)
499499

500500
param_coordinator = self.get_param_coordinator(training=sub_module.training)
501501
param_coordinator.trace_prologue(sub_module)
502502
if param_coordinator.is_record_trace():
503503
param_coordinator.record_module(sub_module)
504-
param_coordinator.fetch_sub_module(sub_module, forward=True)
505-
504+
param_coordinator.fetch_sub_module(sub_module, forward=prev_grad_state)
505+
torch.set_grad_enabled(prev_grad_state)
506506
see_memory_usage(f"Before sub module function {sub_module.__class__.__name__} after fetch", force=False)
507507

508508
@torch.no_grad()

deepspeed/runtime/zero/partition_parameters.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1033,10 +1033,6 @@ def all_gather_coalesced(params: Iterable[Parameter],
10331033
safe_mode: bool = False,
10341034
quantize: bool = False) -> AllGatherCoalescedHandle:
10351035

1036-
# check if currently in torch.no_grad context
1037-
if not torch.is_grad_enabled():
1038-
forward = False
1039-
10401036
# fetches from nvme if the partition is not available and in nvme
10411037
self._ensure_availability_of_partitioned_params(params)
10421038

0 commit comments

Comments
 (0)