From 06b76cd33f288ed598e3bcffa1d6288759e29ce1 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 14 May 2025 15:36:44 +0000 Subject: [PATCH 01/16] some sft fixes --- fast_llm/layers/language_model/embedding.py | 5 ++++- fast_llm/models/gpt/model.py | 6 +++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 1d9406ed1..f51f40df7 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -99,7 +99,10 @@ def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None) -> t input_ = split(input_, group=group, dim=0) if self._use_absolute_position_embeddings: position_ids = split(position_ids, group=group, dim=0) - embeddings = torch.embedding(self.word_embeddings_weight, input_) + # mask padded tokens + input_mask = input_ >= 0 + masked_input = input_ * input_mask + embeddings = torch.embedding(self.word_embeddings_weight, masked_input) * input_mask.unsqueeze(2) if self._use_absolute_position_embeddings: embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) with set_generator( diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index b548ab525..ad8a61ba0 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -312,7 +312,7 @@ def preprocess( if batch.loss_masking_spans is not None: # avoid changing input tokens labels = labels.clone() - for i, spans in enumerate(batch.loss_masking_spans): + for idx, spans in enumerate(batch.loss_masking_spans): if not spans.numel(): continue valid_spans = spans[ @@ -326,9 +326,9 @@ def preprocess( loss_mask = torch.ones_like(labels, dtype=torch.bool) for start, end in valid_spans: if sequence_first: - loss_mask[start : end + 1, i] = False + loss_mask[start : end + 1, idx] = False else: - loss_mask[i, start : end + 1] = False + loss_mask[idx, start : end + 1] = False if self._config.distillation_model is not None: kwargs[LanguageModelKwargs.loss_mask] = loss_mask labels = torch.where(loss_mask, labels, -100) From 0364fc6b540a904b6ea6e0429c6d323196c52f2d Mon Sep 17 00:00:00 2001 From: root Date: Wed, 4 Jun 2025 17:23:08 +0000 Subject: [PATCH 02/16] bug fixes and improvements --- fast_llm/data/dataset/gpt/sampled.py | 14 ++++++-------- fast_llm/functional/cross_entropy.py | 2 +- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 8bb5f7370..5100b6932 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -144,7 +144,7 @@ def _sample(self) -> None: " Please make sure Fast-LLM is installed correctly." 
) long_docs_filter = document_sizes > self._parameters.sequence_length + 1 - ignored_documents = sum(long_docs_filter) + ignored_documents = long_docs_filter.sum().item() if ignored_documents: log_main_rank( f" > {ignored_documents}/{documents_per_epoch} documents are longer than {self._parameters.sequence_length+1} tokens and will be ignored.", @@ -202,8 +202,6 @@ def _sample(self) -> None: if self._yaml_path is not None and self._yaml_path.is_file(): loaded_yaml_data = yaml.safe_load(self._yaml_path.open("r")) self._load_yaml_data(yaml_data) - if not self._truncate_documents and not self._parameters.use_preference_loss_spans: - del loaded_yaml_data["unshuffled_tokens"] if loaded_yaml_data != yaml_data: raise RuntimeError( @@ -456,10 +454,10 @@ def __getitem__(self, index: int) -> typing.Any: document_sampling_index += 1 continue tokens_in_sample = token_count % (self._parameters.sequence_length + 1) - if document_size + tokens_in_sample > self._parameters.sequence_length + 1: + if document_size + tokens_in_sample >= self._parameters.sequence_length + 1: # Document belongs to the next sample, need to account for padding. padding_size = self._parameters.sequence_length + 1 - tokens_in_sample - if token_count > token_start: + if token_count >= token_start: # Add padding tokens to current sample token_ids.append(np.full((padding_size,), -100, dtype=np.int64)) Assert.eq(token_count + padding_size, token_end) @@ -487,7 +485,7 @@ def __getitem__(self, index: int) -> typing.Any: 0, self._parameters.sequence_length + self._parameters.extra_tokens, ) - if span[1] > span[0]: + if span[1] >= span[0]: loss_masking_spans.append(span) # Go to the next document. @@ -525,8 +523,8 @@ def _load_yaml_data(self, data: dict[str, typing.Any]) -> None: elif "unshuffled_tokens" not in data: # Backward compatibility # TODO v0.x: Remove - assert self._truncate_documents - data["unshuffled_tokens"] = data["tokens_per_epoch"] * data["unshuffled_epochs"] + assert not self._truncate_documents + data["unshuffled_tokens"] = data["dataset"]["tokens_per_epoch"] * data["unshuffled_epochs"] self._unshuffled_tokens = data["unshuffled_tokens"] self._unshuffled_documents = data["unshuffled_epochs"] * self._documents_per_epoch diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 513510ec7..53b5979ed 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -145,7 +145,7 @@ def _fused_cross_entropy_forward_backward( per_sample_loss = sum_exp_logits.log() - predicted_logits if loss_mask is not None: - per_sample_loss = per_sample_loss * loss_mask + per_sample_loss = per_sample_loss[loss_mask] loss = per_sample_loss.mean() if target_format != TargetFormat.labels and group is not None: From 02fbb7e4a608a8d7613927ecc9812f4780892e15 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 4 Jun 2025 18:05:08 +0000 Subject: [PATCH 03/16] fix --- fast_llm/data/dataset/gpt/sampled.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 5100b6932..5d6e32a6d 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -454,10 +454,10 @@ def __getitem__(self, index: int) -> typing.Any: document_sampling_index += 1 continue tokens_in_sample = token_count % (self._parameters.sequence_length + 1) - if document_size + tokens_in_sample >= self._parameters.sequence_length + 1: + if document_size + tokens_in_sample > 
self._parameters.sequence_length + 1: # Document belongs to the next sample, need to account for padding. padding_size = self._parameters.sequence_length + 1 - tokens_in_sample - if token_count >= token_start: + if token_count > token_start: # Add padding tokens to current sample token_ids.append(np.full((padding_size,), -100, dtype=np.int64)) Assert.eq(token_count + padding_size, token_end) @@ -465,6 +465,12 @@ def __getitem__(self, index: int) -> typing.Any: else: # Move on to the next sample. token_count += padding_size + elif document_size + tokens_in_sample == self._parameters.sequence_length + 1: + if token_count + document_size == token_start: + # Document belongs to the current sample but the condition below will include it for the next sample + token_count += document_size + document_sampling_index += 1 + continue # Determine if the document belongs to the requested sample. if token_count + document_size >= token_start: From b31987a252f7dc223234fa8d5f55bd252f13ce1c Mon Sep 17 00:00:00 2001 From: root Date: Wed, 4 Jun 2025 19:45:50 +0000 Subject: [PATCH 04/16] unshuffled tokens for padding --- fast_llm/data/dataset/gpt/sampled.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 5d6e32a6d..43ec07a37 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -202,6 +202,9 @@ def _sample(self) -> None: if self._yaml_path is not None and self._yaml_path.is_file(): loaded_yaml_data = yaml.safe_load(self._yaml_path.open("r")) self._load_yaml_data(yaml_data) + # Hack to make sure unshuffled tokens are loaded + if not self._truncate_documents: + yaml_data["unshuffled_tokens"] = loaded_yaml_data.get["unshuffled_tokens"] if loaded_yaml_data != yaml_data: raise RuntimeError( @@ -529,8 +532,8 @@ def _load_yaml_data(self, data: dict[str, typing.Any]) -> None: elif "unshuffled_tokens" not in data: # Backward compatibility # TODO v0.x: Remove - assert not self._truncate_documents - data["unshuffled_tokens"] = data["dataset"]["tokens_per_epoch"] * data["unshuffled_epochs"] + assert self._truncate_documents + data["unshuffled_tokens"] = data["tokens_per_epoch"] * data["unshuffled_epochs"] self._unshuffled_tokens = data["unshuffled_tokens"] self._unshuffled_documents = data["unshuffled_epochs"] * self._documents_per_epoch From 9db9cdf985f3e56e1c1493d85d439dccff5b4e5d Mon Sep 17 00:00:00 2001 From: root Date: Wed, 4 Jun 2025 19:46:48 +0000 Subject: [PATCH 05/16] fix --- fast_llm/data/dataset/gpt/sampled.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 43ec07a37..72a4244fd 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -204,7 +204,7 @@ def _sample(self) -> None: self._load_yaml_data(yaml_data) # Hack to make sure unshuffled tokens are loaded if not self._truncate_documents: - yaml_data["unshuffled_tokens"] = loaded_yaml_data.get["unshuffled_tokens"] + yaml_data["unshuffled_tokens"] = loaded_yaml_data["unshuffled_tokens"] if loaded_yaml_data != yaml_data: raise RuntimeError( From a81ba7f8005df94c6c0ea43044118d452cb28d56 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 4 Jun 2025 20:04:43 +0000 Subject: [PATCH 06/16] mask optionally --- fast_llm/layers/language_model/embedding.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/fast_llm/layers/language_model/embedding.py 
b/fast_llm/layers/language_model/embedding.py index f51f40df7..18e581e89 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -101,8 +101,9 @@ def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None) -> t position_ids = split(position_ids, group=group, dim=0) # mask padded tokens input_mask = input_ >= 0 - masked_input = input_ * input_mask - embeddings = torch.embedding(self.word_embeddings_weight, masked_input) * input_mask.unsqueeze(2) + if (~input_mask).any(): + masked_input = input_ * input_mask + embeddings = torch.embedding(self.word_embeddings_weight, masked_input) * input_mask.unsqueeze(2) if self._use_absolute_position_embeddings: embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) with set_generator( From eab15c0b0faff60f37f8a3a90ab7b14c31d585e4 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 4 Jun 2025 20:06:44 +0000 Subject: [PATCH 07/16] fix hack --- fast_llm/data/dataset/gpt/sampled.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 72a4244fd..85207a32d 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -201,10 +201,10 @@ def _sample(self) -> None: if self._yaml_path is not None and self._yaml_path.is_file(): loaded_yaml_data = yaml.safe_load(self._yaml_path.open("r")) - self._load_yaml_data(yaml_data) # Hack to make sure unshuffled tokens are loaded if not self._truncate_documents: yaml_data["unshuffled_tokens"] = loaded_yaml_data["unshuffled_tokens"] + self._load_yaml_data(yaml_data) if loaded_yaml_data != yaml_data: raise RuntimeError( From e0e0f78af3df98de4991c962ccfeb5e692cb40a0 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 4 Jun 2025 20:26:24 +0000 Subject: [PATCH 08/16] use mask sum --- fast_llm/functional/cross_entropy.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 53b5979ed..67b7573c1 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -145,9 +145,13 @@ def _fused_cross_entropy_forward_backward( per_sample_loss = sum_exp_logits.log() - predicted_logits if loss_mask is not None: - per_sample_loss = per_sample_loss[loss_mask] + per_sample_loss = per_sample_loss * loss_mask - loss = per_sample_loss.mean() + unmasked_inputs = loss_mask.sum() + if unmasked_inputs: + loss = per_sample_loss.sum() / unmasked_inputs + else: + loss = torch.tensor(0.0, dtype=per_sample_loss.dtype, device=per_sample_loss.device) if target_format != TargetFormat.labels and group is not None: all_reduce(loss, op=ReduceOp.MEAN, group=group) From 6b2b59816b521bbe82ee44e031005a3950f33084 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 4 Jun 2025 22:38:33 +0000 Subject: [PATCH 09/16] handle None mask --- fast_llm/functional/cross_entropy.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 67b7573c1..b34bd8848 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -146,12 +146,13 @@ def _fused_cross_entropy_forward_backward( per_sample_loss = sum_exp_logits.log() - predicted_logits if loss_mask is not None: per_sample_loss = per_sample_loss * loss_mask - - unmasked_inputs = loss_mask.sum() - if unmasked_inputs: - loss = 
per_sample_loss.sum() / unmasked_inputs + unmasked_inputs = loss_mask.sum() + if unmasked_inputs: + loss = per_sample_loss.sum() / unmasked_inputs + else: + loss = torch.tensor(0.0, dtype=per_sample_loss.dtype, device=per_sample_loss.device) else: - loss = torch.tensor(0.0, dtype=per_sample_loss.dtype, device=per_sample_loss.device) + loss = per_sample_loss.mean() if target_format != TargetFormat.labels and group is not None: all_reduce(loss, op=ReduceOp.MEAN, group=group) From 29cb0a84b78c54aa631ec0ca344abb8b1b893ed1 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 5 Jun 2025 15:31:48 +0000 Subject: [PATCH 10/16] flag to mask inputs, move truncate documents --- fast_llm/data/data/gpt/config.py | 9 --------- fast_llm/data/data/gpt/data.py | 1 - fast_llm/data/dataset/gpt/config.py | 2 +- fast_llm/data/dataset/gpt/sampled.py | 2 +- fast_llm/layers/language_model/config.py | 1 + fast_llm/layers/language_model/embedding.py | 21 ++++++++++++++------- fast_llm/models/gpt/config.py | 9 +++++++++ fast_llm/models/gpt/model.py | 1 + fast_llm/models/gpt/trainer.py | 1 + 9 files changed, 28 insertions(+), 19 deletions(-) diff --git a/fast_llm/data/data/gpt/config.py b/fast_llm/data/data/gpt/config.py index 85bcc6561..709764d44 100644 --- a/fast_llm/data/data/gpt/config.py +++ b/fast_llm/data/data/gpt/config.py @@ -48,15 +48,6 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig): desc="Multiprocessing context. Do not touch.", hint=FieldHint.expert, ) - truncate_documents: bool = Field( - default=True, - desc=( - "If enabled, documents may be truncated while being packed to fit the sequence length." - "Otherwise, sequences will be padded such that every document lies entirely within a sample" - " (and documents exceeding the sequence length will be skipped altogether)." - ), - hint=FieldHint.feature, - ) def _validate(self) -> None: if not self.datasets: diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index c6fece9d7..4e49bec7f 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -127,7 +127,6 @@ def setup( distributed=distributed, dataset_name=dataset_name, tokenizer=self._tokenizer, - truncate_documents=self._config.truncate_documents, ) dataset = self._config.datasets[dataset_name].build_and_sample(sampling) self._datasets[dataset_name] = DatasetMonitor(dataset, self._config.data_sample_warn_time_ms) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index ae87e0e78..ef2efedc9 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -75,6 +75,7 @@ class GPTSamplingParameters(SamplingParameters): use_loss_masking_spans: bool = False use_preference_loss_spans: bool = False cross_document_attention: bool = True + truncate_documents: bool = True # How many extra tokens to add to the sequence length. # This is used to provide labels even for the last tokens in the sequence. 
extra_tokens: int = 1 @@ -90,7 +91,6 @@ class GPTSamplingData(SamplingData): config: GPTSamplingConfig parameters: GPTSamplingParameters tokenizer: "Tokenizer" - truncate_documents: bool = True @config_class(registry=True) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 85207a32d..4b70ca807 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -89,7 +89,7 @@ def __init__( self._indexed_dataset = indexed_dataset self._config = sampling.config self._parameters = sampling.parameters - self._truncate_documents = sampling.truncate_documents + self._truncate_documents = sampling.parameters.truncate_documents self._device = torch.device("cuda" if self._config.gpu else "cpu") if sampling.cache_directory is None: diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 2d5fd8436..f7148fa0b 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -37,6 +37,7 @@ class LanguageModelKwargs: chosen_spans = "chosen_spans" rejected_spans = "rejected_spans" loss_mask = "loss_mask" + mask_inputs = "mask_inputs" @config_class() diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 18e581e89..ea30f5832 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -82,7 +82,7 @@ def __init__( ) @torch.compile - def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None) -> torch.Tensor: + def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None, mask_inputs: bool) -> torch.Tensor: Assert.eq(position_ids is not None, self._use_absolute_position_embeddings) group = self._tensor_space.distributed.tensor_group if self._parallel_embeddings: @@ -99,13 +99,18 @@ def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None) -> t input_ = split(input_, group=group, dim=0) if self._use_absolute_position_embeddings: position_ids = split(position_ids, group=group, dim=0) - # mask padded tokens - input_mask = input_ >= 0 - if (~input_mask).any(): - masked_input = input_ * input_mask - embeddings = torch.embedding(self.word_embeddings_weight, masked_input) * input_mask.unsqueeze(2) + # handle masked tokens + if mask_inputs: + input_mask = input_ >= 0 + if (~input_mask).any(): + masked_input = input_ * input_mask + embeddings = torch.embedding(self.word_embeddings_weight, masked_input) + else: + embeddings = torch.embedding(self.word_embeddings_weight, input_) if self._use_absolute_position_embeddings: embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) + if mask_inputs: + embeddings = embeddings * input_mask.unsqueeze(2) with set_generator( self._tensor_space.distributed.tp_generator if self._sequence_parallel @@ -127,4 +132,6 @@ def forward( tensor_name="Embedding output", dtype=self._residual_dtype, ) - return self._forward(input_, kwargs.get(LanguageModelKwargs.position_ids)) + return self._forward( + input_, kwargs.get(LanguageModelKwargs.position_ids), kwargs.get(LanguageModelKwargs.mask_inputs) + ) diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index d9085c670..5545c3cc4 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -82,6 +82,15 @@ class GPTBatchConfig(BatchConfig): desc="Read loss masking spans from the dataset.", hint=FieldHint.feature, ) + truncate_documents: bool = Field( + 
default=True, + desc=( + "If enabled, documents may be truncated while being packed to fit the sequence length." + "Otherwise, sequences will be padded such that every document lies entirely within a sample" + " (and documents exceeding the sequence length will be skipped altogether)." + ), + hint=FieldHint.feature, + ) def _validate(self) -> None: if self.micro_sequence_length is None: diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index ad8a61ba0..0a40faf38 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -179,6 +179,7 @@ def preprocess_meta( TransformerKwargs.hidden_dims: hidden_dims, TransformerKwargs.sequence_length: sequence_length, TransformerKwargs.sequence_q_dim: sequence_q_dim, + LanguageModelKwargs.mask_inputs: not batch_meta.truncate_documents, } sequence_k_pasts = range( diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index cc39d7f70..ff4d773f5 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -30,6 +30,7 @@ def _get_sampling_parameters( "use_loss_masking_spans": self._config.batch.use_loss_masking_spans, "use_preference_loss_spans": self._config.model.base_model.enable_dpo, "cross_document_attention": self._config.batch.cross_document_attention, + "truncate_documents": self._config.batch.truncate_documents, "extra_tokens": self._config.model.base_model.prediction_heads, } ) From 8fe536fb432d7d46a36c9b322cb5744a5b56d500 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 5 Jun 2025 15:38:31 +0000 Subject: [PATCH 11/16] fix tests --- tests/data/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/data/common.py b/tests/data/common.py index cacb28e6b..0e9e72c64 100644 --- a/tests/data/common.py +++ b/tests/data/common.py @@ -51,12 +51,12 @@ def get_sampling_data( num_samples=num_samples, sequence_length=sequence_length, vocab_size=vocab_size, + truncate_documents=truncate_documents, ), cache_directory=cache_directory, distributed=distributed, dataset_name=phase.value, tokenizer=tokenizer, - truncate_documents=truncate_documents, ) From 1dc76a6342998a671f502908b17b95f319bca1e4 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 5 Jun 2025 15:52:58 +0000 Subject: [PATCH 12/16] fix legacy sampler --- fast_llm/data/dataset/gpt/sampled.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 4b70ca807..b9af48fae 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -554,7 +554,7 @@ def __init__( ): assert isinstance(sampling, GPTSamplingData) self._indexed_dataset = indexed_dataset - if not sampling.truncate_documents: + if not sampling.parameters.truncate_documents: raise NotImplementedError( "Legacy sampling only supports document truncation. Please use the latest dataset format." 
) From 29cc7096432b048a8ec489fa123344d21a8f16b2 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 5 Jun 2025 15:54:47 +0000 Subject: [PATCH 13/16] fix --- fast_llm/layers/language_model/embedding.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index ea30f5832..81ddfee93 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -102,9 +102,8 @@ def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None, mask # handle masked tokens if mask_inputs: input_mask = input_ >= 0 - if (~input_mask).any(): - masked_input = input_ * input_mask - embeddings = torch.embedding(self.word_embeddings_weight, masked_input) + masked_input = input_ * input_mask + embeddings = torch.embedding(self.word_embeddings_weight, masked_input) else: embeddings = torch.embedding(self.word_embeddings_weight, input_) if self._use_absolute_position_embeddings: From 43c868af36fc25a180f86f52ca20758752d05ab8 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 12 Jun 2025 19:21:18 +0000 Subject: [PATCH 14/16] review --- fast_llm/data/data/gpt/config.py | 10 ++++++++++ fast_llm/data/dataset/gpt/sampled.py | 8 +------- fast_llm/functional/cross_entropy.py | 7 +------ fast_llm/models/gpt/config.py | 11 ++++++++++- 4 files changed, 22 insertions(+), 14 deletions(-) diff --git a/fast_llm/data/data/gpt/config.py b/fast_llm/data/data/gpt/config.py index 709764d44..5f9e0b72e 100644 --- a/fast_llm/data/data/gpt/config.py +++ b/fast_llm/data/data/gpt/config.py @@ -48,6 +48,16 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig): desc="Multiprocessing context. Do not touch.", hint=FieldHint.expert, ) + truncate_documents: bool | None = Field( + default=None, + desc=( + "Please use batch.truncate_documents instead " + "If enabled, documents may be truncated while being packed to fit the sequence length." + "Otherwise, sequences will be padded such that every document lies entirely within a sample" + " (and documents exceeding the sequence length will be skipped altogether)." + ), + hint=FieldHint.deprecated, + ) def _validate(self) -> None: if not self.datasets: diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index b9af48fae..6a06002cb 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -468,15 +468,9 @@ def __getitem__(self, index: int) -> typing.Any: else: # Move on to the next sample. token_count += padding_size - elif document_size + tokens_in_sample == self._parameters.sequence_length + 1: - if token_count + document_size == token_start: - # Document belongs to the current sample but the condition below will include it for the next sample - token_count += document_size - document_sampling_index += 1 - continue # Determine if the document belongs to the requested sample. - if token_count + document_size >= token_start: + if token_count + document_size > token_start: # Determine which part of the document belong to the sample, and add it to the list. 
token_start_index_in_document = max(token_start - token_count, 0) token_end_index_in_document = min(token_end - token_count, document_size) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index b34bd8848..8c5492596 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -145,12 +145,7 @@ def _fused_cross_entropy_forward_backward( per_sample_loss = sum_exp_logits.log() - predicted_logits if loss_mask is not None: - per_sample_loss = per_sample_loss * loss_mask - unmasked_inputs = loss_mask.sum() - if unmasked_inputs: - loss = per_sample_loss.sum() / unmasked_inputs - else: - loss = torch.tensor(0.0, dtype=per_sample_loss.dtype, device=per_sample_loss.device) + loss = (per_sample_loss * loss_mask).sum() / torch.maximum(loss_mask.sum(), 1) else: loss = per_sample_loss.mean() if target_format != TargetFormat.labels and group is not None: diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 5545c3cc4..d3723952b 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -1,4 +1,5 @@ import functools +import logging import typing from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class @@ -16,6 +17,8 @@ from fast_llm.models.gpt.model import GPTInferenceRunner, GPTModel from fast_llm.models.gpt.trainer import GPTTrainer +logger = logging.getLogger(__name__) + class GPTHuggingfaceCheckpointFormat(CheckpointFormat): support_optimizer: typing.ClassVar[bool] = False @@ -82,7 +85,7 @@ class GPTBatchConfig(BatchConfig): desc="Read loss masking spans from the dataset.", hint=FieldHint.feature, ) - truncate_documents: bool = Field( + truncate_documents: bool | None = Field( default=True, desc=( "If enabled, documents may be truncated while being packed to fit the sequence length." @@ -181,6 +184,12 @@ def _validate(self) -> None: self.batch.sequence_length = self.model.base_model.max_position_embeddings if self.model.base_model.use_megatron_initialization: set_megatron_distributed_seeds(self.model.distributed) + if self.data.truncate_documents is not None: + logger.warning( + "Using deprecated field `data.truncate_documents`, `batch.truncate_documents` will be overridden if specified. " + "Use `batch.truncate_documents` instead." + ) + self.batch.truncate_documents = self.data.truncate_documents super()._validate() if self.model.base_model.use_absolute_position_embeddings: From ba11c56c99b98af41d58dc57b20c5fa51a8d4dc4 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 27 Jun 2025 16:27:52 +0000 Subject: [PATCH 15/16] move to _from_dict --- fast_llm/data/data/gpt/config.py | 10 ---------- fast_llm/models/gpt/config.py | 20 ++++++++++++-------- 2 files changed, 12 insertions(+), 18 deletions(-) diff --git a/fast_llm/data/data/gpt/config.py b/fast_llm/data/data/gpt/config.py index 5f9e0b72e..709764d44 100644 --- a/fast_llm/data/data/gpt/config.py +++ b/fast_llm/data/data/gpt/config.py @@ -48,16 +48,6 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig): desc="Multiprocessing context. Do not touch.", hint=FieldHint.expert, ) - truncate_documents: bool | None = Field( - default=None, - desc=( - "Please use batch.truncate_documents instead " - "If enabled, documents may be truncated while being packed to fit the sequence length." - "Otherwise, sequences will be padded such that every document lies entirely within a sample" - " (and documents exceeding the sequence length will be skipped altogether)." 
- ), - hint=FieldHint.deprecated, - ) def _validate(self) -> None: if not self.datasets: diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 6ca2271b7..139360fba 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -59,11 +59,13 @@ class MixtralGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): class MTPLlamaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "mtp_llama" trust_remote_code: typing.ClassVar[bool] = True - + + class DiffusionDreamGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "dream" trust_remote_code: typing.ClassVar[bool] = True - + + class DiffusionLlamaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "diffusion_llama" trust_remote_code: typing.ClassVar[bool] = True @@ -195,12 +197,6 @@ def _validate(self) -> None: self.batch.sequence_length = self.model.base_model.max_position_embeddings if self.model.base_model.use_megatron_initialization: set_megatron_distributed_seeds(self.model.distributed) - if self.data.truncate_documents is not None: - logger.warning( - "Using deprecated field `data.truncate_documents`, `batch.truncate_documents` will be overridden if specified. " - "Use `batch.truncate_documents` instead." - ) - self.batch.truncate_documents = self.data.truncate_documents super()._validate() if self.model.base_model.use_absolute_position_embeddings: @@ -241,6 +237,14 @@ def _from_dict( cls._handle_renamed_field( default, ("data", "sampling", "use_loss_masking_spans"), ("batch", "use_loss_masking_spans") ) + if "truncate_documents" in default["data"]: + # Backward compatibility for the legacy truncate_documents field. + # TODO v0.x: Remove backward compatibility. + logger.warning( + "`data.truncate_documents` field is deprecated. " "Please use `batch.truncate_documents` instead." + ) + assert "truncate_documents" not in default["batch"] + default["batch"]["truncate_documents"] = default["data"].pop("truncate_documents") return super()._from_dict(default, strict, flat) @classmethod From 439241ee3e4b362f1ef9eefddbaeb80c8015b713 Mon Sep 17 00:00:00 2001 From: sohamparikh Date: Fri, 27 Jun 2025 09:48:15 -0700 Subject: [PATCH 16/16] fix --- fast_llm/models/gpt/config.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 139360fba..0da16428e 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -237,13 +237,15 @@ def _from_dict( cls._handle_renamed_field( default, ("data", "sampling", "use_loss_masking_spans"), ("batch", "use_loss_masking_spans") ) - if "truncate_documents" in default["data"]: + if "truncate_documents" in default.get("data", {}): # Backward compatibility for the legacy truncate_documents field. # TODO v0.x: Remove backward compatibility. logger.warning( "`data.truncate_documents` field is deprecated. " "Please use `batch.truncate_documents` instead." ) - assert "truncate_documents" not in default["batch"] + assert "truncate_documents" not in default.get("batch", {}) + if "batch" not in default: + default["batch"] = {} default["batch"]["truncate_documents"] = default["data"].pop("truncate_documents") return super()._from_dict(default, strict, flat)
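Taken together, the series replaces hard truncation with padding-aware packing: packed samples are filled out with -100 placeholder tokens, the embedding layer zeroes the contribution of those positions, and the fused cross-entropy averages the loss only over unmasked labels. Below is a minimal plain-PyTorch sketch of that combined behaviour; the function name and tensor shapes are illustrative and none of it is Fast-LLM API, only the -100 sentinel convention is carried over from the patches.

import torch

def masked_embed_and_loss(token_ids, logits, labels, embedding_weight):
    # token_ids and labels use -100 as the padding / loss-masking sentinel.
    input_mask = token_ids >= 0
    # Clamp padded ids to a valid index for the lookup, then zero out their embeddings.
    embeddings = torch.embedding(embedding_weight, token_ids * input_mask) * input_mask.unsqueeze(-1)

    # Per-token cross entropy, averaged only over positions whose label is not masked.
    loss_mask = labels >= 0
    per_token_loss = torch.nn.functional.cross_entropy(
        logits.flatten(0, -2), labels.clamp(min=0).flatten(), reduction="none"
    ).view_as(labels)
    loss = (per_token_loss * loss_mask).sum() / loss_mask.sum().clamp(min=1)
    return embeddings, loss

Dividing by loss_mask.sum().clamp(min=1) rather than taking a plain mean keeps a fully masked micro-batch from producing a NaN loss, which mirrors the guard introduced in PATCH 08/16 and condensed in PATCH 14/16.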