@@ -34,13 +34,16 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None
34
34
self ._name = name
35
35
self ._prefix = pathlib .Path (prefix )
36
36
self ._has_spans = 0
37
+ self ._has_preference_spans = False
37
38
38
39
with self ._prefix .with_suffix (".idx" ).open ("rb" ) as stream :
39
40
Assert .eq (stream .read (9 ), MEMMAP_INDEX_HEADER , msg = f"File: { stream .name } " )
40
41
self ._version = struct .unpack ("<Q" , stream .read (8 ))[0 ]
41
- assert self ._version in [1 , 2 ], f"Unsupported version for gpt_memmap dataset: { self ._version } ."
42
- if self ._version = = 2 :
42
+ assert self ._version in [1 , 2 , 3 ], f"Unsupported version for gpt_memmap dataset: { self ._version } ."
43
+ if self ._version > = 2 :
43
44
self ._has_spans = struct .unpack ("<B" , stream .read (1 ))[0 ]
45
+ if self ._version >= 3 :
46
+ self ._has_preference_spans = struct .unpack ("<B" , stream .read (1 ))[0 ]
44
47
45
48
self ._dtype = MEMMAP_DTYPES [struct .unpack ("<B" , stream .read (1 ))[0 ]].numpy
46
49
self ._num_documents = struct .unpack ("<Q" , stream .read (8 ))[0 ]
@@ -52,18 +55,23 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None
52
55
53
56
self ._index_bin_buffer_mmap = np .memmap (self ._prefix .with_suffix (".idx" ), mode = "r" , order = "C" )
54
57
self ._index_bin_buffer = memoryview (self ._index_bin_buffer_mmap )
58
+
59
+ # read document sizes
55
60
self ._document_sizes = np .frombuffer (
56
61
self ._index_bin_buffer , dtype = np .int32 , count = self ._num_documents , offset = offset
57
62
)
63
+
64
+ # read pointers
58
65
self ._pointers = np .frombuffer (
59
66
self ._index_bin_buffer ,
60
67
dtype = np .int64 ,
61
68
count = self ._num_documents ,
62
69
offset = offset + self ._document_sizes .nbytes ,
63
70
)
64
71
72
+ # read spans
65
73
self ._spans = None
66
- if self ._has_spans and self ._version = = 2 :
74
+ if self ._has_spans and self ._version > = 2 :
67
75
self ._spans = []
68
76
self ._num_spans = np .frombuffer (
69
77
self ._index_bin_buffer ,
@@ -83,6 +91,36 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None
83
91
).reshape (- 1 , 2 )
84
92
)
85
93
94
+ # read preference spans
95
+ self ._chosen_spans = None
96
+ self ._rejected_spans = None
97
+ if self ._has_preference_spans and self ._version >= 3 :
98
+ self ._chosen_spans = []
99
+ self ._rejected_spans = []
100
+ chosen_span_offset = offset + self ._document_sizes .nbytes + self ._pointers .nbytes
101
+ for idx in range (self ._num_documents ):
102
+ self ._chosen_spans .append (
103
+ np .frombuffer (
104
+ self ._index_bin_buffer ,
105
+ dtype = np .int32 ,
106
+ count = 2 ,
107
+ offset = chosen_span_offset + idx * 2 * np .dtype (np .int32 ).itemsize ,
108
+ )
109
+ )
110
+
111
+ rejected_span_offset = (
112
+ offset + self ._document_sizes .nbytes + self ._pointers .nbytes + np .array (self ._chosen_spans ).nbytes
113
+ )
114
+ for idx in range (self ._num_documents ):
115
+ self ._rejected_spans .append (
116
+ np .frombuffer (
117
+ self ._index_bin_buffer ,
118
+ dtype = np .int32 ,
119
+ count = 2 ,
120
+ offset = rejected_span_offset + idx * 2 * np .dtype (np .int32 ).itemsize ,
121
+ )
122
+ )
123
+
86
124
self ._bin_buffer_mmap = np .memmap (self ._prefix .with_suffix (".bin" ), mode = "r" , order = "C" )
87
125
self ._bin_buffer = memoryview (self ._bin_buffer_mmap )
88
126
@@ -105,7 +143,12 @@ def __del__(self):
105
143
del self ._index_bin_buffer_mmap
106
144
107
145
def get (
108
- self , idx : int , offset : int = 0 , length : int | None = None , use_loss_masking_spans : bool = False
146
+ self ,
147
+ idx : int ,
148
+ offset : int = 0 ,
149
+ length : int | None = None ,
150
+ use_loss_masking_spans : bool = False ,
151
+ use_preference_loss_spans : bool = False ,
109
152
) -> GPTSample :
110
153
token_ids = np .frombuffer (
111
154
self ._bin_buffer ,
@@ -116,13 +159,53 @@ def get(
116
159
sample_spans = None
117
160
if use_loss_masking_spans and self ._spans is not None :
118
161
sample_spans = self ._spans [idx ]
119
- # adjust the spans for the offset and length
162
+
163
+ # filter spans that are outside the range of the selected tokens in the document
120
164
sample_spans = sample_spans [
121
165
(sample_spans [:, 0 ] < offset + len (token_ids )) & (sample_spans [:, 1 ] >= offset )
122
166
]
123
- sample_spans [:, 0 ] = np .maximum (sample_spans [:, 0 ], offset ) - offset
167
+
168
+ # subtract by offset to normalize span boundaries
169
+ sample_spans [:, 0 ] = np .maximum (sample_spans [:, 0 ], offset ) - offset # offset
124
170
sample_spans [:, 1 ] = np .minimum (sample_spans [:, 1 ], offset + len (token_ids ) - 1 ) - offset
125
- return GPTSample (token_ids = token_ids , loss_masking_spans = sample_spans )
171
+
172
+ chosen_span = None
173
+ rejected_span = None
174
+
175
+ if use_preference_loss_spans :
176
+ if not self ._has_preference_spans :
177
+ raise ValueError ("No preference spans found in memmap dataset." )
178
+ elif self ._has_preference_spans and self ._chosen_spans is None :
179
+ raise ValueError ("Failed to read chosen spans from memmap dataset." )
180
+ elif self ._has_preference_spans and self ._rejected_spans is None :
181
+ raise ValueError ("Failed to read rejected spans from memmap dataset." )
182
+ else :
183
+ chosen_span = self ._chosen_spans [idx ]
184
+
185
+ # filter spans that are outside the range of the selected tokens in the document
186
+ chosen_span = chosen_span [(chosen_span [0 ] < offset + len (token_ids )) & (chosen_span [1 ] >= offset )][0 ]
187
+
188
+ # subtract by offset to normalize span boundaries
189
+ chosen_span [0 ] = np .maximum (chosen_span [0 ], offset ) - offset # offset
190
+ chosen_span [1 ] = np .minimum (chosen_span [1 ], offset + len (token_ids ) - 1 ) - offset
191
+
192
+ rejected_span = self ._rejected_spans [idx ]
193
+
194
+ # filter spans that are outside the range of the selected tokens in the document
195
+ rejected_span = rejected_span [
196
+ (rejected_span [0 ] < offset + len (token_ids )) & (rejected_span [1 ] >= offset )
197
+ ][0 ]
198
+
199
+ # subtract by offset to normalize span boundaries
200
+ rejected_span [0 ] = np .maximum (rejected_span [0 ], offset ) - offset # offset
201
+ rejected_span [1 ] = np .minimum (rejected_span [1 ], offset + len (token_ids ) - 1 ) - offset
202
+
203
+ return GPTSample (
204
+ token_ids = token_ids ,
205
+ loss_masking_spans = sample_spans ,
206
+ chosen_span = chosen_span ,
207
+ rejected_span = rejected_span ,
208
+ )
126
209
127
210
@property
128
211
def name (self ) -> str :
@@ -157,6 +240,8 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP
157
240
# number of spans for each document
158
241
num_spans = []
159
242
spans = []
243
+ chosen_spans = []
244
+ rejected_spans = []
160
245
161
246
prefix = pathlib .Path (prefix )
162
247
prefix .parent .mkdir (parents = True , exist_ok = True )
@@ -182,6 +267,10 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP
182
267
if document .loss_masking_spans is not None :
183
268
num_spans .append (len (document .loss_masking_spans ))
184
269
spans .append (document .loss_masking_spans )
270
+ if document .chosen_span is not None :
271
+ chosen_spans .append (document .chosen_span )
272
+ if document .rejected_span is not None :
273
+ rejected_spans .append (document .rejected_span )
185
274
offset += doc_length * np .dtype (dtype ).itemsize
186
275
num_documents += 1
187
276
@@ -193,15 +282,20 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP
193
282
spans = np .vstack (spans , dtype = np .int32 )
194
283
else :
195
284
spans = np .array (spans , dtype = np .int32 )
285
+ chosen_spans = np .array (chosen_spans , dtype = np .int32 ).reshape (- 1 , 2 )
286
+ rejected_spans = np .array (rejected_spans , dtype = np .int32 ).reshape (- 1 , 2 )
196
287
197
288
# Write the index file (.idx)
198
289
with prefix .with_suffix (".idx" ).open ("wb" ) as idx_stream :
199
290
idx_stream .write (MEMMAP_INDEX_HEADER )
200
291
# Indicates the version
201
292
# Version 2 optionally adds loss-masking spans
202
- idx_stream .write (struct .pack ("<Q" , 2 ))
293
+ # Version 3 optionally adds chosen/rejected spans
294
+ idx_stream .write (struct .pack ("<Q" , 3 ))
203
295
# Flag to indicate whether loss-masking spans are present
204
296
idx_stream .write (struct .pack ("<B" , 1 if spans .size > 0 else 0 ))
297
+ # Flag to indicate whether preference loss-masking spans are present
298
+ idx_stream .write (struct .pack ("<B" , 1 if chosen_spans .size > 0 and rejected_spans .size > 0 else 0 ))
205
299
# Data type
206
300
idx_stream .write (struct .pack ("<B" , MEMMAP_DTYPES_INV [DataType .from_numpy (dtype .type )]))
207
301
# "Number of sequences", same as documents in our case
@@ -216,5 +310,9 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP
216
310
idx_stream .write (num_spans .tobytes (order = "C" ))
217
311
# Span indices for each document
218
312
idx_stream .write (spans .tobytes (order = "C" ))
313
+ # Chosen indices for each document
314
+ idx_stream .write (chosen_spans .tobytes (order = "C" ))
315
+ # Rejected indices for each document
316
+ idx_stream .write (rejected_spans .tobytes (order = "C" ))
219
317
# Document indices, unused but needed for compatibility with Megatron-LM
220
318
idx_stream .write (np .arange (num_documents + 1 , dtype = np .int64 ).tobytes (order = "C" ))
0 commit comments