
Commit 11eca61: rename ppo -> mpoppo
Parent: 0371df4

12,567 files changed: 113 additions and 106 deletions

README.md

Lines changed: 10 additions & 10 deletions
@@ -3,18 +3,18 @@
 
 ## What’s Been Implemented?
 
-- **Main script** for launching MPO training on top of PPO: `examples/scripts/mpo.py`
-- **`MPOTrainer`**: Located in `trl/trainer/mpo_trainer.py`, this extends `PPOTrainer` to implement the full MPO procedure as described in the paper.
-- **`MPOConfig`**: Defined in `trl/trainer/mpo_config.py`, this contains all hyperparameters for MPO training.
-- **Processed corpora** for four tasks (essay writing, summarization, ethical reasoning, and mathematical reasoning) are provided in `trl/extras/mpo/corpora`.
-- **Initial prompts and meta-prompts** for each task are located in `trl/extras/mpo/prompts`.
-- **LLM-based reward models (RMs)** and **meta-reward models (MRMs)** are implemented in task-specific files under `trl/extras/mpo/rm_{task_name}.py`, and dataset loading/processing is handled in `trl/extras/mpo/mpo_datasets.py`.
-- **Utility functions** for MPO training are implemented in `trl/trainer/utils.py`.
+- **Main script** for launching MPOPPO training on top of PPO: `examples/scripts/mpoppo.py`
+- **`MPOPPOTrainer`**: Located in `trl/trainer/mpoppo_trainer.py`, this extends `PPOTrainer` to implement the full MPO procedure as described in the paper.
+- **`MPOPPOConfig`**: Defined in `trl/trainer/mpoppo_config.py`, this contains all hyperparameters for MPOPPO training.
+- **Processed corpora** for four tasks (essay writing, summarization, ethical reasoning, and mathematical reasoning) are provided in `trl/extras/mpoppo/corpora`.
+- **Initial prompts and meta-prompts** for each task are located in `trl/extras/mpoppo/prompts`.
+- **LLM-based reward models (RMs)** and **meta-reward models (MRMs)** are implemented in task-specific files under `trl/extras/mpoppo/rm_{task_name}.py`, and dataset loading/processing is handled in `trl/extras/mpoppo/mpoppo_datasets.py`.
+- **Utility functions** for MPOPPO training are implemented in `trl/trainer/utils.py`.
 - **Additional scripts** for launching remote LLM servers and evaluating trained models are provided in `scripts/mpo_experiments`.
 
 ## Installation & Execution Requirements
 
-- Running MPO requires two components:
+- Running MPOPPO requires two components:
 1. **A primary node or subset of GPUs** dedicated to RL training.
 2. **A separate node or the remaining GPUs** dedicated to serving reward scores in an online fashion.
 - For the former, install this repository using `virtualenv` and `uv` (recommended for clean and reproducible environments):
@@ -36,8 +36,8 @@
 $ uv pip install vllm==0.8.4
 ```
 - Refer to the [SGLang documentation](https://docs.sglang.ai/) for more details.
-- Training start and end notifications are currently sent via [Pushover](https://pushover.net/api). If you do not wish to use this feature, you can simply comment out the relevant lines in the launch script: `examples/scripts/mpo.py`.
-- The `launch_mpo.sh` script in `scripts/mpo_experiments` demonstrates how to train models using MPO with different parameter configurations.
+- Training start and end notifications are currently sent via [Pushover](https://pushover.net/api). If you do not wish to use this feature, you can simply comment out the relevant lines in the launch script: `examples/scripts/mpoppo.py`.
+- The `launch_mpoppo.sh` script in `scripts/mpo_experiments` demonstrates how to train models using MPOPPO with different parameter configurations.
 - The `launch_rm_mrm.sh` script in `scripts/mpo_experiments` shows how to instantiate and serve LLMs via SGLang over an SSH connection.
 
 Below is the README from trl repository.
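
For orientation, the snippet below is a minimal sketch of the renamed entry points described in the README above, assuming the patched `trl` from this repository is installed. It only covers argument parsing; the policy model, value head, reward-model endpoints, and datasets are wired up in `examples/scripts/mpoppo.py` and omitted here.

```python
# Minimal sketch (not the full training script): parse MPOPPO arguments the same
# way examples/scripts/mpoppo.py does. Expects the usual CLI flags (e.g. --output_dir).
from transformers import HfArgumentParser

from trl import ModelConfig, MPOPPOConfig, ScriptArguments

parser = HfArgumentParser((ScriptArguments, MPOPPOConfig, ModelConfig))
script_args, training_args, model_args = parser.parse_args_into_dataclasses()

# training_args is an MPOPPOConfig instance; it carries the MPOPPO hyperparameters
# (e.g. num_mpo_interval, num_mpo_samples) and is what MPOPPOTrainer consumes.
print(type(training_args).__name__)
```

The full construction of the policy (`AutoModelForCausalLMWithValueHead`) and of `MPOPPOTrainer` appears in the script diff that follows.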
Lines changed: 12 additions & 5 deletions
@@ -9,14 +9,21 @@
 from peft import LoraConfig
 from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
 
-from trl import ModelConfig, MPOConfig, MPOTrainer, ScriptArguments, get_kbit_device_map, get_quantization_config
-from trl.extras.mpo import get_task_dataset
+from trl import (
+    MPOPPOConfig,
+    MPOPPOTrainer,
+    ModelConfig,
+    ScriptArguments,
+    get_kbit_device_map,
+    get_quantization_config,
+)
+from trl.extras.mpoppo import get_task_dataset
 from trl.models.modeling_value_head import AutoModelForCausalLMWithValueHead
 from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
 
 
 """
-See launch script in scripts/mpo_experiments/launch_mpo.sh
+See launch script in scripts/mpo_experiments/launch_mpoppo.sh
 """
 
 
@@ -38,7 +45,7 @@ def seed_everything(seed: int = 42):
 pushover = Pushover(user=os.environ["PUSHOVER_USER"], token=os.environ["PUSHOVER_TOKEN"])
 
 seed_everything(42)
-parser = HfArgumentParser((ScriptArguments, MPOConfig, ModelConfig))
+parser = HfArgumentParser((ScriptArguments, MPOPPOConfig, ModelConfig))
 script_args, training_args, model_args = parser.parse_args_into_dataclasses()
 if os.path.exists(training_args.output_dir):
     raise ValueError(
@@ -149,7 +156,7 @@ def seed_everything(seed: int = 42):
     sound="magic",
 )
 try:
-    trainer = MPOTrainer(
+    trainer = MPOPPOTrainer(
         args=training_args,
         processing_class=tokenizer,
         model=policy,

scripts/mpo_experiments/dgx/launch_sglang_dgx.sbatch

Lines changed: 7 additions & 7 deletions
@@ -7,7 +7,7 @@
 #SBATCH --mem=100GB
 #SBATCH --gpus-per-node=8
 #SBATCH --cpus-per-task=32
-#SBATCH -t 72:00:00
+#SBATCH -t 96:00:00
 #SBATCH -o %x.%j.out
 
 set -euo pipefail
@@ -17,20 +17,20 @@ set -euo pipefail
 ##############################
 
 # Your project root
-export MPO_ROOT="/lustre/fs0/scratch/zkim/Development/mpo-old"
+export MPOPPO_ROOT="/lustre/fs0/scratch/zkim/Development/mpo-old"
 
 # Load .env (expects a line like: HF_TOKEN=xxxx)
-if [ -f "${MPO_ROOT}/.env" ]; then
+if [ -f "${MPOPPO_ROOT}/.env" ]; then
   # Export variables defined inside .env
   set -a
   # shellcheck source=/dev/null
-  source "${MPO_ROOT}/.env"
+  source "${MPOPPO_ROOT}/.env"
   set +a
 else
-  echo "[WARN] ${MPO_ROOT}/.env not found; HF_TOKEN may be unset" >&2
+  echo "[WARN] ${MPOPPO_ROOT}/.env not found; HF_TOKEN may be unset" >&2
 fi
 
-: "${HF_TOKEN:?HF_TOKEN must be set in ${MPO_ROOT}/.env}"
+: "${HF_TOKEN:?HF_TOKEN must be set in ${MPOPPO_ROOT}/.env}"
 
 # Container image (your SQSH)
 export SGLANG_IMAGE="/lustre/fs0/scratch/zkim/sqsh-files/lmsysorg+sglang+latest.sqsh"
@@ -140,4 +140,4 @@ echo " curl http://localhost:30000/v1/chat/completions ..."
 echo
 
 # Keep job alive as long as servers run
-wait "${RM_PID}" "${MRM_PID}"
+wait "${RM_PID}" "${MRM_PID}"

scripts/mpo_experiments/dgx/launch_train.sbatch

Lines changed: 6 additions & 6 deletions
@@ -24,23 +24,23 @@ mrm_address=$2
 ###############################################################################
 
 # Point to this repo (overrides the default in mpoppo_train.sh)
-export MPO_ROOT="/lustre/fs0/scratch/zkim/Development/mpo-old"
+export MPOPPO_ROOT="/lustre/fs0/scratch/zkim/Development/mpo-old"
 
 # Load secrets (e.g., HF_TOKEN) if present
-if [ -f "${MPO_ROOT}/.env" ]; then
+if [ -f "${MPOPPO_ROOT}/.env" ]; then
   set -a
   # shellcheck source=/dev/null
-  source "${MPO_ROOT}/.env"
+  source "${MPOPPO_ROOT}/.env"
   set +a
 fi
 
-: "${HF_TOKEN:?HF_TOKEN must be set (place it in ${MPO_ROOT}/.env)}"
+: "${HF_TOKEN:?HF_TOKEN must be set (place it in ${MPOPPO_ROOT}/.env)}"
 
-cd "${MPO_ROOT}"
+cd "${MPOPPO_ROOT}"
 
 # Force single-node, 8 GPU layout
 export CUDA_DEVICES_OVERRIDE="${CUDA_DEVICES_OVERRIDE:-0,1,2,3,4,5,6,7}"
-export ACC_CONFIG_OVERRIDE="${ACC_CONFIG_OVERRIDE:-${MPO_ROOT}/examples/accelerate_configs/deepspeed_zero2.yaml}"
+export ACC_CONFIG_OVERRIDE="${ACC_CONFIG_OVERRIDE:-${MPOPPO_ROOT}/examples/accelerate_configs/deepspeed_zero2.yaml}"
 
 echo "Job ID: ${SLURM_JOB_ID}"
 echo "Node list: ${SLURM_NODELIST}"

scripts/mpo_experiments/dgx/mpoppo_train.sh

Lines changed: 7 additions & 7 deletions
@@ -25,21 +25,21 @@ prompt="evaluation_rubric_real_iter_0.txt"
 # Paths & constants
 ###############################################################################
 
-# Use MPO_ROOT if set; fallback to your explicit path
-trl_dir="${MPO_ROOT:-/lustre/fs0/scratch/zkim/Development/mpo}"
-SCRIPT="$trl_dir/examples/scripts/mpo.py"
+# Use MPOPPO_ROOT if set; fallback to your explicit path
+trl_dir="${MPOPPO_ROOT:-/lustre/fs0/scratch/zkim/Development/mpo-old}"
+SCRIPT="$trl_dir/examples/scripts/mpoppo.py"
 
 WANDB_ENTITY="iterater"
 WANDB_PROJECT="mpoppo-new"
 DATASET="essay_writing"
 TASK="essay_writing"
-PROMPT_DIR="$trl_dir/trl/extras/mpo/prompts/essay_writing"
+PROMPT_DIR="$trl_dir/trl/extras/mpoppo/prompts/essay_writing"
 
 ###############################################################################
 # Main runner
 ###############################################################################
 run_experiment() {
-  local exp_type=$1     # mpogrpo / ppo …
+  local exp_type=$1     # mpoppo / mpogrpo / ppo …
   local rubric_type=$2  # e.g. iter0
   local rm_params=$3    # reward-model size
   local mrm_params=$4   # meta-reward-model size
@@ -70,7 +70,7 @@ run_experiment() {
   # gradient accumulation scaling
   local grad_acc_steps=16
 
-  # MPOGRPO interval
+  # MPOPPO/MPOGRPO interval
   local num_mpo_interval=99999999
   [[ "$exp_type" == "mpogrpo" || "$exp_type" == "mpoppo" ]] && num_mpo_interval=2
 
@@ -106,7 +106,7 @@ run_experiment() {
     --learning_rate 3e-6 \
     --num_ppo_epochs 4 \
     --num_mpo_interval "$num_mpo_interval" \
-    --save_n_updates 20 \
+    --save_n_updates 2 \
     --num_mpo_samples 10 \
     --num_mini_batches 1 \
     --per_device_train_batch_size 2 \

scripts/mpo_experiments/elo_simulation.py

Lines changed: 3 additions & 3 deletions
@@ -15,7 +15,7 @@
 
 load_dotenv()  # take environment variables from .env.
 
-exp_name = sys.argv[1]  # "mpo_variations" "rm_32" "rm_72" "32b_32bvs32b_72b" ""72b_32bvs72b_72b""
+exp_name = sys.argv[1]  # "mpoppo_variations" "rm_32" "rm_72" "32b_32bvs32b_72b" ""72b_32bvs72b_72b""
 num_matches = 2000
 task_name = "summarization"  # "essay_writing"
 print(f"exp_name is: {exp_name}")
@@ -36,7 +36,7 @@
         "base-1.5b": "ModelD",
         "iter0-72b": "ModelE",
     }
-elif exp_name == "mpo_variations":
+elif exp_name == "mpoppo_variations":
     model_names_to_annon = {
         "32b_32b": "ModelA",
         "32b_72b": "ModelB",
@@ -53,7 +53,7 @@
         "72b_32b": "ModelA",
         "72b_72b": "ModelB",
     }
-elif exp_name == "mpo_vs_oracle":
+elif exp_name == "mpoppo_vs_oracle":
     model_names_to_annon = {
         "32b_72b": "ModelB",
         "72b_72b": "ModelD",
Lines changed: 10 additions & 10 deletions
@@ -21,20 +21,20 @@ remote_host=$4 # default unchanged
 ###############################################################################
 # Paths & constants
 ###############################################################################
-trl_dir="$HOME/Development/trl"
-SCRIPT="$trl_dir/examples/scripts/mpo.py"
+trl_dir="${MPOPPO_ROOT:-$HOME/Development/trl}"
+SCRIPT="$trl_dir/examples/scripts/mpoppo.py"
 
 WANDB_ENTITY="iterater"
-WANDB_PROJECT="mpo-new"
+WANDB_PROJECT="mpoppo-new"
 DATASET="math_reasoning"
 TASK="math_reasoning"
-PROMPT_DIR="$trl_dir/trl/extras/mpo/prompts/math_reasoning"
+PROMPT_DIR="$trl_dir/trl/extras/mpoppo/prompts/math_reasoning"
 
 ###############################################################################
 # Main runner
 ###############################################################################
 run_experiment() {
-  local exp_type=$1     # mpo / ppo …
+  local exp_type=$1     # mpoppo / ppo …
   local rubric_type=$2  # e.g. iter0
   local rm=$3           # reward-model size (e.g. 1.5b)
   local mrm=$4          # meta-reward-model size (e.g. 3b)
@@ -51,7 +51,7 @@ run_experiment() {
   # ------------------------------------------------------------------------
   local policy_model="policy-1.5b"
   local model_name
-  if [[ "$exp_type" == "mpo" ]]; then
+  if [[ "$exp_type" == "mpoppo" ]]; then
     model_name="${rubric_type}-${rm}_${mrm}"
   else
     model_name="${rubric_type}-${rm}"
@@ -67,9 +67,9 @@ run_experiment() {
   # gradient accumulation scaling
   local grad_acc_steps=8
 
-  # MPO interval
+  # MPOPPO interval
   local num_mpo_interval=99999999
-  [[ "$exp_type" == "mpo" ]] && num_mpo_interval=20
+  [[ "$exp_type" == "mpoppo" ]] && num_mpo_interval=20
 
   local _mrm_address=$mrm_address
   [[ $rm == $mrm ]] && _mrm_address=$rm_address
@@ -132,7 +132,7 @@ run_experiment() {
 ###############################################################################
 # Sweep
 ###############################################################################
-exp_type="mpo"
+exp_type="mpoppo"
 rubric_type="iter0"
 prompt="evaluation_rubric_real_iter_0.txt"
 # rubric_type="autoprompt"
@@ -146,4 +146,4 @@ for rm in "${rms[@]}"; do
     # run_experiment "$exp_type" "$rubric_type" "$rm" "$mrm" "$prompt"
     run_experiment "$exp_type" "$rubric_type" "$rm" "$rm" "$prompt"
   done
-done
+done

scripts/mpo_experiments/llm_generations.py

Lines changed: 3 additions & 3 deletions
@@ -13,8 +13,8 @@
 from tqdm import tqdm
 from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, HfArgumentParser, TrainingArguments
 
-from trl.extras.mpo import get_task_dataset
-from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE, MPODataCollatorWithPadding, generate
+from trl.extras.mpoppo import get_task_dataset
+from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE, MPOPPODataCollatorWithPadding, generate
 
 
 """Example
@@ -153,7 +153,7 @@ class InferenceConfig(TrainingArguments):
     dataloader = DataLoader(
         dataset,
         batch_size=args.batch_size,
-        collate_fn=MPODataCollatorWithPadding(tokenizer),
+        collate_fn=MPOPPODataCollatorWithPadding(tokenizer),
         drop_last=False,
     )
 
trl/__init__.py

Lines changed: 4 additions & 4 deletions
@@ -72,8 +72,8 @@
     "LogCompletionsCallback",
     "MergeModelCallback",
     "ModelConfig",
-    "MPOConfig",
-    "MPOTrainer",
+    "MPOPPOConfig",
+    "MPOPPOTrainer",
     "NashMDConfig",
     "NashMDTrainer",
     "OnlineDPOConfig",
@@ -169,8 +169,8 @@
     LogCompletionsCallback,
     MergeModelCallback,
     ModelConfig,
-    MPOConfig,
-    MPOTrainer,
+    MPOPPOConfig,
+    MPOPPOTrainer,
     NashMDConfig,
     NashMDTrainer,
     OnlineDPOConfig,

trl/extras/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -19,10 +19,12 @@
 
 _import_structure = {
     "best_of_n_sampler": ["BestOfNSampler"],
+    "mpoppo": [],
 }
 
 if TYPE_CHECKING:
     from .best_of_n_sampler import BestOfNSampler
+    from .mpoppo import *  # noqa: F401,F403
 else:
     import sys
 
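
A quick way to verify the renamed public surface after this commit; this is a usage sketch that assumes the patched repository is installed, not part of the diff.

```python
# Sanity check: the renamed classes are exported from trl, and the new
# trl.extras.mpoppo subpackage (registered above) resolves and exposes
# get_task_dataset, which the training and generation scripts import.
import importlib

from trl import MPOPPOConfig, MPOPPOTrainer

print(MPOPPOConfig.__name__, MPOPPOTrainer.__name__)
mpoppo_extras = importlib.import_module("trl.extras.mpoppo")
print(hasattr(mpoppo_extras, "get_task_dataset"))  # expected: True
```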