From 4a3af9eb41661e713c426063ec12940494cb8579 Mon Sep 17 00:00:00 2001
From: Fabian Schuetze
Date: Sat, 11 Oct 2025 14:08:55 +0200
Subject: [PATCH 1/3] optional collate-func and lazy encoding

---
 trl/trainer/dpo_config.py  | 10 ++++++++++
 trl/trainer/dpo_trainer.py | 25 ++++++++++++++++---------
 2 files changed, 26 insertions(+), 9 deletions(-)

diff --git a/trl/trainer/dpo_config.py b/trl/trainer/dpo_config.py
index b9e73a78ff4..f98dfe987c7 100644
--- a/trl/trainer/dpo_config.py
+++ b/trl/trainer/dpo_config.py
@@ -126,6 +126,9 @@ class DPOConfig(TrainingArguments):
         tools (`Optional[list[Union[dict, Callable]]]`, *optional*):
             List of tools (callable functions) that will be accessible to the model. If the template does not support
             function calling, this argument will have no effect.
+        dataset_kwargs (`dict[str, Any]`, *optional*):
+            Dictionary of optional keyword arguments for the dataset preparation. The only supported key is
+            `skip_prepare_dataset`.
 
     > Parameters that control the training
 
@@ -301,6 +304,13 @@ class DPOConfig(TrainingArguments):
         default=None,
         metadata={"help": "Number of processes to use for processing the dataset."},
     )
+    dataset_kwargs: Optional[dict[str, Any]] = field(
+        default=None,
+        metadata={
+            "help": "Dictionary of optional keyword arguments for the dataset preparation. The only supported key is "
+            "`skip_prepare_dataset`."
+ }, + ) pad_token: Optional[str] = field( default=None, metadata={ diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index bfcc4b4c53e..e929339effb 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -470,15 +470,17 @@ def __init__( self.dataset_num_proc = args.dataset_num_proc # Dataset preparation - train_dataset = self._prepare_dataset(train_dataset, processing_class, args, "train") - if eval_dataset is not None: - if isinstance(eval_dataset, dict): - eval_dataset = { - key: self._prepare_dataset(dataset, processing_class, args, key) - for key, dataset in eval_dataset.items() - } - else: - eval_dataset = self._prepare_dataset(eval_dataset, processing_class, args, "eval") + skip_prepare_dataset = args.dataset_kwargs.get("skip_prepare_dataset", False) + if not skip_prepare_dataset: + train_dataset = self._prepare_dataset(train_dataset, processing_class, args, "train") + if eval_dataset is not None: + if isinstance(eval_dataset, dict): + eval_dataset = { + key: self._prepare_dataset(dataset, processing_class, args, key) + for key, dataset in eval_dataset.items() + } + else: + eval_dataset = self._prepare_dataset(eval_dataset, processing_class, args, "eval") super().__init__( model=model, @@ -991,6 +993,8 @@ def concatenated_inputs( ) if "image_sizes" in batch: output["image_sizes"] = torch.cat([batch["image_sizes"], batch["image_sizes"]], dim=0) + if "image_grid_thw" in batch: + output["image_grid_thw"] = torch.cat([batch["image_grid_thw"], batch["image_grid_thw"]], dim=0) # Concatenate the chosen and rejected completions max_completion_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1]) @@ -1249,6 +1253,9 @@ def _compute_loss_liger( model_kwargs["pixel_attention_mask"] = concatenated_batch["pixel_attention_mask"] if "image_sizes" in concatenated_batch: model_kwargs["image_sizes"] = concatenated_batch["image_sizes"] + # For Qwen-VL models + if "image_grid_thw" in concatenated_batch: + 
model_kwargs['image_grid_thw'] = concatenated_batch['image_grid_tw'] prompt_attention_mask = concatenated_batch["prompt_attention_mask"] completion_attention_mask = concatenated_batch["completion_attention_mask"] From c8760fa73add9bc90d2e6a98f1fc163aac00b1a5 Mon Sep 17 00:00:00 2001 From: Fabian Schuetze Date: Sat, 11 Oct 2025 14:35:19 +0200 Subject: [PATCH 2/3] check if arg is set --- trl/trainer/dpo_trainer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index e929339effb..5012331d832 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -470,7 +470,9 @@ def __init__( self.dataset_num_proc = args.dataset_num_proc # Dataset preparation - skip_prepare_dataset = args.dataset_kwargs.get("skip_prepare_dataset", False) + skip_prepare_dataset = args.dataset_kwargs is not None and args.dataset_kwargs.get( + "skip_prepare_dataset", False + ) if not skip_prepare_dataset: train_dataset = self._prepare_dataset(train_dataset, processing_class, args, "train") if eval_dataset is not None: @@ -1255,7 +1257,7 @@ def _compute_loss_liger( model_kwargs["image_sizes"] = concatenated_batch["image_sizes"] # For Qwen-VL models if "image_grid_thw" in concatenated_batch: - model_kwargs['image_grid_thw'] = concatenated_batch['image_grid_tw'] + model_kwargs["image_grid_thw"] = concatenated_batch["image_grid_tw"] prompt_attention_mask = concatenated_batch["prompt_attention_mask"] completion_attention_mask = concatenated_batch["completion_attention_mask"] From adb34f339895e2f92a33d47195cc46c3c1f3bda9 Mon Sep 17 00:00:00 2001 From: Fabian Schuetze Date: Sat, 11 Oct 2025 17:31:32 +0200 Subject: [PATCH 3/3] add to concatenated forward --- trl/trainer/dpo_trainer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 5012331d832..8cdd0f4b1da 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py 
@@ -1257,7 +1257,7 @@ def _compute_loss_liger( model_kwargs["image_sizes"] = concatenated_batch["image_sizes"] # For Qwen-VL models if "image_grid_thw" in concatenated_batch: - model_kwargs["image_grid_thw"] = concatenated_batch["image_grid_tw"] + model_kwargs["image_grid_thw"] = concatenated_batch["image_grid_thw"] prompt_attention_mask = concatenated_batch["prompt_attention_mask"] completion_attention_mask = concatenated_batch["completion_attention_mask"] @@ -1505,6 +1505,9 @@ def concatenated_forward( model_kwargs["pixel_attention_mask"] = concatenated_batch["pixel_attention_mask"] if "image_sizes" in concatenated_batch: model_kwargs["image_sizes"] = concatenated_batch["image_sizes"] + # For Qwen-VL models + if "image_grid_thw" in concatenated_batch: + model_kwargs["image_grid_thw"] = concatenated_batch["image_grid_thw"] prompt_input_ids = concatenated_batch["prompt_input_ids"] prompt_attention_mask = concatenated_batch["prompt_attention_mask"]