Support option to skip the optimizer for training step#3490
Conversation
Force-pushed from 54b4b54 to 8682ecf
Codecov Report❌ Patch coverage is
Force-pushed from d2f5c75 to a51468c
🤖 Hi @RissyRan, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
This PR successfully implements an optimizer wrapper to skip training steps during severe loss or gradient anomalies, effectively porting the OLMo-core logic to MaxText using JAX. The core logic elegantly computes rolling statistics and appropriately bypasses the inner optimizer during a spike to prevent momentum poisoning.
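The summary above can be made concrete with a small sketch of the two pieces it praises: rolling statistics over recent losses, and a `jax.lax.cond` that bypasses the inner optimizer on a spike so its momentum state is left untouched. The names, window size, and `sigma_factor` threshold here are illustrative, not the PR's actual code:

```python
import jax
import jax.numpy as jnp

def should_skip(loss, loss_buffer, count, sigma_factor=6.0):
  # Rolling statistics over recent losses; only skip once the buffer is
  # warmed up and the current loss is a clear outlier.
  mean = jnp.mean(loss_buffer)
  std = jnp.std(loss_buffer) + 1e-8
  warmed_up = count >= loss_buffer.shape[0]
  is_spike = loss > mean + sigma_factor * std
  return jnp.logical_and(warmed_up, is_spike)

def maybe_update(skip, grads, opt_state, inner_update):
  # On a spike, emit zero updates and return the inner optimizer state
  # unchanged, so momentum is not poisoned by the anomalous batch.
  def skipped(_):
    return jax.tree_util.tree_map(jnp.zeros_like, grads), opt_state
  def normal(_):
    return inner_update(grads, opt_state)
  return jax.lax.cond(skip, skipped, normal, None)
```

Both `lax.cond` branches return the same pytree structure (updates, state), which is what lets the skip be expressed inside a jitted training step.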
🔍 General Feedback
- JAX Idioms: The usage of `jax.lax.cond` to defer and conditionalize the inner optimizer step is very cleanly implemented.
- Resilience: Added a few critical suggestions to explicitly handle `NaN` or `Inf` loss cases. Preventing buffer poisoning and explicitly skipping on non-finite metrics will make this logic foolproof against catastrophic anomalies.
- Kwargs Forwarding: Recommended using `.pop()` on `**extra_args` to ensure consumed arguments like `loss` aren't passed downstream, guaranteeing better compatibility with inner optimizers.
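The resilience suggestion can be sketched as a guard on buffer writes: a non-finite loss is itself grounds to skip, and must never enter the rolling buffer where it would poison every later mean/std. The helper name and buffer layout are hypothetical:

```python
import jax.numpy as jnp

def push_metric(buffer, idx, value):
  # Write the value into the circular buffer only if it is finite;
  # on NaN/Inf, keep the old entry so rolling statistics stay valid
  # (a sketch of the reviewer's suggestion, not the PR's final code).
  pos = idx % buffer.shape[0]
  safe = jnp.where(jnp.isfinite(value), value, buffer[pos])
  return buffer.at[pos].set(safe)
```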
Force-pushed from a51468c to 5b8835d

Force-pushed from 5b8835d to 960398a
gagika left a comment
Thanks, one minor comment.
"losses": state["losses"],
"grad_norms": state["grad_norms"],
"count": state["count"],
"is_skipped": jnp.array(False, dtype=jnp.bool_),
Will the user see `is_skipped` in logging? It would be good if the logging shows "skipped" when the step is skipped, but there is no need to show `is_skipped` for every step, since the majority are non-skipped. Alternatively, we can show a separate warning message when a step is skipped.
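One way to realize this suggestion is to format the skip mention on the host side, after the jitted step returns its metrics, so the per-step log stays clean and only skipped steps carry a warning. This is a hypothetical helper, not MaxText's actual logging code:

```python
def format_step_log(step, loss, is_skipped):
  # Base per-step line; only mention skipping when it actually happened,
  # and omit is_skipped entirely for the common non-skipped case
  # (hypothetical sketch of the reviewer's suggestion).
  msg = f"step={step} loss={loss:.4f}"
  if is_skipped:
    msg += " WARNING: step skipped (loss/grad spike detected)"
  return msg
```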
Description
This PR introduces a mechanism to skip training steps during severe loss or gradient anomalies (b/489540436). Reference implementation at OLMo-core.
- `base.yml` & `types.py`
- `skip_step_on_spikes` as an `optax.GradientTransformationExtraArgs` wrapper

Tests
- `tests/unit/optimizers_test.py`

Checklist
Before submitting this PR, please make sure (put X in square brackets):
`gemini-review` label.