Curated train, validation, and SFT datasets from MATH (the SFT dataset has 383 samples, created with the GPT-4o API).
There isn't much to say about SFT: it is not very sensitive to the learning rate, so any reasonable learning rate gives good results. Training on the full 383-sample SFT dataset produces a checkpoint with a validation accuracy of approximately 0.64.
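For reference, here is a minimal sketch of the per-example SFT loss. It is the negative log-likelihood of the response tokens only, with prompt and padding tokens masked out. The sketch makes two assumptions for illustration:

- It uses plain Python lists of per-token log-probabilities instead of PyTorch tensors.
- The names `masked_normalize`, `sft_loss`, and `normalize_constant` are hypothetical, not necessarily the code in cs336_alignment/post_training_utils.

```python
from typing import Sequence


def masked_normalize(values: Sequence[float], mask: Sequence[bool], normalize_constant: float) -> float:
    """Sum the entries of `values` where `mask` is True, then divide by `normalize_constant`."""
    if len(values) != len(mask):
        raise ValueError("values and mask must have the same length")
    return sum(v for v, keep in zip(values, mask) if keep) / normalize_constant


def sft_loss(token_log_probs: Sequence[float], response_mask: Sequence[bool], normalize_constant: float = 1.0) -> float:
    """Negative log-likelihood of the response tokens only (prompt/padding masked out)."""
    return -masked_normalize(token_log_probs, response_mask, normalize_constant)


# Two prompt tokens (masked out) followed by three response tokens.
log_probs = [-0.5, -1.0, -0.25, -0.5, -0.75]
mask = [False, False, True, True, True]
print(sft_loss(log_probs, mask))                        # 1.5
print(sft_loss(log_probs, mask, normalize_constant=3))  # 0.5
```

Masking the prompt means the model is trained only to reproduce the response, not the question it was given.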
See my wandb report for detailed results of the ablation studies.
Implementations of the functions exercised by the unit tests are in cs336_alignment/post_training_utils.
Run SFT with uv run cs336_alignment/sft.py
Run GRPO with uv run cs336_alignment/grpo.py
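GRPO replaces a learned value baseline with group statistics. Several responses are sampled per prompt, and each response's reward is normalized by the mean (and optionally the standard deviation) of its group. Below is a minimal pure-Python sketch of that advantage computation. The function name, `group_size`, `eps`, and the use of the sample standard deviation are assumptions for illustration, not necessarily what cs336_alignment/grpo.py does.

```python
import statistics
from typing import List, Sequence


def group_normalized_advantages(
    rewards: Sequence[float],
    group_size: int,
    normalize_by_std: bool = True,
    eps: float = 1e-6,
) -> List[float]:
    """Compute GRPO-style advantages from rewards laid out group by group.

    rewards[i * group_size:(i + 1) * group_size] holds the rewards of the
    responses sampled for prompt i.
    """
    # statistics.stdev needs at least two samples per group.
    if group_size < 2 or len(rewards) % group_size != 0:
        raise ValueError("need group_size >= 2 and len(rewards) divisible by group_size")
    advantages: List[float] = []
    for start in range(0, len(rewards), group_size):
        group = rewards[start:start + group_size]
        mean = statistics.fmean(group)
        centered = [r - mean for r in group]
        if normalize_by_std:
            # Sample standard deviation; eps avoids division by zero when all rewards tie.
            std = statistics.stdev(group)
            centered = [c / (std + eps) for c in centered]
        advantages.extend(centered)
    return advantages


# One prompt, four sampled responses, binary correctness rewards.
adv = group_normalized_advantages([1.0, 0.0, 0.0, 1.0], group_size=4)
print([round(a, 3) for a in adv])  # [0.866, -0.866, -0.866, 0.866]
```

If every response in a group gets the same reward, all its advantages are zero. That group then contributes no policy-gradient signal.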
brandon-snider's GitHub repository
For a full description of the assignment, see the assignment handout at cs336_spring2025_assignment5_alignment.pdf.
We include a supplemental (and completely optional) assignment on safety alignment, instruction tuning, and RLHF at cs336_spring2025_assignment5_supplement_safety_rlhf.pdf.
If you see any issues with the assignment handout or code, please feel free to raise a GitHub issue or open a pull request with a fix.
As in previous assignments, we use uv to manage dependencies.
- Install all packages except flash-attn, then all packages (flash-attn is weird):
uv sync --no-install-package flash-attn
uv sync
- Run unit tests:
uv run pytest
Initially, all tests should fail with a NotImplementedError.
To connect your implementation to the tests, complete the functions in ./tests/adapters.py.
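The adapter pattern described above can be sketched as follows. Each adapter starts as a stub that raises NotImplementedError, and you complete it by delegating to your own implementation. All names here (`my_masked_mean`, `run_masked_mean_stub`, `run_masked_mean`) are hypothetical stand-ins, not the actual functions in ./tests/adapters.py. The sketch keeps everything in one file so it runs on its own.

```python
from typing import Sequence


def my_masked_mean(values: Sequence[float], mask: Sequence[bool]) -> float:
    """Hypothetical implementation that would live in your own package."""
    kept = [v for v, keep in zip(values, mask) if keep]
    if not kept:
        raise ValueError("mask selects no elements")
    return sum(kept) / len(kept)


def run_masked_mean_stub(values: Sequence[float], mask: Sequence[bool]) -> float:
    """What an adapter looks like before you fill it in: every test calling it fails."""
    raise NotImplementedError


def run_masked_mean(values: Sequence[float], mask: Sequence[bool]) -> float:
    """The completed adapter: it only forwards arguments to your implementation."""
    return my_masked_mean(values, mask)


# A pytest-style test calls the adapter, never your module directly.
def test_masked_mean() -> None:
    assert run_masked_mean([1.0, 2.0, 3.0], [True, False, True]) == 2.0


test_masked_mean()
```

Because the tests only see the adapters, the test files stay fixed while your code can live in any module you like.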

