Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions trl/trainer/dpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ class DPOConfig(TrainingArguments):
tools (`Optional[list[Union[dict, Callable]]]`, *optional*):
List of tools (callable functions) that will be accessible to the model. If the template does not support
function calling, this argument will have no effect.
dataset_kwargs (`dict[str, Any]`, *optional*):
Dictionary of optional keyword arguments for the dataset preparation. The only supported key is
`skip_prepare_dataset`.

> Parameters that control the training

Expand Down Expand Up @@ -301,6 +304,13 @@ class DPOConfig(TrainingArguments):
default=None,
metadata={"help": "Number of processes to use for processing the dataset."},
)
# Optional kwargs forwarded to dataset preparation. Per the help text, the
# only key the trainer honors is `skip_prepare_dataset` (bool); when True,
# the trainer's tokenization/preprocessing step is bypassed entirely.
dataset_kwargs: Optional[dict[str, Any]] = field(
    default=None,
    metadata={
        "help": "Dictionary of optional keyword arguments for the dataset preparation. The only supported key is "
        "`skip_prepare_dataset`."
    },
)
pad_token: Optional[str] = field(
default=None,
metadata={
Expand Down
30 changes: 21 additions & 9 deletions trl/trainer/dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,15 +470,19 @@ def __init__(
self.dataset_num_proc = args.dataset_num_proc

# Dataset preparation
train_dataset = self._prepare_dataset(train_dataset, processing_class, args, "train")
if eval_dataset is not None:
if isinstance(eval_dataset, dict):
eval_dataset = {
key: self._prepare_dataset(dataset, processing_class, args, key)
for key, dataset in eval_dataset.items()
}
else:
eval_dataset = self._prepare_dataset(eval_dataset, processing_class, args, "eval")
# Let users opt out of dataset preparation entirely via
# `dataset_kwargs={"skip_prepare_dataset": True}` — useful when the
# dataset has already been tokenized/processed upstream. Default is to
# prepare (a missing dict or missing key means False).
skip_prepare_dataset = args.dataset_kwargs is not None and args.dataset_kwargs.get(
    "skip_prepare_dataset", False
)
if not skip_prepare_dataset:
    train_dataset = self._prepare_dataset(train_dataset, processing_class, args, "train")
    if eval_dataset is not None:
        if isinstance(eval_dataset, dict):
            # A dict maps eval-split names to datasets; prepare each split,
            # passing its key through as the dataset label.
            eval_dataset = {
                key: self._prepare_dataset(dataset, processing_class, args, key)
                for key, dataset in eval_dataset.items()
            }
        else:
            eval_dataset = self._prepare_dataset(eval_dataset, processing_class, args, "eval")

super().__init__(
model=model,
Expand Down Expand Up @@ -991,6 +995,8 @@ def concatenated_inputs(
)
if "image_sizes" in batch:
    # Duplicate per-image size metadata along dim 0 so it stays aligned
    # with the chosen+rejected concatenation of the batch.
    output["image_sizes"] = torch.cat([batch["image_sizes"], batch["image_sizes"]], dim=0)
if "image_grid_thw" in batch:
    # Qwen-VL-style models pass per-image (temporal, height, width) grid
    # info; duplicate it the same way for the doubled batch.
    output["image_grid_thw"] = torch.cat([batch["image_grid_thw"], batch["image_grid_thw"]], dim=0)

# Concatenate the chosen and rejected completions
max_completion_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
Expand Down Expand Up @@ -1249,6 +1255,9 @@ def _compute_loss_liger(
# NOTE(review): this assignment appears to sit under a presence check for
# "pixel_attention_mask" just above this hunk — confirm against full file.
model_kwargs["pixel_attention_mask"] = concatenated_batch["pixel_attention_mask"]
if "image_sizes" in concatenated_batch:
    # Forward per-image size metadata to the model when the batch has it.
    model_kwargs["image_sizes"] = concatenated_batch["image_sizes"]
# For Qwen-VL models
if "image_grid_thw" in concatenated_batch:
    model_kwargs["image_grid_thw"] = concatenated_batch["image_grid_thw"]

prompt_attention_mask = concatenated_batch["prompt_attention_mask"]
completion_attention_mask = concatenated_batch["completion_attention_mask"]
Expand Down Expand Up @@ -1496,6 +1505,9 @@ def concatenated_forward(
# NOTE(review): mirrors the vision-kwarg forwarding in _compute_loss_liger;
# this assignment likely sits under a presence check cut off by the hunk
# boundary — confirm against full file.
model_kwargs["pixel_attention_mask"] = concatenated_batch["pixel_attention_mask"]
if "image_sizes" in concatenated_batch:
    # Forward per-image size metadata to the model when the batch has it.
    model_kwargs["image_sizes"] = concatenated_batch["image_sizes"]
# For Qwen-VL models
if "image_grid_thw" in concatenated_batch:
    model_kwargs["image_grid_thw"] = concatenated_batch["image_grid_thw"]

prompt_input_ids = concatenated_batch["prompt_input_ids"]
prompt_attention_mask = concatenated_batch["prompt_attention_mask"]
Expand Down