diff --git a/tests/experimental/test_kto_trainer.py b/tests/experimental/test_kto_trainer.py
index 9b13325a34..7108de87c2 100644
--- a/tests/experimental/test_kto_trainer.py
+++ b/tests/experimental/test_kto_trainer.py
@@ -15,7 +15,7 @@
 import pytest
 import torch
 from datasets import load_dataset
-from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer

 from trl.experimental.kto import KTOConfig, KTOTrainer
 from trl.experimental.kto.kto_trainer import _get_kl_dataset, _process_tokens, _tokenize
@@ -31,26 +31,16 @@ def setup_method(self):
         self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
         self.tokenizer.pad_token = self.tokenizer.eos_token

-        # get t5 as seq2seq example:
-        model_id = "trl-internal-testing/tiny-T5ForConditionalGeneration"
-        self.t5_model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
-        self.t5_ref_model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
-        self.t5_tokenizer = AutoTokenizer.from_pretrained(model_id)
-
     @pytest.mark.parametrize(
-        "name, config_name, loss_type, pre_compute, eval_dataset",
+        "config_name, loss_type, pre_compute, eval_dataset",
         [
-            ("qwen", "standard_preference", "kto", True, True),
-            # ("t5", "standard_implicit_prompt_preference", "kto", True, False),  # KTO broken for enc-dec
-            ("qwen", "standard_unpaired_preference", "kto", False, True),
-            # ("t5", "conversational_preference", "kto", False, False),
-            ("qwen", "conversational_implicit_prompt_preference", "apo_zero_unpaired", True, True),
-            # ("t5", "conversational_unpaired_preference", "apo_zero_unpaired", True, False),
-            ("qwen", "standard_unpaired_preference", "apo_zero_unpaired", False, True),
-            # ("t5", "conversational_unpaired_preference", "apo_zero_unpaired", False, False),
+            ("standard_preference", "kto", True, True),
+            ("standard_unpaired_preference", "kto", False, True),
+            ("conversational_implicit_prompt_preference", "apo_zero_unpaired", True, True),
+            ("standard_unpaired_preference", "apo_zero_unpaired", False, True),
         ],
     )
-    def test_kto_trainer(self, name, config_name, loss_type, pre_compute, eval_dataset):
+    def test_kto_trainer(self, config_name, loss_type, pre_compute, eval_dataset):
         training_args = KTOConfig(
             output_dir=self.tmp_dir,
             per_device_train_batch_size=2,
@@ -67,20 +57,11 @@ def test_kto_trainer(self, name, config_name, loss_type, pre_compute, eval_datas

         dummy_dataset = load_dataset("trl-internal-testing/zen", config_name)

-        if name == "qwen":
-            model = self.model
-            ref_model = self.ref_model
-            tokenizer = self.tokenizer
-        elif name == "t5":
-            model = self.t5_model
-            ref_model = self.t5_ref_model
-            tokenizer = self.t5_tokenizer
-
         trainer = KTOTrainer(
-            model=model,
-            ref_model=ref_model,
+            model=self.model,
+            ref_model=self.ref_model,
             args=training_args,
-            processing_class=tokenizer,
+            processing_class=self.tokenizer,
             train_dataset=dummy_dataset["train"],
             eval_dataset=dummy_dataset["test"] if eval_dataset else None,
         )
@@ -172,7 +153,6 @@ def test_tokenize_and_process_tokens(self):

         fn_kwargs = {
             "prefix": "",
-            "is_encoder_decoder": trainer.is_encoder_decoder,
             "tokenizer": trainer.processing_class,
             "max_length": trainer.max_length,
             "truncation_mode": trainer.truncation_mode,
diff --git a/trl/experimental/kto/kto_config.py b/trl/experimental/kto/kto_config.py
index 6801d47c7b..b67b6416a1 100644
--- a/trl/experimental/kto/kto_config.py
+++ b/trl/experimental/kto/kto_config.py
@@ -38,9 +38,6 @@ class KTOConfig(TrainingArguments):
             to use the default data collator.
         max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
             Maximum length of the prompt. This argument is required if you want to use the default data collator.
-        max_completion_length (`int`, *optional*):
-            Maximum length of the completion. This argument is required if you want to use the default data collator
-            and your model is an encoder-decoder.
         beta (`float`, *optional*, defaults to `0.1`):
             Parameter controlling the deviation from the reference model. Higher β means less deviation from the
             reference model.
@@ -65,9 +62,6 @@
         generate_during_eval (`bool`, *optional*, defaults to `False`):
             If `True`, generates and logs completions from both the model and the reference model to W&B or Comet
             during evaluation.
-        is_encoder_decoder (`bool`, *optional*):
-            When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
-            you need to specify if the model returned by the callable is an encoder-decoder model.
         precompute_ref_log_probs (`bool`, *optional*, defaults to `False`):
             Whether to precompute reference model log probabilities for training and evaluation datasets. This is
             useful when training without the reference model to reduce the total GPU memory needed.
@@ -144,14 +138,7 @@ class KTOConfig(TrainingArguments):
         default=512,
         metadata={
             "help": "Maximum length of the prompt. This argument is required if you want to use the default data "
-            "collator and your model is an encoder-decoder."
-        },
-    )
-    max_completion_length: int | None = field(
-        default=None,
-        metadata={
-            "help": "Maximum length of the completion. This argument is required if you want to use the default data "
-            "collator and your model is an encoder-decoder."
+            "collator."
         },
     )
     beta: float = field(
@@ -206,13 +193,6 @@ class KTOConfig(TrainingArguments):
             "during evaluation."
         },
     )
-    is_encoder_decoder: bool | None = field(
-        default=None,
-        metadata={
-            "help": "When using the `model_init` argument (callable) to instantiate the model instead of the `model` "
-            "argument, you need to specify if the model returned by the callable is an encoder-decoder model."
-        },
-    )
     disable_dropout: bool = field(
         default=True,
         metadata={"help": "Whether to disable dropout in the model."},
diff --git a/trl/experimental/kto/kto_trainer.py b/trl/experimental/kto/kto_trainer.py
index 1376f09e8d..6a9b7f36aa 100644
--- a/trl/experimental/kto/kto_trainer.py
+++ b/trl/experimental/kto/kto_trainer.py
@@ -171,105 +171,83 @@ def _process_tokens(example: dict[str, Any], model: "PreTrainedModel" = None, **
         f"{kwargs['prefix']}label": example["label"],
     }

-    if not kwargs["is_encoder_decoder"]:
-        # Check issues below for more details
-        # 1. https://github.com/huggingface/trl/issues/907
-        # 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
-        # 3. https://github.com/LianjiaTech/BELLE/issues/337
-
-        if not isinstance(prompt, str):
-            raise ValueError(f"prompt should be an str but got {type(prompt)}")
-
-        if not isinstance(completion, str):
-            raise ValueError(f"completion should be an str but got {type(completion)}")
-
-        # keys of format prompt_* refers to just the prompt and answer_* refers to just the answer
-        all_tokens = {
-            "prompt_input_ids": example["prompt_input_ids"],
-            "prompt_attention_mask": example["prompt_attention_mask"],
-            "answer_input_ids": example["answer_input_ids"],
-            "answer_attention_mask": example["answer_attention_mask"],
-        }
+    # Check issues below for more details
+    # 1. https://github.com/huggingface/trl/issues/907
+    # 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
+    # 3. https://github.com/LianjiaTech/BELLE/issues/337
+
+    if not isinstance(prompt, str):
+        raise ValueError(f"prompt should be an str but got {type(prompt)}")
+
+    if not isinstance(completion, str):
+        raise ValueError(f"completion should be an str but got {type(completion)}")
+
+    # keys of format prompt_* refers to just the prompt and answer_* refers to just the answer
+    all_tokens = {
+        "prompt_input_ids": example["prompt_input_ids"],
+        "prompt_attention_mask": example["prompt_attention_mask"],
+        "answer_input_ids": example["answer_input_ids"],
+        "answer_attention_mask": example["answer_attention_mask"],
+    }

-        # calculate max length by checking if BOS/EOS is already there
-        max_length = kwargs["max_length"]
-        bos_token_id = kwargs["tokenizer"].bos_token_id
-        eos_token_id = kwargs["tokenizer"].eos_token_id
-        if len(all_tokens["prompt_input_ids"]) > 0 and bos_token_id != all_tokens["prompt_input_ids"][0]:
-            max_length -= 1
-        if len(all_tokens["answer_input_ids"]) > 0 and eos_token_id != all_tokens["answer_input_ids"][-1]:
-            max_length -= 1
-
-        # if combined sequence is too long (> max_length - 1 for BOS token - 1 for EOS), truncate the prompt
-        if len(all_tokens["prompt_input_ids"]) + len(all_tokens["answer_input_ids"]) > max_length:
-            for k in ["prompt_input_ids", "prompt_attention_mask"]:
-                if kwargs["truncation_mode"] == "keep_start":
-                    all_tokens[k] = all_tokens[k][: kwargs["max_prompt_length"]]
-                elif kwargs["truncation_mode"] == "keep_end":
-                    all_tokens[k] = all_tokens[k][-kwargs["max_prompt_length"] :]
-                else:
-                    raise ValueError(f"Unknown truncation mode: {kwargs['truncation_mode']}")
-
-        # if that's still too long, truncate the response
-        if len(all_tokens["prompt_input_ids"]) + len(all_tokens["answer_input_ids"]) > max_length:
-            for k in ["answer_input_ids", "answer_attention_mask"]:
-                all_tokens[k] = all_tokens[k][: max_length - kwargs["max_prompt_length"]]
-
-        # all input_ids and attention mask as is. We then check if we need to add BOS/EOS tokens
-        batch[f"{kwargs['prefix']}prompt_input_ids"] = all_tokens["prompt_input_ids"]
-        batch[f"{kwargs['prefix']}prompt_attention_mask"] = all_tokens["prompt_attention_mask"]
-        batch[f"{kwargs['prefix']}completion_input_ids"] = (
-            all_tokens["prompt_input_ids"] + all_tokens["answer_input_ids"]
-        )
-        batch[f"{kwargs['prefix']}completion_attention_mask"] = (
-            all_tokens["prompt_attention_mask"] + all_tokens["answer_attention_mask"]
-        )
+    # calculate max length by checking if BOS/EOS is already there
+    max_length = kwargs["max_length"]
+    bos_token_id = kwargs["tokenizer"].bos_token_id
+    eos_token_id = kwargs["tokenizer"].eos_token_id
+    if len(all_tokens["prompt_input_ids"]) > 0 and bos_token_id != all_tokens["prompt_input_ids"][0]:
+        max_length -= 1
+    if len(all_tokens["answer_input_ids"]) > 0 and eos_token_id != all_tokens["answer_input_ids"][-1]:
+        max_length -= 1
+
+    # if combined sequence is too long (> max_length - 1 for BOS token - 1 for EOS), truncate the prompt
+    if len(all_tokens["prompt_input_ids"]) + len(all_tokens["answer_input_ids"]) > max_length:
+        for k in ["prompt_input_ids", "prompt_attention_mask"]:
+            if kwargs["truncation_mode"] == "keep_start":
+                all_tokens[k] = all_tokens[k][: kwargs["max_prompt_length"]]
+            elif kwargs["truncation_mode"] == "keep_end":
+                all_tokens[k] = all_tokens[k][-kwargs["max_prompt_length"] :]
+            else:
+                raise ValueError(f"Unknown truncation mode: {kwargs['truncation_mode']}")
+
+    # if that's still too long, truncate the response
+    if len(all_tokens["prompt_input_ids"]) + len(all_tokens["answer_input_ids"]) > max_length:
+        for k in ["answer_input_ids", "answer_attention_mask"]:
+            all_tokens[k] = all_tokens[k][: max_length - kwargs["max_prompt_length"]]
+
+    # all input_ids and attention mask as is. We then check if we need to add BOS/EOS tokens
+    batch[f"{kwargs['prefix']}prompt_input_ids"] = all_tokens["prompt_input_ids"]
+    batch[f"{kwargs['prefix']}prompt_attention_mask"] = all_tokens["prompt_attention_mask"]
+    batch[f"{kwargs['prefix']}completion_input_ids"] = all_tokens["prompt_input_ids"] + all_tokens["answer_input_ids"]
+    batch[f"{kwargs['prefix']}completion_attention_mask"] = (
+        all_tokens["prompt_attention_mask"] + all_tokens["answer_attention_mask"]
+    )

-        # add BOS, which affects both prompt and the full completion
-        if bos_token_id is not None:
-            if len(all_tokens["prompt_input_ids"]) == 0 or bos_token_id != all_tokens["prompt_input_ids"][0]:
-                batch[f"{kwargs['prefix']}prompt_input_ids"] = [bos_token_id] + batch[
-                    f"{kwargs['prefix']}prompt_input_ids"
-                ]
-                batch[f"{kwargs['prefix']}prompt_attention_mask"] = [1] + batch[
-                    f"{kwargs['prefix']}prompt_attention_mask"
-                ]
-                batch[f"{kwargs['prefix']}completion_input_ids"] = [bos_token_id] + batch[
-                    f"{kwargs['prefix']}completion_input_ids"
-                ]
-                batch[f"{kwargs['prefix']}completion_attention_mask"] = [1] + batch[
-                    f"{kwargs['prefix']}completion_attention_mask"
-                ]
-        # add EOS, which affects only the full completion
-        if len(all_tokens["answer_input_ids"]) == 0 or eos_token_id != all_tokens["answer_input_ids"][-1]:
-            batch[f"{kwargs['prefix']}completion_input_ids"] = batch[f"{kwargs['prefix']}completion_input_ids"] + [
-                eos_token_id
+    # add BOS, which affects both prompt and the full completion
+    if bos_token_id is not None:
+        if len(all_tokens["prompt_input_ids"]) == 0 or bos_token_id != all_tokens["prompt_input_ids"][0]:
+            batch[f"{kwargs['prefix']}prompt_input_ids"] = [bos_token_id] + batch[
+                f"{kwargs['prefix']}prompt_input_ids"
+            ]
+            batch[f"{kwargs['prefix']}prompt_attention_mask"] = [1] + batch[f"{kwargs['prefix']}prompt_attention_mask"]
+            batch[f"{kwargs['prefix']}completion_input_ids"] = [bos_token_id] + batch[
+                f"{kwargs['prefix']}completion_input_ids"
             ]
-            batch[f"{kwargs['prefix']}completion_attention_mask"] = batch[
+            batch[f"{kwargs['prefix']}completion_attention_mask"] = [1] + batch[
                 f"{kwargs['prefix']}completion_attention_mask"
-            ] + [1]
-
-        batch[f"{kwargs['prefix']}completion_labels"] = batch[f"{kwargs['prefix']}completion_input_ids"][:]
-        batch[f"{kwargs['prefix']}completion_labels"][: len(batch[f"{kwargs['prefix']}prompt_input_ids"])] = [
-            kwargs["label_pad_token_id"]
-        ] * len(batch[f"{kwargs['prefix']}prompt_input_ids"])
-    else:
-        completion_tokens = kwargs["tokenizer"](
-            completion, truncation=True, max_length=kwargs["max_completion_length"], add_special_tokens=True
-        )
-        prompt_tokens = kwargs["tokenizer"](
-            prompt, truncation=True, max_length=kwargs["max_prompt_length"], add_special_tokens=True
-        )
-
-        batch[f"{kwargs['prefix']}prompt_input_ids"] = prompt_tokens["input_ids"]
-        batch[f"{kwargs['prefix']}prompt_attention_mask"] = prompt_tokens["attention_mask"]
-
-        batch[f"{kwargs['prefix']}completion_labels"] = completion_tokens["input_ids"]
-        batch[f"{kwargs['prefix']}completion_attention_mask"] = completion_tokens["attention_mask"]
-        if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
-            batch[f"{kwargs['prefix']}completion_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
-                labels=torch.tensor(batch["completion_labels"])
-            )
+            ]
+    # add EOS, which affects only the full completion
+    if len(all_tokens["answer_input_ids"]) == 0 or eos_token_id != all_tokens["answer_input_ids"][-1]:
+        batch[f"{kwargs['prefix']}completion_input_ids"] = batch[f"{kwargs['prefix']}completion_input_ids"] + [
+            eos_token_id
+        ]
+        batch[f"{kwargs['prefix']}completion_attention_mask"] = batch[
+            f"{kwargs['prefix']}completion_attention_mask"
+        ] + [1]
+
+    batch[f"{kwargs['prefix']}completion_labels"] = batch[f"{kwargs['prefix']}completion_input_ids"][:]
+    batch[f"{kwargs['prefix']}completion_labels"][: len(batch[f"{kwargs['prefix']}prompt_input_ids"])] = [
+        kwargs["label_pad_token_id"]
+    ] * len(batch[f"{kwargs['prefix']}prompt_input_ids"])

     return batch

@@ -461,12 +439,12 @@ def make_inputs_require_grad(module, input, output):
                 " Please install `wandb` or `comet-ml` to resolve."
             )

-        if model is not None:
-            self.is_encoder_decoder = model.config.is_encoder_decoder
-        elif args.is_encoder_decoder is None:
-            raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
-        else:
-            self.is_encoder_decoder = args.is_encoder_decoder
+        # KTO only supports causal language models, not encoder-decoder models
+        if model is not None and hasattr(model.config, "is_encoder_decoder") and model.config.is_encoder_decoder:
+            raise ValueError(
+                "KTO only supports causal language models. Encoder-decoder models are not supported. "
+                "Please use a causal LM (e.g., GPT, Llama, Mistral) instead of an encoder-decoder model (e.g., T5, BART)."
+            )

         self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
         self.model_adapter_name = model_adapter_name
@@ -502,21 +480,10 @@ def make_inputs_require_grad(module, input, output):
         if args.max_prompt_length is not None:
             max_prompt_length = args.max_prompt_length

-        max_completion_length = None
-        if args.max_completion_length is None and self.is_encoder_decoder:
-            logger.warning(
-                "When using DPODataCollatorWithPadding with an encoder decoder architecture, you should set `max_completion_length` in the KTOTrainer's init"
-                " it will be set to `128` by default, but you should do it yourself in the future.",
-            )
-            max_completion_length = 128
-        if args.max_completion_length is not None and self.is_encoder_decoder:
-            max_completion_length = args.max_completion_length
-
         if data_collator is None:
             data_collator = DPODataCollatorWithPadding(
                 pad_token_id=processing_class.pad_token_id,
                 label_pad_token_id=args.label_pad_token_id,
-                is_encoder_decoder=self.is_encoder_decoder,
             )

         if args.remove_unused_columns:
@@ -544,7 +511,6 @@ def make_inputs_require_grad(module, input, output):
         self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
         self.max_prompt_length = max_prompt_length
         self.truncation_mode = args.truncation_mode
-        self.max_completion_length = max_completion_length
         self.processing_class = processing_class
         self.precompute_ref_log_probs = args.precompute_ref_log_probs

@@ -627,13 +593,11 @@ def make_inputs_require_grad(module, input, output):

             fn_kwargs = {
                 "prefix": "",
-                "is_encoder_decoder": self.is_encoder_decoder,
                 "tokenizer": self.processing_class,
                 "max_length": self.max_length,
                 "truncation_mode": self.truncation_mode,
                 "label_pad_token_id": self.label_pad_token_id,
                 "max_prompt_length": self.max_prompt_length,
-                "max_completion_length": self.max_completion_length,
             }

             train_dataset = train_dataset.map(
@@ -923,64 +887,31 @@ def compute_reference_log_probs(self, padded_batch: dict) -> dict:
         with torch.no_grad():
             if self.ref_model is None:
                 with self.null_ref_context():
-                    if self.is_encoder_decoder:
-                        completion_logits = self.model(
-                            padded_batch["prompt_input_ids"],
-                            attention_mask=padded_batch["prompt_attention_mask"],
-                            decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
-                            labels=padded_batch["completion_labels"],
-                        ).logits
-
-                        if self.calculate_KL:
-                            KL_logits = self.model(
-                                padded_batch["KL_prompt_input_ids"],
-                                attention_mask=padded_batch["KL_prompt_attention_mask"],
-                                decoder_input_ids=padded_batch.get("KL_completion_decoder_input_ids"),
-                                labels=padded_batch["KL_completion_labels"],
-                            ).logits
-                    else:
-                        completion_logits = self.model(
-                            padded_batch["completion_input_ids"],
-                            attention_mask=padded_batch["completion_attention_mask"],
-                        ).logits
-
-                        if self.calculate_KL:
-                            KL_logits = self.model(
-                                padded_batch["KL_completion_input_ids"],
-                                attention_mask=padded_batch["KL_completion_attention_mask"],
-                            ).logits
-            else:
-                if self.is_encoder_decoder:
-                    completion_logits = self.ref_model(
-                        padded_batch["prompt_input_ids"],
-                        attention_mask=padded_batch["prompt_attention_mask"],
-                        decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
-                        labels=padded_batch["completion_labels"],
-                    ).logits
-
-                    if self.calculate_KL:
-                        KL_logits = self.ref_model(
-                            padded_batch["KL_prompt_input_ids"],
-                            attention_mask=padded_batch["KL_prompt_attention_mask"],
-                            decoder_input_ids=padded_batch.get("KL_completion_decoder_input_ids"),
-                            labels=padded_batch["KL_completion_labels"],
-                        ).logits
-                else:
-                    completion_logits = self.ref_model(
-                        padded_batch["completion_input_ids"], attention_mask=padded_batch["completion_attention_mask"]
+                    completion_logits = self.model(
+                        padded_batch["completion_input_ids"],
+                        attention_mask=padded_batch["completion_attention_mask"],
                     ).logits

                     if self.calculate_KL:
-                        KL_logits = self.ref_model(
+                        KL_logits = self.model(
                             padded_batch["KL_completion_input_ids"],
                             attention_mask=padded_batch["KL_completion_attention_mask"],
                         ).logits
+            else:
+                completion_logits = self.ref_model(
+                    padded_batch["completion_input_ids"], attention_mask=padded_batch["completion_attention_mask"]
+                ).logits
+
+                if self.calculate_KL:
+                    KL_logits = self.ref_model(
+                        padded_batch["KL_completion_input_ids"],
+                        attention_mask=padded_batch["KL_completion_attention_mask"],
+                    ).logits

         completion_logps = self.get_batch_logps(
             completion_logits,
             padded_batch["completion_labels"],
             average_log_prob=False,
-            is_encoder_decoder=self.is_encoder_decoder,
             label_pad_token_id=self.label_pad_token_id,
         )

@@ -989,7 +920,6 @@ def compute_reference_log_probs(self, padded_batch: dict) -> dict:
                 KL_logits,
                 padded_batch["KL_completion_labels"],
                 average_log_prob=False,
-                is_encoder_decoder=self.is_encoder_decoder,
                 label_pad_token_id=self.label_pad_token_id,
             )
         else:
@@ -1003,7 +933,6 @@ def get_batch_logps(
         labels: torch.LongTensor,
         average_log_prob: bool = False,
         label_pad_token_id: int = -100,
-        is_encoder_decoder: bool = False,
     ) -> torch.FloatTensor:
         """Compute the log probabilities of the given labels under the given logits.

@@ -1018,10 +947,6 @@ def get_batch_logps(
                 log probabilities of the (non-masked) tokens.
             label_pad_token_id:
                 The label value to ignore when computing log probabilities.
-            is_encoder_decoder:
-                Whether the model is an encoder-decoder model. If True, the labels are not shifted and the logits are
-                assumed to already be aligned with the labels. If False, the labels are shifted to the right by one
-                position, and the logits are assumed to be aligned with the shifted labels.

         Returns:
             A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the
@@ -1030,12 +955,9 @@
         if logits.shape[:-1] != labels.shape:
             raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")

-        if not is_encoder_decoder:
-            labels = labels[:, 1:].clone()
-            logits = logits[:, :-1, :]
-        else:
-            # Fixes end-dec RuntimeError
-            labels = labels.clone()
+        # For causal LM, shift labels and logits by one position
+        labels = labels[:, 1:].clone()
+        logits = logits[:, :-1, :]

         loss_mask = labels != label_pad_token_id

@@ -1054,14 +976,7 @@ def forward(
     ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
         KL_logps = self._compute_kl_logps(model, batch)

-        model_kwargs = (
-            {
-                "labels": batch["completion_labels"],
-                "decoder_input_ids": batch.get("completion_decoder_input_ids"),
-            }
-            if self.is_encoder_decoder
-            else {}
-        )
+        model_kwargs = {}
         if self.aux_loss_enabled:
             model_kwargs["output_router_logits"] = True

@@ -1076,7 +991,6 @@ def forward(
             completion_logits,
             batch["completion_labels"],
             average_log_prob=False,
-            is_encoder_decoder=self.is_encoder_decoder,
             label_pad_token_id=self.label_pad_token_id,
         )

@@ -1185,18 +1099,10 @@ def _compute_kl_logps(self, model, batch):
         """Compute KL log probabilities for a given batch."""
         KL_logps = None
         if self.calculate_KL:
-            if self.is_encoder_decoder:
-                KL_model_kwargs = {
-                    "input_ids": batch["KL_prompt_input_ids"],
-                    "attention_mask": batch["KL_prompt_attention_mask"],
-                    "labels": batch["KL_completion_labels"],
-                    "decoder_input_ids": batch.get("KL_completion_decoder_input_ids"),
-                }
-            else:
-                KL_model_kwargs = {
-                    "input_ids": batch["KL_completion_input_ids"],
-                    "attention_mask": batch["KL_completion_attention_mask"],
-                }
+            KL_model_kwargs = {
+                "input_ids": batch["KL_completion_input_ids"],
+                "attention_mask": batch["KL_completion_attention_mask"],
+            }

             with torch.no_grad():
                 KL_logits = model(**KL_model_kwargs).logits
@@ -1205,7 +1111,6 @@ def _compute_kl_logps(self, model, batch):
                 KL_logits,
                 batch["KL_completion_labels"],
                 average_log_prob=False,
-                is_encoder_decoder=self.is_encoder_decoder,
                 label_pad_token_id=self.label_pad_token_id,
             )
         return KL_logps
@@ -1242,72 +1147,35 @@ def _compute_loss_liger(self, model, batch):
         else:
             kl = torch.zeros(1).to(self.accelerator.device)

-        model_kwargs = (
-            {
-                "labels": batch["completion_labels"],
-                "decoder_input_ids": batch.get("completion_decoder_input_ids"),
-            }
-            if self.is_encoder_decoder
-            else {}
-        )
+        model_kwargs = {}
         if self.aux_loss_enabled:
             model_kwargs["output_router_logits"] = True

-        if self.is_encoder_decoder:
-            # 1. Get encoder outputs
-            encoder_outputs = model.get_encoder()(
-                batch["completion_input_ids"],
-                attention_mask=batch["completion_attention_mask"],
-                return_dict=True,
-                **model_kwargs,
-            )
-            # 2. Get decoder outputs
-            outputs = model.get_decoder()(
-                input_ids=model_kwargs["decoder_input_ids"],
-                encoder_hidden_states=encoder_outputs.last_hidden_state,
-                use_cache=False,
-                **model_kwargs,
-            )
-            # 1. Get reference encoder outputs
-            ref_encoder_outputs = self.ref_model.get_encoder()(
-                batch["completion_input_ids"],
-                attention_mask=batch["completion_attention_mask"],
-                return_dict=True,
-                **model_kwargs,
-            )
-            # 2. Get reference decoder outputs
-            ref_outputs = self.ref_model.get_decoder()(
-                input_ids=model_kwargs["decoder_input_ids"],
-                encoder_hidden_states=ref_encoder_outputs.last_hidden_state,
-                use_cache=False,
-                **model_kwargs,
-            )
+        # skip the lm head and get the last hidden state
+        if hasattr(model, "get_decoder") and model.get_decoder() is not None:
+            base_model = model.get_decoder()
         else:
-            # skip the lm head and get the last hidden state
-            if hasattr(model, "get_decoder") and model.get_decoder() is not None:
-                base_model = model.get_decoder()
-            else:
-                base_attr = getattr(model, "base_model_prefix", self.args.base_model_attribute_name)
-                base_model = getattr(model, base_attr, model)
-            outputs = base_model(
-                batch["completion_input_ids"],
-                attention_mask=batch["completion_attention_mask"],
-                use_cache=False,
-                **model_kwargs,
-            )
+            base_attr = getattr(model, "base_model_prefix", self.args.base_model_attribute_name)
+            base_model = getattr(model, base_attr, model)
+        outputs = base_model(
+            batch["completion_input_ids"],
+            attention_mask=batch["completion_attention_mask"],
+            use_cache=False,
+            **model_kwargs,
+        )

-            # reference model
-            if hasattr(self.ref_model, "get_decoder") and self.ref_model.get_decoder() is not None:
-                ref_base_model = self.ref_model.get_decoder()
-            else:
-                ref_attr = getattr(self.ref_model, "base_model_prefix", self.args.base_model_attribute_name)
-                ref_base_model = getattr(self.ref_model, ref_attr, self.ref_model)
-            ref_outputs = ref_base_model(
-                batch["completion_input_ids"],
-                attention_mask=batch["completion_attention_mask"],
-                use_cache=False,
-                **model_kwargs,
-            )
+        # reference model
+        if hasattr(self.ref_model, "get_decoder") and self.ref_model.get_decoder() is not None:
+            ref_base_model = self.ref_model.get_decoder()
+        else:
+            ref_attr = getattr(self.ref_model, "base_model_prefix", self.args.base_model_attribute_name)
+            ref_base_model = getattr(self.ref_model, ref_attr, self.ref_model)
+        ref_outputs = ref_base_model(
+            batch["completion_input_ids"],
+            attention_mask=batch["completion_attention_mask"],
+            use_cache=False,
+            **model_kwargs,
+        )

         lm_head = model.get_output_embeddings()
         ref_lm_head = self.ref_model.get_output_embeddings()
@@ -1322,14 +1190,12 @@ def _compute_loss_liger(self, model, batch):
                 rejected_rewards_sum,
             ),
         ) = self.kto_loss_fn(
-            _input=outputs.last_hidden_state[:, :-1] if not self.is_encoder_decoder else outputs.last_hidden_state,
+            _input=outputs.last_hidden_state[:, :-1],
             lin_weight=lm_head.weight,
             target=batch["completion_labels"][:, 1:],
             bias=lm_head.bias if hasattr(lm_head, "bias") else None,
             preference_labels=torch.tensor(batch["label"], dtype=torch.bool).to(self.accelerator.device),
-            ref_input=ref_outputs.last_hidden_state[:, :-1]
-            if not self.is_encoder_decoder
-            else outputs.last_hidden_state,
+            ref_input=ref_outputs.last_hidden_state[:, :-1],
             ref_weight=ref_lm_head.weight,
             ref_bias=ref_lm_head.bias if hasattr(lm_head, "bias") else None,
             kl=kl,
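
# With the encoder-decoder paths removed, get_batch_logps always applies the causal-LM
# convention: labels are shifted right by one position against the logits, prompt and
# padding positions are masked with label_pad_token_id, and per-sequence sums are returned.
# A minimal standalone PyTorch sketch of that computation (independent of TRL; the function
# name and toy values below are illustrative, not part of the library):
import torch


def causal_sum_logps(
    logits: torch.FloatTensor, labels: torch.LongTensor, label_pad_token_id: int = -100
) -> torch.FloatTensor:
    if logits.shape[:-1] != labels.shape:
        raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
    labels = labels[:, 1:].clone()  # token t is predicted from positions < t
    logits = logits[:, :-1, :]
    loss_mask = labels != label_pad_token_id
    labels = labels.masked_fill(~loss_mask, 0)  # dummy index so gather stays in range
    per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
    return (per_token_logps * loss_mask).sum(-1)


# toy usage: batch of 2 sequences of length 6, vocabulary of 5
example_logits = torch.randn(2, 6, 5)
example_labels = torch.tensor([[-100, -100, 1, 2, 3, 4], [-100, 2, 3, 1, -100, -100]])
print(causal_sum_logps(example_logits, example_labels).shape)  # torch.Size([2])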
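
# Under the new guard in KTOTrainer.__init__, passing an encoder-decoder checkpoint raises
# a ValueError instead of entering the old seq2seq path. A rough usage sketch, assuming the
# tiny T5 checkpoint named in the removed test setup and the zen test dataset used above;
# output_dir and the dataset config are illustrative choices, not values from the diff:
from datasets import load_dataset
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

from trl.experimental.kto import KTOConfig, KTOTrainer

model = AutoModelForSeq2SeqLM.from_pretrained("trl-internal-testing/tiny-T5ForConditionalGeneration")
tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-T5ForConditionalGeneration")
dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference")

try:
    KTOTrainer(
        model=model,
        args=KTOConfig(output_dir="kto-t5-check", report_to="none"),
        processing_class=tokenizer,
        train_dataset=dataset["train"],
    )
except ValueError as err:
    print(err)  # "KTO only supports causal language models. ..."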