
Commit c935f3a

fix loss masking and padding (#287)
1 parent 97e6cf4 commit c935f3a

11 files changed: +58 −30 lines

11 files changed

+58
-30
lines changed

fast_llm/data/data/gpt/config.py

Lines changed: 0 additions & 9 deletions
@@ -48,15 +48,6 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig):
         desc="Multiprocessing context. Do not touch.",
         hint=FieldHint.expert,
     )
-    truncate_documents: bool = Field(
-        default=True,
-        desc=(
-            "If enabled, documents may be truncated while being packed to fit the sequence length."
-            "Otherwise, sequences will be padded such that every document lies entirely within a sample"
-            " (and documents exceeding the sequence length will be skipped altogether)."
-        ),
-        hint=FieldHint.feature,
-    )
 
     def _validate(self) -> None:
         if not self.datasets:

fast_llm/data/data/gpt/data.py

Lines changed: 0 additions & 1 deletion
@@ -129,7 +129,6 @@ def setup(
             distributed=distributed,
             dataset_name=dataset_name,
             tokenizer=self._tokenizer,
-            truncate_documents=self._config.truncate_documents,
         )
         dataset = self._config.datasets[dataset_name].build_and_sample(sampling)
         self._datasets[dataset_name] = DatasetMonitor(dataset, self._config.data_sample_warn_time_ms)

fast_llm/data/dataset/gpt/config.py

Lines changed: 1 addition & 1 deletion
@@ -75,6 +75,7 @@ class GPTSamplingParameters(SamplingParameters):
     use_loss_masking_spans: bool = False
     use_preference_loss_spans: bool = False
     cross_document_attention: bool = True
+    truncate_documents: bool = True
     # How many extra tokens to add to the sequence length.
     # This is used to provide labels even for the last tokens in the sequence.
     extra_tokens: int = 1
@@ -90,7 +91,6 @@ class GPTSamplingData(SamplingData):
     config: GPTSamplingConfig
     parameters: GPTSamplingParameters
     tokenizer: "Tokenizer"
-    truncate_documents: bool = True
 
 
 @config_class(registry=True)

fast_llm/data/dataset/gpt/sampled.py

Lines changed: 8 additions & 7 deletions
@@ -89,7 +89,7 @@ def __init__(
         self._indexed_dataset = indexed_dataset
         self._config = sampling.config
         self._parameters = sampling.parameters
-        self._truncate_documents = sampling.truncate_documents
+        self._truncate_documents = sampling.parameters.truncate_documents
         self._device = torch.device("cuda" if self._config.gpu else "cpu")
 
         if sampling.cache_directory is None:
@@ -144,7 +144,7 @@ def _sample(self) -> None:
                 " Please make sure Fast-LLM is installed correctly."
             )
         long_docs_filter = document_sizes > self._parameters.sequence_length + 1
-        ignored_documents = sum(long_docs_filter)
+        ignored_documents = long_docs_filter.sum().item()
         if ignored_documents:
             log_main_rank(
                 f" > {ignored_documents}/{documents_per_epoch} documents are longer than {self._parameters.sequence_length+1} tokens and will be ignored.",
@@ -201,9 +201,10 @@ def _sample(self) -> None:
 
         if self._yaml_path is not None and self._yaml_path.is_file():
             loaded_yaml_data = yaml.safe_load(self._yaml_path.open("r"))
+            # Hack to make sure unshuffled tokens are loaded
+            if not self._truncate_documents:
+                yaml_data["unshuffled_tokens"] = loaded_yaml_data["unshuffled_tokens"]
             self._load_yaml_data(yaml_data)
-            if not self._truncate_documents and not self._parameters.use_preference_loss_spans:
-                del loaded_yaml_data["unshuffled_tokens"]
 
             if loaded_yaml_data != yaml_data:
                 raise RuntimeError(
@@ -469,7 +470,7 @@ def __getitem__(self, index: int) -> typing.Any:
                 token_count += padding_size
 
             # Determine if the document belongs to the requested sample.
-            if token_count + document_size >= token_start:
+            if token_count + document_size > token_start:
                 # Determine which part of the document belong to the sample, and add it to the list.
                 token_start_index_in_document = max(token_start - token_count, 0)
                 token_end_index_in_document = min(token_end - token_count, document_size)
@@ -487,7 +488,7 @@ def __getitem__(self, index: int) -> typing.Any:
                             0,
                             self._parameters.sequence_length + self._parameters.extra_tokens,
                         )
-                        if span[1] > span[0]:
+                        if span[1] >= span[0]:
                             loss_masking_spans.append(span)
 
             # Go to the next document.
@@ -547,7 +548,7 @@ def __init__(
     ):
         assert isinstance(sampling, GPTSamplingData)
         self._indexed_dataset = indexed_dataset
-        if not sampling.truncate_documents:
+        if not sampling.parameters.truncate_documents:
             raise NotImplementedError(
                 "Legacy sampling only supports document truncation. Please use the latest dataset format."
            )
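
Note on the `>=` → `>` change above: a document occupies tokens [token_count, token_count + document_size), so it appears to overlap the requested sample only when it ends strictly after token_start; the old `>=` test admitted a document ending exactly at token_start, contributing an empty slice. A minimal standalone sketch of that overlap test (hypothetical values, not part of the commit):

# Sketch, assuming a document covers [token_count, token_count + document_size)
# and the sample covers [token_start, token_end).
def document_overlaps_sample(token_count: int, document_size: int, token_start: int) -> bool:
    return token_count + document_size > token_start

# A 10-token document ending exactly at token_start=30 contributes nothing.
assert not document_overlaps_sample(token_count=20, document_size=10, token_start=30)
# One token earlier, it overlaps and belongs to the sample.
assert document_overlaps_sample(token_count=20, document_size=10, token_start=29)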

fast_llm/functional/cross_entropy.py

Lines changed: 3 additions & 3 deletions
@@ -145,9 +145,9 @@ def _fused_cross_entropy_forward_backward(
 
     per_sample_loss = sum_exp_logits.log() - predicted_logits
     if loss_mask is not None:
-        per_sample_loss = per_sample_loss * loss_mask
-
-    loss = per_sample_loss.mean()
+        loss = (per_sample_loss * loss_mask).sum() / torch.maximum(loss_mask.sum(), 1)
+    else:
+        loss = per_sample_loss.mean()
     if target_format != TargetFormat.labels and group is not None:
         all_reduce(loss, op=ReduceOp.MEAN, group=group)
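
The hunk switches from a plain mean over all positions to a masked mean, so masked or padded tokens no longer dilute the loss. A minimal standalone sketch of the idea (not the repository function; loss_mask assumed to be a 0/1 float tensor, and clamp_min used here only to avoid a zero denominator):

import torch

def masked_mean_loss(per_sample_loss: torch.Tensor, loss_mask: torch.Tensor | None) -> torch.Tensor:
    # Average only over positions where loss_mask == 1; a plain .mean() would
    # also count masked positions in the denominator and shrink the loss.
    if loss_mask is None:
        return per_sample_loss.mean()
    return (per_sample_loss * loss_mask).sum() / loss_mask.sum().clamp_min(1)

per_sample_loss = torch.tensor([2.0, 4.0, 6.0, 8.0])
loss_mask = torch.tensor([1.0, 1.0, 0.0, 0.0])
print(masked_mean_loss(per_sample_loss, loss_mask))  # tensor(3.) rather than tensor(5.)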

fast_llm/layers/language_model/config.py

Lines changed: 1 addition & 0 deletions
@@ -38,6 +38,7 @@ class LanguageModelKwargs:
     chosen_spans = "chosen_spans"
     rejected_spans = "rejected_spans"
     loss_mask = "loss_mask"
+    mask_inputs = "mask_inputs"
 
 
 @config_class()

fast_llm/layers/language_model/embedding.py

Lines changed: 13 additions & 3 deletions
@@ -84,7 +84,7 @@ def __init__(
         )
 
     @torch.compile
-    def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None) -> torch.Tensor:
+    def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None, mask_inputs: bool) -> torch.Tensor:
         Assert.eq(position_ids is not None, self._use_absolute_position_embeddings)
         group = self._tensor_space.distributed.tensor_group
         if self._parallel_embeddings:
@@ -101,9 +101,17 @@ def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None) -> t
             input_ = split(input_, group=group, dim=0)
             if self._use_absolute_position_embeddings:
                 position_ids = split(position_ids, group=group, dim=0)
-        embeddings = torch.embedding(self.word_embeddings_weight, input_)
+        # handle masked tokens
+        if mask_inputs:
+            input_mask = input_ >= 0
+            masked_input = input_ * input_mask
+            embeddings = torch.embedding(self.word_embeddings_weight, masked_input)
+        else:
+            embeddings = torch.embedding(self.word_embeddings_weight, input_)
         if self._use_absolute_position_embeddings:
             embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight)
+        if mask_inputs:
+            embeddings = embeddings * input_mask.unsqueeze(2)
         with set_generator(
             self._tensor_space.distributed.tp_generator
             if self._sequence_parallel
@@ -125,4 +133,6 @@ def forward(
             tensor_name="Embedding output",
             dtype=self._residual_dtype,
         )
-        return self._forward(input_, kwargs.get(LanguageModelKwargs.position_ids))
+        return self._forward(
+            input_, kwargs.get(LanguageModelKwargs.position_ids), kwargs.get(LanguageModelKwargs.mask_inputs)
+        )
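
With padding enabled, padded positions carry a negative token id that would be an invalid index into the embedding table. The hunk above replaces those ids with 0 for the lookup and then zeroes the resulting embeddings so padding contributes nothing downstream. A standalone sketch of that pattern (hypothetical sizes and a hypothetical -100 pad id, not the module itself):

import torch

vocab_size, hidden_size = 8, 4
word_embeddings_weight = torch.randn(vocab_size, hidden_size)

# Padded positions are marked with a negative id (e.g. -100).
input_ = torch.tensor([[3, 5, -100, -100]])

input_mask = input_ >= 0                       # True only for real tokens
masked_input = input_ * input_mask             # negative ids become 0, a valid index
embeddings = torch.embedding(word_embeddings_weight, masked_input)
embeddings = embeddings * input_mask.unsqueeze(2)  # zero out the padded positions

print(embeddings[0, 2])  # all zeros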

fast_llm/models/gpt/config.py

Lines changed: 26 additions & 2 deletions
@@ -1,4 +1,5 @@
 import functools
+import logging
 import typing
 
 from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class
@@ -17,6 +18,8 @@
 from fast_llm.models.gpt.model import GPTInferenceRunner, GPTModel
 from fast_llm.models.gpt.trainer import GPTTrainer
 
+logger = logging.getLogger(__name__)
+
 
 class GPTHuggingfaceCheckpointFormat(CheckpointFormat):
     support_optimizer: typing.ClassVar[bool] = False
@@ -56,11 +59,13 @@ class MixtralGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat):
 class MTPLlamaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat):
     name: typing.ClassVar[str] = "mtp_llama"
     trust_remote_code: typing.ClassVar[bool] = True
-
+
+
 class DiffusionDreamGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat):
     name: typing.ClassVar[str] = "dream"
     trust_remote_code: typing.ClassVar[bool] = True
-
+
+
 class DiffusionLlamaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat):
     name: typing.ClassVar[str] = "diffusion_llama"
     trust_remote_code: typing.ClassVar[bool] = True
@@ -91,6 +96,15 @@ class GPTBatchConfig(BatchConfig):
         desc="Read loss masking spans from the dataset.",
         hint=FieldHint.feature,
     )
+    truncate_documents: bool | None = Field(
+        default=True,
+        desc=(
+            "If enabled, documents may be truncated while being packed to fit the sequence length."
+            "Otherwise, sequences will be padded such that every document lies entirely within a sample"
+            " (and documents exceeding the sequence length will be skipped altogether)."
+        ),
+        hint=FieldHint.feature,
+    )
 
     def _validate(self) -> None:
         if self.micro_sequence_length is None:
@@ -223,6 +237,16 @@ def _from_dict(
         cls._handle_renamed_field(
             default, ("data", "sampling", "use_loss_masking_spans"), ("batch", "use_loss_masking_spans")
         )
+        if "truncate_documents" in default.get("data", {}):
+            # Backward compatibility for the legacy truncate_documents field.
+            # TODO v0.x: Remove backward compatibility.
+            logger.warning(
+                "`data.truncate_documents` field is deprecated. " "Please use `batch.truncate_documents` instead."
+            )
+            assert "truncate_documents" not in default.get("batch", {})
+            if "batch" not in default:
+                default["batch"] = {}
+            default["batch"]["truncate_documents"] = default["data"].pop("truncate_documents")
         return super()._from_dict(default, strict, flat)
 
     @classmethod
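
The last hunk keeps old configs working: if a serialized config still carries the legacy `data.truncate_documents`, the value is moved under `batch` before normal parsing, with a deprecation warning. A minimal sketch of that migration on a plain dict (the helper name is illustrative, the field names mirror the diff):

import logging

logger = logging.getLogger(__name__)

def migrate_truncate_documents(default: dict) -> dict:
    # Move the legacy data.truncate_documents flag to batch.truncate_documents.
    if "truncate_documents" in default.get("data", {}):
        logger.warning(
            "`data.truncate_documents` field is deprecated. Please use `batch.truncate_documents` instead."
        )
        assert "truncate_documents" not in default.get("batch", {})
        default.setdefault("batch", {})["truncate_documents"] = default["data"].pop("truncate_documents")
    return default

config = {"data": {"truncate_documents": False}, "batch": {"sequence_length": 4096}}
print(migrate_truncate_documents(config))
# {'data': {}, 'batch': {'sequence_length': 4096, 'truncate_documents': False}}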

fast_llm/models/gpt/model.py

Lines changed: 4 additions & 3 deletions
@@ -169,6 +169,7 @@ def preprocess_meta(
             TransformerKwargs.hidden_dims: hidden_dims,
             TransformerKwargs.sequence_length: sequence_length,
             TransformerKwargs.sequence_q_dim: sequence_q_dim,
+            LanguageModelKwargs.mask_inputs: not batch_meta.truncate_documents,
         }
 
         sequence_k_pasts = range(
@@ -302,7 +303,7 @@ def preprocess(
                 if batch.loss_masking_spans is not None:
                     # avoid changing input tokens
                     labels = labels.clone()
-                    for i, spans in enumerate(batch.loss_masking_spans):
+                    for idx, spans in enumerate(batch.loss_masking_spans):
                         if not spans.numel():
                             continue
                         valid_spans = spans[
@@ -316,9 +317,9 @@ def preprocess(
                         loss_mask = torch.ones_like(labels, dtype=torch.bool)
                         for start, end in valid_spans:
                             if sequence_first:
-                                loss_mask[start : end + 1, i] = False
+                                loss_mask[start : end + 1, idx] = False
                             else:
-                                loss_mask[i, start : end + 1] = False
+                                loss_mask[idx, start : end + 1] = False
                         if self._config.distillation_model is not None:
                             kwargs[LanguageModelKwargs.loss_mask] = loss_mask
                         labels = torch.where(loss_mask, labels, -100)
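
A standalone sketch of how loss-masking spans become masked labels, mirroring the batch-first branch above (spans are inclusive [start, end] pairs; the tensor shapes and values are illustrative, not taken from the model):

import torch

labels = torch.arange(12).reshape(2, 6)            # (batch, sequence), batch-first
loss_masking_spans = [torch.tensor([[1, 2]]),      # mask tokens 1..2 of sample 0
                      torch.tensor([[0, 3]])]      # mask tokens 0..3 of sample 1

loss_mask = torch.ones_like(labels, dtype=torch.bool)
for idx, spans in enumerate(loss_masking_spans):
    for start, end in spans.tolist():
        loss_mask[idx, start : end + 1] = False    # spans are inclusive on both ends

# Masked positions get the ignore index so the loss skips them.
labels = torch.where(loss_mask, labels, -100)
print(labels)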

fast_llm/models/gpt/trainer.py

Lines changed: 1 addition & 0 deletions
@@ -29,6 +29,7 @@ def _get_sampling_parameters(
                 "use_loss_masking_spans": self._config.batch.use_loss_masking_spans,
                 "use_preference_loss_spans": self._config.model.base_model.enable_dpo,
                 "cross_document_attention": self._config.batch.cross_document_attention,
+                "truncate_documents": self._config.batch.truncate_documents,
                 "extra_tokens": self._config.model.base_model.prediction_heads,
             }
         )
