Merged
40 changes: 10 additions & 30 deletions tests/experimental/test_kto_trainer.py
@@ -15,7 +15,7 @@
 import pytest
 import torch
 from datasets import load_dataset
-from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from trl.experimental.kto import KTOConfig, KTOTrainer
 from trl.experimental.kto.kto_trainer import _get_kl_dataset, _process_tokens, _tokenize
@@ -31,26 +31,16 @@ def setup_method(self):
         self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
         self.tokenizer.pad_token = self.tokenizer.eos_token
 
-        # get t5 as seq2seq example:
-        model_id = "trl-internal-testing/tiny-T5ForConditionalGeneration"
-        self.t5_model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
-        self.t5_ref_model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
-        self.t5_tokenizer = AutoTokenizer.from_pretrained(model_id)
-
     @pytest.mark.parametrize(
-        "name, config_name, loss_type, pre_compute, eval_dataset",
+        "config_name, loss_type, pre_compute, eval_dataset",
         [
-            ("qwen", "standard_preference", "kto", True, True),
-            # ("t5", "standard_implicit_prompt_preference", "kto", True, False), # KTO broken for enc-dec
-            ("qwen", "standard_unpaired_preference", "kto", False, True),
-            # ("t5", "conversational_preference", "kto", False, False),
-            ("qwen", "conversational_implicit_prompt_preference", "apo_zero_unpaired", True, True),
-            # ("t5", "conversational_unpaired_preference", "apo_zero_unpaired", True, False),
-            ("qwen", "standard_unpaired_preference", "apo_zero_unpaired", False, True),
-            # ("t5", "conversational_unpaired_preference", "apo_zero_unpaired", False, False),
+            ("standard_preference", "kto", True, True),
+            ("standard_unpaired_preference", "kto", False, True),
+            ("conversational_implicit_prompt_preference", "apo_zero_unpaired", True, True),
+            ("standard_unpaired_preference", "apo_zero_unpaired", False, True),
         ],
     )
-    def test_kto_trainer(self, name, config_name, loss_type, pre_compute, eval_dataset):
+    def test_kto_trainer(self, config_name, loss_type, pre_compute, eval_dataset):
         training_args = KTOConfig(
             output_dir=self.tmp_dir,
             per_device_train_batch_size=2,
@@ -67,20 +57,11 @@ def test_kto_trainer(self, name, config_name, loss_type, pre_compute, eval_dataset):
 
         dummy_dataset = load_dataset("trl-internal-testing/zen", config_name)
 
-        if name == "qwen":
-            model = self.model
-            ref_model = self.ref_model
-            tokenizer = self.tokenizer
-        elif name == "t5":
-            model = self.t5_model
-            ref_model = self.t5_ref_model
-            tokenizer = self.t5_tokenizer
-
         trainer = KTOTrainer(
-            model=model,
-            ref_model=ref_model,
+            model=self.model,
+            ref_model=self.ref_model,
             args=training_args,
-            processing_class=tokenizer,
+            processing_class=self.tokenizer,
             train_dataset=dummy_dataset["train"],
             eval_dataset=dummy_dataset["test"] if eval_dataset else None,
         )
@@ -172,7 +153,6 @@ def test_tokenize_and_process_tokens(self):
 
         fn_kwargs = {
             "prefix": "",
-            "is_encoder_decoder": trainer.is_encoder_decoder,
             "tokenizer": trainer.processing_class,
             "max_length": trainer.max_length,
             "truncation_mode": trainer.truncation_mode,
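For orientation, here is a minimal sketch of the decoder-only path the updated test now exercises. The checkpoint, dataset, and trainer arguments are the ones visible in the diff above; `output_dir` and the final `trainer.train()` call are illustrative assumptions rather than part of the test.

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from trl.experimental.kto import KTOConfig, KTOTrainer

# Tiny decoder-only checkpoint used by the test suite (any causal LM works here).
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
model = AutoModelForCausalLM.from_pretrained(model_id)
ref_model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token

# Unpaired-preference data in the format the KTO trainer expects.
dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference")

training_args = KTOConfig(
    output_dir="kto-smoke-test",  # assumed path; the test uses a temporary directory
    per_device_train_batch_size=2,
    loss_type="kto",
)

trainer = KTOTrainer(
    model=model,
    ref_model=ref_model,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
)
trainer.train()
```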
22 changes: 1 addition & 21 deletions trl/experimental/kto/kto_config.py
@@ -38,9 +38,6 @@ class KTOConfig(TrainingArguments):
             to use the default data collator.
         max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
             Maximum length of the prompt. This argument is required if you want to use the default data collator.
-        max_completion_length (`int`, *optional*):
-            Maximum length of the completion. This argument is required if you want to use the default data collator
-            and your model is an encoder-decoder.
         beta (`float`, *optional*, defaults to `0.1`):
             Parameter controlling the deviation from the reference model. Higher β means less deviation from the
             reference model.
@@ -65,9 +62,6 @@
         generate_during_eval (`bool`, *optional*, defaults to `False`):
             If `True`, generates and logs completions from both the model and the reference model to W&B or Comet
             during evaluation.
-        is_encoder_decoder (`bool`, *optional*):
-            When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
-            you need to specify if the model returned by the callable is an encoder-decoder model.
         precompute_ref_log_probs (`bool`, *optional*, defaults to `False`):
             Whether to precompute reference model log probabilities for training and evaluation datasets. This is
             useful when training without the reference model to reduce the total GPU memory needed.
@@ -144,14 +138,7 @@ class KTOConfig(TrainingArguments):
         default=512,
         metadata={
             "help": "Maximum length of the prompt. This argument is required if you want to use the default data "
-            "collator and your model is an encoder-decoder."
-        },
-    )
-    max_completion_length: int | None = field(
-        default=None,
-        metadata={
-            "help": "Maximum length of the completion. This argument is required if you want to use the default data "
-            "collator and your model is an encoder-decoder."
+            "collator."
         },
     )
     beta: float = field(
@@ -206,13 +193,6 @@ class KTOConfig(TrainingArguments):
             "during evaluation."
         },
     )
-    is_encoder_decoder: bool | None = field(
-        default=None,
-        metadata={
-            "help": "When using the `model_init` argument (callable) to instantiate the model instead of the `model` "
-            "argument, you need to specify if the model returned by the callable is an encoder-decoder model."
-        },
-    )
     disable_dropout: bool = field(
         default=True,
         metadata={"help": "Whether to disable dropout in the model."},
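As a quick sanity check on the slimmed-down config surface, a hedged sketch of what still constructs cleanly after this change; the field names are the ones visible in the diff above, while the concrete values and `output_dir` are illustrative assumptions.

```python
from trl.experimental.kto import KTOConfig

# Still valid: decoder-only settings only.
config = KTOConfig(
    output_dir="kto-run",        # assumed output path
    max_prompt_length=512,
    beta=0.1,
    precompute_ref_log_probs=False,
    disable_dropout=True,
)

# No longer accepted: these dataclass fields were removed, so passing them
# now raises a TypeError at construction time.
# KTOConfig(output_dir="kto-run", max_completion_length=128)
# KTOConfig(output_dir="kto-run", is_encoder_decoder=False)
```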