Skip to content

Commit 5417491

Browse files
committed
add do_not_average_loss arg
Signed-off-by: ashors1 <[email protected]>
1 parent b745aef commit 5417491

File tree

2 files changed

+3
-1
lines changed

2 files changed

+3
-1
lines changed

nemo_rl/models/megatron/train.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ def megatron_forward_backward(
180 180
defer_fp32_logits: Optional[bool] = None,
181 181
global_valid_seqs: Optional[torch.Tensor] = None,
182 182
global_valid_toks: Optional[torch.Tensor] = None,
183 +
do_not_average_loss: bool = False,
183 184
) -> Any:
184 185
"""
185 186
Execute forward and backward passes using Megatron's utilities.
@@ -222,6 +223,7 @@ def megatron_forward_backward(
222 223
micro_batch_size=mbs,
223 224
decoder_seq_length=seq_length,
224 225
forward_only=forward_only,
226 +
do_not_average_loss=do_not_average_loss,
225 227
)
226 228

227 229
class LossPostProcessor:

nemo_rl/models/policy/workers/megatron_policy_worker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -963,10 +963,10 @@ def train(
963 963
mbs=micro_batch_size,
964 964
post_processing_fn=loss_fn_wrapped,
965 965
forward_only=eval_mode,
966 -
#do_not_average_loss=True, ## TODO!
967 966
defer_fp32_logits=self.defer_fp32_logits,
968 967
global_valid_seqs=global_valid_seqs,
969 968
global_valid_toks=global_valid_toks,
969 +
do_not_average_loss=True,
970 970
)
971 971

972 972
# Empty unused memory.

0 commit comments

Comments
 (0)