
Commit cd35db0

Merge remote-tracking branch 'origin/vllm_client_custom_url' into vllm_client_custom_url

2 parents: a232a1a + c3408cd

File tree

6 files changed (+70, -20 lines)


docs/source/sft_trainer.md (+2, -2)

@@ -424,9 +424,9 @@ Below are some numbers you can get in terms of speedup and memory efficiency, us
 
 | use_flash_attn_1 | model_name        | max_seq_len | batch_size | time per training step |
 | ---------------- | ----------------- | ----------- | ---------- | ---------------------- |
-| x                | facebook/opt-350m | 2048        | 8          | ~59.1s                 |
+|                  | facebook/opt-350m | 2048        | 8          | ~59.1s                 |
 |                  | facebook/opt-350m | 2048        | 8          | **OOM**                |
-| x                | facebook/opt-350m | 2048        | 4          | ~30.3s                 |
+|                  | facebook/opt-350m | 2048        | 4          | ~30.3s                 |
 |                  | facebook/opt-350m | 2048        | 4          | ~148.9s                |
 
 ### Using Flash Attention-2

docs/source/text_environments.md (+1, -1)

@@ -157,7 +157,7 @@ The `TextHistory` object stores the interactions between the model and the text
 
 ### Attributes
 
-The following table summarises the available attributes of the `TextEnvironment` class:
+The following table summarises the available attributes of the `TextHistory` class:
 
 | Attribute | Description |
 |:-------------------|:----------------|

tests/test_grpo_trainer.py (+15, -8)

@@ -24,7 +24,7 @@
 from transformers.utils import is_peft_available
 
 from trl import GRPOConfig, GRPOTrainer
-from trl.trainer.grpo_trainer import RepeatRandomSampler
+from trl.trainer.grpo_trainer import RepeatSampler
 
 from .testing_utils import require_vllm
 
@@ -36,7 +36,7 @@
 class RepeatRandomSamplerTester(unittest.TestCase):
     def test_sampler(self):
         dataset = ["a", "b", "c", "d", "e", "f", "g"]
-        sampler = RepeatRandomSampler(dataset, mini_repeat_count=2)
+        sampler = RepeatSampler(dataset, mini_repeat_count=2)
         # Should output something like [4, 4, 3, 3, 0, 0, 1, 1, 2, 2, 6, 6, 5, 5]
         sampled = list(sampler)
         # Check that the length is doubled
@@ -46,9 +46,16 @@ def test_sampler(self):
         # Check that each element is repeated twice
         assert all(sampled[i] == sampled[i + 1] for i in range(0, len(sampled), 2))
 
+    def test_sampler_no_shuffle(self):
+        dataset = ["a", "b", "c", "d", "e", "f", "g"]
+        sampler = RepeatSampler(dataset, mini_repeat_count=2, shuffle=False)
+        sampled = list(sampler)
+        expected = [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6]
+        self.assertEqual(sampled, expected)
+
     def test_sampler_no_repeat(self):
         dataset = ["a", "b", "c", "d", "e", "f", "g"]
-        sampler = RepeatRandomSampler(dataset, mini_repeat_count=1)
+        sampler = RepeatSampler(dataset, mini_repeat_count=1)
         # Should output something like [4, 3, 0, 1, 2, 6, 5]
         sampled = list(sampler)
         # Check that the length is the same
@@ -58,7 +65,7 @@ def test_sampler_no_repeat(self):
 
     def test_sampler_with_batch_size(self):
         dataset = ["a", "b", "c", "d", "e", "f", "g", "h"]
-        sampler = RepeatRandomSampler(dataset, mini_repeat_count=1, batch_size=2, repeat_count=2)
+        sampler = RepeatSampler(dataset, mini_repeat_count=1, batch_size=2, repeat_count=2)
         # Should output something like [4, 3, 4, 3, 0, 1, 0, 1, 2, 6, 2, 6, 5, 7, 5, 7]
         sampled = list(sampler)
         # Check that the length is doubled
@@ -70,7 +77,7 @@ def test_sampler_with_batch_size(self):
 
     def test_sampler_with_batch_size_and_drop(self):
         dataset = ["a", "b", "c", "d", "e", "f", "g"]
-        sampler = RepeatRandomSampler(dataset, mini_repeat_count=1, batch_size=2, repeat_count=2)
+        sampler = RepeatSampler(dataset, mini_repeat_count=1, batch_size=2, repeat_count=2)
         # Should output something like [4, 3, 4, 3, 0, 1, 0, 1, 2, 6, 2, 6]
         sampled = list(sampler)
         # Check that the length is doubled
@@ -84,7 +91,7 @@ def test_sampler_with_batch_size_and_drop(self):
 
     def test_sampler_with_mini_repeat_count_and_batch_size_1(self):
         dataset = ["a", "b", "c", "d", "e", "f", "g"]
-        sampler = RepeatRandomSampler(dataset, mini_repeat_count=2, batch_size=3, repeat_count=2)
+        sampler = RepeatSampler(dataset, mini_repeat_count=2, batch_size=3, repeat_count=2)
         # Should output something like [4, 4, 3, 3, 0, 0, 4, 4, 3, 3, 0, 0,
         #                               1, 1, 2, 2, 6, 6, 1, 1, 2, 2, 6, 6]
         sampled = list(sampler)
@@ -100,7 +107,7 @@ def test_sampler_with_mini_repeat_count_and_batch_size_1(self):
 
     def test_sampler_with_mini_repeat_count_and_batch_size_2(self):
         dataset = ["a", "b", "c", "d", "e", "f", "g"]
-        sampler = RepeatRandomSampler(dataset, mini_repeat_count=3, batch_size=2, repeat_count=2)
+        sampler = RepeatSampler(dataset, mini_repeat_count=3, batch_size=2, repeat_count=2)
         # Should output something like [4, 4, 4, 3, 3, 3, 4, 4, 4, 3, 3, 3,
         #                               0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1,
         #                               2, 2, 2, 6, 6, 6, 2, 2, 2, 6, 6, 6]
@@ -118,7 +125,7 @@ def test_sampler_with_mini_repeat_count_and_batch_size_2(self):
 
     def test_sampler_with_mini_repeat_count_and_batch_size_3(self):
         dataset = ["a", "b", "c", "d", "e", "f", "g"]
-        sampler = RepeatRandomSampler(dataset, mini_repeat_count=2, batch_size=2, repeat_count=3)
+        sampler = RepeatSampler(dataset, mini_repeat_count=2, batch_size=2, repeat_count=3)
         # Should output something like [4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3,
         #                               0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1,
         #                               2, 2, 6, 6, 2, 2, 6, 6, 2, 2, 6, 6]
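Note: a minimal sketch of the behavior the new test covers, assuming this branch of trl is installed so that RepeatSampler is importable from trl.trainer.grpo_trainer:

    # Sketch only: RepeatSampler is the renamed sampler introduced in this commit.
    from trl.trainer.grpo_trainer import RepeatSampler

    dataset = ["a", "b", "c", "d", "e", "f", "g"]

    # shuffle=False yields indices in dataset order, each repeated mini_repeat_count times.
    sampler = RepeatSampler(dataset, mini_repeat_count=2, shuffle=False)
    print(list(sampler))  # [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6]

    # shuffle=True (the default) keeps the same repeat structure but randomizes the order;
    # pass seed for reproducibility.
    sampler = RepeatSampler(dataset, mini_repeat_count=2, seed=42)
    print(list(sampler))  # e.g. [4, 4, 3, 3, 0, 0, 1, 1, 2, 2, 6, 6, 5, 5]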

trl/scripts/vllm_serve.py (+12, -0)

@@ -174,6 +174,9 @@ class ScriptArguments:
         enable_prefix_caching (`bool` or `None`, *optional*, defaults to `None`):
             Whether to enable prefix caching in vLLM. If set to `True`, ensure that the model and the hardware support
             this feature.
+        enforce_eager (`bool` or `None`, *optional*, defaults to `None`):
+            Whether to enforce eager execution. If set to `True`, we will disable CUDA graph and always execute the
+            model in eager mode. If `False` (default behavior), we will use CUDA graph and eager execution in hybrid.
     """
 
     model: str = field(metadata={"help": "Model name or path to load the model from."})
@@ -224,6 +227,14 @@ class ScriptArguments:
             "hardware support this feature."
         },
     )
+    enforce_eager: Optional[bool] = field(
+        default=None,
+        metadata={
+            "help": "Whether to enforce eager execution. If set to `True`, we will disable CUDA graph and always "
+            "execute the model in eager mode. If `False` (default behavior), we will use CUDA graph and eager "
+            "execution in hybrid."
+        },
+    )
 
 
 def main(script_args: ScriptArguments):
@@ -250,6 +261,7 @@ def main(script_args: ScriptArguments):
         revision=script_args.revision,
         tensor_parallel_size=script_args.tensor_parallel_size,
         gpu_memory_utilization=script_args.gpu_memory_utilization,
+        enforce_eager=script_args.enforce_eager,
         dtype=script_args.dtype,
         # Automatic Prefix Caching caches the KV cache of existing queries, so that a new query can
         # directly reuse the KV cache if it shares the same prefix with one of the existing queries.
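Note: a rough sketch of how the new option could be used programmatically, assuming the module is importable as shown and the remaining ScriptArguments fields keep their defaults; the model name is a placeholder:

    # Sketch only: enforce_eager is forwarded to the vLLM engine constructor inside main().
    from trl.scripts.vllm_serve import ScriptArguments, main

    args = ScriptArguments(
        model="Qwen/Qwen2.5-0.5B-Instruct",  # placeholder model
        enforce_eager=True,                  # disable CUDA graphs; always run the model eagerly
    )
    main(args)  # starts the vLLM server with eager execution enforced

When the script is launched from the command line, the new field should surface as an --enforce_eager flag alongside the existing ScriptArguments options.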

trl/trainer/grpo_config.py (+6, -0)

@@ -59,6 +59,8 @@ class GRPOConfig(TrainingArguments):
             improving generation speed. However, disabling this option allows training models that exceed the VRAM
             capacity of a single GPU, albeit at the cost of slower generation. Disabling this option is not compatible
             with vLLM generation.
+        shuffle_dataset (`bool`, *optional*, defaults to `True`):
+            Whether to shuffle the training dataset.
 
     > Parameters that control generation
 
@@ -225,6 +227,10 @@ class GRPOConfig(TrainingArguments):
             "is not compatible with vLLM generation."
         },
     )
+    shuffle_dataset: Optional[bool] = field(
+        default=True,
+        metadata={"help": "Whether to shuffle the training dataset."},
+    )
 
     # Parameters that control generation
     temperature: float = field(
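Note: a minimal sketch of the new config field in use; the output directory is a placeholder:

    # Sketch only: shuffle_dataset=False keeps the training data in its original order.
    from trl import GRPOConfig

    training_args = GRPOConfig(
        output_dir="grpo-output",  # placeholder path
        shuffle_dataset=False,
    )

GRPOTrainer reads this value and passes it as shuffle= to the train sampler (see trl/trainer/grpo_trainer.py below).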

trl/trainer/grpo_trainer.py (+34, -9)

@@ -78,7 +78,7 @@
 RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
 
 
-class RepeatRandomSampler(Sampler):
+class RepeatSampler(Sampler):
     """
     Sampler that repeats the indices of a dataset in a structured manner.
 
@@ -91,6 +91,8 @@ class RepeatRandomSampler(Sampler):
             Number of unique indices per batch.
         repeat_count (`int`, *optional*, defaults to `1`):
             Number of times to repeat the full sampling process.
+        shuffle (`bool`, *optional*, defaults to `True`):
+            Whether to shuffle the dataset.
         seed (`int` or `None`, *optional*, defaults to `None`):
             Random seed for reproducibility (only affects this sampler).
 
@@ -132,21 +134,28 @@ def __init__(
         mini_repeat_count: int,
         batch_size: int = 1,
         repeat_count: int = 1,
+        shuffle: bool = True,
         seed: Optional[int] = None,
     ):
         self.data_source = data_source
         self.mini_repeat_count = mini_repeat_count
         self.batch_size = batch_size
         self.repeat_count = repeat_count
         self.num_samples = len(data_source)
+        self.shuffle = shuffle
         self.seed = seed
-        self.generator = torch.Generator()  # Create a local random generator
-        if seed is not None:
-            self.generator.manual_seed(seed)
+
+        if shuffle:
+            self.generator = torch.Generator()  # Create a local random generator
+            if seed is not None:
+                self.generator.manual_seed(seed)
 
     def __iter__(self):
-        # E.g., [2, 4, 3, 1, 0, 6, 5] (num_samples = 7)
-        indexes = torch.randperm(self.num_samples, generator=self.generator).tolist()
+        if self.shuffle:
+            # E.g., [2, 4, 3, 1, 0, 6, 5] (num_samples = 7)
+            indexes = torch.randperm(self.num_samples, generator=self.generator).tolist()
+        else:
+            indexes = list(range(self.num_samples))
 
         # [2, 4, 3, 1, 0, 6, 5]
         # -> [[2, 4, 3], [1, 0, 6], [5]] (batch_size = 3)
@@ -166,6 +175,15 @@ def __len__(self) -> int:
         return self.num_samples * self.mini_repeat_count * self.repeat_count
 
 
+class RepeatRandomSampler(RepeatSampler):
+    def __init__(self, *args, **kwargs):
+        warnings.warn(
+            "RepeatRandomSampler is deprecated and will be removed in version 0.18. Use RepeatSampler instead.",
+            DeprecationWarning,
+        )
+        super().__init__(*args, **kwargs)
+
+
 # torch.nanstd doesn't exist, so we define it here
 def nanstd(tensor: torch.Tensor) -> torch.Tensor:
     """
@@ -312,7 +330,9 @@ def reward_func(completions, **kwargs):
             Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
         processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`):
             Processing class used to process the data. The padding side must be set to "left". If `None`, the
-            processing class is loaded from the model's name with [`~transformers.AutoTokenizer.from_pretrained`].
+            processing class is loaded from the model's name with [`~transformers.AutoTokenizer.from_pretrained`]. A
+            padding token, `processing_class.pad_token`, must be set. If the processing class has not set a padding
+            token, `processing_class.eos_token` will be used as the default.
         reward_processing_classes (`Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]`, *optional*, defaults to `None`):
             Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either:
 
@@ -418,6 +438,8 @@ def __init__(
         # Processing class
         if processing_class is None:
             processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path, padding_side="left")
+        if processing_class.pad_token is None:
+            processing_class.pad_token = processing_class.eos_token
 
         # Reward functions
         if not isinstance(reward_funcs, list):
@@ -481,6 +503,8 @@ def data_collator(features):  # No data collation is needed in GRPO
         self.mask_truncated_completions = args.mask_truncated_completions
 
         # Datasets
+        self.shuffle_dataset = args.shuffle_dataset
+
         if (
             isinstance(train_dataset, IterableDataset)
             or isinstance(eval_dataset, IterableDataset)
@@ -734,17 +758,18 @@ def _get_train_sampler(self) -> Sampler:
             * self.accelerator.num_processes
             * self.args.gradient_accumulation_steps
         )
-        return RepeatRandomSampler(
+        return RepeatSampler(
             data_source=self.train_dataset,
             mini_repeat_count=self.num_generations,
             batch_size=effective_batch_size // self.num_generations,
             repeat_count=self.num_iterations * self.args.gradient_accumulation_steps,
+            shuffle=self.shuffle_dataset,
            seed=self.args.seed,
         )
 
     def _get_eval_sampler(self, eval_dataset) -> Sampler:
         # See _get_train_sampler for an explanation of the sampler.
-        return RepeatRandomSampler(
+        return RepeatSampler(
             data_source=eval_dataset,
             mini_repeat_count=self.num_generations,
             seed=self.args.seed,
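Note: the old class name remains importable as a thin deprecation shim; a minimal sketch of what that implies for existing callers:

    # Sketch only: RepeatRandomSampler now subclasses RepeatSampler and warns on construction.
    import warnings
    from trl.trainer.grpo_trainer import RepeatRandomSampler, RepeatSampler

    dataset = ["a", "b", "c"]
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        sampler = RepeatRandomSampler(dataset, mini_repeat_count=1)

    assert isinstance(sampler, RepeatSampler)
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)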
