Add main grad before fwd pass #1142
vedanuj wants to merge 5 commits into ngoyal_changes_for_pp_fp8
Conversation
```diff
-        assert param.grad is not None, param.shape
+        # assert param.grad is not None, param.shape
         if param.grad.requires_grad:
             raise RuntimeError("FSDP only works with gradients that don't require gradients")
```
Perhaps some check is needed to make sure parameters are not shared (as would be the case with weight tying)?
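A minimal sketch of such a check (my illustration, not part of this PR), flagging parameters that share storage, as tied weights do:

```python
import torch

def find_tied_params(module: torch.nn.Module):
    # Parameters that share storage (e.g. tied input/output embeddings)
    # appear under multiple names with the same data_ptr().
    seen, tied = {}, []
    for name, p in module.named_parameters(remove_duplicate=False):
        ptr = p.data_ptr()
        if ptr in seen:
            tied.append((seen[ptr], name))
        else:
            seen[ptr] = name
    return tied
```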
```python
param.grad = None
if param.main_grad is not None:
    grad = param.main_grad
    param.main_grad = None
```
Doesn't `.main_grad` need to be restored somewhere before the next forward?
One option could be to add this logic to `_prep_grads_for_backward()`.
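For illustration, a hypothetical sketch of that idea (the snippet from the original comment was not captured here; the `self.params` attribute and the fp32 dtype are assumptions):

```python
import torch

def _prep_grads_for_backward(self) -> None:
    for param in self.params:
        # Re-create the fp32 accumulation buffer if the post-backward
        # hook released it at the end of the previous iteration.
        if getattr(param, "main_grad", None) is None:
            param.main_grad = torch.zeros_like(param.data, dtype=torch.float32)
```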
```diff
-    param.grad.data = param.grad.data.float()
+    if param.grad is not None:
+        if param.main_grad is not None:
+            param.main_grad.copy_(param.grad.float())
```
nit: torch can upcast and copy in one kernel:
```diff
-            param.main_grad.copy_(param.grad.float())
+            param.main_grad.copy_(param.grad)
```
Correctness Example
```python
>>> import torch
>>> t_fp32 = torch.empty((4,))
>>> t_bf16 = torch.randn((4,), dtype=torch.bfloat16)
>>> t_fp32  # uninitialized memory
tensor([-8.3762e-20,  3.0801e-41, -1.3043e-16,  3.0801e-41])
>>> t_bf16
tensor([-1.3516, -0.5156, -0.6055,  0.3535], dtype=torch.bfloat16)
>>> t_fp32.copy_(t_bf16)  # copy_ upcasts bf16 -> fp32 in one kernel
tensor([-1.3516, -0.5156, -0.6055,  0.3535])
>>> t_fp32
tensor([-1.3516, -0.5156, -0.6055,  0.3535])
```
@awgu It seems from my testing that the changes are still necessary.

@jspark1105 I have borrowed some changes from your PR #1136 to update the view when reallocating the zero buffers for main_grad.
```python
self._streams["post_backward"].wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self._streams["post_backward"]):
    orig_grad_data = param.grad.data
    if param.main_grad is not None and not param.main_grad.eq(0).all():
```
Are we concerned that this `param.main_grad.eq(0).all()` might be a CPU sync? Perhaps it is not so much of a concern if we already have CPU syncs for rate-limiting FSDP.
Is there another way I can check if `main_grad` is non-zero without doing a CPU sync?
We are checking if this is all zeros to skip modules that didn't use main_grad?
Yes, because all parameters have `.main_grad`, so I am not sure how else to make sure we are not using the ones that do not have grads stored in `.main_grad`.
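One sync-free alternative (my suggestion, not from this PR; `_main_grad_written` is a hypothetical flag) is to track on the host whether `.main_grad` was actually written, instead of comparing tensor values on the GPU:

```python
# Wherever main_grad is populated (e.g. when copying from param.grad):
param.main_grad.copy_(param.grad)
param._main_grad_written = True  # hypothetical host-side bookkeeping flag

# In the post-backward hook, branch on the flag instead of the
# device-to-host sync implied by `if not param.main_grad.eq(0).all():`
if getattr(param, "_main_grad_written", False):
    grad = param.main_grad
    param._main_grad_written = False
else:
    grad = param.grad
```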
Force-pushed from 0be2e5e to 8cf28fa


Adds main_grad before the FWD pass to FlatParameter, to be used with https://github.com/fairinternal/xlformers/pull/1418
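A minimal sketch of the idea (assumed from the PR description; the real `FlatParameter` integration and hook placement may differ):

```python
import torch

def attach_main_grad(flat_param: torch.nn.Parameter) -> None:
    # Before the forward pass, attach a zeroed fp32 accumulation buffer
    # alongside the (possibly bf16/fp8) flat parameter so that backward
    # kernels can accumulate full-precision gradients into it.
    if getattr(flat_param, "main_grad", None) is None:
        flat_param.main_grad = torch.zeros(
            flat_param.shape, dtype=torch.float32, device=flat_param.device
        )
```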