File tree Expand file tree Collapse file tree 2 files changed +3
-1
lines changed
Expand file tree Collapse file tree 2 files changed +3
-1
lines changed Original file line number Diff line number Diff line change @@ -180,6 +180,7 @@ def megatron_forward_backward(
180180 defer_fp32_logits : Optional [bool ] = None ,
181181 global_valid_seqs : Optional [torch .Tensor ] = None ,
182182 global_valid_toks : Optional [torch .Tensor ] = None ,
183+ do_not_average_loss : bool = False ,
183184) -> Any :
184185 """
185186 Execute forward and backward passes using Megatron's utilities.
@@ -222,6 +223,7 @@ def megatron_forward_backward(
222223 micro_batch_size = mbs ,
223224 decoder_seq_length = seq_length ,
224225 forward_only = forward_only ,
226+ do_not_average_loss = do_not_average_loss ,
225227 )
226228
227229class LossPostProcessor :
Original file line number Diff line number Diff line change @@ -963,10 +963,10 @@ def train(
963963 mbs = micro_batch_size ,
964964 post_processing_fn = loss_fn_wrapped ,
965965 forward_only = eval_mode ,
966- #do_not_average_loss=True, ## TODO!
967966 defer_fp32_logits = self .defer_fp32_logits ,
968967 global_valid_seqs = global_valid_seqs ,
969968 global_valid_toks = global_valid_toks ,
969+ do_not_average_loss = True ,
970970 )
971971
972972 # Empty unused memory.
You can’t perform that action at this time.
0 commit comments