fix loss masking and padding #287

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged (24 commits), Jun 27, 2025
9 changes: 0 additions & 9 deletions fast_llm/data/data/gpt/config.py
@@ -48,15 +48,6 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig):
desc="Multiprocessing context. Do not touch.",
hint=FieldHint.expert,
)
truncate_documents: bool = Field(
default=True,
desc=(
"If enabled, documents may be truncated while being packed to fit the sequence length."
"Otherwise, sequences will be padded such that every document lies entirely within a sample"
" (and documents exceeding the sequence length will be skipped altogether)."
),
hint=FieldHint.feature,
)

def _validate(self) -> None:
if not self.datasets:
1 change: 0 additions & 1 deletion fast_llm/data/data/gpt/data.py
@@ -127,7 +127,6 @@ def setup(
distributed=distributed,
dataset_name=dataset_name,
tokenizer=self._tokenizer,
truncate_documents=self._config.truncate_documents,
)
dataset = self._config.datasets[dataset_name].build_and_sample(sampling)
self._datasets[dataset_name] = DatasetMonitor(dataset, self._config.data_sample_warn_time_ms)
2 changes: 1 addition & 1 deletion fast_llm/data/dataset/gpt/config.py
@@ -75,6 +75,7 @@ class GPTSamplingParameters(SamplingParameters):
use_loss_masking_spans: bool = False
use_preference_loss_spans: bool = False
cross_document_attention: bool = True
truncate_documents: bool = True
Collaborator: Not sure we need to move this, but if we do, we need to add backward compatibility.

# How many extra tokens to add to the sequence length.
# This is used to provide labels even for the last tokens in the sequence.
extra_tokens: int = 1
@@ -90,7 +91,6 @@ class GPTSamplingData(SamplingData):
config: GPTSamplingConfig
parameters: GPTSamplingParameters
tokenizer: "Tokenizer"
truncate_documents: bool = True


@config_class(registry=True)
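On the backward-compatibility concern raised in the review comment above: since truncate_documents moves from GPTSamplingData to GPTSamplingParameters, configs serialized against the old layout would need the key relocated before validation. A minimal sketch of one possible migration, assuming plain dictionary configs and a hypothetical migrate_sampling_config helper (the actual Fast-LLM config machinery may handle renames differently):

```python
# Hypothetical helper (not the actual Fast-LLM mechanism): move a key that used
# to live on the sampling data config into the nested "parameters" section,
# so configs written before the move keep loading.
def migrate_sampling_config(config: dict) -> dict:
    config = {**config, "parameters": dict(config.get("parameters", {}))}
    if "truncate_documents" in config:
        config["parameters"].setdefault("truncate_documents", config.pop("truncate_documents"))
    return config


# Example: an old-style config is rewritten into the new layout.
old = {"truncate_documents": False, "parameters": {"sequence_length": 4096}}
print(migrate_sampling_config(old))
# {'parameters': {'sequence_length': 4096, 'truncate_documents': False}}
```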
19 changes: 13 additions & 6 deletions fast_llm/data/dataset/gpt/sampled.py
@@ -89,7 +89,7 @@ def __init__(
self._indexed_dataset = indexed_dataset
self._config = sampling.config
self._parameters = sampling.parameters
self._truncate_documents = sampling.truncate_documents
self._truncate_documents = sampling.parameters.truncate_documents
self._device = torch.device("cuda" if self._config.gpu else "cpu")

if sampling.cache_directory is None:
@@ -144,7 +144,7 @@ def _sample(self) -> None:
" Please make sure Fast-LLM is installed correctly."
)
long_docs_filter = document_sizes > self._parameters.sequence_length + 1
ignored_documents = sum(long_docs_filter)
ignored_documents = long_docs_filter.sum().item()
if ignored_documents:
log_main_rank(
f" > {ignored_documents}/{documents_per_epoch} documents are longer than {self._parameters.sequence_length+1} tokens and will be ignored.",
@@ -201,9 +201,10 @@

if self._yaml_path is not None and self._yaml_path.is_file():
loaded_yaml_data = yaml.safe_load(self._yaml_path.open("r"))
# Hack to make sure unshuffled tokens are loaded
if not self._truncate_documents:
yaml_data["unshuffled_tokens"] = loaded_yaml_data["unshuffled_tokens"]
self._load_yaml_data(yaml_data)
if not self._truncate_documents and not self._parameters.use_preference_loss_spans:
del loaded_yaml_data["unshuffled_tokens"]

if loaded_yaml_data != yaml_data:
raise RuntimeError(
@@ -467,6 +468,12 @@ def __getitem__(self, index: int) -> typing.Any:
else:
# Move on to the next sample.
token_count += padding_size
elif document_size + tokens_in_sample == self._parameters.sequence_length + 1:
if token_count + document_size == token_start:
# Document belongs to the current sample but the condition below will include it for the next sample
Collaborator: I'm not following: why are we ignoring the document if it belongs to the current sample? (Also, it clearly belongs to the previous sample.)

Collaborator: From what I understand, in this scenario we'll have token_start_index_in_document == token_end_index_in_document == document_size, so we'll load 0 tokens from the sample. That seems unnecessary but not wrong, and it doesn't seem to relate to document_size + tokens_in_sample == self._parameters.sequence_length + 1? Seems to me the actual fix would be to replace >= with > in the condition below.

Member Author: Oh yes, I got confused because I faced this issue in the multimodal branch, but it only occurs when there are images right after the text tokens. Will handle it there.

token_count += document_size
document_sampling_index += 1
continue

# Determine if the document belongs to the requested sample.
if token_count + document_size >= token_start:
@@ -487,7 +494,7 @@
0,
self._parameters.sequence_length + self._parameters.extra_tokens,
)
if span[1] > span[0]:
if span[1] >= span[0]:
loss_masking_spans.append(span)

# Go to the next document.
@@ -547,7 +554,7 @@ def __init__(
):
assert isinstance(sampling, GPTSamplingData)
self._indexed_dataset = indexed_dataset
if not sampling.truncate_documents:
if not sampling.parameters.truncate_documents:
raise NotImplementedError(
"Legacy sampling only supports document truncation. Please use the latest dataset format."
)
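To make the corner case from the review thread above concrete, here is a toy reconstruction of the indexing arithmetic under simplified assumptions (variable names follow the discussion; the real __getitem__ carries more bookkeeping):

```python
# Toy numbers: the previous documents end exactly where the requested sample
# starts, i.e. token_count + document_size == token_start.
sequence_length = 7
document_size = 8
token_count = 8      # tokens consumed by documents before this one
token_start = 16     # first token index of the requested sample
token_end = token_start + sequence_length + 1  # simplified end-of-sample bound

# With ">=" the document is still selected for this sample...
included = token_count + document_size >= token_start          # 16 >= 16 -> True

# ...but the slice it contributes collapses to zero tokens:
start_in_document = max(token_start - token_count, 0)          # 8
end_in_document = min(token_end - token_count, document_size)  # 8
print(included, end_in_document - start_in_document)           # True 0

# As suggested in the review, a strict ">" would simply skip the document here.
```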
9 changes: 7 additions & 2 deletions fast_llm/functional/cross_entropy.py
@@ -146,8 +146,13 @@ def _fused_cross_entropy_forward_backward(
per_sample_loss = sum_exp_logits.log() - predicted_logits
if loss_mask is not None:
per_sample_loss = per_sample_loss * loss_mask

loss = per_sample_loss.mean()
unmasked_inputs = loss_mask.sum()
Collaborator: This still causes a CUDA sync. You can just do loss = (per_sample_loss * loss_mask).sum() / torch.maximum(loss_mask.sum(), 1).

Member Author: For my own understanding, how can I check whether a PyTorch op causes a CUDA sync?

if unmasked_inputs:
loss = per_sample_loss.sum() / unmasked_inputs
else:
loss = torch.tensor(0.0, dtype=per_sample_loss.dtype, device=per_sample_loss.device)
else:
loss = per_sample_loss.mean()
if target_format != TargetFormat.labels and group is not None:
all_reduce(loss, op=ReduceOp.MEAN, group=group)

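Two asides on the thread above, as a hedged sketch rather than the merged code: the masked mean can be written without any host-side branch on a tensor value (which is what forces the synchronization), and PyTorch ships a debug mode that flags synchronizing ops. Assumes a CUDA device and toy shapes:

```python
import torch

# Warn (or raise, with "error") whenever an op forces a host/device sync,
# e.g. tensor.item() or a Python `if` on a CUDA tensor.
torch.cuda.set_sync_debug_mode("warn")

logits = torch.randn(8, 32, device="cuda")
labels = torch.randint(0, 32, (8,), device="cuda")
loss_mask = torch.tensor([1, 1, 0, 1, 0, 0, 1, 1], device="cuda", dtype=torch.float32)

per_sample_loss = torch.nn.functional.cross_entropy(logits, labels, reduction="none")

# Sync-free masked mean: clamp the denominator on-device instead of testing it
# in Python; a fully masked batch then yields 0 instead of dividing by zero.
loss = (per_sample_loss * loss_mask).sum() / loss_mask.sum().clamp(min=1)
```

(The torch.maximum form quoted above expects a tensor second argument, so clamp(min=1), or torch.clamp_min, is a convenient equivalent.)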
1 change: 1 addition & 0 deletions fast_llm/layers/language_model/config.py
@@ -37,6 +37,7 @@ class LanguageModelKwargs:
chosen_spans = "chosen_spans"
rejected_spans = "rejected_spans"
loss_mask = "loss_mask"
mask_inputs = "mask_inputs"


@config_class()
16 changes: 13 additions & 3 deletions fast_llm/layers/language_model/embedding.py
@@ -82,7 +82,7 @@ def __init__(
)

@torch.compile
def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None) -> torch.Tensor:
def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None, mask_inputs: bool) -> torch.Tensor:
Assert.eq(position_ids is not None, self._use_absolute_position_embeddings)
group = self._tensor_space.distributed.tensor_group
if self._parallel_embeddings:
@@ -99,9 +99,17 @@ def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None) -> torch.Tensor:
input_ = split(input_, group=group, dim=0)
if self._use_absolute_position_embeddings:
position_ids = split(position_ids, group=group, dim=0)
embeddings = torch.embedding(self.word_embeddings_weight, input_)
# handle masked tokens
if mask_inputs:
input_mask = input_ >= 0
masked_input = input_ * input_mask
embeddings = torch.embedding(self.word_embeddings_weight, masked_input)
else:
embeddings = torch.embedding(self.word_embeddings_weight, input_)
if self._use_absolute_position_embeddings:
embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight)
if mask_inputs:
embeddings = embeddings * input_mask.unsqueeze(2)
with set_generator(
self._tensor_space.distributed.tp_generator
if self._sequence_parallel
@@ -123,4 +131,6 @@ def forward(
tensor_name="Embedding output",
dtype=self._residual_dtype,
)
return self._forward(input_, kwargs.get(LanguageModelKwargs.position_ids))
return self._forward(
input_, kwargs.get(LanguageModelKwargs.position_ids), kwargs.get(LanguageModelKwargs.mask_inputs)
)
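The embedding change above follows a common pattern for padded batches: padding positions carry a negative token id, which is not a valid embedding index, so the ids are zeroed for the lookup and the resulting rows are zeroed afterwards. A standalone sketch with toy sizes (not the Fast-LLM module, which also handles tensor and sequence parallelism):

```python
import torch

vocab_size, hidden_size = 100, 16
word_embeddings_weight = torch.randn(vocab_size, hidden_size)

# Padded positions are marked with a negative id (e.g. -100).
input_ = torch.tensor([[5, 17, 3, -100, -100]])

input_mask = input_ >= 0                  # True where the token is real
masked_input = input_ * input_mask        # negative ids become 0, a valid index
embeddings = torch.embedding(word_embeddings_weight, masked_input)
embeddings = embeddings * input_mask.unsqueeze(2)   # zero out the padded rows

print(embeddings.shape)              # torch.Size([1, 5, 16])
print(embeddings[0, 3].abs().sum())  # tensor(0.) -> padded position contributes nothing
```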
9 changes: 9 additions & 0 deletions fast_llm/models/gpt/config.py
@@ -82,6 +82,15 @@ class GPTBatchConfig(BatchConfig):
desc="Read loss masking spans from the dataset.",
hint=FieldHint.feature,
)
truncate_documents: bool = Field(
default=True,
desc=(
"If enabled, documents may be truncated while being packed to fit the sequence length."
"Otherwise, sequences will be padded such that every document lies entirely within a sample"
" (and documents exceeding the sequence length will be skipped altogether)."
),
hint=FieldHint.feature,
)

def _validate(self) -> None:
if self.micro_sequence_length is None:
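For intuition on the two behaviours the new truncate_documents flag selects, here is a deliberately simplified packer, with toy documents and a hypothetical PAD id; the real sampler works on index arrays rather than Python lists:

```python
PAD = -100  # hypothetical padding id, for the sketch only

def pack(docs: list[list[int]], sequence_length: int, truncate_documents: bool) -> list[list[int]]:
    samples: list[list[int]] = []
    current: list[int] = []
    for doc in docs:
        if truncate_documents:
            # Documents are split freely across sample boundaries.
            remaining = list(doc)
            while remaining:
                space = sequence_length - len(current)
                current += remaining[:space]
                remaining = remaining[space:]
                if len(current) == sequence_length:
                    samples.append(current)
                    current = []
        else:
            if len(doc) > sequence_length:
                continue  # can never fit whole in one sample: skipped altogether
            if len(current) + len(doc) > sequence_length:
                # Pad out the current sample so the document stays in one piece.
                samples.append(current + [PAD] * (sequence_length - len(current)))
                current = []
            current += doc
    if current:
        samples.append(current + [PAD] * (sequence_length - len(current)))
    return samples


docs = [[1, 2, 3], [4, 5, 6, 7], [8, 9]]
print(pack(docs, sequence_length=4, truncate_documents=True))
# [[1, 2, 3, 4], [5, 6, 7, 8], [9, PAD, PAD, PAD]]
print(pack(docs, sequence_length=4, truncate_documents=False))
# [[1, 2, 3, PAD], [4, 5, 6, 7], [8, 9, PAD, PAD]]
```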
7 changes: 4 additions & 3 deletions fast_llm/models/gpt/model.py
@@ -179,6 +179,7 @@ def preprocess_meta(
TransformerKwargs.hidden_dims: hidden_dims,
TransformerKwargs.sequence_length: sequence_length,
TransformerKwargs.sequence_q_dim: sequence_q_dim,
LanguageModelKwargs.mask_inputs: not batch_meta.truncate_documents,
}

sequence_k_pasts = range(
Expand Down Expand Up @@ -312,7 +313,7 @@ def preprocess(
if batch.loss_masking_spans is not None:
# avoid changing input tokens
labels = labels.clone()
for i, spans in enumerate(batch.loss_masking_spans):
for idx, spans in enumerate(batch.loss_masking_spans):
if not spans.numel():
continue
valid_spans = spans[
@@ -326,9 +327,9 @@
loss_mask = torch.ones_like(labels, dtype=torch.bool)
for start, end in valid_spans:
if sequence_first:
loss_mask[start : end + 1, i] = False
loss_mask[start : end + 1, idx] = False
else:
loss_mask[i, start : end + 1] = False
loss_mask[idx, start : end + 1] = False
if self._config.distillation_model is not None:
kwargs[LanguageModelKwargs.loss_mask] = loss_mask
labels = torch.where(loss_mask, labels, -100)
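The span handling above follows the usual recipe: each inclusive (start, end) span marks tokens whose loss should not count, a boolean mask is built, and the corresponding labels are set to -100, which cross-entropy ignores. A standalone batch-first sketch with toy values:

```python
import torch

labels = torch.arange(12).reshape(2, 6)    # (batch, sequence)
loss_masking_spans = [
    torch.tensor([[1, 2]]),                # sample 0: mask tokens 1..2 (inclusive)
    torch.tensor([[0, 0], [4, 5]]),        # sample 1: mask token 0 and tokens 4..5
]

loss_mask = torch.ones_like(labels, dtype=torch.bool)
for idx, spans in enumerate(loss_masking_spans):
    for start, end in spans.tolist():
        loss_mask[idx, start : end + 1] = False   # spans are inclusive on both ends

masked_labels = torch.where(loss_mask, labels, -100)  # -100 is ignored by F.cross_entropy
print(masked_labels)
# tensor([[   0, -100, -100,    3,    4,    5],
#         [-100,    7,    8,    9, -100, -100]])
```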
1 change: 1 addition & 0 deletions fast_llm/models/gpt/trainer.py
@@ -30,6 +30,7 @@ def _get_sampling_parameters(
"use_loss_masking_spans": self._config.batch.use_loss_masking_spans,
"use_preference_loss_spans": self._config.model.base_model.enable_dpo,
"cross_document_attention": self._config.batch.cross_document_attention,
"truncate_documents": self._config.batch.truncate_documents,
"extra_tokens": self._config.model.base_model.prediction_heads,
}
)
2 changes: 1 addition & 1 deletion tests/data/common.py
@@ -51,12 +51,12 @@ def get_sampling_data(
num_samples=num_samples,
sequence_length=sequence_length,
vocab_size=vocab_size,
truncate_documents=truncate_documents,
),
cache_directory=cache_directory,
distributed=distributed,
dataset_name=phase.value,
tokenizer=tokenizer,
truncate_documents=truncate_documents,
)

