
Commit 3ac976b

Adding DPO Implementation (#223)

tobyzl2 and Toby Liang authored
Co-authored-by: Toby Liang <toby.liang@serivcenow.com>
1 parent 98d3969 commit 3ac976b


18 files changed (+696 -39 lines)

fast_llm/data/data/gpt/data.py

Lines changed: 13 additions & 1 deletion
@@ -32,18 +32,29 @@ class GPTBatch:
     token_ids: torch.Tensor
     loss_masking_spans: list[torch.Tensor] | None = None
     sequence_lengths: list[torch.Tensor] | None = None
+    chosen_spans: list[torch.Tensor] | None = None
+    rejected_spans: list[torch.Tensor] | None = None


 def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSamplingParameters) -> GPTBatch:
     stacked_ids = np.stack([sample.token_ids for sample in batch])
     stacked_spans = None
     sequence_lengths = None
+    stacked_chosen_spans = None
+    stacked_rejected_spans = None
     if sampling_parameters.use_loss_masking_spans:
         stacked_spans = [torch.from_numpy(sample.loss_masking_spans) for sample in batch]
+    if sampling_parameters.use_preference_loss_spans:
+        stacked_chosen_spans = [torch.from_numpy(sample.chosen_span) for sample in batch]
+        stacked_rejected_spans = [torch.from_numpy(sample.rejected_span) for sample in batch]
     if not sampling_parameters.cross_document_attention:
         sequence_lengths = [torch.tensor(sample.sequence_lengths) for sample in batch]
     return GPTBatch(
-        token_ids=torch.from_numpy(stacked_ids), loss_masking_spans=stacked_spans, sequence_lengths=sequence_lengths
+        token_ids=torch.from_numpy(stacked_ids),
+        loss_masking_spans=stacked_spans,
+        sequence_lengths=sequence_lengths,
+        chosen_spans=stacked_chosen_spans,
+        rejected_spans=stacked_rejected_spans,
     )


@@ -149,6 +160,7 @@ def get_iterator(
         sampling_parameters = self._sampling_parameters[dataset_name]
         Assert.in_range_incl(batch_config.sequence_length, 1, sampling_parameters.sequence_length)
         log_main_rank(f"Initializing {dataset_name} dataset iterator from sample {consumed_samples}...")
+
         return iter(
             torch.utils.data.DataLoader(
                 self._datasets[dataset_name],  # noqa
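
For orientation, here is a minimal, self-contained sketch of what the extended collate step does for a preference batch: token ids are stacked into a single tensor, while the chosen and rejected spans stay as per-sample tensors. The _Sample dataclass and collate_preference helper are hypothetical stand-ins for illustration only, not the repo's GPTSample or gpt_data_collate_fn.

import dataclasses

import numpy as np
import torch


@dataclasses.dataclass
class _Sample:
    # Stand-in for a collated sample: fixed-length token ids plus one preference pair of spans.
    token_ids: np.ndarray
    chosen_span: np.ndarray
    rejected_span: np.ndarray


def collate_preference(batch: list[_Sample]) -> tuple[torch.Tensor, list[torch.Tensor], list[torch.Tensor]]:
    # Token ids are stacked into a single (batch, sequence) tensor, while the
    # per-sample chosen/rejected spans remain lists of small tensors.
    token_ids = torch.from_numpy(np.stack([s.token_ids for s in batch]))
    chosen = [torch.from_numpy(s.chosen_span) for s in batch]
    rejected = [torch.from_numpy(s.rejected_span) for s in batch]
    return token_ids, chosen, rejected


# Hypothetical toy batch of two 8-token samples.
batch = [
    _Sample(np.arange(8), np.array([2, 4]), np.array([5, 7])),
    _Sample(np.arange(8, 16), np.array([1, 3]), np.array([4, 7])),
]
token_ids, chosen, rejected = collate_preference(batch)
print(token_ids.shape)  # torch.Size([2, 8])
print(chosen)           # [tensor([2, 4]), tensor([1, 3])]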

fast_llm/data/dataset/gpt/config.py

Lines changed: 1 addition & 0 deletions
@@ -73,6 +73,7 @@ class GPTSamplingParameters(SamplingParameters):
     sequence_length: int
     vocab_size: int
     use_loss_masking_spans: bool = False
+    use_preference_loss_spans: bool = False
     cross_document_attention: bool = True
     # How many extra tokens to add to the sequence length.
     # This is used to provide labels even for the last tokens in the sequence.

fast_llm/data/dataset/gpt/fim.py

Lines changed: 3 additions & 0 deletions
@@ -20,8 +20,11 @@ def __init__(
     ):
         if sampling.parameters.use_loss_masking_spans:
            raise NotImplementedError("FIM is currently not compatible with loss masking.")
+        if sampling.parameters.use_preference_loss_spans:
+            raise NotImplementedError("FIM is currently not compatible with preference loss masking.")
         self._config = config
         self._dataset = dataset
+
         self._seed = sampling.config.seed
         self._tokenizer = sampling.tokenizer
         if self._tokenizer is None:

fast_llm/data/dataset/gpt/memmap.py

Lines changed: 106 additions & 8 deletions
@@ -34,13 +34,16 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None
         self._name = name
         self._prefix = pathlib.Path(prefix)
         self._has_spans = 0
+        self._has_preference_spans = False

         with self._prefix.with_suffix(".idx").open("rb") as stream:
             Assert.eq(stream.read(9), MEMMAP_INDEX_HEADER, msg=f"File: {stream.name}")
             self._version = struct.unpack("<Q", stream.read(8))[0]
-            assert self._version in [1, 2], f"Unsupported version for gpt_memmap dataset: {self._version}."
-            if self._version == 2:
+            assert self._version in [1, 2, 3], f"Unsupported version for gpt_memmap dataset: {self._version}."
+            if self._version >= 2:
                 self._has_spans = struct.unpack("<B", stream.read(1))[0]
+            if self._version >= 3:
+                self._has_preference_spans = struct.unpack("<B", stream.read(1))[0]

             self._dtype = MEMMAP_DTYPES[struct.unpack("<B", stream.read(1))[0]].numpy
             self._num_documents = struct.unpack("<Q", stream.read(8))[0]
@@ -52,18 +55,23 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None

         self._index_bin_buffer_mmap = np.memmap(self._prefix.with_suffix(".idx"), mode="r", order="C")
         self._index_bin_buffer = memoryview(self._index_bin_buffer_mmap)
+
+        # read document sizes
         self._document_sizes = np.frombuffer(
             self._index_bin_buffer, dtype=np.int32, count=self._num_documents, offset=offset
         )
+
+        # read pointers
         self._pointers = np.frombuffer(
             self._index_bin_buffer,
             dtype=np.int64,
             count=self._num_documents,
             offset=offset + self._document_sizes.nbytes,
         )

+        # read spans
         self._spans = None
-        if self._has_spans and self._version == 2:
+        if self._has_spans and self._version >= 2:
             self._spans = []
             self._num_spans = np.frombuffer(
                 self._index_bin_buffer,
@@ -83,6 +91,36 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None
                 ).reshape(-1, 2)
             )

+        # read preference spans
+        self._chosen_spans = None
+        self._rejected_spans = None
+        if self._has_preference_spans and self._version >= 3:
+            self._chosen_spans = []
+            self._rejected_spans = []
+            chosen_span_offset = offset + self._document_sizes.nbytes + self._pointers.nbytes
+            for idx in range(self._num_documents):
+                self._chosen_spans.append(
+                    np.frombuffer(
+                        self._index_bin_buffer,
+                        dtype=np.int32,
+                        count=2,
+                        offset=chosen_span_offset + idx * 2 * np.dtype(np.int32).itemsize,
+                    )
+                )
+
+            rejected_span_offset = (
+                offset + self._document_sizes.nbytes + self._pointers.nbytes + np.array(self._chosen_spans).nbytes
+            )
+            for idx in range(self._num_documents):
+                self._rejected_spans.append(
+                    np.frombuffer(
+                        self._index_bin_buffer,
+                        dtype=np.int32,
+                        count=2,
+                        offset=rejected_span_offset + idx * 2 * np.dtype(np.int32).itemsize,
+                    )
+                )
+
         self._bin_buffer_mmap = np.memmap(self._prefix.with_suffix(".bin"), mode="r", order="C")
         self._bin_buffer = memoryview(self._bin_buffer_mmap)

@@ -105,7 +143,12 @@ def __del__(self):
         del self._index_bin_buffer_mmap

     def get(
-        self, idx: int, offset: int = 0, length: int | None = None, use_loss_masking_spans: bool = False
+        self,
+        idx: int,
+        offset: int = 0,
+        length: int | None = None,
+        use_loss_masking_spans: bool = False,
+        use_preference_loss_spans: bool = False,
     ) -> GPTSample:
         token_ids = np.frombuffer(
             self._bin_buffer,
@@ -116,13 +159,53 @@ def get(
         sample_spans = None
         if use_loss_masking_spans and self._spans is not None:
             sample_spans = self._spans[idx]
-            # adjust the spans for the offset and length
+
+            # filter spans that are outside the range of the selected tokens in the document
             sample_spans = sample_spans[
                 (sample_spans[:, 0] < offset + len(token_ids)) & (sample_spans[:, 1] >= offset)
             ]
-            sample_spans[:, 0] = np.maximum(sample_spans[:, 0], offset) - offset
+
+            # subtract by offset to normalize span boundaries
+            sample_spans[:, 0] = np.maximum(sample_spans[:, 0], offset) - offset  # offset
             sample_spans[:, 1] = np.minimum(sample_spans[:, 1], offset + len(token_ids) - 1) - offset
-        return GPTSample(token_ids=token_ids, loss_masking_spans=sample_spans)
+
+        chosen_span = None
+        rejected_span = None
+
+        if use_preference_loss_spans:
+            if not self._has_preference_spans:
+                raise ValueError("No preference spans found in memmap dataset.")
+            elif self._has_preference_spans and self._chosen_spans is None:
+                raise ValueError("Failed to read chosen spans from memmap dataset.")
+            elif self._has_preference_spans and self._rejected_spans is None:
+                raise ValueError("Failed to read rejected spans from memmap dataset.")
+            else:
+                chosen_span = self._chosen_spans[idx]
+
+                # filter spans that are outside the range of the selected tokens in the document
+                chosen_span = chosen_span[(chosen_span[0] < offset + len(token_ids)) & (chosen_span[1] >= offset)][0]
+
+                # subtract by offset to normalize span boundaries
+                chosen_span[0] = np.maximum(chosen_span[0], offset) - offset  # offset
+                chosen_span[1] = np.minimum(chosen_span[1], offset + len(token_ids) - 1) - offset
+
+                rejected_span = self._rejected_spans[idx]
+
+                # filter spans that are outside the range of the selected tokens in the document
+                rejected_span = rejected_span[
+                    (rejected_span[0] < offset + len(token_ids)) & (rejected_span[1] >= offset)
+                ][0]
+
+                # subtract by offset to normalize span boundaries
+                rejected_span[0] = np.maximum(rejected_span[0], offset) - offset  # offset
+                rejected_span[1] = np.minimum(rejected_span[1], offset + len(token_ids) - 1) - offset
+
+        return GPTSample(
+            token_ids=token_ids,
+            loss_masking_spans=sample_spans,
+            chosen_span=chosen_span,
+            rejected_span=rejected_span,
+        )

     @property
     def name(self) -> str:
@@ -157,6 +240,8 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP
         # number of spans for each document
         num_spans = []
         spans = []
+        chosen_spans = []
+        rejected_spans = []

         prefix = pathlib.Path(prefix)
         prefix.parent.mkdir(parents=True, exist_ok=True)
@@ -182,6 +267,10 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP
             if document.loss_masking_spans is not None:
                 num_spans.append(len(document.loss_masking_spans))
                 spans.append(document.loss_masking_spans)
+            if document.chosen_span is not None:
+                chosen_spans.append(document.chosen_span)
+            if document.rejected_span is not None:
+                rejected_spans.append(document.rejected_span)
             offset += doc_length * np.dtype(dtype).itemsize
             num_documents += 1

@@ -193,15 +282,20 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP
             spans = np.vstack(spans, dtype=np.int32)
         else:
             spans = np.array(spans, dtype=np.int32)
+        chosen_spans = np.array(chosen_spans, dtype=np.int32).reshape(-1, 2)
+        rejected_spans = np.array(rejected_spans, dtype=np.int32).reshape(-1, 2)

         # Write the index file (.idx)
         with prefix.with_suffix(".idx").open("wb") as idx_stream:
             idx_stream.write(MEMMAP_INDEX_HEADER)
             # Indicates the version
             # Version 2 optionally adds loss-masking spans
-            idx_stream.write(struct.pack("<Q", 2))
+            # Version 3 optionally adds chosen/rejected spans
+            idx_stream.write(struct.pack("<Q", 3))
             # Flag to indicate whether loss-masking spans are present
             idx_stream.write(struct.pack("<B", 1 if spans.size > 0 else 0))
+            # Flag to indicate whether preference loss-masking spans are present
+            idx_stream.write(struct.pack("<B", 1 if chosen_spans.size > 0 and rejected_spans.size > 0 else 0))
             # Data type
             idx_stream.write(struct.pack("<B", MEMMAP_DTYPES_INV[DataType.from_numpy(dtype.type)]))
             # "Number of sequences", same as documents in our case
@@ -216,5 +310,9 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP
             idx_stream.write(num_spans.tobytes(order="C"))
             # Span indices for each document
             idx_stream.write(spans.tobytes(order="C"))
+            # Chosen indices for each document
+            idx_stream.write(chosen_spans.tobytes(order="C"))
+            # Rejected indices for each document
+            idx_stream.write(rejected_spans.tobytes(order="C"))
             # Document indices, unused but needed for compatibility with Megatron-LM
             idx_stream.write(np.arange(num_documents + 1, dtype=np.int64).tobytes(order="C"))
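
For reference, the version-3 index header written above can be read back field by field in the same order. This is a hedged sketch, not code from the repo: read_idx_header is a hypothetical helper, and it reads the 9-byte magic without validating it against MEMMAP_INDEX_HEADER.

import struct


def read_idx_header(path: str) -> dict:
    # Reads the fixed-size header fields of a gpt_memmap .idx file in the order
    # the writer emits them: magic, version, optional span flags, dtype code, document count.
    with open(path, "rb") as stream:
        magic = stream.read(9)                                                    # MEMMAP_INDEX_HEADER bytes
        version = struct.unpack("<Q", stream.read(8))[0]                          # 1, 2, or 3
        has_spans = struct.unpack("<B", stream.read(1))[0] if version >= 2 else 0
        has_preference_spans = struct.unpack("<B", stream.read(1))[0] if version >= 3 else 0
        dtype_code = struct.unpack("<B", stream.read(1))[0]                       # key into MEMMAP_DTYPES
        num_documents = struct.unpack("<Q", stream.read(8))[0]
    return {
        "magic": magic,
        "version": version,
        "has_spans": bool(has_spans),
        "has_preference_spans": bool(has_preference_spans),
        "dtype_code": dtype_code,
        "num_documents": num_documents,
    }

# Example (hypothetical path): read_idx_header("my_dataset.idx")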

fast_llm/data/dataset/gpt/sampled.py

Lines changed: 74 additions & 4 deletions
@@ -30,6 +30,8 @@
 class GPTSample:
     token_ids: np.ndarray
     loss_masking_spans: np.ndarray | None = None
+    chosen_span: np.ndarray | None = None
+    rejected_span: np.ndarray | None = None
     sequence_lengths: np.ndarray | None = None


@@ -112,6 +114,14 @@ def __init__(
         self._token_cumsum_shuffled = MemmapArray(base_path.with_name(base_path.name + "_shuffled_cumsum.npy"))
         self._token_cumsum_unshuffled = MemmapArray(base_path.with_name(base_path.name + "_unshuffled_cumsum.npy"))
         self._yaml_path = base_path.with_suffix(".yaml")
+
+        # keep document sizes and len filtered docs for preference loss masking
+        if self._parameters.use_preference_loss_spans:
+            self._document_sizes = MemmapArray(base_path.with_name(base_path.name + "_doc_sizes.npy"))
+            self._doc_length_filtered_indicies = MemmapArray(
+                base_path.with_name(base_path.name + "_doc_length_filtered_indices.npy")
+            )
+
         # Sample or validate the dataset of a given rank.
         if sampling.distributed.config.rank == sampling.get_next_rank():
             self._sample()
@@ -145,10 +155,14 @@ def _sample(self) -> None:
             raise RuntimeError(
                 f" > No documents shorter than {self._parameters.sequence_length+1} tokens found in dataset {self._indexed_dataset.name}."
             )
+
         # We produce sequences of length `self._sequence_length + extra_tokens` so the last token has a label for all prediction heads,
         # but in case of truncations we also include those last labels in the following sample,
         # so we need `sequence_length * num_samples + extra_tokens` tokens in total.
-        if self._truncate_documents:
+        if self._parameters.use_preference_loss_spans:
+            documents_per_epoch = (~long_docs_filter).sum().item()
+            num_epochs = math.ceil(self._parameters.num_samples / documents_per_epoch)
+        elif self._truncate_documents:
             num_epochs = math.ceil(
                 (self._parameters.sequence_length * self._parameters.num_samples + self._parameters.extra_tokens)
                 / tokens_per_epoch
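
With preference loss spans, each sample is one whole document that passed the length filter, so the epoch count in the branch above is taken over documents rather than tokens. A worked example with hypothetical numbers:

import math

num_samples = 10_000         # requested training samples
documents_per_epoch = 3_200  # documents short enough to fit in one sequence
num_epochs = math.ceil(num_samples / documents_per_epoch)
print(num_epochs)            # 4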
@@ -187,8 +201,8 @@ def _sample(self) -> None:

         if self._yaml_path is not None and self._yaml_path.is_file():
             loaded_yaml_data = yaml.safe_load(self._yaml_path.open("r"))
-            self._load_yaml_data(loaded_yaml_data)
-            if not self._truncate_documents:
+            self._load_yaml_data(yaml_data)
+            if not self._truncate_documents and not self._parameters.use_preference_loss_spans:
                 del loaded_yaml_data["unshuffled_tokens"]

             if loaded_yaml_data != yaml_data:
@@ -251,6 +265,24 @@ def _sample(self) -> None:
         else:
             raise NotImplementedError(f"Unknown shuffling type: {self._config.shuffle}")

+        if self._parameters.use_preference_loss_spans:
+            yaml_data["unshuffled_tokens"] = 0  # not used, ignore
+
+            # index of all documents less than seq length long
+            doc_length_filtered_indicies = torch.nonzero(~long_docs_filter, as_tuple=True)[0]
+            self._doc_length_filtered_indicies.save(doc_length_filtered_indicies.numpy(force=self._config.gpu))
+
+            # apply shuffling on doc_length_filtered_indicies
+            if shuffled_epochs > 0:
+                self._document_shuffling.save(
+                    document_shuffling[: self._parameters.num_samples].numpy(force=self._config.gpu)
+                )
+            self._document_sizes.save(document_sizes.numpy(force=self._config.gpu))
+            if self._yaml_path is not None:
+                self._yaml_path.parent.mkdir(parents=True, exist_ok=True)
+                yaml.safe_dump(yaml_data, self._yaml_path.open("w"))
+            return
+
         # To get a sample on the fly we need to know where it begins,
         # and this is a non-trivial information because the documents have variable length.
         # The starting point `(document[idx], token[idx])` corresponds to the `(idx * sequence_length)` th token, i.e.
@@ -349,6 +381,40 @@ def __getitem__(self, index: int) -> typing.Any:
         The returned sample is ready to be concatenated, then fed to a `GPTModel` (see `GPTModel.preprocess`).
         """
         self._lazy_load()
+
+        if self._parameters.use_preference_loss_spans:
+            if index < self._unshuffled_documents:
+                document_index = self._doc_length_filtered_indicies[index % self._documents_per_epoch]
+            else:
+                document_index = self._doc_length_filtered_indicies[
+                    self._document_shuffling[index - self._unshuffled_documents].item()
+                ]
+
+            sample = self._indexed_dataset.get(
+                document_index,
+                offset=0,
+                length=self._document_sizes[document_index],
+                use_loss_masking_spans=self._parameters.use_loss_masking_spans,
+                use_preference_loss_spans=self._parameters.use_preference_loss_spans,
+            )
+
+            chosen_span_end = sample.chosen_span[1] + 1
+            sequence_lengths = [
+                chosen_span_end,
+                len(sample.token_ids) - chosen_span_end,
+            ]
+
+            # compute padding size
+            padding = np.full((self._parameters.sequence_length + 1,), 0)
+            padding[: len(sample.token_ids)] = sample.token_ids
+            sequence_lengths.append(self._parameters.sequence_length - len(sample.token_ids))
+            sample.token_ids = padding
+
+            if not self._parameters.cross_document_attention:
+                sample.sequence_lengths = np.array(sequence_lengths)
+
+            return sample
+
         # tokens at the boundary are included in only one sample when we pack without truncations
         # in case of packing with truncations, the last token from the previous sample is also the first token of the next sample
         sample_length = (
@@ -454,7 +520,9 @@ def _lazy_load(self):
     def _load_yaml_data(self, data: dict[str, typing.Any]) -> None:
         self._documents_per_epoch = data["dataset"]["documents_per_epoch"]

-        if "unshuffled_tokens" not in data:
+        if self._parameters.use_preference_loss_spans:
+            data["unshuffled_tokens"] = 0  # not used, ignore
+        elif "unshuffled_tokens" not in data:
             # Backward compatibility
             # TODO v0.x: Remove
             assert self._truncate_documents
@@ -485,6 +553,8 @@ def __init__(
         )
         self._config = sampling.config
         self._parameters = sampling.parameters
+        if self._parameters.use_preference_loss_spans:
+            raise NotImplementedError("Legacy sampling does not support preference loss masking.")

         if sampling.cache_directory is None:
             log_main_rank(
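
To make the preference branch of __getitem__ above concrete, here is a self-contained mirror of its padding and sequence-length bookkeeping: the document is left-aligned in a buffer of sequence_length + 1 tokens, and sequence_lengths records the chosen prefix, the remainder of the document, and the padding. The pack_preference_sample helper and the toy values below are hypothetical, for illustration only.

import numpy as np


def pack_preference_sample(token_ids: np.ndarray, chosen_span: np.ndarray, sequence_length: int):
    # The chosen span is inclusive, so the prefix up to and including it has chosen_span[1] + 1 tokens.
    chosen_span_end = int(chosen_span[1]) + 1
    sequence_lengths = [chosen_span_end, len(token_ids) - chosen_span_end]

    # Left-align the document in a zero-padded buffer of sequence_length + 1 tokens.
    padded = np.full((sequence_length + 1,), 0, dtype=token_ids.dtype)
    padded[: len(token_ids)] = token_ids
    sequence_lengths.append(sequence_length - len(token_ids))
    return padded, np.array(sequence_lengths)


# Hypothetical toy values, not taken from the repo:
tokens = np.arange(1, 9)   # an 8-token document
chosen = np.array([4, 6])  # chosen span covers token positions 4..6 inclusive
padded, lengths = pack_preference_sample(tokens, chosen, sequence_length=12)
print(padded)   # [1 2 3 4 5 6 7 8 0 0 0 0 0]
print(lengths)  # [7 1 4]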
