diff --git a/fast_llm/data/data/gpt/config.py b/fast_llm/data/data/gpt/config.py
index 5de5e2a2b..405d1c672 100644
--- a/fast_llm/data/data/gpt/config.py
+++ b/fast_llm/data/data/gpt/config.py
@@ -48,15 +48,6 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig):
         desc="Multiprocessing context. Do not touch.",
         hint=FieldHint.expert,
     )
-    truncate_documents: bool = Field(
-        default=True,
-        desc=(
-            "If enabled, documents may be truncated while being packed to fit the sequence length."
-            "Otherwise, sequences will be padded such that every document lies entirely within a sample"
-            " (and documents exceeding the sequence length will be skipped altogether)."
-        ),
-        hint=FieldHint.feature,
-    )

     def _validate(self) -> None:
         if not self.datasets:
diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py
index 176c077a2..6724afb59 100644
--- a/fast_llm/data/data/gpt/data.py
+++ b/fast_llm/data/data/gpt/data.py
@@ -129,7 +129,6 @@ def setup(
                 distributed=distributed,
                 dataset_name=dataset_name,
                 tokenizer=self._tokenizer,
-                truncate_documents=self._config.truncate_documents,
             )
             dataset = self._config.datasets[dataset_name].build_and_sample(sampling)
             self._datasets[dataset_name] = DatasetMonitor(dataset, self._config.data_sample_warn_time_ms)
diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py
index ae87e0e78..ef2efedc9 100644
--- a/fast_llm/data/dataset/gpt/config.py
+++ b/fast_llm/data/dataset/gpt/config.py
@@ -75,6 +75,7 @@ class GPTSamplingParameters(SamplingParameters):
     use_loss_masking_spans: bool = False
     use_preference_loss_spans: bool = False
     cross_document_attention: bool = True
+    truncate_documents: bool = True
     # How many extra tokens to add to the sequence length.
     # This is used to provide labels even for the last tokens in the sequence.
     extra_tokens: int = 1
@@ -90,7 +91,6 @@ class GPTSamplingData(SamplingData):
     config: GPTSamplingConfig
     parameters: GPTSamplingParameters
     tokenizer: "Tokenizer"
-    truncate_documents: bool = True


 @config_class(registry=True)
diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py
index 8bb5f7370..6a06002cb 100644
--- a/fast_llm/data/dataset/gpt/sampled.py
+++ b/fast_llm/data/dataset/gpt/sampled.py
@@ -89,7 +89,7 @@ def __init__(
         self._indexed_dataset = indexed_dataset
         self._config = sampling.config
         self._parameters = sampling.parameters
-        self._truncate_documents = sampling.truncate_documents
+        self._truncate_documents = sampling.parameters.truncate_documents
         self._device = torch.device("cuda" if self._config.gpu else "cpu")

         if sampling.cache_directory is None:
@@ -144,7 +144,7 @@ def _sample(self) -> None:
                 " Please make sure Fast-LLM is installed correctly."
             )
         long_docs_filter = document_sizes > self._parameters.sequence_length + 1
-        ignored_documents = sum(long_docs_filter)
+        ignored_documents = long_docs_filter.sum().item()
         if ignored_documents:
             log_main_rank(
                 f" > {ignored_documents}/{documents_per_epoch} documents are longer than {self._parameters.sequence_length+1} tokens and will be ignored.",
@@ -201,9 +201,10 @@ def _sample(self) -> None:

         if self._yaml_path is not None and self._yaml_path.is_file():
             loaded_yaml_data = yaml.safe_load(self._yaml_path.open("r"))
+            # Hack to make sure unshuffled tokens are loaded
+            if not self._truncate_documents:
+                yaml_data["unshuffled_tokens"] = loaded_yaml_data["unshuffled_tokens"]
             self._load_yaml_data(yaml_data)
-            if not self._truncate_documents and not self._parameters.use_preference_loss_spans:
-                del loaded_yaml_data["unshuffled_tokens"]

             if loaded_yaml_data != yaml_data:
                 raise RuntimeError(
@@ -469,7 +470,7 @@ def __getitem__(self, index: int) -> typing.Any:
                     token_count += padding_size

             # Determine if the document belongs to the requested sample.
-            if token_count + document_size >= token_start:
+            if token_count + document_size > token_start:
                 # Determine which part of the document belong to the sample, and add it to the list.
                 token_start_index_in_document = max(token_start - token_count, 0)
                 token_end_index_in_document = min(token_end - token_count, document_size)
@@ -487,7 +488,7 @@ def __getitem__(self, index: int) -> typing.Any:
                         0,
                         self._parameters.sequence_length + self._parameters.extra_tokens,
                     )
-                    if span[1] > span[0]:
+                    if span[1] >= span[0]:
                         loss_masking_spans.append(span)

             # Go to the next document.
@@ -547,7 +548,7 @@ def __init__(
     ):
         assert isinstance(sampling, GPTSamplingData)
         self._indexed_dataset = indexed_dataset
-        if not sampling.truncate_documents:
+        if not sampling.parameters.truncate_documents:
             raise NotImplementedError(
                 "Legacy sampling only supports document truncation. Please use the latest dataset format."
             )
diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py
index 513510ec7..8c5492596 100644
--- a/fast_llm/functional/cross_entropy.py
+++ b/fast_llm/functional/cross_entropy.py
@@ -145,9 +145,9 @@ def _fused_cross_entropy_forward_backward(
     per_sample_loss = sum_exp_logits.log() - predicted_logits

     if loss_mask is not None:
-        per_sample_loss = per_sample_loss * loss_mask
-
-    loss = per_sample_loss.mean()
+        loss = (per_sample_loss * loss_mask).sum() / torch.clamp(loss_mask.sum(), min=1)
+    else:
+        loss = per_sample_loss.mean()

     if target_format != TargetFormat.labels and group is not None:
         all_reduce(loss, op=ReduceOp.MEAN, group=group)
diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py
index cbe17101e..c4776abe9 100644
--- a/fast_llm/layers/language_model/config.py
+++ b/fast_llm/layers/language_model/config.py
@@ -38,6 +38,7 @@ class LanguageModelKwargs:
     chosen_spans = "chosen_spans"
     rejected_spans = "rejected_spans"
     loss_mask = "loss_mask"
+    mask_inputs = "mask_inputs"


 @config_class()
diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py
index e0386d8df..7036a1e97 100644
--- a/fast_llm/layers/language_model/embedding.py
+++ b/fast_llm/layers/language_model/embedding.py
@@ -84,7 +84,7 @@ def __init__(
         )

     @torch.compile
-    def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None) -> torch.Tensor:
+    def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None, mask_inputs: bool) -> torch.Tensor:
         Assert.eq(position_ids is not None, self._use_absolute_position_embeddings)
         group = self._tensor_space.distributed.tensor_group
         if self._parallel_embeddings:
@@ -101,9 +101,17 @@ def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None) -> t
             input_ = split(input_, group=group, dim=0)
             if self._use_absolute_position_embeddings:
                 position_ids = split(position_ids, group=group, dim=0)
-        embeddings = torch.embedding(self.word_embeddings_weight, input_)
+        # handle masked tokens
+        if mask_inputs:
+            input_mask = input_ >= 0
+            masked_input = input_ * input_mask
+            embeddings = torch.embedding(self.word_embeddings_weight, masked_input)
+        else:
+            embeddings = torch.embedding(self.word_embeddings_weight, input_)
         if self._use_absolute_position_embeddings:
             embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight)
+        if mask_inputs:
+            embeddings = embeddings * input_mask.unsqueeze(2)
         with set_generator(
             self._tensor_space.distributed.tp_generator
             if self._sequence_parallel
@@ -125,4 +133,6 @@ def forward(
             tensor_name="Embedding output",
             dtype=self._residual_dtype,
         )
-        return self._forward(input_, kwargs.get(LanguageModelKwargs.position_ids))
+        return self._forward(
+            input_, kwargs.get(LanguageModelKwargs.position_ids), kwargs.get(LanguageModelKwargs.mask_inputs)
+        )
diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py
index 6bf6e06cf..0da16428e 100644
--- a/fast_llm/models/gpt/config.py
+++ b/fast_llm/models/gpt/config.py
@@ -1,4 +1,5 @@
 import functools
+import logging
 import typing

 from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class
@@ -17,6 +18,8 @@
     from fast_llm.models.gpt.model import GPTInferenceRunner, GPTModel
     from fast_llm.models.gpt.trainer import GPTTrainer

+logger = logging.getLogger(__name__)
+

 class GPTHuggingfaceCheckpointFormat(CheckpointFormat):
     support_optimizer: typing.ClassVar[bool] = False
@@ -56,11 +59,13 @@ class MixtralGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat):
 class MTPLlamaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat):
     name: typing.ClassVar[str] = "mtp_llama"
     trust_remote_code: typing.ClassVar[bool] = True
-
+
+
 class DiffusionDreamGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat):
     name: typing.ClassVar[str] = "dream"
     trust_remote_code: typing.ClassVar[bool] = True
-
+
+
 class DiffusionLlamaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat):
     name: typing.ClassVar[str] = "diffusion_llama"
     trust_remote_code: typing.ClassVar[bool] = True
@@ -91,6 +96,15 @@ class GPTBatchConfig(BatchConfig):
         desc="Read loss masking spans from the dataset.",
         hint=FieldHint.feature,
     )
+    truncate_documents: bool | None = Field(
+        default=True,
+        desc=(
+            "If enabled, documents may be truncated while being packed to fit the sequence length."
+            "Otherwise, sequences will be padded such that every document lies entirely within a sample"
+            " (and documents exceeding the sequence length will be skipped altogether)."
+        ),
+        hint=FieldHint.feature,
+    )

     def _validate(self) -> None:
         if self.micro_sequence_length is None:
@@ -223,6 +237,16 @@ def _from_dict(
         cls._handle_renamed_field(
             default, ("data", "sampling", "use_loss_masking_spans"), ("batch", "use_loss_masking_spans")
         )
+        if "truncate_documents" in default.get("data", {}):
+            # Backward compatibility for the legacy truncate_documents field.
+            # TODO v0.x: Remove backward compatibility.
+            logger.warning(
+                "`data.truncate_documents` field is deprecated. " "Please use `batch.truncate_documents` instead."
+            )
+            assert "truncate_documents" not in default.get("batch", {})
+            if "batch" not in default:
+                default["batch"] = {}
+            default["batch"]["truncate_documents"] = default["data"].pop("truncate_documents")
         return super()._from_dict(default, strict, flat)

     @classmethod
diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py
index f19ef151b..3e7e3f30d 100644
--- a/fast_llm/models/gpt/model.py
+++ b/fast_llm/models/gpt/model.py
@@ -169,6 +169,7 @@ def preprocess_meta(
             TransformerKwargs.hidden_dims: hidden_dims,
             TransformerKwargs.sequence_length: sequence_length,
             TransformerKwargs.sequence_q_dim: sequence_q_dim,
+            LanguageModelKwargs.mask_inputs: not batch_meta.truncate_documents,
         }

         sequence_k_pasts = range(
@@ -302,7 +303,7 @@ def preprocess(
             if batch.loss_masking_spans is not None:
                 # avoid changing input tokens
                 labels = labels.clone()
-                for i, spans in enumerate(batch.loss_masking_spans):
+                for idx, spans in enumerate(batch.loss_masking_spans):
                     if not spans.numel():
                         continue
                     valid_spans = spans[
@@ -316,9 +317,9 @@ def preprocess(
                     loss_mask = torch.ones_like(labels, dtype=torch.bool)
                     for start, end in valid_spans:
                         if sequence_first:
-                            loss_mask[start : end + 1, i] = False
+                            loss_mask[start : end + 1, idx] = False
                         else:
-                            loss_mask[i, start : end + 1] = False
+                            loss_mask[idx, start : end + 1] = False
                     if self._config.distillation_model is not None:
                         kwargs[LanguageModelKwargs.loss_mask] = loss_mask
                     labels = torch.where(loss_mask, labels, -100)
diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py
index 0b2bb3433..54508e8e1 100644
--- a/fast_llm/models/gpt/trainer.py
+++ b/fast_llm/models/gpt/trainer.py
@@ -29,6 +29,7 @@ def _get_sampling_parameters(
                 "use_loss_masking_spans": self._config.batch.use_loss_masking_spans,
                 "use_preference_loss_spans": self._config.model.base_model.enable_dpo,
                 "cross_document_attention": self._config.batch.cross_document_attention,
+                "truncate_documents": self._config.batch.truncate_documents,
"truncate_documents": self._config.batch.truncate_documents, "extra_tokens": self._config.model.base_model.prediction_heads, } ) diff --git a/tests/data/common.py b/tests/data/common.py index 2d3cb905f..2bb90a6b4 100644 --- a/tests/data/common.py +++ b/tests/data/common.py @@ -51,12 +51,12 @@ def get_sampling_data( num_samples=num_samples, sequence_length=sequence_length, vocab_size=vocab_size, + truncate_documents=truncate_documents, ), cache_directory=cache_directory, distributed=distributed, dataset_name=phase.value, tokenizer=tokenizer, - truncate_documents=truncate_documents, )