24
24
from transformers .utils import is_peft_available
25
25
26
26
from trl import GRPOConfig , GRPOTrainer
27
- from trl .trainer .grpo_trainer import RepeatRandomSampler
27
+ from trl .trainer .grpo_trainer import RepeatSampler
28
28
29
29
from .testing_utils import require_vllm
30
30
36
36
class RepeatRandomSamplerTester (unittest .TestCase ):
37
37
def test_sampler (self ):
38
38
dataset = ["a" , "b" , "c" , "d" , "e" , "f" , "g" ]
39
- sampler = RepeatRandomSampler (dataset , mini_repeat_count = 2 )
39
+ sampler = RepeatSampler (dataset , mini_repeat_count = 2 )
40
40
# Should output something like [4, 4, 3, 3, 0, 0, 1, 1, 2, 2, 6, 6, 5, 5]
41
41
sampled = list (sampler )
42
42
# Check that the length is doubled
@@ -46,9 +46,16 @@ def test_sampler(self):
46
46
# Check that each element is repeated twice
47
47
assert all (sampled [i ] == sampled [i + 1 ] for i in range (0 , len (sampled ), 2 ))
48
48
49
+ def test_sampler_no_shuffle (self ):
50
+ dataset = ["a" , "b" , "c" , "d" , "e" , "f" , "g" ]
51
+ sampler = RepeatSampler (dataset , mini_repeat_count = 2 , shuffle = False )
52
+ sampled = list (sampler )
53
+ expected = [0 , 0 , 1 , 1 , 2 , 2 , 3 , 3 , 4 , 4 , 5 , 5 , 6 , 6 ]
54
+ self .assertEqual (sampled , expected )
55
+
49
56
def test_sampler_no_repeat (self ):
50
57
dataset = ["a" , "b" , "c" , "d" , "e" , "f" , "g" ]
51
- sampler = RepeatRandomSampler (dataset , mini_repeat_count = 1 )
58
+ sampler = RepeatSampler (dataset , mini_repeat_count = 1 )
52
59
# Should output something like [4, 3, 0, 1, 2, 6, 5]
53
60
sampled = list (sampler )
54
61
# Check that the length is the same
@@ -58,7 +65,7 @@ def test_sampler_no_repeat(self):
58
65
59
66
def test_sampler_with_batch_size (self ):
60
67
dataset = ["a" , "b" , "c" , "d" , "e" , "f" , "g" , "h" ]
61
- sampler = RepeatRandomSampler (dataset , mini_repeat_count = 1 , batch_size = 2 , repeat_count = 2 )
68
+ sampler = RepeatSampler (dataset , mini_repeat_count = 1 , batch_size = 2 , repeat_count = 2 )
62
69
# Should output something like [4, 3, 4, 3, 0, 1, 0, 1, 2, 6, 2, 6, 5, 7, 5, 7]
63
70
sampled = list (sampler )
64
71
# Check that the length is doubled
@@ -70,7 +77,7 @@ def test_sampler_with_batch_size(self):
70
77
71
78
def test_sampler_with_batch_size_and_drop (self ):
72
79
dataset = ["a" , "b" , "c" , "d" , "e" , "f" , "g" ]
73
- sampler = RepeatRandomSampler (dataset , mini_repeat_count = 1 , batch_size = 2 , repeat_count = 2 )
80
+ sampler = RepeatSampler (dataset , mini_repeat_count = 1 , batch_size = 2 , repeat_count = 2 )
74
81
# Should output something like [4, 3, 4, 3, 0, 1, 0, 1, 2, 6, 2, 6]
75
82
sampled = list (sampler )
76
83
# Check that the length is doubled
@@ -84,7 +91,7 @@ def test_sampler_with_batch_size_and_drop(self):
84
91
85
92
def test_sampler_with_mini_repeat_count_and_batch_size_1 (self ):
86
93
dataset = ["a" , "b" , "c" , "d" , "e" , "f" , "g" ]
87
- sampler = RepeatRandomSampler (dataset , mini_repeat_count = 2 , batch_size = 3 , repeat_count = 2 )
94
+ sampler = RepeatSampler (dataset , mini_repeat_count = 2 , batch_size = 3 , repeat_count = 2 )
88
95
# Should output something like [4, 4, 3, 3, 0, 0, 4, 4, 3, 3, 0, 0,
89
96
# 1, 1, 2, 2, 6, 6, 1, 1, 2, 2, 6, 6]
90
97
sampled = list (sampler )
@@ -100,7 +107,7 @@ def test_sampler_with_mini_repeat_count_and_batch_size_1(self):
100
107
101
108
def test_sampler_with_mini_repeat_count_and_batch_size_2 (self ):
102
109
dataset = ["a" , "b" , "c" , "d" , "e" , "f" , "g" ]
103
- sampler = RepeatRandomSampler (dataset , mini_repeat_count = 3 , batch_size = 2 , repeat_count = 2 )
110
+ sampler = RepeatSampler (dataset , mini_repeat_count = 3 , batch_size = 2 , repeat_count = 2 )
104
111
# Should output something like [4, 4, 4, 3, 3, 3, 4, 4, 4, 3, 3, 3,
105
112
# 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1,
106
113
# 2, 2, 2, 6, 6, 6, 2, 2, 2, 6, 6, 6]
@@ -118,7 +125,7 @@ def test_sampler_with_mini_repeat_count_and_batch_size_2(self):
118
125
119
126
def test_sampler_with_mini_repeat_count_and_batch_size_3 (self ):
120
127
dataset = ["a" , "b" , "c" , "d" , "e" , "f" , "g" ]
121
- sampler = RepeatRandomSampler (dataset , mini_repeat_count = 2 , batch_size = 2 , repeat_count = 3 )
128
+ sampler = RepeatSampler (dataset , mini_repeat_count = 2 , batch_size = 2 , repeat_count = 3 )
122
129
# Should output something like [4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3,
123
130
# 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1,
124
131
# 2, 2, 6, 6, 2, 2, 6, 6, 2, 2, 6, 6]
0 commit comments